diff options
Diffstat (limited to 'compiler/rustc_const_eval/src')
| -rw-r--r-- | compiler/rustc_const_eval/src/interpret/cast.rs | 41 | ||||
| -rw-r--r-- | compiler/rustc_const_eval/src/interpret/eval_context.rs | 30 | ||||
| -rw-r--r-- | compiler/rustc_const_eval/src/interpret/terminator.rs | 17 | ||||
| -rw-r--r-- | compiler/rustc_const_eval/src/interpret/traits.rs | 48 |
4 files changed, 95 insertions, 41 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index 83b61ab1749..bd2a5812cfa 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } (ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => { let val = self.read_immediate(src)?; - if data_a.principal() == data_b.principal() { - // A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables. - // (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.) - return self.write_immediate(*val, dest); - } + // Take apart the old pointer, and find the dynamic type. let (old_data, old_vptr) = val.to_scalar_pair(); let old_data = old_data.to_pointer(self)?; let old_vptr = old_vptr.to_pointer(self)?; let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?; + + // Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces + // our destination trait. + if cfg!(debug_assertions) { + let vptr_entry_idx = + self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty)); + let vtable_entries = self.vtable_entries(data_a.principal(), ty); + if let Some(entry_idx) = vptr_entry_idx { + let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) = + vtable_entries.get(entry_idx) + else { + span_bug!( + self.cur_span(), + "invalid vtable entry index in {} -> {} upcast", + src_pointee_ty, + dest_pointee_ty + ); + }; + let erased_trait_ref = upcast_trait_ref + .map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r)); + assert!( + data_b + .principal() + .is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b)) + ); + } else { + // In this case codegen would keep using the old vtable. We don't want to do + // that as it has the wrong trait. The reason codegen can do this is that + // one vtable is a prefix of the other, so we double-check that. + let vtable_entries_b = self.vtable_entries(data_b.principal(), ty); + assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b); + }; + } + + // Get the destination trait vtable and return that. let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?; self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest) } diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs index 6d3e5ea1031..9fddeec2973 100644 --- a/compiler/rustc_const_eval/src/interpret/eval_context.rs +++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs @@ -2,11 +2,15 @@ use std::cell::Cell; use std::{fmt, mem}; use either::{Either, Left, Right}; +use rustc_infer::infer::at::ToTrace; +use rustc_infer::traits::ObligationCause; +use rustc_trait_selection::traits::ObligationCtxt; use tracing::{debug, info, info_span, instrument, trace}; use rustc_errors::DiagCtxtHandle; use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData}; use rustc_index::IndexVec; +use rustc_infer::infer::TyCtxtInferExt; use rustc_middle::mir; use rustc_middle::mir::interpret::{ CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo, @@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } } + /// Check if the two things are equal in the current param_env, using an infctx to get proper + /// equality checks. + pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool + where + T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>, + { + // Fast path: compare directly. + if a == b { + return true; + } + // Slow path: spin up an inference context to check if these traits are sufficiently equal. + let infcx = self.tcx.infer_ctxt().build(); + let ocx = ObligationCtxt::new(&infcx); + let cause = ObligationCause::dummy_with_span(self.cur_span()); + // equate the two trait refs after normalization + let a = ocx.normalize(&cause, self.param_env, a); + let b = ocx.normalize(&cause, self.param_env, b); + if ocx.eq(&cause, self.param_env, a, b).is_ok() { + if ocx.select_all_or_error().is_empty() { + // All good. + return true; + } + } + return false; + } + /// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a /// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic, /// and is primarily intended for the panic machinery. diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index 25f6bd64055..56d3dc94104 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use either::Either; -use rustc_middle::ty::TyCtxt; use tracing::trace; use rustc_middle::{ @@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { }; // Obtain the underlying trait we are working on, and the adjusted receiver argument. - let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) = + let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) = receiver_place.layout.ty.kind() { let recv = self.unpack_dyn_star(&receiver_place, data)?; @@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { (receiver_trait.principal(), dyn_ty, receiver_place.ptr()) }; - // Now determine the actual method to call. We can do that in two different ways and - // compare them to ensure everything fits. - let vtable_entries = if let Some(dyn_trait) = dyn_trait { - let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty); - let trait_ref = self.tcx.erase_regions(trait_ref); - self.tcx.vtable_entries(trait_ref) - } else { - TyCtxt::COMMON_VTABLE_ENTRIES - }; + // Now determine the actual method to call. Usually we use the easy way of just + // looking up the method at index `idx`. + let vtable_entries = self.vtable_entries(trait_, dyn_ty); let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else { // FIXME(fee1-dead) these could be variants of the UB info enum instead of this throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method); }; trace!("Virtual call dispatches to {fn_inst:#?}"); + // We can also do the lookup based on `def_id` and `dyn_ty`, and check that that + // produces the same result. if cfg!(debug_assertions) { let tcx = *self.tcx; diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs index bd2c6519421..fb50661b826 100644 --- a/compiler/rustc_const_eval/src/interpret/traits.rs +++ b/compiler/rustc_const_eval/src/interpret/traits.rs @@ -1,10 +1,7 @@ -use rustc_infer::infer::TyCtxtInferExt; -use rustc_infer::traits::ObligationCause; use rustc_middle::mir::interpret::{InterpResult, Pointer}; use rustc_middle::ty::layout::LayoutOf; -use rustc_middle::ty::{self, Ty}; +use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry}; use rustc_target::abi::{Align, Size}; -use rustc_trait_selection::traits::ObligationCtxt; use tracing::trace; use super::util::ensure_monomorphic_enough; @@ -47,6 +44,20 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { Ok((layout.size, layout.align.abi)) } + pub(super) fn vtable_entries( + &self, + trait_: Option<ty::PolyExistentialTraitRef<'tcx>>, + dyn_ty: Ty<'tcx>, + ) -> &'tcx [VtblEntry<'tcx>] { + if let Some(trait_) = trait_ { + let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty); + let trait_ref = self.tcx.erase_regions(trait_ref); + self.tcx.vtable_entries(trait_ref) + } else { + TyCtxt::COMMON_VTABLE_ENTRIES + } + } + /// Check that the given vtable trait is valid for a pointer/reference/place with the given /// expected trait type. pub(super) fn check_vtable_for_type( @@ -54,28 +65,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>, expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>, ) -> InterpResult<'tcx> { - // Fast path: if they are equal, it's all fine. - if expected_trait.principal() == vtable_trait { - return Ok(()); - } - if let (Some(expected_trait), Some(vtable_trait)) = - (expected_trait.principal(), vtable_trait) - { - // Slow path: spin up an inference context to check if these traits are sufficiently equal. - let infcx = self.tcx.infer_ctxt().build(); - let ocx = ObligationCtxt::new(&infcx); - let cause = ObligationCause::dummy_with_span(self.cur_span()); - // equate the two trait refs after normalization - let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait); - let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait); - if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() { - if ocx.select_all_or_error().is_empty() { - // All good. - return Ok(()); - } - } + let eq = match (expected_trait.principal(), vtable_trait) { + (Some(a), Some(b)) => self.eq_in_param_env(a, b), + (None, None) => true, + _ => false, + }; + if !eq { + throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait }); } - throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait }); + Ok(()) } /// Turn a place with a `dyn Trait` type into a place with the actual dynamic type. |
