about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src')
-rw-r--r--compiler/rustc_const_eval/src/interpret/cast.rs41
-rw-r--r--compiler/rustc_const_eval/src/interpret/eval_context.rs30
-rw-r--r--compiler/rustc_const_eval/src/interpret/terminator.rs17
-rw-r--r--compiler/rustc_const_eval/src/interpret/traits.rs48
4 files changed, 95 insertions, 41 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs
index 83b61ab1749..bd2a5812cfa 100644
--- a/compiler/rustc_const_eval/src/interpret/cast.rs
+++ b/compiler/rustc_const_eval/src/interpret/cast.rs
@@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             }
             (ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => {
                 let val = self.read_immediate(src)?;
-                if data_a.principal() == data_b.principal() {
-                    // A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables.
-                    // (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.)
-                    return self.write_immediate(*val, dest);
-                }
+                // Take apart the old pointer, and find the dynamic type.
                 let (old_data, old_vptr) = val.to_scalar_pair();
                 let old_data = old_data.to_pointer(self)?;
                 let old_vptr = old_vptr.to_pointer(self)?;
                 let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
+
+                // Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
+                // our destination trait.
+                if cfg!(debug_assertions) {
+                    let vptr_entry_idx =
+                        self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
+                    let vtable_entries = self.vtable_entries(data_a.principal(), ty);
+                    if let Some(entry_idx) = vptr_entry_idx {
+                        let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
+                            vtable_entries.get(entry_idx)
+                        else {
+                            span_bug!(
+                                self.cur_span(),
+                                "invalid vtable entry index in {} -> {} upcast",
+                                src_pointee_ty,
+                                dest_pointee_ty
+                            );
+                        };
+                        let erased_trait_ref = upcast_trait_ref
+                            .map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r));
+                        assert!(
+                            data_b
+                                .principal()
+                                .is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b))
+                        );
+                    } else {
+                        // In this case codegen would keep using the old vtable. We don't want to do
+                        // that as it has the wrong trait. The reason codegen can do this is that
+                        // one vtable is a prefix of the other, so we double-check that.
+                        let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
+                        assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
+                    };
+                }
+
+                // Get the destination trait vtable and return that.
                 let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
                 self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
             }
diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index 6d3e5ea1031..9fddeec2973 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -2,11 +2,15 @@ use std::cell::Cell;
 use std::{fmt, mem};
 
 use either::{Either, Left, Right};
+use rustc_infer::infer::at::ToTrace;
+use rustc_infer::traits::ObligationCause;
+use rustc_trait_selection::traits::ObligationCtxt;
 use tracing::{debug, info, info_span, instrument, trace};
 
 use rustc_errors::DiagCtxtHandle;
 use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
 use rustc_index::IndexVec;
+use rustc_infer::infer::TyCtxtInferExt;
 use rustc_middle::mir;
 use rustc_middle::mir::interpret::{
     CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
@@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         }
     }
 
+    /// Check if the two things are equal in the current param_env, using an infctx to get proper
+    /// equality checks.
+    pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
+    where
+        T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
+    {
+        // Fast path: compare directly.
+        if a == b {
+            return true;
+        }
+        // Slow path: spin up an inference context to check if these traits are sufficiently equal.
+        let infcx = self.tcx.infer_ctxt().build();
+        let ocx = ObligationCtxt::new(&infcx);
+        let cause = ObligationCause::dummy_with_span(self.cur_span());
+        // equate the two trait refs after normalization
+        let a = ocx.normalize(&cause, self.param_env, a);
+        let b = ocx.normalize(&cause, self.param_env, b);
+        if ocx.eq(&cause, self.param_env, a, b).is_ok() {
+            if ocx.select_all_or_error().is_empty() {
+                // All good.
+                return true;
+            }
+        }
+        return false;
+    }
+
     /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
     /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
     /// and is primarily intended for the panic machinery.
diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs
index 25f6bd64055..56d3dc94104 100644
--- a/compiler/rustc_const_eval/src/interpret/terminator.rs
+++ b/compiler/rustc_const_eval/src/interpret/terminator.rs
@@ -1,7 +1,6 @@
 use std::borrow::Cow;
 
 use either::Either;
