about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src')
-rw-r--r--compiler/rustc_const_eval/src/check_consts/mod.rs2
-rw-r--r--compiler/rustc_const_eval/src/const_eval/machine.rs10
-rw-r--r--compiler/rustc_const_eval/src/const_eval/valtrees.rs48
-rw-r--r--compiler/rustc_const_eval/src/interpret/call.rs3
-rw-r--r--compiler/rustc_const_eval/src/interpret/cast.rs10
-rw-r--r--compiler/rustc_const_eval/src/interpret/eval_context.rs40
-rw-r--r--compiler/rustc_const_eval/src/interpret/intern.rs4
-rw-r--r--compiler/rustc_const_eval/src/interpret/operand.rs2
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs20
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs19
-rw-r--r--compiler/rustc_const_eval/src/util/check_validity_requirement.rs16
11 files changed, 63 insertions, 111 deletions
diff --git a/compiler/rustc_const_eval/src/check_consts/mod.rs b/compiler/rustc_const_eval/src/check_consts/mod.rs
index bfa0a0319c3..52e000858b4 100644
--- a/compiler/rustc_const_eval/src/check_consts/mod.rs
+++ b/compiler/rustc_const_eval/src/check_consts/mod.rs
@@ -31,7 +31,7 @@ pub struct ConstCx<'mir, 'tcx> {
 impl<'mir, 'tcx> ConstCx<'mir, 'tcx> {
     pub fn new(tcx: TyCtxt<'tcx>, body: &'mir mir::Body<'tcx>) -> Self {
         let typing_env = body.typing_env(tcx);
-        let const_kind = tcx.hir().body_const_context(body.source.def_id().expect_local());
+        let const_kind = tcx.hir_body_const_context(body.source.def_id().expect_local());
         ConstCx { body, tcx, typing_env, const_kind }
     }
 
diff --git a/compiler/rustc_const_eval/src/const_eval/machine.rs b/compiler/rustc_const_eval/src/const_eval/machine.rs
index 82e0a6e6666..4db862afd9f 100644
--- a/compiler/rustc_const_eval/src/const_eval/machine.rs
+++ b/compiler/rustc_const_eval/src/const_eval/machine.rs
@@ -502,12 +502,10 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> {
             RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
             ResumedAfterReturn(coroutine_kind) => ResumedAfterReturn(*coroutine_kind),
             ResumedAfterPanic(coroutine_kind) => ResumedAfterPanic(*coroutine_kind),
-            MisalignedPointerDereference { ref required, ref found } => {
-                MisalignedPointerDereference {
-                    required: eval_to_int(required)?,
-                    found: eval_to_int(found)?,
-                }
-            }
+            MisalignedPointerDereference { required, found } => MisalignedPointerDereference {
+                required: eval_to_int(required)?,
+                found: eval_to_int(found)?,
+            },
             NullPointerDereference => NullPointerDereference,
         };
         Err(ConstEvalErrKind::AssertFailure(err)).into()
diff --git a/compiler/rustc_const_eval/src/const_eval/valtrees.rs b/compiler/rustc_const_eval/src/const_eval/valtrees.rs
index c35910a706b..3776fb55c2e 100644
--- a/compiler/rustc_const_eval/src/const_eval/valtrees.rs
+++ b/compiler/rustc_const_eval/src/const_eval/valtrees.rs
@@ -2,7 +2,7 @@ use rustc_abi::{BackendRepr, VariantIdx};
 use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_middle::mir::interpret::{EvalToValTreeResult, GlobalId, ReportedErrorInfo};
 use rustc_middle::ty::layout::{LayoutCx, LayoutOf, TyAndLayout};
-use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
+use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_middle::{bug, mir};
 use rustc_span::DUMMY_SP;
 use tracing::{debug, instrument, trace};
@@ -21,7 +21,7 @@ use crate::interpret::{
 fn branches<'tcx>(
     ecx: &CompileTimeInterpCx<'tcx>,
     place: &MPlaceTy<'tcx>,
-    n: usize,
+    field_count: usize,
     variant: Option<VariantIdx>,
     num_nodes: &mut usize,
 ) -> ValTreeCreationResult<'tcx> {
@@ -29,30 +29,28 @@ fn branches<'tcx>(
         Some(variant) => ecx.project_downcast(place, variant).unwrap(),
         None => place.clone(),
     };
-    let variant = variant.map(|variant| Some(ty::ValTree::Leaf(ScalarInt::from(variant.as_u32()))));
-    debug!(?place, ?variant);
+    debug!(?place);
 
-    let mut fields = Vec::with_capacity(n);
-    for i in 0..n {
-        let field = ecx.project_field(&place, i).unwrap();
-        let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
-        fields.push(Some(valtree));
-    }
+    let mut branches = Vec::with_capacity(field_count + variant.is_some() as usize);
 
     // For enums, we prepend their variant index before the variant's fields so we can figure out
     // the variant again when just seeing a valtree.
-    let branches = variant
-        .into_iter()
-        .chain(fields.into_iter())
-        .collect::<Option<Vec<_>>>()
-        .expect("should have already checked for errors in ValTree creation");
+    if let Some(variant) = variant {
+        branches.push(ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()));
+    }
+
+    for i in 0..field_count {
+        let field = ecx.project_field(&place, i).unwrap();
+        let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
+        branches.push(valtree);
+    }
 
     // Have to account for ZSTs here
     if branches.len() == 0 {
         *num_nodes += 1;
     }
 
