about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_middle/src/dep_graph/dep_node.rs2
-rw-r--r--compiler/rustc_middle/src/query/mod.rs4
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs52
-rw-r--r--compiler/rustc_typeck/src/collect.rs3
-rw-r--r--src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs56
5 files changed, 102 insertions, 15 deletions
diff --git a/compiler/rustc_middle/src/dep_graph/dep_node.rs b/compiler/rustc_middle/src/dep_graph/dep_node.rs
index aa54d1ae7b9..8476929eaec 100644
--- a/compiler/rustc_middle/src/dep_graph/dep_node.rs
+++ b/compiler/rustc_middle/src/dep_graph/dep_node.rs
@@ -285,7 +285,7 @@ pub type DepNode = rustc_query_system::dep_graph::DepNode<DepKind>;
 // required that their size stay the same, but we don't want to change
 // it inadvertently. This assert just ensures we're aware of any change.
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-static_assert_size!(DepNode, 17);
+static_assert_size!(DepNode, 18);
 
 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
 static_assert_size!(DepNode, 24);
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 9a2f1149316..d218172282d 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -916,6 +916,10 @@ rustc_queries! {
         desc { |tcx| "looking up const stability of `{}`", tcx.def_path_str(def_id) }
     }
 
+    query should_inherit_track_caller(def_id: DefId) -> bool {
+        desc { |tcx| "computing should_inherit_track_caller of `{}`", tcx.def_path_str(def_id) }
+    }
+
     query lookup_deprecation_entry(def_id: DefId) -> Option<DeprecationEntry> {
         desc { |tcx| "checking whether `{}` is deprecated", tcx.def_path_str(def_id) }
     }
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 41d953216e0..261a19f862e 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -227,8 +227,9 @@ impl<'tcx> InstanceDef<'tcx> {
 
     pub fn requires_caller_location(&self, tcx: TyCtxt<'_>) -> bool {
         match *self {
-            InstanceDef::Item(def) => {
-                tcx.codegen_fn_attrs(def.did).flags.contains(CodegenFnAttrFlags::TRACK_CALLER)
+            InstanceDef::Item(ty::WithOptConstParam { did: def_id, .. })
+            | InstanceDef::Virtual(def_id, _) => {
+                tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::TRACK_CALLER)
             }
             _ => false,
         }
@@ -403,7 +404,7 @@ impl<'tcx> Instance<'tcx> {
         def_id: DefId,
         substs: SubstsRef<'tcx>,
     ) -> Option<Instance<'tcx>> {
-        debug!("resolve(def_id={:?}, substs={:?})", def_id, substs);
+        debug!("resolve_for_vtable(def_id={:?}, substs={:?})", def_id, substs);
         let fn_sig = tcx.fn_sig(def_id);
         let is_vtable_shim = !fn_sig.inputs().skip_binder().is_empty()
             && fn_sig.input(0).skip_binder().is_param(0)
@@ -412,7 +413,50 @@ impl<'tcx> Instance<'tcx> {
             debug!(" => associated item with unsizeable self: Self");
             Some(Instance { def: InstanceDef::VtableShim(def_id), substs })
         } else {
-            Instance::resolve_for_fn_ptr(tcx, param_env, def_id, substs)
+            Instance::resolve(tcx, param_env, def_id, substs).ok().flatten().map(|mut resolved| {
+                match resolved.def {
+                    InstanceDef::Item(def) => {
+                        // We need to generate a shim when we cannot guarantee that
+                        // the caller of a trait object method will be aware of
+                        // `#[track_caller]` - this ensures that the caller
+                        // and callee ABI will always match.
+                        //
+                        // The shim is generated when all of these conditions are met:
+                        //
+                        // 1) The underlying method expects a caller location parameter
+                        // in the ABI
+                        if resolved.def.requires_caller_location(tcx)
+                            // 2) The caller location parameter comes from having `#[track_caller]`
+                            // on the implementation, and *not* on the trait method.
+                            && !tcx.should_inherit_track_caller(def.did)
+                            // If the method implementation comes from the trait definition itself
+                            // (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
+                            // then we don't need to generate a shim. This check is needed because
+                            // `should_inherit_track_caller` returns `false` if our method
+                            // implementation comes from the trait block, and not an impl block
+                            && !matches!(
+                                tcx.opt_associated_item(def.did),
+                                Some(ty::AssocItem {
+                                    container: ty::AssocItemContainer::TraitContainer(_),
+                                    ..
+                                })
+                            )
+                        {
+                            debug!(
+                                " => vtable fn pointer created for function with #[track_caller]"
+                            );
+                            resolved.def = InstanceDef::ReifyShim(def.did);
+                        }
+                    }
+                    InstanceDef::Virtual(def_id, _) => {
+                        debug!(" => vtable fn pointer created for virtual call");
+                        resolved.def = InstanceDef::ReifyShim(def_id);
+                    }
+                    _ => {}
+                }
+
+                resolved
+            })
         }
     }
 
diff --git a/compiler/rustc_typeck/src/collect.rs b/compiler/rustc_typeck/src/collect.rs
index 583ba9392f0..506ca98b960 100644
--- a/compiler/rustc_typeck/src/collect.rs
+++ b/compiler/rustc_typeck/src/collect.rs
@@ -92,6 +92,7 @@ pub fn provide(providers: &mut Providers) {
         generator_kind,
         codegen_fn_attrs,
         collect_mod_item_types,
+        should_inherit_track_caller,
         ..*providers
     };
 }
