about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src')
-rw-r--r--compiler/rustc_const_eval/src/errors.rs18
-rw-r--r--compiler/rustc_const_eval/src/interpret/cast.rs6
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs14
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs67
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs8
5 files changed, 66 insertions, 47 deletions
diff --git a/compiler/rustc_const_eval/src/errors.rs b/compiler/rustc_const_eval/src/errors.rs
index c38d7f3d03c..c60bacb8506 100644
--- a/compiler/rustc_const_eval/src/errors.rs
+++ b/compiler/rustc_const_eval/src/errors.rs
@@ -522,12 +522,9 @@ impl<'a> ReportErrorExt for UndefinedBehaviorInfo<'a> {
             UnterminatedCString(ptr) | InvalidFunctionPointer(ptr) | InvalidVTablePointer(ptr) => {
                 diag.arg("pointer", ptr);
             }
-            InvalidVTableTrait { expected_trait, vtable_trait } => {
-                diag.arg("expected_trait", expected_trait.to_string());
-                diag.arg(
-                    "vtable_trait",
-                    vtable_trait.map(|t| t.to_string()).unwrap_or_else(|| format!("<trivial>")),
-                );
+            InvalidVTableTrait { expected_dyn_type, vtable_dyn_type } => {
+                diag.arg("expected_dyn_type", expected_dyn_type.to_string());
+                diag.arg("vtable_dyn_type", vtable_dyn_type.to_string());
             }
             PointerUseAfterFree(alloc_id, msg) => {
                 diag.arg("alloc_id", alloc_id)
@@ -777,12 +774,9 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
             DanglingPtrNoProvenance { pointer, .. } => {
                 err.arg("pointer", pointer);
             }
-            InvalidMetaWrongTrait { expected_trait: ref_trait, vtable_trait } => {
-                err.arg("ref_trait", ref_trait.to_string());
-                err.arg(
-                    "vtable_trait",
-                    vtable_trait.map(|t| t.to_string()).unwrap_or_else(|| format!("<trivial>")),
-                );
+            InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type } => {
+                err.arg("vtable_dyn_type", vtable_dyn_type.to_string());
+                err.arg("expected_dyn_type", expected_dyn_type.to_string());
             }
             NullPtr { .. }
             | ConstRefToMutable
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs
index 198aa1bbd5b..70d074cfdc5 100644
--- a/compiler/rustc_const_eval/src/interpret/cast.rs
+++ b/compiler/rustc_const_eval/src/interpret/cast.rs
@@ -128,7 +128,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             CastKind::DynStar => {
                 if let ty::Dynamic(data, _, ty::DynStar) = cast_ty.kind() {
                     // Initial cast from sized to dyn trait
-                    let vtable = self.get_vtable_ptr(src.layout.ty, data.principal())?;
+                    let vtable = self.get_vtable_ptr(src.layout.ty, data)?;
                     let vtable = Scalar::from_maybe_pointer(vtable, self);
                     let data = self.read_immediate(src)?.to_scalar();
                     let _assert_pointer_like = data.to_pointer(self)?;
@@ -446,12 +446,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                 }
 
                 // Get the destination trait vtable and return that.
-                let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
+                let new_vptr = self.get_vtable_ptr(ty, data_b)?;
                 self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
             }
             (_, &ty::Dynamic(data, _, ty::Dyn)) => {
                 // Initial cast from sized to dyn trait
-                let vtable = self.get_vtable_ptr(src_pointee_ty, data.principal())?;
+                let vtable = self.get_vtable_ptr(src_pointee_ty, data)?;
                 let ptr = self.read_pointer(src)?;
                 let val = Immediate::new_dyn_trait(ptr, vtable, &*self.tcx);
                 self.write_immediate(val, dest)
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index e5fdf592ec9..c3b506d848c 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -943,12 +943,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         if offset.bytes() != 0 {
             throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
         }
-        let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
+        let Some(GlobalAlloc::VTable(ty, vtable_dyn_type)) =
+            self.tcx.try_get_global_alloc(alloc_id)
         else {
             throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
         };
-        if let Some(expected_trait) = expected_trait {
-            self.check_vtable_for_type(vtable_trait, expected_trait)?;
+        if let Some(expected_dyn_type) = expected_trait {
+            self.check_vtable_for_type(vtable_dyn_type, expected_dyn_type)?;
         }
         Ok(ty)
     }
@@ -1113,11 +1114,8 @@ impl<'a, 'tcx, M: Machine<'tcx>> std::fmt::Debug for DumpAllocs<'a, 'tcx, M> {
                         Some(GlobalAlloc::Function { instance, .. }) => {
                             write!(fmt, " (fn: {instance})")?;
                         }
-                        Some(GlobalAlloc::VTable(ty, Some(trait_ref))) => {
-                            write!(fmt, " (vtable: impl {trait_ref} for {ty})")?;
-                        }
-                        Some(GlobalAlloc::VTable(ty, None)) => {
-                            write!(fmt, " (vtable: impl <auto trait> for {ty})")?;
+                        Some(GlobalAlloc::VTable(ty, dyn_ty)) => {
+                            write!(fmt, " (vtable: impl {dyn_ty} for {ty})")?;
                         }
                         Some(GlobalAlloc::Static(did)) => {
                             write!(fmt, " (static: {})", self.ecx.tcx.def_path_str(did))?;
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index b5eef0fd8c9..8eead6018ac 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -1,6 +1,6 @@
 use rustc_middle::mir::interpret::{InterpResult, Pointer};
 use rustc_middle::ty::layout::LayoutOf;
-use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
+use rustc_middle::ty::{self, ExistentialPredicateStableCmpExt, Ty, TyCtxt, VtblEntry};
 use rustc_target::abi::{Align, Size};
 use tracing::trace;
 
@@ -11,26 +11,25 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     /// Creates a dynamic vtable for the given type and vtable origin. This is used only for
     /// objects.
     ///
-    /// The `trait_ref` encodes the erased self type. Hence, if we are making an object `Foo<Trait>`
-    /// from a value of type `Foo<T>`, then `trait_ref` would map `T: Trait`. `None` here means that
-    /// this is an auto trait without any methods, so we only need the basic vtable (drop, size,
-    /// align).
+    /// The `dyn_ty` encodes the erased self type. Hence, if we are making an object
+    /// `Foo<dyn Trait<Assoc = A> + Send>` from a value of type `Foo<T>`, then `dyn_ty`
+    /// would be `Trait<Assoc = A> + Send`. If this list doesn't have a principal trait ref,
+    /// we only need the basic vtable prefix (drop, size, align).
     pub fn get_vtable_ptr(
         &self,
         ty: Ty<'tcx>,
-        poly_trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
+        dyn_ty: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
     ) -> InterpResult<'tcx, Pointer<Option<M::Provenance>>> {
-        trace!("get_vtable(trait_ref={:?})", poly_trait_ref);
+        trace!("get_vtable(ty={ty:?}, dyn_ty={dyn_ty:?})");
 
-        let (ty, poly_trait_ref) = self.tcx.erase_regions((ty, poly_trait_ref));
+        let (ty, dyn_ty) = self.tcx.erase_regions((ty, dyn_ty));
 
         // All vtables must be monomorphic, bail out otherwise.
         ensure_monomorphic_enough(*self.tcx, ty)?;
-        ensure_monomorphic_enough(*self.tcx, poly_trait_ref)?;
+        ensure_monomorphic_enough(*self.tcx, dyn_ty)?;
 
         let salt = M::get_global_alloc_salt(self, None);
-        let vtable_symbolic_allocation =
-            self.tcx.reserve_and_set_vtable_alloc(ty, poly_trait_ref, salt);
+        let vtable_symbolic_allocation = self.tcx.reserve_and_set_vtable_alloc(ty, dyn_ty, salt);
         let vtable_ptr = self.global_root_pointer(Pointer::from(vtable_symbolic_allocation))?;
         Ok(vtable_ptr.into())
     }
@@ -64,17 +63,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     /// expected trait type.
     pub(super) fn check_vtable_for_type(
         &self,
-        vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
-        expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
+        vtable_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
+        expected_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
     ) -> InterpResult<'tcx> {
-        let eq = match (expected_trait.principal(), vtable_trait) {
-            (Some(a), Some(b)) => self.eq_in_param_env(a, b),
-            (None, None) => true,
-            _ => false,
-        };
-        if !eq {
-            throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
+        // We check validity by comparing the lists of predicates for equality. We *could* instead
+        // check that the dynamic type to which the vtable belongs satisfies all the expected
+        // predicates, but that would likely be a lot slower and seems unnecessarily permissive.
+
+        // FIXME: we are skipping auto traits for now, but might revisit this in the future.
+        let mut sorted_vtable: Vec<_> = vtable_dyn_type.without_auto_traits().collect();
+        let mut sorted_expected: Vec<_> = expected_dyn_type.without_auto_traits().collect();
+        // `skip_binder` here is okay because `stable_cmp` doesn't look at binders
+        sorted_vtable.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
+        sorted_vtable.dedup();
+        sorted_expected.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
+        sorted_expected.dedup();
+
+        if sorted_vtable.len() != sorted_expected.len() {
+            throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
+        }
+
+        for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
+            let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
+                (
+                    ty::ExistentialPredicate::Trait(a_data),
+                    ty::ExistentialPredicate::Trait(b_data),
+                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+
+                (
+                    ty::ExistentialPredicate::Projection(a_data),
+                    ty::ExistentialPredicate::Projection(b_data),
+                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+
+                _ => false,
+            };
+            if !is_eq {
+                throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
+            }
         }
+
         Ok(())
     }
 
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index ff3c6120f0c..203cceccd9d 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -452,8 +452,8 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
                     self.path,
                     Ub(DanglingIntPointer{ .. } | InvalidVTablePointer(..)) =>
                         InvalidVTablePtr { value: format!("{vtable}") },
-                    Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
-                        InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
+                    Ub(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }) => {
+                        InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type }
                     },
                 );
             }
@@ -1281,8 +1281,8 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
                     self.path,
                     // It's not great to catch errors here, since we can't give a very good path,
                     // but it's better than ICEing.
-                    Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
-                        InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
+                    Ub(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type }) => {
+                        InvalidMetaWrongTrait { vtable_dyn_type, expected_dyn_type: *expected_dyn_type }
                     },
                 );
             }