diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/add_retag.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/spans.rs | 10 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/generator.rs | 118 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 52 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/sroa.rs | 2 |
6 files changed, 146 insertions, 39 deletions
diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index 3d22035f078..7d2146214c6 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -120,7 +120,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag { // PART 3 // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not. for block_data in basic_blocks { - // We want to insert statements as we iterate. To this end, we + // We want to insert statements as we iterate. To this end, we // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (retag_kind, place) = match block_data.statements[i].kind { diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 9f842c929dc..c5434840453 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -341,11 +341,11 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { if a.is_in_same_bcb(b) { Some(Ordering::Equal) } else { - // Sort equal spans by dominator relationship, in reverse order (so - // dominators always come after the dominated equal spans). When later - // comparing two spans in order, the first will either dominate the second, - // or they will have no dominator relationship. - self.basic_coverage_blocks.dominators().rank_partial_cmp(b.bcb, a.bcb) + // Sort equal spans by dominator relationship (so dominators always come + // before the dominated equal spans). When later comparing two spans in + // order, the first will either dominate the second, or they will have no + // dominator relationship. + self.basic_coverage_blocks.dominators().rank_partial_cmp(a.bcb, b.bcb) } } else { // Sort hi() in reverse order so shorter spans are attempted after longer spans. diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index c097af61611..39c61a34afc 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -460,6 +460,104 @@ fn replace_local<'tcx>( new_local } +/// Transforms the `body` of the generator applying the following transforms: +/// +/// - Eliminates all the `get_context` calls that async lowering created. +/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`). +/// +/// The `Local`s that have their types replaced are: +/// - The `resume` argument itself. +/// - The argument to `get_context`. +/// - The yielded value of a `yield`. +/// +/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the +/// `get_context` function is being used to convert that back to a `&mut Context<'_>`. +/// +/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection, +/// but rather directly use `&mut Context<'_>`, however that would currently +/// lead to higher-kinded lifetime errors. +/// See <https://github.com/rust-lang/rust/issues/105501>. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `ResumeTy` indirection for the time being, and that indirection +/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`. +fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let context_mut_ref = tcx.mk_task_context(); + + // replace the type of the `resume` argument + replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref); + + let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); + + for bb in BasicBlock::new(0)..body.basic_blocks.next_index() { + let bb_data = &body[bb]; + if bb_data.is_cleanup { + continue; + } + + match &bb_data.terminator().kind { + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(body, tcx); + if let ty::FnDef(def_id, _) = *func_ty.kind() { + if def_id == get_context_def_id { + let local = eliminate_get_context_call(&mut body[bb]); + replace_resume_ty_local(tcx, body, local, context_mut_ref); + } + } else { + continue; + } + } + TerminatorKind::Yield { resume_arg, .. } => { + replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref); + } + _ => {} + } + } +} + +fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local { + let terminator = bb_data.terminator.take().unwrap(); + if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind { + let arg = args.pop().unwrap(); + let local = arg.place().unwrap().local; + + let arg = Rvalue::Use(arg); + let assign = Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new((destination, arg))), + }; + bb_data.statements.push(assign); + bb_data.terminator = Some(Terminator { + source_info: terminator.source_info, + kind: TerminatorKind::Goto { target: target.unwrap() }, + }); + local + } else { + bug!(); + } +} + +#[cfg_attr(not(debug_assertions), allow(unused))] +fn replace_resume_ty_local<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + local: Local, + context_mut_ref: Ty<'tcx>, +) { + let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref); + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` in MIR. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = local_ty.kind() { + let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", local_ty); + }; + } +} + struct LivenessInfo { /// Which locals are live across any suspension point. saved_locals: GeneratorSavedLocals, @@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } }; - let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen; + let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_))); let (state_adt_ref, state_substs) = if is_async_kind { // Compute Poll<return_ty> - let state_did = tcx.require_lang_item(LangItem::Poll, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[body.return_ty().into()]); - (state_adt_ref, state_substs) + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_substs = tcx.intern_substs(&[body.return_ty().into()]); + (poll_adt_ref, poll_substs) } else { // Compute GeneratorState<yield_ty, return_ty> let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); @@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // RETURN_PLACE then is a fresh unused local with type ret_ty. let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx); + // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. + if is_async_kind { + transform_async_context(tcx, body); + } + // We also replace the resume argument and insert an `Assign`. // This is needed because the resume argument `_2` might be live across a `yield`, in which // case there is no `Assign` to it that the transform can turn into a store to the generator // state. After the yield the slot in the generator state would then be uninitialized. let resume_local = Local::new(2); - let new_resume_local = - replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx); + let resume_ty = + if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty }; + let new_resume_local = replace_local(resume_local, resume_ty, body, tcx); // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 4219e6280eb..28c9080d38d 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -542,6 +542,21 @@ impl<'tcx> Inliner<'tcx> { destination }; + // Always create a local to hold the destination, as `RETURN_PLACE` may appear + // where a full `Place` is not allowed. + let (remap_destination, destination_local) = if let Some(d) = dest.as_local() { + (false, d) + } else { + ( + true, + self.new_call_temp( + caller_body, + &callsite, + destination.ty(caller_body, self.tcx).ty, + ), + ) + }; + // Copy the arguments if needed. let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body); @@ -560,7 +575,7 @@ impl<'tcx> Inliner<'tcx> { new_locals: Local::new(caller_body.local_decls.len()).., new_scopes: SourceScope::new(caller_body.source_scopes.len()).., new_blocks: BasicBlock::new(caller_body.basic_blocks.len()).., - destination: dest, + destination: destination_local, callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(), callsite, cleanup_block: cleanup, @@ -591,6 +606,16 @@ impl<'tcx> Inliner<'tcx> { // To avoid repeated O(n) insert, push any new statements to the end and rotate // the slice once. let mut n = 0; + if remap_destination { + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Move(destination_local.into())), + ))), + }); + n += 1; + } for local in callee_body.vars_and_temps_iter().rev() { if !callee_body.local_decls[local].internal && integrator.always_live_locals.contains(local) @@ -959,7 +984,7 @@ struct Integrator<'a, 'tcx> { new_locals: RangeFrom<Local>, new_scopes: RangeFrom<SourceScope>, new_blocks: RangeFrom<BasicBlock>, - destination: Place<'tcx>, + destination: Local, callsite_scope: SourceScopeData<'tcx>, callsite: &'a CallSite<'tcx>, cleanup_block: Option<BasicBlock>, @@ -972,7 +997,7 @@ struct Integrator<'a, 'tcx> { impl Integrator<'_, '_> { fn map_local(&self, local: Local) -> Local { let new = if local == RETURN_PLACE { - self.destination.local + self.destination } else { let idx = local.index() - 1; if idx < self.args.len() { @@ -1053,27 +1078,6 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { *span = span.fresh_expansion(self.expn_data); } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - for elem in place.projection { - // FIXME: Make sure that return place is not used in an indexing projection, since it - // won't be rebased as it is supposed to be. - assert_ne!(ProjectionElem::Index(RETURN_PLACE), elem); - } - - // If this is the `RETURN_PLACE`, we need to rebase any projections onto it. - let dest_proj_len = self.destination.projection.len(); - if place.local == RETURN_PLACE && dest_proj_len > 0 { - let mut projs = Vec::with_capacity(dest_proj_len + place.projection.len()); - projs.extend(self.destination.projection); - projs.extend(place.projection); - - place.projection = self.tcx.intern_place_elems(&*projs); - } - // Handles integrating any locals that occur in the base - // or projections - self.super_place(place, context, location) - } - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { self.in_cleanup_block = data.is_cleanup; self.super_basic_block_data(block, data); diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 16b8a901f36..20b7fdcfe6d 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -487,7 +487,6 @@ fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx> fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let passes: &[&dyn MirPass<'tcx>] = &[ &cleanup_post_borrowck::CleanupPostBorrowck, - &simplify_branches::SimplifyConstCondition::new("initial"), &remove_noop_landing_pads::RemoveNoopLandingPads, &simplify::SimplifyCfg::new("early-opt"), &deref_separator::Derefer, diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 3a2bf051516..42124f5a480 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -215,7 +215,7 @@ struct ReplacementVisitor<'tcx, 'll> { replacements: ReplacementMap<'tcx>, /// This is used to check that we are not leaving references to replaced locals behind. all_dead_locals: BitSet<Local>, - /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage /// and deinit statement and debuginfo. fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>, } |