-    Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(branches)))
+    Ok(ty::ValTree::from_branches(*ecx.tcx, branches))
 }
 
 #[instrument(skip(ecx), level = "debug")]
@@ -70,7 +68,7 @@ fn slice_branches<'tcx>(
         elems.push(valtree);
     }
 
-    Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(elems)))
+    Ok(ty::ValTree::from_branches(*ecx.tcx, elems))
 }
 
 #[instrument(skip(ecx), level = "debug")]
@@ -79,6 +77,7 @@ fn const_to_valtree_inner<'tcx>(
     place: &MPlaceTy<'tcx>,
     num_nodes: &mut usize,
 ) -> ValTreeCreationResult<'tcx> {
+    let tcx = *ecx.tcx;
     let ty = place.layout.ty;
     debug!("ty kind: {:?}", ty.kind());
 
@@ -89,14 +88,14 @@ fn const_to_valtree_inner<'tcx>(
     match ty.kind() {
         ty::FnDef(..) => {
             *num_nodes += 1;
-            Ok(ty::ValTree::zst())
+            Ok(ty::ValTree::zst(tcx))
         }
         ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
             let val = ecx.read_immediate(place).unwrap();
             let val = val.to_scalar_int().unwrap();
             *num_nodes += 1;
 
-            Ok(ty::ValTree::Leaf(val))
+            Ok(ty::ValTree::from_scalar_int(tcx, val))
         }
 
         ty::Pat(base, ..) => {
@@ -127,7 +126,7 @@ fn const_to_valtree_inner<'tcx>(
                 return Err(ValTreeCreationError::NonSupportedType(ty));
             };
             // It's just a ScalarInt!
-            Ok(ty::ValTree::Leaf(val))
+            Ok(ty::ValTree::from_scalar_int(tcx, val))
         }
 
         // Technically we could allow function pointers (represented as `ty::Instance`), but this is not guaranteed to
@@ -287,16 +286,11 @@ pub fn valtree_to_const_value<'tcx>(
     // FIXME: Does this need an example?
     match *cv.ty.kind() {
         ty::FnDef(..) => {
-            assert!(cv.valtree.unwrap_branch().is_empty());
+            assert!(cv.valtree.is_zst());
             mir::ConstValue::ZeroSized
         }
         ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => {
-            match cv.valtree {
-                ty::ValTree::Leaf(scalar_int) => mir::ConstValue::Scalar(Scalar::Int(scalar_int)),
-                ty::ValTree::Branch(_) => bug!(
-                    "ValTrees for Bool, Int, Uint, Float, Char or RawPtr should have the form ValTree::Leaf"
-                ),
-            }
+            mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf()))
         }
         ty::Pat(ty, _) => {
             let cv = ty::Value { valtree: cv.valtree, ty };
diff --git a/compiler/rustc_const_eval/src/interpret/call.rs b/compiler/rustc_const_eval/src/interpret/call.rs
index cdf706f3752..29f819cca1f 100644
--- a/compiler/rustc_const_eval/src/interpret/call.rs
+++ b/compiler/rustc_const_eval/src/interpret/call.rs
@@ -241,7 +241,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         interp_ok(caller == callee)
     }
 
