about summary refs log tree commit diff
path: root/compiler/rustc_middle/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_middle/src')
-rw-r--r--compiler/rustc_middle/src/mir/visit.rs2
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs66
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs2
-rw-r--r--compiler/rustc_middle/src/ty/structural_impls.rs1
4 files changed, 61 insertions, 10 deletions
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 3835bd371d9..4f7b2f7cbe4 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -341,7 +341,7 @@ macro_rules! make_mir_visitor {
 
                         ty::InstanceDef::Intrinsic(_def_id) |
                         ty::InstanceDef::VTableShim(_def_id) |
-                        ty::InstanceDef::ReifyShim(_def_id) |
+                        ty::InstanceDef::ReifyShim(_def_id, _) |
                         ty::InstanceDef::Virtual(_def_id, _) |
                         ty::InstanceDef::ThreadLocalShim(_def_id) |
                         ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 4fec5653a79..4a7720b38f8 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -31,6 +31,28 @@ pub struct Instance<'tcx> {
     pub args: GenericArgsRef<'tcx>,
 }
 
+/// Describes why a `ReifyShim` was created. This is needed to distingish a ReifyShim created to
+/// adjust for things like `#[track_caller]` in a vtable from a `ReifyShim` created to produce a
+/// function pointer from a vtable entry.
+/// Currently, this is only used when KCFI is enabled, as only KCFI needs to treat those two
+/// `ReifyShim`s differently.
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+#[derive(TyEncodable, TyDecodable, HashStable)]
+pub enum ReifyReason {
+    /// The `ReifyShim` was created to produce a function pointer. This happens when:
+    /// * A vtable entry is directly converted to a function call (e.g. creating a fn ptr from a
+    ///   method on a `dyn` object).
+    /// * A function with `#[track_caller]` is converted to a function pointer
+    /// * If KCFI is enabled, creating a function pointer from a method on an object-safe trait.
+    /// This includes the case of converting `::call`-like methods on closure-likes to function
+    /// pointers.
+    FnPtr,
+    /// This `ReifyShim` was created to populate a vtable. Currently, this happens when a
+    /// `#[track_caller]` mismatch occurs between the implementation of a method and the method.
+    /// This includes the case of `::call`-like methods in closure-likes' vtables.
+    Vtable,
+}
+
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 #[derive(TyEncodable, TyDecodable, HashStable, TypeFoldable, TypeVisitable, Lift)]
 pub enum InstanceDef<'tcx> {
@@ -67,7 +89,13 @@ pub enum InstanceDef<'tcx> {
     /// Because this is a required part of the function's ABI but can't be tracked
     /// as a property of the function pointer, we use a single "caller location"
     /// (the definition of the function itself).
-    ReifyShim(DefId),
+    ///
+    /// The second field encodes *why* this shim was created. This allows distinguishing between
+    /// a `ReifyShim` that appears in a vtable vs one that appears as a function pointer.
+    ///
+    /// This field will only be populated if we are compiling in a mode that needs these shims
+    /// to be separable, currently only when KCFI is enabled.
+    ReifyShim(DefId, Option<ReifyReason>),
 
     /// `<fn() as FnTrait>::call_*` (generated `FnTrait` implementation for `fn()` pointers).
     ///
