about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Markeffsky <@>2025-01-09 17:34:58 +0000
committerLukas Markeffsky <@>2025-01-31 17:43:28 +0100
commita90cb05da642607545560fb64dd057d3fedf2e97 (patch)
treed657c8f9bb4fc8060c911b4e219daa8cbf550e66
parent7f36543a48e52912ac6664a70c0a5b9d86509eaf (diff)
downloadrust-a90cb05da642607545560fb64dd057d3fedf2e97.tar.gz
rust-a90cb05da642607545560fb64dd057d3fedf2e97.zip
interpret: adjust vtable validity check for higher-ranked types
-rw-r--r--compiler/rustc_const_eval/src/interpret/cast.rs10
-rw-r--r--compiler/rustc_const_eval/src/interpret/eval_context.rs40
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs20
-rw-r--r--src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.rs30
-rw-r--r--src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.stderr15
-rw-r--r--src/tools/miri/tests/pass/dyn-upcast.rs30
6 files changed, 89 insertions, 56 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs
index e110c155da0..86fdfae1ffb 100644
--- a/compiler/rustc_const_eval/src/interpret/cast.rs
+++ b/compiler/rustc_const_eval/src/interpret/cast.rs
@@ -430,10 +430,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                     };
                     let erased_trait_ref =
                         ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
-                    assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
-                        erased_trait_ref,
-                        self.tcx.instantiate_bound_regions_with_erased(b)
-                    )));
+                    assert_eq!(
+                        data_b.principal().map(|b| {
+                            self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b)
+                        }),
+                        Some(erased_trait_ref),
+                    );
                 } else {
                     // In this case codegen would keep using the old vtable. We don't want to do
                     // that as it has the wrong trait. The reason codegen can do this is that
diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index 95a72d3cbc1..242cf6484dd 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -4,9 +4,6 @@ use either::{Left, Right};
 use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout};
 use rustc_errors::DiagCtxtHandle;
 use rustc_hir::def_id::DefId;
-use rustc_infer::infer::TyCtxtInferExt;
-use rustc_infer::infer::at::ToTrace;
-use rustc_infer::traits::ObligationCause;
 use rustc_middle::mir::interpret::{ErrorHandled, InvalidMetaKind, ReportedErrorInfo};
 use rustc_middle::query::TyCtxtAt;
 use rustc_middle::ty::layout::{
@@ -17,8 +14,7 @@ use rustc_middle::{mir, span_bug};
 use rustc_session::Limit;
 use rustc_span::Span;
 use rustc_target::callconv::FnAbi;
-use rustc_trait_selection::traits::ObligationCtxt;
-use tracing::{debug, instrument, trace};
+use tracing::{debug, trace};
 
 use super::{
     Frame, FrameInfo, GlobalId, InterpErrorInfo, InterpErrorKind, InterpResult, MPlaceTy, Machine,
@@ -323,40 +319,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         }
     }
 
-    /// Check if the two things are equal in the current param_env, using an infcx to get proper
-    /// equality checks.
-    #[instrument(level = "trace", skip(self), ret)]
-    pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
-    where
-        T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
-    {
-        // Fast path: compare directly.
-        if a == b {
-            return true;
-        }
-        // Slow path: spin up an inference context to check if these traits are sufficiently equal.
-        let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env);
-        let ocx = ObligationCtxt::new(&infcx);
-        let cause = ObligationCause::dummy_with_span(self.cur_span());
-        // equate the two trait refs after normalization
-        let a = ocx.normalize(&cause, param_env, a);
-        let b = ocx.normalize(&cause, param_env, b);
-
-        if let Err(terr) = ocx.eq(&cause, param_env, a, b) {
-            trace!(?terr);
-            return false;
-        }
-
-        let errors = ocx.select_all_or_error();
-        if !errors.is_empty() {
-            trace!(?errors);
-            return false;
-        }
-
-        // All good.
-        true
-    }
-
     /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
     /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
     /// and is primarily intended for the panic machinery.
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index 4cfaacebfcd..a5029eea5a7 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -86,21 +86,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
         }
 
+        // This checks whether there is a subtyping relation between the predicates in either direction.
+        // For example:
+        // - casting between `dyn for<'a> Trait<fn(&'a u8)>` and `dyn Trait<fn(&'static u8)>` is OK
+        // - casting between `dyn Trait<for<'a> fn(&'a u8)>` and either of the above is UB
         for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
-            let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
-                (
-                    ty::ExistentialPredicate::Trait(a_data),
-                    ty::ExistentialPredicate::Trait(b_data),
-                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
+            let a_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, a_pred);
+            let b_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b_pred);
 