-    fn check_argument_compat(
+    /// Returns a `bool` saying whether the two arguments are ABI-compatible.
+    pub fn check_argument_compat(
         &self,
         caller_abi: &ArgAbi<'tcx, Ty<'tcx>>,
         callee_abi: &ArgAbi<'tcx, Ty<'tcx>>,
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs
index b53c1505f6b..7b706334ed8 100644
--- a/compiler/rustc_const_eval/src/interpret/cast.rs
+++ b/compiler/rustc_const_eval/src/interpret/cast.rs
@@ -430,10 +430,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                     };
                     let erased_trait_ref =
                         ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
-                    assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
-                        erased_trait_ref,
-                        self.tcx.instantiate_bound_regions_with_erased(b)
-                    )));
+                    assert_eq!(
+                        data_b.principal().map(|b| {
+                            self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b)
+                        }),
+                        Some(erased_trait_ref),
+                    );
                 } else {
                     // In this case codegen would keep using the old vtable. We don't want to do
                     // that as it has the wrong trait. The reason codegen can do this is that
diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index 66a75113652..c1948e9f31f 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -4,9 +4,6 @@ use either::{Left, Right};
 use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout};
 use rustc_errors::DiagCtxtHandle;
 use rustc_hir::def_id::DefId;
-use rustc_infer::infer::TyCtxtInferExt;
-use rustc_infer::infer::at::ToTrace;
-use rustc_infer::traits::ObligationCause;
 use rustc_middle::mir::interpret::{ErrorHandled, InvalidMetaKind, ReportedErrorInfo};
 use rustc_middle::query::TyCtxtAt;
 use rustc_middle::ty::layout::{
@@ -17,8 +14,7 @@ use rustc_middle::{mir, span_bug};
 use rustc_session::Limit;
 use rustc_span::Span;
 use rustc_target::callconv::FnAbi;
-use rustc_trait_selection::traits::ObligationCtxt;
-use tracing::{debug, instrument, trace};
+use tracing::{debug, trace};
 
 use super::{
     Frame, FrameInfo, GlobalId, InterpErrorInfo, InterpErrorKind, InterpResult, MPlaceTy, Machine,
@@ -320,40 +316,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         }
     }
 
-    /// Check if the two things are equal in the current param_env, using an infcx to get proper
-    /// equality checks.
-    #[instrument(level = "trace", skip(self), ret)]
-    pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
-    where
-        T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
-    {
-        // Fast path: compare directly.
-        if a == b {
-            return true;
-        }
-        // Slow path: spin up an inference context to check if these traits are sufficiently equal.
-        let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env);
-        let ocx = ObligationCtxt::new(&infcx);
-        let cause = ObligationCause::dummy_with_span(self.cur_span());
-        // equate the two trait refs after normalization
-        let a = ocx.normalize(&cause, param_env, a);
-        let b = ocx.normalize(&cause, param_env, b);
-
-        if let Err(terr) = ocx.eq(&cause, param_env, a, b) {
-            trace!(?terr);
-            return false;
-        }
-
-        let errors = ocx.select_all_or_error();
-        if !errors.is_empty() {
-            trace!(?errors);
-            return false;
-        }
-
-        // All good.
-        true
-    }
-
     /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
     /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
     /// and is primarily intended for the panic machinery.
diff --git a/compiler/rustc_const_eval/src/interpret/intern.rs b/compiler/rustc_const_eval/src/interpret/intern.rs
index e14cc1dad36..43631330f89 100644
--- a/compiler/rustc_const_eval/src/interpret/intern.rs
+++ b/compiler/rustc_const_eval/src/interpret/intern.rs
@@ -61,8 +61,8 @@ impl HasStaticRootDefId for const_eval::CompileTimeMachine<'_> {
 /// already mutable (as a sanity check).
 ///
 /// Returns an iterator over all relocations referred to by this allocation.
-fn intern_shallow<'rt, 'tcx, T, M: CompileTimeMachine<'tcx, T>>(
-    ecx: &'rt mut InterpCx<'tcx, M>,
+fn intern_shallow<'tcx, T, M: CompileTimeMachine<'tcx, T>>(
+    ecx: &mut InterpCx<'tcx, M>,
     alloc_id: AllocId,
     mutability: Mutability,
 ) -> Result<impl Iterator<Item = CtfeProvenance> + 'tcx, ()> {
diff --git a/compiler/rustc_const_eval/src/interpret/operand.rs b/compiler/rustc_const_eval/src/interpret/operand.rs
index 5d905cff1f2..36da9037e43 100644
--- a/compiler/rustc_const_eval/src/interpret/operand.rs
+++ b/compiler/rustc_const_eval/src/interpret/operand.rs
@@ -385,7 +385,7 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
             (Immediate::Uninit, _) => Immediate::Uninit,
             // If the field is uninhabited, we can forget the data (can happen in ConstProp).
             // `enum S { A(!), B, C }` is an example of an enum with Scalar layout that
