about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthew Maurer <mmaurer@google.com>2024-03-25 19:27:43 +0000
committerMatthew Maurer <mmaurer@google.com>2024-04-02 19:11:16 +0000
commit6aa89f684e4427a9d08e35b572f9071705105140 (patch)
tree205064196db2851129cda806a8179f67db1c934d
parent93c2bace58b36ba297f06505b55aef5b8eba954f (diff)
downloadrust-6aa89f684e4427a9d08e35b572f9071705105140.tar.gz
rust-6aa89f684e4427a9d08e35b572f9071705105140.zip
Track reason for creating a `ReifyShim`
KCFI needs to be able to tell which kind of `ReifyShim` it is examining
in order to decide whether to use a concrete type (`FnPtr` case) or an
abstract case (`Vtable` case). You can *almost* tell this from context,
but there is one case where you can't - if a trait has a method which is
*not* `#[track_caller]`, with an impl that *is* `#[track_caller]`, both
the vtable and a function pointer created from that method will be
`ReifyShim(def_id)`.

Currently, the reason is optional to ensure no additional unique
`ReifyShim`s are added without KCFI on. However, the case in which an
extra `ReifyShim` is created is sufficiently rare that this may be worth
revisiting to reduce complexity.
-rw-r--r--compiler/rustc_middle/src/mir/visit.rs2
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs48
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs2
-rw-r--r--compiler/rustc_middle/src/ty/structural_impls.rs1
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs2
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs2
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs2
-rw-r--r--compiler/rustc_symbol_mangling/src/legacy.rs12
-rw-r--r--compiler/rustc_symbol_mangling/src/v0.rs8
9 files changed, 60 insertions, 19 deletions
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 3835bd371d9..4f7b2f7cbe4 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -341,7 +341,7 @@ macro_rules! make_mir_visitor {
 
                         ty::InstanceDef::Intrinsic(_def_id) |
                         ty::InstanceDef::VTableShim(_def_id) |
-                        ty::InstanceDef::ReifyShim(_def_id) |
+                        ty::InstanceDef::ReifyShim(_def_id, _) |
                         ty::InstanceDef::Virtual(_def_id, _) |
                         ty::InstanceDef::ThreadLocalShim(_def_id) |
                         ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 4fec5653a79..e5625c8a5c6 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -31,6 +31,28 @@ pub struct Instance<'tcx> {
     pub args: GenericArgsRef<'tcx>,
 }
 
+/// Describes why a `ReifyShim` was created. This is needed to distingish a ReifyShim created to
+/// adjust for things like `#[track_caller]` in a vtable from a `ReifyShim` created to produce a
+/// function pointer from a vtable entry.
+/// Currently, this is only used when KCFI is enabled, as only KCFI needs to treat those two
+/// `ReifyShim`s differently.
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+#[derive(TyEncodable, TyDecodable, HashStable)]
+pub enum ReifyReason {
+    /// The `ReifyShim` was created to produce a function pointer. This happens when:
+    /// * A vtable entry is directly converted to a function call (e.g. creating a fn ptr from a
+    ///   method on a `dyn` object).
+    /// * A function with `#[track_caller]` is converted to a function pointer
+    /// * If KCFI is enabled, creating a function pointer from a method on an object-safe trait.
+    /// This includes the case of converting `::call`-like methods on closure-likes to function
+    /// pointers.
+    FnPtr,
+    /// This `ReifyShim` was created to populate a vtable. Currently, this happens when a
+    /// `#[track_caller]` mismatch occurs between the implementation of a method and the method.
+    /// This includes the case of `::call`-like methods in closure-likes' vtables.
+    Vtable,
+}
+
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 #[derive(TyEncodable, TyDecodable, HashStable, TypeFoldable, TypeVisitable, Lift)]
 pub enum InstanceDef<'tcx> {
@@ -67,7 +89,13 @@ pub enum InstanceDef<'tcx> {
     /// Because this is a required part of the function's ABI but can't be tracked
     /// as a property of the function pointer, we use a single "caller location"
     /// (the definition of the function itself).
-    ReifyShim(DefId),
+    ///
+    /// The second field encodes *why* this shim was created. This allows distinguishing between
+    /// a `ReifyShim` that appears in a vtable vs one that appears as a function pointer.
+    ///
+    /// This field will only be populated if we are compiling in a mode that needs these shims
+    /// to be separable, currently only when KCFI is enabled.
+    ReifyShim(DefId, Option<ReifyReason>),
 
     /// `<fn() as FnTrait>::call_*` (generated `FnTrait` implementation for `fn()` pointers).
     ///