-                (
-                    ty::ExistentialPredicate::Projection(a_data),
-                    ty::ExistentialPredicate::Projection(b_data),
-                ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
-
-                _ => false,
-            };
-            if !is_eq {
+            if a_pred != b_pred {
                 throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
             }
         }
diff --git a/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.rs b/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.rs
new file mode 100644
index 00000000000..7de4aef422a
--- /dev/null
+++ b/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.rs
@@ -0,0 +1,30 @@
+// Test that transmuting from `&dyn Trait<fn(&'static ())>` to `&dyn Trait<for<'a> fn(&'a ())>` is UB.
+//
+// The vtable of `() as Trait<fn(&'static ())>` and `() as Trait<for<'a> fn(&'a ())>` can have
+// different entries and, because in the former the entry for `foo` is vacant, this test will
+// segfault at runtime.
+
+trait Trait<U> {
+    fn foo(&self)
+    where
+        U: HigherRanked,
+    {
+    }
+}
+impl<T, U> Trait<U> for T {}
+
+trait HigherRanked {}
+impl HigherRanked for for<'a> fn(&'a ()) {}
+
+// 2nd candidate is required so that selecting `(): Trait<fn(&'static ())>` will
+// evaluate the candidates and fail the leak check instead of returning the
+// only applicable candidate.
+trait Unsatisfied {}
+impl<T: Unsatisfied> HigherRanked for T {}
+
+fn main() {
+    let x: &dyn Trait<fn(&'static ())> = &();
+    let y: &dyn Trait<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
+    //~^ ERROR: wrong trait in wide pointer vtable
+    y.foo();
+}
diff --git a/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.stderr b/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.stderr
new file mode 100644
index 00000000000..cfdf279a605
--- /dev/null
+++ b/src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.stderr
@@ -0,0 +1,15 @@
+error: Undefined Behavior: constructing invalid value: wrong trait in wide pointer vtable: expected `Trait<for<'a> fn(&'a ())>`, but encountered `Trait<fn(&())>`
+  --> tests/fail/validity/dyn-transmute-inner-binder.rs:LL:CC
+   |
+LL |     let y: &dyn Trait<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
+   |                                                      ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: wrong trait in wide pointer vtable: expected `Trait<for<'a> fn(&'a ())>`, but encountered `Trait<fn(&())>`
+   |
+   = help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
+   = help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information
+   = note: BACKTRACE:
+   = note: inside `main` at tests/fail/validity/dyn-transmute-inner-binder.rs:LL:CC
+
+note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace
+
+error: aborting due to 1 previous error
+
diff --git a/src/tools/miri/tests/pass/dyn-upcast.rs b/src/tools/miri/tests/pass/dyn-upcast.rs
index f100c4d6a86..394e80aa257 100644
--- a/src/tools/miri/tests/pass/dyn-upcast.rs
+++ b/src/tools/miri/tests/pass/dyn-upcast.rs
@@ -12,6 +12,7 @@ fn main() {
     drop_principal();
     modulo_binder();
     modulo_assoc();
+    bidirectional_subtyping();
 }
 
 fn vtable_nop_cast() {
@@ -534,3 +535,32 @@ fn modulo_assoc() {
 
     (&() as &dyn Trait as &dyn Middle<()>).say_hello(&0);
 }
+
+fn bidirectional_subtyping() {
+    // Test that transmuting between subtypes of dyn traits is fine, even in the
+    // "wrong direction", i.e. going from a lower-ranked to a higher-ranked dyn trait.
+    // Note that compared to the `dyn-transmute-inner-binder` test, the `for` is on the
+    // *outside* here!
+
+    trait Trait<U: ?Sized> {}
+    impl<T, U: ?Sized> Trait<U> for T {}
+
+    struct Wrapper<T: ?Sized>(T);
+
+    let x: &dyn Trait<fn(&'static ())> = &();
+    let _y: &dyn for<'a> Trait<fn(&'a ())> = unsafe { std::mem::transmute(x) };
+
+    let x: &dyn for<'a> Trait<fn(&'a ())> = &();
+    let _y: &dyn Trait<fn(&'static ())> = unsafe { std::mem::transmute(x) };
+
+    let x: &dyn Trait<dyn Trait<fn(&'static ())>> = &();
+    let _y: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = unsafe { std::mem::transmute(x) };
+
+    let x: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = &();
+    let _y: &dyn Trait<dyn Trait<fn(&'static ())>> = unsafe { std::mem::transmute(x) };
+
+    // This lowers to a ptr-to-ptr cast (which behaves like a transmute)
+    // and not an unsizing coercion:
+    let x: *const dyn for<'a> Trait<&'a ()> = &();
+    let _y: *const Wrapper<dyn Trait<&'static ()>> = x as _;
+}