diff options
| author | Eric Holk <ericholk@microsoft.com> | 2023-12-08 17:00:11 -0800 |
|---|---|---|
| committer | Eric Holk <ericholk@microsoft.com> | 2023-12-19 12:26:27 -0800 |
| commit | 97df0d3657fafac6e6ca481c1d4123bf6e2633f1 (patch) | |
| tree | 22c7d7314ac22444869f27eb80487a56e499ae3a | |
| parent | 27d6539a46123dcdb6fae6e043b8c1c12b3e0d6f (diff) | |
| download | rust-97df0d3657fafac6e6ca481c1d4123bf6e2633f1.tar.gz rust-97df0d3657fafac6e6ca481c1d4123bf6e2633f1.zip | |
Desugar for await loops
| -rw-r--r-- | compiler/rustc_ast_lowering/src/expr.rs | 120 | ||||
| -rw-r--r-- | compiler/rustc_ast_lowering/src/lib.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_builtin_macros/src/assert/context.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_feature/src/unstable.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_hir/src/lang_items.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_span/src/symbol.rs | 2 | ||||
| -rw-r--r-- | library/core/src/async_iter/async_iter.rs | 1 | ||||
| -rw-r--r-- | tests/ui/async-await/for-await.rs | 24 |
8 files changed, 125 insertions, 30 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index c2d91b09453..660cc9b70f1 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> { ), ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr), - ExprKind::Paren(_) | ExprKind::ForLoop{..} => { + ExprKind::Paren(_) | ExprKind::ForLoop { .. } => { unreachable!("already handled") } @@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> { /// } /// ``` fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> { + let expr = self.arena.alloc(self.lower_expr_mut(expr)); + self.make_lowered_await(await_kw_span, expr, FutureKind::Future) + } + + /// Takes an expr that has already been lowered and generates a desugared await loop around it + fn make_lowered_await( + &mut self, + await_kw_span: Span, + expr: &'hir hir::Expr<'hir>, + await_kind: FutureKind, + ) -> hir::ExprKind<'hir> { let full_span = expr.span.to(await_kw_span); let is_async_gen = match self.coroutine_kind { @@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> { } }; - let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None); + let features = match await_kind { + FutureKind::Future => None, + FutureKind::AsyncIterator => Some(self.allow_for_await.clone()), + }; + let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features); let gen_future_span = self.mark_span_with_reason( DesugaringKind::Await, full_span, Some(self.allow_gen_future.clone()), ); - let expr = self.lower_expr_mut(expr); let expr_hir_id = expr.hir_id; // Note that the name of this binding must not be changed to something else because @@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::LangItem::GetContext, arena_vec![self; task_context], ); - let call = self.expr_call_lang_item_fn( - span, - hir::LangItem::FuturePoll, - arena_vec![self; new_unchecked, get_context], - ); + let call = match await_kind { + FutureKind::Future => self.expr_call_lang_item_fn( + span, + hir::LangItem::FuturePoll, + arena_vec![self; new_unchecked, get_context], + ), + FutureKind::AsyncIterator => self.expr_call_lang_item_fn( + span, + hir::LangItem::AsyncIteratorPollNext, + arena_vec![self; new_unchecked, get_context], + ), + }; self.arena.alloc(self.expr_unsafe(call)) }; @@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> { let awaitee_arm = self.arm(awaitee_pat, loop_expr); // `match ::std::future::IntoFuture::into_future(<expr>) { ... }` - let into_future_expr = self.expr_call_lang_item_fn( - span, - hir::LangItem::IntoFutureIntoFuture, - arena_vec![self; expr], - ); + let into_future_expr = match await_kind { + FutureKind::Future => self.expr_call_lang_item_fn( + span, + hir::LangItem::IntoFutureIntoFuture, + arena_vec![self; *expr], + ), + // Not needed for `for await` because we expect to have already called + // `IntoAsyncIterator::into_async_iter` on it. + FutureKind::AsyncIterator => expr, + }; // match <into_future_expr> { // mut __awaitee => loop { .. } @@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> { head: &Expr, body: &Block, opt_label: Option<Label>, - _loop_kind: ForLoopKind, + loop_kind: ForLoopKind, ) -> hir::Expr<'hir> { let head = self.lower_expr_mut(head); let pat = self.lower_pat(pat); @@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> { let (iter_pat, iter_pat_nid) = self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT); - // `match Iterator::next(&mut iter) { ... }` let match_expr = { let iter = self.expr_ident(head_span, iter, iter_pat_nid); - let ref_mut_iter = self.expr_mut_addr_of(head_span, iter); - let next_expr = self.expr_call_lang_item_fn( - head_span, - hir::LangItem::IteratorNext, - arena_vec![self; ref_mut_iter], - ); + let next_expr = match loop_kind { + ForLoopKind::For => { + // `Iterator::next(&mut iter)` + let ref_mut_iter = self.expr_mut_addr_of(head_span, iter); + self.expr_call_lang_item_fn( + head_span, + hir::LangItem::IteratorNext, + arena_vec![self; ref_mut_iter], + ) + } + ForLoopKind::ForAwait => { + // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this + // to make_lowered_await with `FutureKind::AsyncIterator` which will generator + // calls to `poll_next`. In user code, this would probably be a call to + // `Pin::as_mut` but here it's easy enough to do `new_unchecked`. + + // `&mut iter` + let iter = self.expr_mut_addr_of(head_span, iter); + // `Pin::new_unchecked(...)` + let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut( + head_span, + hir::LangItem::PinNewUnchecked, + arena_vec![self; iter], + )); + // `unsafe { ... }` + let iter = self.arena.alloc(self.expr_unsafe(iter)); + let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator); + self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span }) + } + }; let arms = arena_vec![self; none_arm, some_arm]; + // `match $next_expr { ... }` self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar) }; let match_stmt = self.stmt_expr(for_span, match_expr); @@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> { // `mut iter => { ... }` let iter_arm = self.arm(iter_pat, loop_expr); - // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }` - let into_iter_expr = { - self.expr_call_lang_item_fn( - head_span, - hir::LangItem::IntoIterIntoIter, - arena_vec![self; head], - ) + let into_iter_expr = match loop_kind { + ForLoopKind::For => { + // `::std::iter::IntoIterator::into_iter(<head>)` + self.expr_call_lang_item_fn( + head_span, + hir::LangItem::IntoIterIntoIter, + arena_vec![self; head], + ) + } + ForLoopKind::ForAwait => self.arena.alloc(head), }; let match_expr = self.arena.alloc(self.expr_match( @@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> { } } } + +/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind +/// of future we are awaiting. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum FutureKind { + /// We are awaiting a normal future + Future, + /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of + /// a `for await` loop) + AsyncIterator, +} diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs index 96ed3eee02e..1de5db4deb8 100644 --- a/compiler/rustc_ast_lowering/src/lib.rs +++ b/compiler/rustc_ast_lowering/src/lib.rs @@ -130,6 +130,7 @@ struct LoweringContext<'a, 'hir> { allow_try_trait: Lrc<[Symbol]>, allow_gen_future: Lrc<[Symbol]>, allow_async_iterator: Lrc<[Symbol]>, + allow_for_await: Lrc<[Symbol]>, /// Mapping from generics `def_id`s to TAIT generics `def_id`s. /// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic @@ -174,6 +175,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { } else { [sym::gen_future].into() }, + allow_for_await: [sym::async_iterator].into(), // FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller` // interact with `gen`/`async gen` blocks allow_async_iterator: [sym::gen_future, sym::async_iterator].into(), diff --git a/compiler/rustc_builtin_macros/src/assert/context.rs b/compiler/rustc_builtin_macros/src/assert/context.rs index 4e07122a5c4..d244897f8a5 100644 --- a/compiler/rustc_builtin_macros/src/assert/context.rs +++ b/compiler/rustc_builtin_macros/src/assert/context.rs @@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> { | ExprKind::Continue(_) | ExprKind::Err | ExprKind::Field(_, _) - | ExprKind::ForLoop {..} + | ExprKind::ForLoop { .. } | ExprKind::FormatArgs(_) | ExprKind::IncludedBytes(..) | ExprKind::InlineAsm(_) diff --git a/compiler/rustc_feature/src/unstable.rs b/compiler/rustc_feature/src/unstable.rs index fdd4bb755e4..bb1c95c25d7 100644 --- a/compiler/rustc_feature/src/unstable.rs +++ b/compiler/rustc_feature/src/unstable.rs @@ -358,7 +358,7 @@ declare_features! ( /// Allows `#[track_caller]` on async functions. (unstable, async_fn_track_caller, "1.73.0", Some(110011)), /// Allows `for await` loops. - (unstable, async_for_loop, "CURRENT_RUSTC_VERSION", None), + (unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)), /// Allows builtin # foo() syntax (unstable, builtin_syntax, "1.71.0", Some(110680)), /// Treat `extern "C"` function as nounwind. diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index b0b53bb7478..7691cd11c4f 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -307,6 +307,8 @@ language_item_table! { Context, sym::Context, context, Target::Struct, GenericRequirement::None; FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None; + AsyncIteratorPollNext, sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(0); + Option, sym::Option, option_type, Target::Enum, GenericRequirement::None; OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None; OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None; diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 2e7037532b5..95106cc64c1 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -428,6 +428,7 @@ symbols! { async_fn_track_caller, async_for_loop, async_iterator, + async_iterator_poll_next, atomic, atomic_mod, atomics, @@ -894,6 +895,7 @@ symbols! { instruction_set, integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below integral, + into_async_iter_into_iter, into_future, into_iter, intra_doc_pointers, diff --git a/library/core/src/async_iter/async_iter.rs b/library/core/src/async_iter/async_iter.rs index 8a45bd36f7a..ed9cb5bfea5 100644 --- a/library/core/src/async_iter/async_iter.rs +++ b/library/core/src/async_iter/async_iter.rs @@ -47,6 +47,7 @@ pub trait AsyncIterator { /// Rust's usual rules apply: calls must never cause undefined behavior /// (memory corruption, incorrect use of `unsafe` functions, or the like), /// regardless of the async iterator's state. + #[cfg_attr(not(bootstrap), lang = "async_iterator_poll_next")] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>; /// Returns the bounds on the remaining length of the async iterator. diff --git a/tests/ui/async-await/for-await.rs b/tests/ui/async-await/for-await.rs new file mode 100644 index 00000000000..6345ceb0c27 --- /dev/null +++ b/tests/ui/async-await/for-await.rs @@ -0,0 +1,24 @@ +// run-pass +// edition: 2021 +#![feature(async_iterator, async_iter_from_iter, const_waker, async_for_loop, noop_waker)] + +use std::future::Future; + +// make sure a simple for await loop works +async fn real_main() { + let iter = core::async_iter::from_iter(0..3); + let mut count = 0; + for await i in iter { + assert_eq!(i, count); + count += 1; + } + assert_eq!(count, 3); +} + +fn main() { + let future = real_main(); + let waker = std::task::Waker::noop(); + let mut cx = &mut core::task::Context::from_waker(&waker); + let mut future = core::pin::pin!(future); + while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {} +} |
