about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Holk <ericholk@microsoft.com>2023-12-08 17:00:11 -0800
committerEric Holk <ericholk@microsoft.com>2023-12-19 12:26:27 -0800
commit97df0d3657fafac6e6ca481c1d4123bf6e2633f1 (patch)
tree22c7d7314ac22444869f27eb80487a56e499ae3a
parent27d6539a46123dcdb6fae6e043b8c1c12b3e0d6f (diff)
downloadrust-97df0d3657fafac6e6ca481c1d4123bf6e2633f1.tar.gz
rust-97df0d3657fafac6e6ca481c1d4123bf6e2633f1.zip
Desugar for await loops
-rw-r--r--compiler/rustc_ast_lowering/src/expr.rs120
-rw-r--r--compiler/rustc_ast_lowering/src/lib.rs2
-rw-r--r--compiler/rustc_builtin_macros/src/assert/context.rs2
-rw-r--r--compiler/rustc_feature/src/unstable.rs2
-rw-r--r--compiler/rustc_hir/src/lang_items.rs2
-rw-r--r--compiler/rustc_span/src/symbol.rs2
-rw-r--r--library/core/src/async_iter/async_iter.rs1
-rw-r--r--tests/ui/async-await/for-await.rs24
8 files changed, 125 insertions, 30 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs
index c2d91b09453..660cc9b70f1 100644
--- a/compiler/rustc_ast_lowering/src/expr.rs
+++ b/compiler/rustc_ast_lowering/src/expr.rs
@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
                 ),
                 ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),
 
-                ExprKind::Paren(_) | ExprKind::ForLoop{..} => {
+                ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
                     unreachable!("already handled")
                 }
 
@@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
     /// }
     /// ```
     fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
+        let expr = self.arena.alloc(self.lower_expr_mut(expr));
+        self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
+    }
+
+    /// Takes an expr that has already been lowered and generates a desugared await loop around it
+    fn make_lowered_await(
+        &mut self,
+        await_kw_span: Span,
+        expr: &'hir hir::Expr<'hir>,
+        await_kind: FutureKind,
+    ) -> hir::ExprKind<'hir> {
         let full_span = expr.span.to(await_kw_span);
 
         let is_async_gen = match self.coroutine_kind {
@@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
             }
         };
 
-        let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
+        let features = match await_kind {
+            FutureKind::Future => None,
+            FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
+        };
+        let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
         let gen_future_span = self.mark_span_with_reason(
             DesugaringKind::Await,
             full_span,
             Some(self.allow_gen_future.clone()),
         );
-        let expr = self.lower_expr_mut(expr);
         let expr_hir_id = expr.hir_id;
 
         // Note that the name of this binding must not be changed to something else because
@@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
                 hir::LangItem::GetContext,
                 arena_vec![self; task_context],
             );
-            let call = self.expr_call_lang_item_fn(
-                span,
-                hir::LangItem::FuturePoll,
-                arena_vec![self; new_unchecked, get_context],
-            );
+            let call = match await_kind {
+                FutureKind::Future => self.expr_call_lang_item_fn(
+                    span,
+                    hir::LangItem::FuturePoll,
+                    arena_vec![self; new_unchecked, get_context],
+                ),
+                FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
+                    span,
+                    hir::LangItem::AsyncIteratorPollNext,
+                    arena_vec![self; new_unchecked, get_context],
+                ),
+            };
             self.arena.alloc(self.expr_unsafe(call))
         };
 
@@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
         let awaitee_arm = self.arm(awaitee_pat, loop_expr);
 
         // `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
-        let into_future_expr = self.expr_call_lang_item_fn(
-            span,
-            hir::LangItem::IntoFutureIntoFuture,
-            arena_vec![self; expr],
-        );
+        let into_future_expr = match await_kind {
+            FutureKind::Future => self.expr_call_lang_item_fn(
+                span,
+                hir::LangItem::IntoFutureIntoFuture,
+                arena_vec![self; *expr],
+            ),
+            // Not needed for `for await` because we expect to have already called
+            // `IntoAsyncIterator::into_async_iter` on it.
+            FutureKind::AsyncIterator => expr,
+        };
 
         // match <into_future_expr> {
         //     mut __awaitee => loop { .. }
