about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-11-17 13:47:03 +0000
committerbors <bors@rust-lang.org>2022-11-17 13:47:03 +0000
commitb6097f2e1b2ca62e188ba53cf43bd66b06b36915 (patch)
tree2e5d5c944967d1a4629252a8016b34df8da1efd9
parent36db030a7c3c51cb4484cbd8c8daebcf5057d61c (diff)
parent79c06fc595261a118cea2e5440ed98fbf5659a99 (diff)
downloadrust-b6097f2e1b2ca62e188ba53cf43bd66b06b36915.tar.gz
rust-b6097f2e1b2ca62e188ba53cf43bd66b06b36915.zip
Auto merge of #104219 - bryangarza:async-track-caller-dup, r=eholk
Support `#[track_caller]` on async fns

Adds `#[track_caller]` to the generator that is created when we desugar the async fn.

Fixes #78840

Open questions:
- What is the performance impact of adding `#[track_caller]` to every `GenFuture`'s `poll(...)` function, even if it's unused (i.e., the parent span does not set `#[track_caller]`)? We might need to set it only conditionally, if the indirection causes overhead we don't want.
-rw-r--r--compiler/rustc_ast_lowering/src/expr.rs37
-rw-r--r--compiler/rustc_ast_lowering/src/item.rs2
-rw-r--r--library/core/src/future/mod.rs1
-rw-r--r--src/test/ui/async-await/track-caller/panic-track-caller.rs76
4 files changed, 109 insertions, 7 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs
index eb64027f369..01716859a9d 100644
--- a/compiler/rustc_ast_lowering/src/expr.rs
+++ b/compiler/rustc_ast_lowering/src/expr.rs
@@ -655,15 +655,40 @@ impl<'hir> LoweringContext<'_, 'hir> {
 
             hir::ExprKind::Closure(c)
         };
-        let generator = hir::Expr {
-            hir_id: self.lower_node_id(closure_node_id),
-            kind: generator_kind,
-            span: self.lower_span(span),
+        let parent_has_track_caller = self
+            .attrs
+            .values()
+            .find(|attrs| attrs.into_iter().find(|attr| attr.has_name(sym::track_caller)).is_some())
+            .is_some();
+        let unstable_span =
+            self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone());
+
+        let hir_id = if parent_has_track_caller {
+            let generator_hir_id = self.lower_node_id(closure_node_id);
+            self.lower_attrs(
+                generator_hir_id,
+                &[Attribute {
+                    kind: AttrKind::Normal(ptr::P(NormalAttr {
+                        item: AttrItem {
+                            path: Path::from_ident(Ident::new(sym::track_caller, span)),
+                            args: MacArgs::Empty,
+                            tokens: None,
+                        },
+                        tokens: None,
+                    })),
+                    id: self.tcx.sess.parse_sess.attr_id_generator.mk_attr_id(),
+                    style: AttrStyle::Outer,
+                    span: unstable_span,
+                }],
+            );
+            generator_hir_id
+        } else {
+            self.lower_node_id(closure_node_id)
         };
 
+        let generator = hir::Expr { hir_id, kind: generator_kind, span: self.lower_span(span) };
+
         // `future::from_generator`:
-        let unstable_span =
-            self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone());
         let gen_future = self.expr_lang_item_path(
             unstable_span,
             hir::LangItem::FromGenerator,
diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs
index 756f35e901d..795ad113ef2 100644
--- a/compiler/rustc_ast_lowering/src/item.rs
+++ b/compiler/rustc_ast_lowering/src/item.rs
@@ -86,7 +86,7 @@ impl<'a, 'hir> ItemLowerer<'a, 'hir> {
             impl_trait_defs: Vec::new(),
             impl_trait_bounds: Vec::new(),
             allow_try_trait: Some([sym::try_trait_v2, sym::yeet_desugar_details][..].into()),
-            allow_gen_future: Some([sym::gen_future][..].into()),
+            allow_gen_future: Some([sym::gen_future, sym::closure_track_caller][..].into()),
             allow_into_future: Some([sym::into_future][..].into()),
             generics_def_id_map: Default::default(),
         };
diff --git a/library/core/src/future/mod.rs b/library/core/src/future/mod.rs
index 6487aa08859..107cf92c1c0 100644
--- a/library/core/src/future/mod.rs
+++ b/library/core/src/future/mod.rs
@@ -82,6 +82,7 @@ where
 
     impl<T: Generator<ResumeTy, Yield = ()>> Future for GenFuture<T> {
         type Output = T::Return;
+        #[track_caller]
         fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
             // SAFETY: Safe because we're !Unpin + !Drop, and this is just a field projection.
             let gen = unsafe { Pin::map_unchecked_mut(self, |s| &mut s.0) };
diff --git a/src/test/ui/async-await/track-caller/panic-track-caller.rs b/src/test/ui/async-await/track-caller/panic-track-caller.rs
new file mode 100644
index 00000000000..b113c56412f
--- /dev/null
+++ b/src/test/ui/async-await/track-caller/panic-track-caller.rs
@@ -0,0 +1,76 @@
+// run-pass
+// edition:2021
+// needs-unwind
+#![feature(closure_track_caller)]
+
+use std::future::Future;
+use std::panic;
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll, Wake};
+use std::thread::{self, Thread};
+
+/// A waker that wakes up the current thread when called.
+struct ThreadWaker(Thread);
+
+impl Wake for ThreadWaker {
+    fn wake(self: Arc<Self>) {
+        self.0.unpark();
+    }
+}
+
+/// Run a future to completion on the current thread.
+fn block_on<T>(fut: impl Future<Output = T>) -> T {
+    // Pin the future so it can be polled.
+    let mut fut = Box::pin(fut);
+
+    // Create a new context to be passed to the future.
+    let t = thread::current();
+    let waker = Arc::new(ThreadWaker(t)).into();
+    let mut cx = Context::from_waker(&waker);
+
+    // Run the future to completion.
+    loop {
+        match fut.as_mut().poll(&mut cx) {
+            Poll::Ready(res) => return res,
+            Poll::Pending => thread::park(),
+        }
+    }
+}
+
+async fn bar() {
+    panic!()
+}
+
+async fn foo() {
+    bar().await
+}
+
+#[track_caller]
+async fn bar_track_caller() {
+    panic!()
+}
+
+async fn foo_track_caller() {
+    bar_track_caller().await
+}
+
+fn panicked_at(f: impl FnOnce() + panic::UnwindSafe) -> u32 {
+    let loc = Arc::new(Mutex::new(None));
+
+    let hook = panic::take_hook();
+    {
+        let loc = loc.clone();
+        panic::set_hook(Box::new(move |info| {
+            *loc.lock().unwrap() = info.location().map(|loc| loc.line())
+        }));
+    }
+    panic::catch_unwind(f).unwrap_err();
+    panic::set_hook(hook);
+    let x = loc.lock().unwrap().unwrap();
+    x
+}
+
+fn main() {
+    assert_eq!(panicked_at(|| block_on(foo())), 41);
+    assert_eq!(panicked_at(|| block_on(foo_track_caller())), 54);
+}