@@ -194,7 +222,7 @@ impl<'tcx> InstanceDef<'tcx> {
         match self {
             InstanceDef::Item(def_id)
             | InstanceDef::VTableShim(def_id)
-            | InstanceDef::ReifyShim(def_id)
+            | InstanceDef::ReifyShim(def_id, _)
             | InstanceDef::FnPtrShim(def_id, _)
             | InstanceDef::Virtual(def_id, _)
             | InstanceDef::Intrinsic(def_id)
@@ -354,7 +382,9 @@ fn fmt_instance(
     match instance.def {
         InstanceDef::Item(_) => Ok(()),
         InstanceDef::VTableShim(_) => write!(f, " - shim(vtable)"),
-        InstanceDef::ReifyShim(_) => write!(f, " - shim(reify)"),
+        InstanceDef::ReifyShim(_, None) => write!(f, " - shim(reify)"),
+        InstanceDef::ReifyShim(_, Some(ReifyReason::FnPtr)) => write!(f, " - shim(reify-fnptr)"),
+        InstanceDef::ReifyShim(_, Some(ReifyReason::Vtable)) => write!(f, " - shim(reify-vtable)"),
         InstanceDef::ThreadLocalShim(_) => write!(f, " - shim(tls)"),
         InstanceDef::Intrinsic(_) => write!(f, " - intrinsic"),
         InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"),
@@ -476,15 +506,34 @@ impl<'tcx> Instance<'tcx> {
         debug!("resolve(def_id={:?}, args={:?})", def_id, args);
         // Use either `resolve_closure` or `resolve_for_vtable`
         assert!(!tcx.is_closure_like(def_id), "Called `resolve_for_fn_ptr` on closure: {def_id:?}");
+        let reason = tcx.sess.is_sanitizer_kcfi_enabled().then_some(ReifyReason::FnPtr);
         Instance::resolve(tcx, param_env, def_id, args).ok().flatten().map(|mut resolved| {
             match resolved.def {
                 InstanceDef::Item(def) if resolved.def.requires_caller_location(tcx) => {
                     debug!(" => fn pointer created for function with #[track_caller]");
-                    resolved.def = InstanceDef::ReifyShim(def);
+                    resolved.def = InstanceDef::ReifyShim(def, reason);
                 }
                 InstanceDef::Virtual(def_id, _) => {
                     debug!(" => fn pointer created for virtual call");
-                    resolved.def = InstanceDef::ReifyShim(def_id);
+                    resolved.def = InstanceDef::ReifyShim(def_id, reason);
+                }
+                // Reify `Trait::method` implementations if KCFI is enabled
+                // FIXME(maurer) only reify it if it is a vtable-safe function
+                _ if tcx.sess.is_sanitizer_kcfi_enabled()
+                    && tcx.associated_item(def_id).trait_item_def_id.is_some() =>
+                {
+                    // If this function could also go in a vtable, we need to `ReifyShim` it with
+                    // KCFI because it can only attach one type per function.
+                    resolved.def = InstanceDef::ReifyShim(resolved.def_id(), reason)
+                }
+                // Reify `::call`-like method implementations if KCFI is enabled
+                _ if tcx.sess.is_sanitizer_kcfi_enabled()
+                    && tcx.is_closure_like(resolved.def_id()) =>
+                {
+                    // Reroute through a reify via the *unresolved* instance. The resolved one can't
+                    // be directly reified because it's closure-like. The reify can handle the
+                    // unresolved instance.
+                    resolved = Instance { def: InstanceDef::ReifyShim(def_id, reason), args }
                 }
                 _ => {}
             }
@@ -508,6 +557,7 @@ impl<'tcx> Instance<'tcx> {
             debug!(" => associated item with unsizeable self: Self");
             Some(Instance { def: InstanceDef::VTableShim(def_id), args })
         } else {
+            let reason = tcx.sess.is_sanitizer_kcfi_enabled().then_some(ReifyReason::Vtable);
             Instance::resolve(tcx, param_env, def_id, args).ok().flatten().map(|mut resolved| {
                 match resolved.def {
                     InstanceDef::Item(def) => {
@@ -544,18 +594,18 @@ impl<'tcx> Instance<'tcx> {
                                 // Create a shim for the `FnOnce/FnMut/Fn` method we are calling
                                 // - unlike functions, invoking a closure always goes through a
                                 // trait.
-                                resolved = Instance { def: InstanceDef::ReifyShim(def_id), args };
+                                resolved = Instance { def: InstanceDef::ReifyShim(def_id, reason), args };
                             } else {
                                 debug!(
                                     " => vtable fn pointer created for function with #[track_caller]: {:?}", def
                                 );
-                                resolved.def = InstanceDef::ReifyShim(def);
+                                resolved.def = InstanceDef::ReifyShim(def, reason);
                             }
                         }
                     }
                     InstanceDef::Virtual(def_id, _) => {
                         debug!(" => vtable fn pointer created for virtual call");
-                        resolved.def = InstanceDef::ReifyShim(def_id);
+                        resolved.def = InstanceDef::ReifyShim(def_id, reason)
                     }
                     _ => {}
                 }
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 62bfa0ff928..fb56c71c517 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -88,7 +88,7 @@ pub use self::context::{
     tls, CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift,
     TyCtxt, TyCtxtFeed,
 };
-pub use self::instance::{Instance, InstanceDef, ShortInstance, UnusedGenericParams};
+pub use self::instance::{Instance, InstanceDef, ReifyReason, ShortInstance, UnusedGenericParams};
 pub use self::list::List;
 pub use self::parameterized::ParameterizedOverTcx;
 pub use self::predicate::{
diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs
index a62379def53..0e7010e67d7 100644
--- a/compiler/rustc_middle/src/ty/structural_impls.rs
+++ b/compiler/rustc_middle/src/ty/structural_impls.rs
@@ -449,6 +449,7 @@ TrivialTypeTraversalAndLiftImpls! {
     crate::ty::ClosureKind,
     crate::ty::ParamConst,
     crate::ty::ParamTy,
+    crate::ty::instance::ReifyReason,
     interpret::AllocId,
     interpret::CtfeProvenance,
     interpret::Scalar,