about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-06-10 17:26:36 +0200
committerRalf Jung <post@ralfj.de>2024-06-11 08:54:49 +0200
commit3757136d8e0d8ddca294453e5a5ce70cfa3417e9 (patch)
treecf5ca21676495d215f0c791c7a67546fd8f46e70
parentd041b7cf30bb5b5236cd9148cde7a3017ed28679 (diff)
downloadrust-3757136d8e0d8ddca294453e5a5ce70cfa3417e9.tar.gz
rust-3757136d8e0d8ddca294453e5a5ce70cfa3417e9.zip
interpret: dyn trait metadata check: equate traits in a proper way
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs4
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs36
-rw-r--r--src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs40
3 files changed, 76 insertions, 4 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 84c6dad1cd3..e2e39399f3a 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -884,9 +884,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
         };
         if let Some(expected_trait) = expected_trait {
-            if vtable_trait != expected_trait.principal() {
-                throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
-            }
+            self.check_vtable_for_type(vtable_trait, expected_trait)?;
         }
         Ok(ty)
     }
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index 44e7244a513..bd2c6519421 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -1,11 +1,14 @@
+use rustc_infer::infer::TyCtxtInferExt;
+use rustc_infer::traits::ObligationCause;
 use rustc_middle::mir::interpret::{InterpResult, Pointer};
 use rustc_middle::ty::layout::LayoutOf;
 use rustc_middle::ty::{self, Ty};
 use rustc_target::abi::{Align, Size};
+use rustc_trait_selection::traits::ObligationCtxt;
 use tracing::trace;
 
 use super::util::ensure_monomorphic_enough;
-use super::{InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};
+use super::{throw_ub, InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};
 
 impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     /// Creates a dynamic vtable for the given type and vtable origin. This is used only for
@@ -44,6 +47,37 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         Ok((layout.size, layout.align.abi))
     }
 
+    /// Check that the given vtable trait is valid for a pointer/reference/place with the given
+    /// expected trait type.
+    pub(super) fn check_vtable_for_type(
+        &self,
+        vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
+        expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
+    ) -> InterpResult<'tcx> {
+        // Fast path: if they are equal, it's all fine.
+        if expected_trait.principal() == vtable_trait {
+            return Ok(());
+        }
+        if let (Some(expected_trait), Some(vtable_trait)) =
+            (expected_trait.principal(), vtable_trait)
+        {
+            // Slow path: spin up an inference context to check if these traits are sufficiently equal.
+            let infcx = self.tcx.infer_ctxt().build();
+            let ocx = ObligationCtxt::new(&infcx);
+            let cause = ObligationCause::dummy_with_span(self.cur_span());
+            // equate the two trait refs after normalization
+            let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
+            let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
+            if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
+                if ocx.select_all_or_error().is_empty() {
+                    // All good.
+                    return Ok(());
+                }
+            }
+        }
+        throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
+    }
+
     /// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
     pub(super) fn unpack_dyn_trait(
         &self,
diff --git a/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs b/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs
new file mode 100644
index 00000000000..c46031de2d8
--- /dev/null
+++ b/src/tools/miri/tests/pass/issues/issue-miri-3541-dyn-vtable-trait-normalization.rs
@@ -0,0 +1,40 @@
+#![feature(ptr_metadata)]
+// This test is the result of minimizing the `emplacable` crate to reproduce
+// <https://github.com/rust-lang/miri/issues/3541>.
+
+use std::{ops::FnMut, ptr::Pointee};
+
+pub type EmplacerFn<'a, T> = dyn for<'b> FnMut(<T as Pointee>::Metadata) + 'a;
+
+#[repr(transparent)]
+pub struct Emplacer<'a, T>(EmplacerFn<'a, T>)
+where
+    T: ?Sized;
+
+impl<'a, T> Emplacer<'a, T>
+where
+    T: ?Sized,
+{
+    pub unsafe fn from_fn<'b>(emplacer_fn: &'b mut EmplacerFn<'a, T>) -> &'b mut Self {
+        // This used to trigger:
+        // constructing invalid value: wrong trait in wide pointer vtable: expected
+        // `std::ops::FnMut(<[std::boxed::Box<i32>] as std::ptr::Pointee>::Metadata)`, but encountered
+        // `std::ops::FnMut<(usize,)>`.
+        unsafe { &mut *((emplacer_fn as *mut EmplacerFn<'a, T>) as *mut Self) }
+    }
+}
+
+pub fn box_new_with<T>()
+where
+    T: ?Sized,
+{
+    let emplacer_closure = &mut |_meta| {
+        unreachable!();
+    };
+
+    unsafe { Emplacer::<T>::from_fn(emplacer_closure) };
+}
+
+fn main() {
+    box_new_with::<[Box<i32>]>();
+}