about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src/interpret/traits.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src/interpret/traits.rs')
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs67
1 files changed, 47 insertions, 20 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index b5eef0fd8c9..8eead6018ac 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -1,6 +1,6 @@
 use rustc_middle::mir::interpret::{InterpResult, Pointer};
 use rustc_middle::ty::layout::LayoutOf;
-use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
+use rustc_middle::ty::{self, ExistentialPredicateStableCmpExt, Ty, TyCtxt, VtblEntry};
 use rustc_target::abi::{Align, Size};
 use tracing::trace;
 
@@ -11,26 +11,25 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     /// Creates a dynamic vtable for the given type and vtable origin. This is used only for
     /// objects.
     ///
-    /// The `trait_ref` encodes the erased self type. Hence, if we are making an object `Foo<Trait>`
-    /// from a value of type `Foo<T>`, then `trait_ref` would map `T: Trait`. `None` here means that
-    /// this is an auto trait without any methods, so we only need the basic vtable (drop, size,
-    /// align).
+    /// The `dyn_ty` encodes the erased self type. Hence, if we are making an object
+    /// `Foo<dyn Trait<Assoc = A> + Send>` from a value of type `Foo<T>`, then `dyn_ty`
+    /// would be `Trait<Assoc = A> + Send`. If this list doesn't have a principal trait ref,
+    /// we only need the basic vtable prefix (drop, size, align).
     pub fn get_vtable_ptr(
         &self,
         ty: Ty<'tcx>,
-        poly_trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
+        dyn_ty: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
     ) -> InterpResult<'tcx, Pointer<Option<M::Provenance>>> {
-        trace!("get_vtable(trait_ref={:?})", poly_trait_ref);
+        trace!("get_vtable(ty={ty:?}, dyn_ty={dyn_ty:?})");
 
-        let (ty, poly_trait_ref) = self.tcx.erase_regions((ty, poly_trait_ref));
+        let (ty, dyn_ty) = self.tcx.erase_regions((ty, dyn_ty));
 
         // All vtables must be monomorphic, bail out otherwise.
         ensure_monomorphic_enough(*self.tcx, ty)?;
-        ensure_monomorphic_enough(*self.tcx, poly_trait_ref)?;
+        ensure_monomorphic_enough(*self.tcx, dyn_ty)?;
 
         let salt = M::get_global_alloc_salt(self, None);
-        let vtable_symbolic_allocation =
-            self.tcx.reserve_and_set_vtable_alloc(ty, poly_trait_ref, salt);
+        let vtable_symbolic_allocation = self.tcx.reserve_and_set_vtable_alloc(ty, dyn_ty, salt);
         let vtable_ptr = self.global_root_pointer(Pointer::from(vtable_symbolic_allocation))?;
         Ok(vtable_ptr.into())
     }
@@ -64,17 +63,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     /// expected trait type.
     pub(super) fn check_vtable_for_type(
         &self,
-        vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
-        expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
+        vtable_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
+        expected_dyn_type: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
     ) -> InterpResult<'tcx> {
-        let eq = match (expected_trait.principal(), vtable_trait) {
-            (Some(a), Some(b)) => self.eq_in_param_env(a, b),
-            (None, None) => true,
-            _ => false,
-        };
-        if !eq {
-            throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
+        // We check validity by comparing the lists of predicates for equality. We *could* instead
+        // check that the dynamic type to which the vtable belongs satisfies all the expected
+        // predicates, but that would likely be a lot slower and seems unnecessarily permissive.
+
+        // FIXME: we are skipping auto traits for now, but might revisit this in the future.
+        let mut sorted_vtable: Vec<_> = vtable_dyn_type.without_auto_traits().collect();
+        let mut sorted_expected: Vec<_> = expected_dyn_type.without_auto_traits().collect();
+        // `skip_binder` here is okay because `stable_cmp` doesn't look at binders
+        sorted_vtable.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
+        sorted_vtable.dedup();
+        sorted_expected.sort_by(|a, b| a.skip_binder().stable_cmp(*self.tcx, &b.skip_binder()));
+        sorted_expected.dedup();
+
+        if sorted_vtable.len() != sorted_expected.len() {
+            throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
+        }
+
+        for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
+            let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
+                (
+                    ty::ExistentialPredicate::Trait(a_data),
+                    ty::ExistentialPredicate::Trait(b_data),
+                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+
+                (
+                    ty::ExistentialPredicate::Projection(a_data),
+                    ty::ExistentialPredicate::Projection(b_data),
+                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+
+                _ => false,
+            };
+            if !is_eq {
+                throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
+            }
         }
+
         Ok(())
     }