diff options
| author | bors <bors@rust-lang.org> | 2023-10-29 00:03:52 +0000 |
|---|---|---|
| committer | bors <bors@rust-lang.org> | 2023-10-29 00:03:52 +0000 |
| commit | 2cad938a8173520f2bc39137b475f136be0aeeda (patch) | |
| tree | 668c1a104a33221cc0a6082cbe42e752d82690b3 /compiler/rustc_mir_transform/src | |
| parent | e5cfc55477eceed1317a02189fdf77a4a98f2124 (diff) | |
| parent | eb66d10cc3a6947cad6b4b169ed86b8c07f464d3 (diff) | |
| download | rust-2cad938a8173520f2bc39137b475f136be0aeeda.tar.gz rust-2cad938a8173520f2bc39137b475f136be0aeeda.zip | |
Auto merge of #116447 - oli-obk:gen_fn, r=compiler-errors
Implement `gen` blocks in the 2024 edition
Coroutines tracking issue https://github.com/rust-lang/rust/issues/43122
`gen` block tracking issue https://github.com/rust-lang/rust/issues/117078
This PR implements `gen` blocks that implement `Iterator`. Most of the logic with `async` blocks is shared, and thus I renamed various types that were referring to `async` specifically.
An example usage of `gen` blocks is
```rust
fn foo() -> impl Iterator<Item = i32> {
gen {
yield 42;
for i in 5..18 {
if i.is_even() { continue }
yield i * 2;
}
}
}
```
The limitations (to be resolved) of the implementation are listed in the tracking issue
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/coroutine.rs | 143 |
1 files changed, 105 insertions, 38 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index fa56d59dd80..50d244d2831 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, - is_async_kind: bool, + coroutine_kind: hir::CoroutineKind, state_adt_ref: AdtDef<'tcx>, state_args: GenericArgsRef<'tcx>, @@ -249,6 +249,47 @@ struct TransformVisitor<'tcx> { } impl<'tcx> TransformVisitor<'tcx> { + fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock { + let block = BasicBlock::new(body.basic_blocks.len()); + + let source_info = SourceInfo::outermost(body.span); + + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true); + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + let statements = vec![Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }]; + + body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + block + } + + fn coroutine_state_adt_and_variant_idx( + &self, + is_return: bool, + ) -> (AggregateKind<'tcx>, VariantIdx) { + let idx = VariantIdx::new(match (is_return, self.coroutine_kind) { + (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete + (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded + (true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready + (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending + (true, hir::CoroutineKind::Gen(_)) => 0, // Option::None + (false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some + }); + + let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); + (kind, idx) + } + // Make a `CoroutineState` or `Poll` variant assignment. // // `core::ops::CoroutineState` only has single element tuple variants, @@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> { is_return: bool, statements: &mut Vec<Statement<'tcx>>, ) { - let idx = VariantIdx::new(match (is_return, self.is_async_kind) { - (true, false) => 1, // CoroutineState::Complete - (false, false) => 0, // CoroutineState::Yielded - (true, true) => 0, // Poll::Ready - (false, true) => 1, // Poll::Pending - }); + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return); - let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); + match self.coroutine_kind { + // `Poll::Pending` + CoroutineKind::Async(_) => { + if !is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - // `Poll::Pending` - if self.is_async_kind && idx == VariantIdx::new(1) { - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + // FIXME(swatinem): assert that `val` is indeed unit? + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + // `Option::None` + CoroutineKind::Gen(_) => { + if is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - // FIXME(swatinem): assert that `val` is indeed unit? - statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), - ))), - source_info, - }); - return; + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + CoroutineKind::Coroutine => {} } - // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)` + // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); statements.push(Statement { @@ -1263,10 +1317,13 @@ fn create_coroutine_resume_function<'tcx>( } if can_return { - cases.insert( - 1, - (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))), - ); + let block = match coroutine_kind { + CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) + } + CoroutineKind::Gen(_) => transform.insert_none_ret_block(body), + }; + cases.insert(1, (RETURNED, block)); } insert_switch(body, cases, &transform, TerminatorKind::Unreachable); @@ -1439,18 +1496,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); - let (state_adt_ref, state_args) = if is_async_kind { - // Compute Poll<return_ty> - let poll_did = tcx.require_lang_item(LangItem::Poll, None); - let poll_adt_ref = tcx.adt_def(poll_did); - let poll_args = tcx.mk_args(&[body.return_ty().into()]); - (poll_adt_ref, poll_args) - } else { - // Compute CoroutineState<yield_ty, return_ty> - let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); - (state_adt_ref, state_args) + let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { + CoroutineKind::Async(_) => { + // Compute Poll<return_ty> + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_args = tcx.mk_args(&[body.return_ty().into()]); + (poll_adt_ref, poll_args) + } + CoroutineKind::Gen(_) => { + // Compute Option<yield_ty> + let option_did = tcx.require_lang_item(LangItem::Option, None); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]); + (option_adt_ref, option_args) + } + CoroutineKind::Coroutine => { + // Compute CoroutineState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_args) + } }; let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); @@ -1518,7 +1585,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. let mut transform = TransformVisitor { tcx, - is_async_kind, + coroutine_kind: body.coroutine_kind().unwrap(), state_adt_ref, state_args, remap, |