@@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
         head: &Expr,
         body: &Block,
         opt_label: Option<Label>,
-        _loop_kind: ForLoopKind,
+        loop_kind: ForLoopKind,
     ) -> hir::Expr<'hir> {
         let head = self.lower_expr_mut(head);
         let pat = self.lower_pat(pat);
@@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
         let (iter_pat, iter_pat_nid) =
             self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);
 
-        // `match Iterator::next(&mut iter) { ... }`
         let match_expr = {
             let iter = self.expr_ident(head_span, iter, iter_pat_nid);
-            let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
-            let next_expr = self.expr_call_lang_item_fn(
-                head_span,
-                hir::LangItem::IteratorNext,
-                arena_vec![self; ref_mut_iter],
-            );
+            let next_expr = match loop_kind {
+                ForLoopKind::For => {
+                    // `Iterator::next(&mut iter)`
+                    let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
+                    self.expr_call_lang_item_fn(
+                        head_span,
+                        hir::LangItem::IteratorNext,
+                        arena_vec![self; ref_mut_iter],
+                    )
+                }
+                ForLoopKind::ForAwait => {
+                    // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
+                    // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
+                    // calls to `poll_next`. In user code, this would probably be a call to
+                    // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
+
+                    // `&mut iter`
+                    let iter = self.expr_mut_addr_of(head_span, iter);
+                    // `Pin::new_unchecked(...)`
+                    let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut(
+                        head_span,
+                        hir::LangItem::PinNewUnchecked,
+                        arena_vec![self; iter],
+                    ));
+                    // `unsafe { ... }`
+                    let iter = self.arena.alloc(self.expr_unsafe(iter));
+                    let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
+                    self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
+                }
+            };
             let arms = arena_vec![self; none_arm, some_arm];
 
+            // `match $next_expr { ... }`
             self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
         };
         let match_stmt = self.stmt_expr(for_span, match_expr);
@@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
         // `mut iter => { ... }`
         let iter_arm = self.arm(iter_pat, loop_expr);
 
-        // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
-        let into_iter_expr = {
-            self.expr_call_lang_item_fn(
-                head_span,
-                hir::LangItem::IntoIterIntoIter,
-                arena_vec![self; head],
-            )
+        let into_iter_expr = match loop_kind {
+            ForLoopKind::For => {
+                // `::std::iter::IntoIterator::into_iter(<head>)`
+                self.expr_call_lang_item_fn(
+                    head_span,
+                    hir::LangItem::IntoIterIntoIter,
+                    arena_vec![self; head],
+                )
+            }
+            ForLoopKind::ForAwait => self.arena.alloc(head),
         };
 
         let match_expr = self.arena.alloc(self.expr_match(
@@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
         }
     }
 }
+
+/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
+/// of future we are awaiting.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum FutureKind {
+    /// We are awaiting a normal future
+    Future,
+    /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
+    /// a `for await` loop)
+    AsyncIterator,
+}
diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index 96ed3eee02e..1de5db4deb8 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -130,6 +130,7 @@ struct LoweringContext<'a, 'hir> {
     allow_try_trait: Lrc<[Symbol]>,
     allow_gen_future: Lrc<[Symbol]>,
     allow_async_iterator: Lrc<[Symbol]>,
+    allow_for_await: Lrc<[Symbol]>,
 
     /// Mapping from generics `def_id`s to TAIT generics `def_id`s.
     /// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@@ -174,6 +175,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             } else {
                 [sym::gen_future].into()
             },
+            allow_for_await: [sym::async_iterator].into(),
             // FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
             // interact with `gen`/`async gen` blocks
             allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
diff --git a/compiler/rustc_builtin_macros/src/assert/context.rs b/compiler/rustc_builtin_macros/src/assert/context.rs
index 4e07122a5c4..d244897f8a5 100644
--- a/compiler/rustc_builtin_macros/src/assert/context.rs
+++ b/compiler/rustc_builtin_macros/src/assert/context.rs
@@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
             | ExprKind::Continue(_)
             | ExprKind::Err
             | ExprKind::Field(_, _)
