diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/coroutine/by_move_body.rs | 45 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline/cycle.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/pass_manager.rs | 11 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/shim.rs | 77 |
5 files changed, 111 insertions, 26 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index 1cc0a5026d1..fcd4715b9e8 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -6,7 +6,7 @@ use rustc_data_structures::fx::FxIndexSet; use rustc_hir as hir; use rustc_middle::mir::visit::MutVisitor; -use rustc_middle::mir::{self, MirPass}; +use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; use rustc_target::abi::FieldIdx; @@ -24,7 +24,9 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { }; let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; - if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { + + let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); + if coroutine_kind == ty::ClosureKind::FnOnce { return; } @@ -58,14 +60,49 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { let mut by_move_body = body.clone(); MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); by_move_body.source = mir::MirSource { - instance: InstanceDef::CoroutineByMoveShim { + instance: InstanceDef::CoroutineKindShim { coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnOnce, }, promoted: None, }; - body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + + // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body. + // This is actually just a copy of the by-ref body, but with a different self type. + // FIXME(async_closures): We could probably unify this with the by-ref body somehow. + if coroutine_kind == ty::ClosureKind::Fn { + let by_mut_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(), + }, + ) + .args, + ); + let mut by_mut_body = body.clone(); + by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty; + dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(())); + by_mut_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnMut, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body); + } } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 24bc84a235c..e77553a03d6 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -318,7 +318,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 77ff780393e..5b03bc361dd 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -88,7 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index c7e770904fb..605e1ad46d7 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -190,10 +190,13 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } - if let Some(coroutine) = body.coroutine.as_mut() - && let Some(by_move_body) = coroutine.by_move_body.as_mut() - { - run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + if let Some(coroutine) = body.coroutine.as_mut() { + if let Some(by_move_body) = coroutine.by_move_body.as_mut() { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } + if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() { + run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each); + } } } diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 668ccdd8735..7b6de3a5439 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -72,32 +72,70 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } => match target_kind { ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"), ty::ClosureKind::FnMut => { - let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); - // No need to optimize the body, it has already been optimized. - return body; + // No need to optimize the body, it has already been optimized + // since we steal it from the `AsyncFn::call` body and just fix + // the return type. + return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); } ty::ClosureKind::FnOnce => { build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) } }, - ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => { - return tcx - .optimized_mir(coroutine_def_id) - .coroutine - .as_ref() - .unwrap() - .by_move_body - .as_ref() - .unwrap() - .clone(); - } + ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind { + ty::ClosureKind::Fn => unreachable!(), + ty::ClosureKind::FnMut => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_mut_body() + .unwrap() + .clone(); + } + ty::ClosureKind::FnOnce => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_move_body() + .unwrap() + .clone(); + } + }, ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) { - let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap(); + let coroutine_body = tcx.optimized_mir(*coroutine_def_id); + + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + // If this is a regular coroutine, grab its drop shim. If this is a coroutine + // that comes from a coroutine-closure, and the kind ty differs from the "maximum" + // kind that it supports, then grab the appropriate drop shim. This ensures that + // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will + // drop the coroutine-closure's upvars. + let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() { + coroutine_body.coroutine_drop().unwrap() + } else { + match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() { + ty::ClosureKind::Fn => { + unreachable!() + } + ty::ClosureKind::FnMut => coroutine_body + .coroutine_by_mut_body() + .unwrap() + .coroutine_drop() + .unwrap(), + ty::ClosureKind::FnOnce => coroutine_body + .coroutine_by_move_body() + .unwrap() + .coroutine_drop() + .unwrap(), + } + }; + let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); debug!("make_shim({:?}) = {:?}", instance, body); @@ -1076,7 +1114,11 @@ fn build_construct_coroutine_by_move_shim<'tcx>( target_kind: ty::ClosureKind::FnOnce, }); - new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) + let body = + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); + dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(())); + + body } fn build_construct_coroutine_by_mut_shim<'tcx>( @@ -1110,5 +1152,8 @@ fn build_construct_coroutine_by_mut_shim<'tcx>( target_kind: ty::ClosureKind::FnMut, }); + body.pass_count = 0; + dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(())); + body } |
