about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-01-10 04:47:45 +0000
committerMichael Goulet <michael@errs.io>2025-01-30 15:34:00 +0000
commit739ef83f318defbb9692029fa98f56639896c6fd (patch)
treecfe995feb5405d026522872a8388e419ec2ea157
parentfdc4bd22b7b8117f4a3864c342773df600f5b956 (diff)
downloadrust-739ef83f318defbb9692029fa98f56639896c6fd.tar.gz
rust-739ef83f318defbb9692029fa98f56639896c6fd.zip
Normalize vtable entries before walking and deduplicating them
-rw-r--r--compiler/rustc_const_eval/src/interpret/cast.rs9
-rw-r--r--compiler/rustc_infer/src/infer/at.rs24
-rw-r--r--compiler/rustc_trait_selection/src/traits/vtable.rs33
-rw-r--r--tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr2
-rw-r--r--tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs32
-rw-r--r--tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.run.stdout1
6 files changed, 77 insertions, 24 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs
index b05efd10e66..52bc2af928d 100644
--- a/compiler/rustc_const_eval/src/interpret/cast.rs
+++ b/compiler/rustc_const_eval/src/interpret/cast.rs
@@ -419,7 +419,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                         self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
                     let vtable_entries = self.vtable_entries(data_a.principal(), ty);
                     if let Some(entry_idx) = vptr_entry_idx {
-                        let Some(&ty::VtblEntry::TraitVPtr(_)) = vtable_entries.get(entry_idx)
+                        let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
+                            vtable_entries.get(entry_idx)
                         else {
                             span_bug!(
                                 self.cur_span(),
@@ -428,6 +429,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                                 dest_pointee_ty
                             );
                         };
+                        let erased_trait_ref =
+                            ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
+                        assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
+                            erased_trait_ref,
+                            self.tcx.instantiate_bound_regions_with_erased(b)
+                        )));
                     } else {
                         // In this case codegen would keep using the old vtable. We don't want to do
                         // that as it has the wrong trait. The reason codegen can do this is that
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index 12e2bbc968f..ad15b764bcc 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -402,6 +402,18 @@ impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialTraitRef<'tcx> {
     }
 }
 
+impl<'tcx> ToTrace<'tcx> for ty::ExistentialTraitRef<'tcx> {
+    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
+        TypeTrace {
+            cause: cause.clone(),
+            values: ValuePairs::ExistentialTraitRef(ExpectedFound::new(
+                ty::Binder::dummy(a),
+                ty::Binder::dummy(b),
+            )),
+        }
+    }
+}
+
 impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialProjection<'tcx> {
     fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
         TypeTrace {
@@ -410,3 +422,15 @@ impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialProjection<'tcx> {
         }
     }
 }
+
+impl<'tcx> ToTrace<'tcx> for ty::ExistentialProjection<'tcx> {
+    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
+        TypeTrace {
+            cause: cause.clone(),
+            values: ValuePairs::ExistentialProjection(ExpectedFound::new(
+                ty::Binder::dummy(a),
+                ty::Binder::dummy(b),
+            )),
+        }
+    }
+}
diff --git a/compiler/rustc_trait_selection/src/traits/vtable.rs b/compiler/rustc_trait_selection/src/traits/vtable.rs
index b23d1da1608..abdf5df6f72 100644
--- a/compiler/rustc_trait_selection/src/traits/vtable.rs
+++ b/compiler/rustc_trait_selection/src/traits/vtable.rs
@@ -3,7 +3,6 @@ use std::ops::ControlFlow;
 
 use rustc_hir::def_id::DefId;
 use rustc_infer::infer::TyCtxtInferExt;
-use rustc_infer::infer::at::ToTrace;
 use rustc_infer::traits::ObligationCause;
 use rustc_infer::traits::util::PredicateSet;
 use rustc_middle::bug;