@@ -194,7 +222,7 @@ impl<'tcx> InstanceDef<'tcx> {
         match self {
             InstanceDef::Item(def_id)
             | InstanceDef::VTableShim(def_id)
-            | InstanceDef::ReifyShim(def_id)
+            | InstanceDef::ReifyShim(def_id, _)
             | InstanceDef::FnPtrShim(def_id, _)
             | InstanceDef::Virtual(def_id, _)
             | InstanceDef::Intrinsic(def_id)
@@ -354,7 +382,9 @@ fn fmt_instance(
     match instance.def {
         InstanceDef::Item(_) => Ok(()),
         InstanceDef::VTableShim(_) => write!(f, " - shim(vtable)"),
-        InstanceDef::ReifyShim(_) => write!(f, " - shim(reify)"),
+        InstanceDef::ReifyShim(_, None) => write!(f, " - shim(reify)"),
+        InstanceDef::ReifyShim(_, Some(ReifyReason::FnPtr)) => write!(f, " - shim(reify-fnptr)"),
+        InstanceDef::ReifyShim(_, Some(ReifyReason::Vtable)) => write!(f, " - shim(reify-vtable)"),
         InstanceDef::ThreadLocalShim(_) => write!(f, " - shim(tls)"),
         InstanceDef::Intrinsic(_) => write!(f, " - intrinsic"),
         InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"),
@@ -476,15 +506,16 @@ impl<'tcx> Instance<'tcx> {
         debug!("resolve(def_id={:?}, args={:?})", def_id, args);
         // Use either `resolve_closure` or `resolve_for_vtable`
         assert!(!tcx.is_closure_like(def_id), "Called `resolve_for_fn_ptr` on closure: {def_id:?}");
+        let reason = tcx.sess.is_sanitizer_kcfi_enabled().then_some(ReifyReason::FnPtr);
         Instance::resolve(tcx, param_env, def_id, args).ok().flatten().map(|mut resolved| {
             match resolved.def {
                 InstanceDef::Item(def) if resolved.def.requires_caller_location(tcx) => {
                     debug!(" => fn pointer created for function with #[track_caller]");
-                    resolved.def = InstanceDef::ReifyShim(def);
+                    resolved.def = InstanceDef::ReifyShim(def, reason);
                 }
                 InstanceDef::Virtual(def_id, _) => {
                     debug!(" => fn pointer created for virtual call");
-                    resolved.def = InstanceDef::ReifyShim(def_id);
+                    resolved.def = InstanceDef::ReifyShim(def_id, reason);
                 }
                 _ => {}
             }
@@ -508,6 +539,7 @@ impl<'tcx> Instance<'tcx> {
             debug!(" => associated item with unsizeable self: Self");
             Some(Instance { def: InstanceDef::VTableShim(def_id), args })
         } else {
+            let reason = tcx.sess.is_sanitizer_kcfi_enabled().then_some(ReifyReason::Vtable);
             Instance::resolve(tcx, param_env, def_id, args).ok().flatten().map(|mut resolved| {
                 match resolved.def {
                     InstanceDef::Item(def) => {
@@ -544,18 +576,18 @@ impl<'tcx> Instance<'tcx> {
                                 // Create a shim for the `FnOnce/FnMut/Fn` method we are calling
                                 // - unlike functions, invoking a closure always goes through a
                                 // trait.
-                                resolved = Instance { def: InstanceDef::ReifyShim(def_id), args };
+                                resolved = Instance { def: InstanceDef::ReifyShim(def_id, reason), args };
                             } else {
                                 debug!(
                                     " => vtable fn pointer created for function with #[track_caller]: {:?}", def
                                 );
-                                resolved.def = InstanceDef::ReifyShim(def);
+                                resolved.def = InstanceDef::ReifyShim(def, reason);
                             }
                         }
                     }
                     InstanceDef::Virtual(def_id, _) => {
                         debug!(" => vtable fn pointer created for virtual call");
-                        resolved.def = InstanceDef::ReifyShim(def_id);
+                        resolved.def = InstanceDef::ReifyShim(def_id, reason)
                     }
                     _ => {}
                 }
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 4e1baaec39e..eef623d77b0 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -88,7 +88,7 @@ pub use self::context::{
     tls, CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift,
     TyCtxt, TyCtxtFeed,
 };
-pub use self::instance::{Instance, InstanceDef, ShortInstance, UnusedGenericParams};
+pub use self::instance::{Instance, InstanceDef, ReifyReason, ShortInstance, UnusedGenericParams};
 pub use self::list::List;
 pub use self::parameterized::ParameterizedOverTcx;
 pub use self::predicate::{
diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs
index a62379def53..0e7010e67d7 100644
--- a/compiler/rustc_middle/src/ty/structural_impls.rs
+++ b/compiler/rustc_middle/src/ty/structural_impls.rs
@@ -449,6 +449,7 @@ TrivialTypeTraversalAndLiftImpls! {
     crate::ty::ClosureKind,
     crate::ty::ParamConst,
     crate::ty::ParamTy,
+    crate::ty::instance::ReifyReason,
     interpret::AllocId,
     interpret::CtfeProvenance,
     interpret::Scalar,
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 5f74841151c..ab9fa165a20 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -324,7 +324,7 @@ impl<'tcx> Inliner<'tcx> {
             // do not need to catch this here, we can wait until the inliner decides to continue
             // inlining a second time.
             InstanceDef::VTableShim(_)
-            | InstanceDef::ReifyShim(_)
+            | InstanceDef::ReifyShim(..)
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::ClosureOnceShim { .. }
             | InstanceDef::ConstructCoroutineInClosureShim { .. }
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index f2b6dcac586..7a1340f3a55 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -84,7 +84,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 // again, a function item can end up getting inlined. Thus we'll be able to cause
                 // a cycle that way
                 InstanceDef::VTableShim(_)
-                | InstanceDef::ReifyShim(_)
+                | InstanceDef::ReifyShim(..)
                 | InstanceDef::FnPtrShim(..)
                 | InstanceDef::ClosureOnceShim { .. }
                 | InstanceDef::ConstructCoroutineInClosureShim { .. }
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index b60ee7649b2..eaef2b80c86 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -55,7 +55,7 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
         // a virtual call, or a direct call to a function for which
         // indirect calls must be codegen'd differently than direct ones
         // (such as `#[track_caller]`).
-        ty::InstanceDef::ReifyShim(def_id) => {
+        ty::InstanceDef::ReifyShim(def_id, _) => {
             build_call_shim(tcx, instance, None, CallKind::Direct(def_id))
         }
         ty::InstanceDef::ClosureOnceShim { call_once: _, track_caller: _ } => {
diff --git a/compiler/rustc_symbol_mangling/src/legacy.rs b/compiler/rustc_symbol_mangling/src/legacy.rs
index 1c62ce2d214..f68668a91e6 100644
--- a/compiler/rustc_symbol_mangling/src/legacy.rs
+++ b/compiler/rustc_symbol_mangling/src/legacy.rs
@@ -2,7 +2,7 @@ use rustc_data_structures::stable_hasher::{Hash64, HashStable, StableHasher};
 use rustc_hir::def_id::CrateNum;
 use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData};
 use rustc_middle::ty::print::{PrettyPrinter, Print, PrintError, Printer};
-use rustc_middle::ty::{self, Instance, Ty, TyCtxt, TypeVisitableExt};
+use rustc_middle::ty::{self, Instance, ReifyReason, Ty, TyCtxt, TypeVisitableExt};
 use rustc_middle::ty::{GenericArg, GenericArgKind};
 
 use std::fmt::{self, Write};
@@ -71,8 +71,14 @@ pub(super) fn mangle<'tcx>(
         ty::InstanceDef::VTableShim(..) => {
             printer.write_str("{{vtable-shim}}").unwrap();
         }
-        ty::InstanceDef::ReifyShim(..) => {
-            printer.write_str("{{reify-shim}}").unwrap();
+        ty::InstanceDef::ReifyShim(_, reason) => {
+            printer.write_str("{{reify-shim").unwrap();
+            match reason {
+                Some(ReifyReason::FnPtr) => printer.write_str("-fnptr").unwrap(),
+                Some(ReifyReason::Vtable) => printer.write_str("-vtable").unwrap(),
+                None => (),
+            }
+            printer.write_str("}}").unwrap();
         }
         // FIXME(async_closures): This shouldn't be needed when we fix
         // `Instance::ty`/`Instance::def_id`.
diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs
index 4369f020d27..8cb5370bb4a 100644
--- a/compiler/rustc_symbol_mangling/src/v0.rs
+++ b/compiler/rustc_symbol_mangling/src/v0.rs
@@ -8,8 +8,8 @@ use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData};
 use rustc_middle::ty::layout::IntegerExt;
 use rustc_middle::ty::print::{Print, PrintError, Printer};
 use rustc_middle::ty::{
-    self, EarlyBinder, FloatTy, Instance, IntTy, Ty, TyCtxt, TypeVisitable, TypeVisitableExt,
-    UintTy,
+    self, EarlyBinder, FloatTy, Instance, IntTy, ReifyReason, Ty, TyCtxt, TypeVisitable,
+    TypeVisitableExt, UintTy,
 };
 use rustc_middle::ty::{GenericArg, GenericArgKind};
 use rustc_span::symbol::kw;
@@ -44,7 +44,9 @@ pub(super) fn mangle<'tcx>(
     let shim_kind = match instance.def {
         ty::InstanceDef::ThreadLocalShim(_) => Some("tls"),
         ty::InstanceDef::VTableShim(_) => Some("vtable"),
-        ty::InstanceDef::ReifyShim(_) => Some("reify"),
+        ty::InstanceDef::ReifyShim(_, None) => Some("reify"),
+        ty::InstanceDef::ReifyShim(_, Some(ReifyReason::FnPtr)) => Some("reify-fnptr"),
+        ty::InstanceDef::ReifyShim(_, Some(ReifyReason::Vtable)) => Some("reify-vtable"),
 
         ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
         | ty::InstanceDef::CoroutineKindShim { .. } => Some("fn_once"),