-            | ExprKind::ForLoop {..}
+            | ExprKind::ForLoop { .. }
             | ExprKind::FormatArgs(_)
             | ExprKind::IncludedBytes(..)
             | ExprKind::InlineAsm(_)
diff --git a/compiler/rustc_feature/src/unstable.rs b/compiler/rustc_feature/src/unstable.rs
index fdd4bb755e4..bb1c95c25d7 100644
--- a/compiler/rustc_feature/src/unstable.rs
+++ b/compiler/rustc_feature/src/unstable.rs
@@ -358,7 +358,7 @@ declare_features! (
     /// Allows `#[track_caller]` on async functions.
     (unstable, async_fn_track_caller, "1.73.0", Some(110011)),
     /// Allows `for await` loops.
-    (unstable, async_for_loop, "CURRENT_RUSTC_VERSION", None),
+    (unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)),
     /// Allows builtin # foo() syntax
     (unstable, builtin_syntax, "1.71.0", Some(110680)),
     /// Treat `extern "C"` function as nounwind.
diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs
index b0b53bb7478..7691cd11c4f 100644
--- a/compiler/rustc_hir/src/lang_items.rs
+++ b/compiler/rustc_hir/src/lang_items.rs
@@ -307,6 +307,8 @@ language_item_table! {
     Context,                 sym::Context,             context,                    Target::Struct,         GenericRequirement::None;
     FuturePoll,              sym::poll,                future_poll_fn,             Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
 
+    AsyncIteratorPollNext,   sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(0);
+
     Option,                  sym::Option,              option_type,                Target::Enum,           GenericRequirement::None;
     OptionSome,              sym::Some,                option_some_variant,        Target::Variant,        GenericRequirement::None;
     OptionNone,              sym::None,                option_none_variant,        Target::Variant,        GenericRequirement::None;
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 2e7037532b5..95106cc64c1 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -428,6 +428,7 @@ symbols! {
         async_fn_track_caller,
         async_for_loop,
         async_iterator,
+        async_iterator_poll_next,
         atomic,
         atomic_mod,
         atomics,
@@ -894,6 +895,7 @@ symbols! {
         instruction_set,
         integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below
         integral,
+        into_async_iter_into_iter,
         into_future,
         into_iter,
         intra_doc_pointers,
diff --git a/library/core/src/async_iter/async_iter.rs b/library/core/src/async_iter/async_iter.rs
index 8a45bd36f7a..ed9cb5bfea5 100644
--- a/library/core/src/async_iter/async_iter.rs
+++ b/library/core/src/async_iter/async_iter.rs
@@ -47,6 +47,7 @@ pub trait AsyncIterator {
     /// Rust's usual rules apply: calls must never cause undefined behavior
     /// (memory corruption, incorrect use of `unsafe` functions, or the like),
     /// regardless of the async iterator's state.
+    #[cfg_attr(not(bootstrap), lang = "async_iterator_poll_next")]
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
 
     /// Returns the bounds on the remaining length of the async iterator.
diff --git a/tests/ui/async-await/for-await.rs b/tests/ui/async-await/for-await.rs
new file mode 100644
index 00000000000..6345ceb0c27
--- /dev/null
+++ b/tests/ui/async-await/for-await.rs
@@ -0,0 +1,24 @@
+// run-pass
+// edition: 2021
+#![feature(async_iterator, async_iter_from_iter, const_waker, async_for_loop, noop_waker)]
+
+use std::future::Future;
+
+// make sure a simple for await loop works
+async fn real_main() {
+    let iter = core::async_iter::from_iter(0..3);
+    let mut count = 0;
+    for await i in iter {
+        assert_eq!(i, count);
+        count += 1;
+    }
+    assert_eq!(count, 3);
+}
+
+fn main() {
+    let future = real_main();
+    let waker = std::task::Waker::noop();
+    let mut cx = &mut core::task::Context::from_waker(&waker);
+    let mut future = core::pin::pin!(future);
+    while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {}
+}