@@ -127,16 +126,15 @@ fn prepare_vtable_segments_inner<'tcx, T>(
                 .explicit_super_predicates_of(inner_most_trait_ref.def_id)
                 .iter_identity_copied()
                 .filter_map(move |(pred, _)| {
-                    Some(
-                        tcx.instantiate_bound_regions_with_erased(
-                            pred.instantiate_supertrait(
-                                tcx,
-                                ty::Binder::dummy(inner_most_trait_ref),
-                            )
-                            .as_trait_clause()?,
-                        )
-                        .trait_ref,
+                    pred.instantiate_supertrait(tcx, ty::Binder::dummy(inner_most_trait_ref))
+                        .as_trait_clause()
+                })
+                .map(move |pred| {
+                    tcx.normalize_erasing_late_bound_regions(
+                        ty::TypingEnv::fully_monomorphized(),
+                        pred,
                     )
+                    .trait_ref
                 });
 
             // Find an unvisited supertrait
@@ -229,6 +227,8 @@ fn vtable_entries<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_ref: ty::TraitRef<'tcx>,
 ) -> &'tcx [VtblEntry<'tcx>] {
+    debug_assert!(!trait_ref.has_non_region_infer() && !trait_ref.has_non_region_param());
+
     debug!("vtable_entries({:?})", trait_ref);
 
     let mut entries = vec![];
@@ -422,17 +422,8 @@ fn trait_refs_are_compatible<'tcx>(
     let ocx = ObligationCtxt::new(&infcx);
     let source_principal = ocx.normalize(&ObligationCause::dummy(), param_env, vtable_principal);
     let target_principal = ocx.normalize(&ObligationCause::dummy(), param_env, target_principal);
-    let Ok(()) = ocx.eq_trace(
-        &ObligationCause::dummy(),
-        param_env,
-        ToTrace::to_trace(
-            &ObligationCause::dummy(),
-            ty::Binder::dummy(target_principal),
-            ty::Binder::dummy(source_principal),
-        ),
-        target_principal,
-        source_principal,
-    ) else {
+    let Ok(()) = ocx.eq(&ObligationCause::dummy(), param_env, target_principal, source_principal)
+    else {
         return false;
     };
     ocx.select_all_or_error().is_empty()
diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr
index 9e93e1c48c9..04b1afae7be 100644
--- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr
+++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr
@@ -3,8 +3,6 @@ error: vtable entries: [
            MetadataSize,
            MetadataAlign,
            Method(<() as Supertrait<()>>::_print_numbers),
-           Method(<() as Supertrait<()>>::_print_numbers),
-           TraitVPtr(<() as Supertrait<<() as Identity>::Selff>>),
            Method(<() as Middle<()>>::say_hello),
        ]
   --> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:29:1
diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs
new file mode 100644
index 00000000000..fd0f62b4255
--- /dev/null
+++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs
@@ -0,0 +1,32 @@
+//@ run-pass
+//@ check-run-results
+
+#![feature(trait_upcasting)]
+
+trait Supertrait<T> {
+    fn _print_numbers(&self, mem: &[usize; 100]) {
+        println!("{mem:?}");
+    }
+}
+impl<T> Supertrait<T> for () {}
+
+trait Identity {
+    type Selff;
+}
+impl<Selff> Identity for Selff {
+    type Selff = Selff;
+}
+
+trait Middle<T>: Supertrait<()> + Supertrait<T> {
+    fn say_hello(&self, _: &usize) {
+        println!("Hello!");
+    }
+}
+impl<T> Middle<T> for () {}
+
+trait Trait: Middle<<() as Identity>::Selff> {}
+impl Trait for () {}
+
+fn main() {
+    (&() as &dyn Trait as &dyn Middle<()>).say_hello(&0);
+}
diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.run.stdout b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.run.stdout
new file mode 100644
index 00000000000..10ddd6d257e
--- /dev/null
+++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.run.stdout
@@ -0,0 +1 @@
+Hello!