-            // has an `Uninhabited` variant, which means this case is possible.
+            // has an uninhabited variant, which means this case is possible.
             _ if layout.is_uninhabited() => Immediate::Uninit,
             // the field contains no information, can be left uninit
             // (Scalar/ScalarPair can contain even aligned ZST, not just 1-ZST)
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index 4cfaacebfcd..a5029eea5a7 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -86,21 +86,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
         }
 
+        // This checks whether there is a subtyping relation between the predicates in either direction.
+        // For example:
+        // - casting between `dyn for<'a> Trait<fn(&'a u8)>` and `dyn Trait<fn(&'static u8)>` is OK
+        // - casting between `dyn Trait<for<'a> fn(&'a u8)>` and either of the above is UB
         for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
-            let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
-                (
-                    ty::ExistentialPredicate::Trait(a_data),
-                    ty::ExistentialPredicate::Trait(b_data),
-                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+            let a_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, a_pred);
+            let b_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b_pred);
 
-                (
-                    ty::ExistentialPredicate::Projection(a_data),
-                    ty::ExistentialPredicate::Projection(b_data),
-                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
-
-                _ => false,
-            };
-            if !is_eq {
+            if a_pred != b_pred {
                 throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
             }
         }
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index 0ac34f4633b..40dec0cb39e 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -1264,21 +1264,20 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
             }
         }
 
-        // *After* all of this, check the ABI. We need to check the ABI to handle
-        // types like `NonNull` where the `Scalar` info is more restrictive than what
-        // the fields say (`rustc_layout_scalar_valid_range_start`).
-        // But in most cases, this will just propagate what the fields say,
-        // and then we want the error to point at the field -- so, first recurse,
-        // then check ABI.
+        // *After* all of this, check further information stored in the layout. We need to check
+        // this to handle types like `NonNull` where the `Scalar` info is more restrictive than what
+        // the fields say (`rustc_layout_scalar_valid_range_start`). But in most cases, this will
+        // just propagate what the fields say, and then we want the error to point at the field --
+        // so, we first recurse, then we do this check.
         //
         // FIXME: We could avoid some redundant checks here. For newtypes wrapping
         // scalars, we do the same check on every "level" (e.g., first we check
         // MyNewtype and then the scalar in there).
+        if val.layout.is_uninhabited() {
+            let ty = val.layout.ty;
+            throw_validation_failure!(self.path, UninhabitedVal { ty });
+        }
         match val.layout.backend_repr {
-            BackendRepr::Uninhabited => {
-                let ty = val.layout.ty;
-                throw_validation_failure!(self.path, UninhabitedVal { ty });
-            }
             BackendRepr::Scalar(scalar_layout) => {
                 if !scalar_layout.is_uninit_valid() {
                     // There is something to check here.
diff --git a/compiler/rustc_const_eval/src/util/check_validity_requirement.rs b/compiler/rustc_const_eval/src/util/check_validity_requirement.rs
index 79baf91c3ce..6426bca2332 100644
--- a/compiler/rustc_const_eval/src/util/check_validity_requirement.rs
+++ b/compiler/rustc_const_eval/src/util/check_validity_requirement.rs
@@ -111,13 +111,15 @@ fn check_validity_requirement_lax<'tcx>(
     };
 
     // Check the ABI.
-    let valid = match this.backend_repr {
-        BackendRepr::Uninhabited => false, // definitely UB
-        BackendRepr::Scalar(s) => scalar_allows_raw_init(s),
-        BackendRepr::ScalarPair(s1, s2) => scalar_allows_raw_init(s1) && scalar_allows_raw_init(s2),
-        BackendRepr::Vector { element: s, count } => count == 0 || scalar_allows_raw_init(s),
-        BackendRepr::Memory { .. } => true, // Fields are checked below.
-    };
+    let valid = !this.is_uninhabited() // definitely UB if uninhabited
+        && match this.backend_repr {
+            BackendRepr::Scalar(s) => scalar_allows_raw_init(s),
+            BackendRepr::ScalarPair(s1, s2) => {
+                scalar_allows_raw_init(s1) && scalar_allows_raw_init(s2)
+            }
+            BackendRepr::Vector { element: s, count } => count == 0 || scalar_allows_raw_init(s),
+            BackendRepr::Memory { .. } => true, // Fields are checked below.
+        };
     if !valid {
         // This is definitely not okay.
         return Ok(false);