-use rustc_middle::ty::TyCtxt;
 use tracing::trace;
 
 use rustc_middle::{
@@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                 };
 
                 // Obtain the underlying trait we are working on, and the adjusted receiver argument.
-                let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
+                let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
                     receiver_place.layout.ty.kind()
                 {
                     let recv = self.unpack_dyn_star(&receiver_place, data)?;
@@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
                     (receiver_trait.principal(), dyn_ty, receiver_place.ptr())
                 };
 
-                // Now determine the actual method to call. We can do that in two different ways and
-                // compare them to ensure everything fits.
-                let vtable_entries = if let Some(dyn_trait) = dyn_trait {
-                    let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
-                    let trait_ref = self.tcx.erase_regions(trait_ref);
-                    self.tcx.vtable_entries(trait_ref)
-                } else {
-                    TyCtxt::COMMON_VTABLE_ENTRIES
-                };
+                // Now determine the actual method to call. Usually we use the easy way of just
+                // looking up the method at index `idx`.
+                let vtable_entries = self.vtable_entries(trait_, dyn_ty);
                 let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
                     // FIXME(fee1-dead) these could be variants of the UB info enum instead of this
                     throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
                 };
                 trace!("Virtual call dispatches to {fn_inst:#?}");
+                // We can also do the lookup based on `def_id` and `dyn_ty`, and check that that
+                // produces the same result.
                 if cfg!(debug_assertions) {
                     let tcx = *self.tcx;
 
diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs
index bd2c6519421..fb50661b826 100644
--- a/compiler/rustc_const_eval/src/interpret/traits.rs
+++ b/compiler/rustc_const_eval/src/interpret/traits.rs
@@ -1,10 +1,7 @@
-use rustc_infer::infer::TyCtxtInferExt;
-use rustc_infer::traits::ObligationCause;
 use rustc_middle::mir::interpret::{InterpResult, Pointer};
 use rustc_middle::ty::layout::LayoutOf;
-use rustc_middle::ty::{self, Ty};
+use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
 use rustc_target::abi::{Align, Size};
-use rustc_trait_selection::traits::ObligationCtxt;
 use tracing::trace;
 
 use super::util::ensure_monomorphic_enough;
@@ -47,6 +44,20 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         Ok((layout.size, layout.align.abi))
     }
 
+    pub(super) fn vtable_entries(
+        &self,
+        trait_: Option<ty::PolyExistentialTraitRef<'tcx>>,
+        dyn_ty: Ty<'tcx>,
+    ) -> &'tcx [VtblEntry<'tcx>] {
+        if let Some(trait_) = trait_ {
+            let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty);
+            let trait_ref = self.tcx.erase_regions(trait_ref);
+            self.tcx.vtable_entries(trait_ref)
+        } else {
+            TyCtxt::COMMON_VTABLE_ENTRIES
+        }
+    }
+
     /// Check that the given vtable trait is valid for a pointer/reference/place with the given
     /// expected trait type.
     pub(super) fn check_vtable_for_type(
@@ -54,28 +65,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
         expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
     ) -> InterpResult<'tcx> {
-        // Fast path: if they are equal, it's all fine.
-        if expected_trait.principal() == vtable_trait {
-            return Ok(());
-        }
-        if let (Some(expected_trait), Some(vtable_trait)) =
-            (expected_trait.principal(), vtable_trait)
-        {
-            // Slow path: spin up an inference context to check if these traits are sufficiently equal.
-            let infcx = self.tcx.infer_ctxt().build();
-            let ocx = ObligationCtxt::new(&infcx);
-            let cause = ObligationCause::dummy_with_span(self.cur_span());
-            // equate the two trait refs after normalization
-            let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
-            let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
-            if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
-                if ocx.select_all_or_error().is_empty() {
-                    // All good.
-                    return Ok(());
-                }
-            }
+        let eq = match (expected_trait.principal(), vtable_trait) {
+            (Some(a), Some(b)) => self.eq_in_param_env(a, b),
+            (None, None) => true,
+            _ => false,
+        };
+        if !eq {
+            throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
         }
-        throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
+        Ok(())
     }
 
     /// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.