diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/coroutine.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coroutine/by_move_body.rs | 108 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline/cycle.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/pass_manager.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/shim.rs | 12 |
7 files changed, 135 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index bde879f6067..297b2fa143d 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -50,6 +50,9 @@ //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. //! Otherwise it drops all the values in scope at the last suspension point. +mod by_move_body; +pub use by_move_body::ByMoveBody; + use crate::abort_unwinding_calls; use crate::deref_separator::deref_finder; use crate::errors; diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..4e3e70bdafe --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,108 @@ +use rustc_data_structures::fx::FxIndexSet; +use rustc_hir as hir; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, MirPass}; +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; +use rustc_target::abi::FieldIdx; + +pub struct ByMoveBody; + +impl<'tcx> MirPass<'tcx> for ByMoveBody { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let Some(coroutine_def_id) = body.source.def_id().as_local() else { + return; + }; + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + return; + }; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; + if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { + return; + } + + let mut by_ref_fields = FxIndexSet::default(); + let by_move_upvars = Ty::new_tup_from_iter( + tcx, + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { + if capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(idx)); + } + capture.place.ty() + }), + ); + let by_move_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: by_move_upvars, + }, + ) + .args, + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + by_move_body.source = mir::MirSource { + instance: InstanceDef::CoroutineByMoveShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + }, + promoted: None, + }; + + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + } +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + by_ref_fields: FxIndexSet<FieldIdx>, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + if place.local == ty::CAPTURE_STRUCT_LOCAL + && !place.projection.is_empty() + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] + && self.by_ref_fields.contains(&idx) + { + let (begin, end) = place.projection[1..].split_first().unwrap(); + assert_eq!(*begin, mir::ProjectionElem::Deref); + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)] + .into_iter() + .chain(end.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 7c731b070a7..24bc84a235c 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 3f3dc9145b6..77ff780393e 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 69f93fa3a0e..031515ea958 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { &Lint(check_packed_ref::CheckPackedRef), &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), + // If this is an async closure's output coroutine, generate + // by-move and by-mut bodies if needed. We do this first so + // they can be optimized in lockstep with their parent bodies. + &coroutine::ByMoveBody, // What we need to do constant evaluation. &simplify::SimplifyCfg::Initial, &rustc_peek::SanityCheck, // Just a lint diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index c1ef2b9f887..c7e770904fb 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } + + if let Some(coroutine) = body.coroutine.as_mut() + && let Some(by_move_body) = coroutine.by_move_body.as_mut() + { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } } pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 29b83f58ef5..668ccdd8735 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } }, + ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine + .as_ref() + .unwrap() + .by_move_body + .as_ref() + .unwrap() + .clone(); + } + ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? |