@@ -2686,7 +2687,7 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, id: DefId) -> CodegenFnAttrs {
     let attrs = tcx.get_attrs(id);
 
     let mut codegen_fn_attrs = CodegenFnAttrs::new();
-    if should_inherit_track_caller(tcx, id) {
+    if tcx.should_inherit_track_caller(id) {
         codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
     }
 
diff --git a/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs b/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs
index 3b2a2238fa8..06883a85790 100644
--- a/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs
+++ b/src/test/ui/rfc-2091-track-caller/tracked-trait-obj.rs
@@ -2,22 +2,60 @@
 
 trait Tracked {
     #[track_caller]
-    fn handle(&self) {
+    fn track_caller_trait_method(&self, line: u32, col: u32) {
         let location = std::panic::Location::caller();
         assert_eq!(location.file(), file!());
-        // we only call this via trait object, so the def site should *always* be returned
-        assert_eq!(location.line(), line!() - 4);
-        assert_eq!(location.column(), 5);
+        // The trait method definition is annotated with `#[track_caller]`,
+        // so caller location information will work through a method
+        // call on a trait object
+        assert_eq!(location.line(), line, "Bad line");
+        assert_eq!(location.column(), col, "Bad col");
     }
+
+    fn track_caller_not_on_trait_method(&self);
+
+    #[track_caller]
+    fn track_caller_through_self(self: Box<Self>, line: u32, col: u32);
 }
 
-impl Tracked for () {}
-impl Tracked for u8 {}
+impl Tracked for () {
+    // We have `#[track_caller]` on the implementation of the method,
+    // but not on the definition of the method in the trait. Therefore,
+    // caller location information will *not* work through a method call
+    // on a trait object. Instead, we will get the location of this method
+    #[track_caller]
+    fn track_caller_not_on_trait_method(&self) {
+        let location = std::panic::Location::caller();
+        assert_eq!(location.file(), file!());
+        assert_eq!(location.line(), line!() - 3);
+        assert_eq!(location.column(), 5);
+    }
+
+    // We don't have a `#[track_caller]` attribute, but
+    // `#[track_caller]` is present on the trait definition,
+    // so we'll still get location information
+    fn track_caller_through_self(self: Box<Self>, line: u32, col: u32) {
+        let location = std::panic::Location::caller();
+        assert_eq!(location.file(), file!());
+        // The trait method definition is annotated with `#[track_caller]`,
+        // so caller location information will work through a method
+        // call on a trait object
+        assert_eq!(location.line(), line, "Bad line");
+        assert_eq!(location.column(), col, "Bad col");
+    }
+}
 
 fn main() {
-    let tracked: &dyn Tracked = &5u8;
-    tracked.handle();
+    let tracked: &dyn Tracked = &();
+    // The column is the start of 'track_caller_trait_method'
+    tracked.track_caller_trait_method(line!(), 13);
 
     const TRACKED: &dyn Tracked = &();
-    TRACKED.handle();
+    // The column is the start of 'track_caller_trait_method'
+    TRACKED.track_caller_trait_method(line!(), 13);
+    TRACKED.track_caller_not_on_trait_method();
+
+    // The column is the start of `track_caller_through_self`
+    let boxed: Box<dyn Tracked> = Box::new(());
+    boxed.track_caller_through_self(line!(), 11);
 }