about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/add_retag.rs2
-rw-r--r--compiler/rustc_mir_transform/src/coverage/spans.rs10
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs118
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs52
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs1
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs2
6 files changed, 146 insertions, 39 deletions
diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs
index 3d22035f078..7d2146214c6 100644
--- a/compiler/rustc_mir_transform/src/add_retag.rs
+++ b/compiler/rustc_mir_transform/src/add_retag.rs
@@ -120,7 +120,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag {
         // PART 3
         // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not.
         for block_data in basic_blocks {
-            // We want to insert statements as we iterate.  To this end, we
+            // We want to insert statements as we iterate. To this end, we
             // iterate backwards using indices.
             for i in (0..block_data.statements.len()).rev() {
                 let (retag_kind, place) = match block_data.statements[i].kind {
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs
index 9f842c929dc..c5434840453 100644
--- a/compiler/rustc_mir_transform/src/coverage/spans.rs
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -341,11 +341,11 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
                     if a.is_in_same_bcb(b) {
                         Some(Ordering::Equal)
                     } else {
-                        // Sort equal spans by dominator relationship, in reverse order (so
-                        // dominators always come after the dominated equal spans). When later
-                        // comparing two spans in order, the first will either dominate the second,
-                        // or they will have no dominator relationship.
-                        self.basic_coverage_blocks.dominators().rank_partial_cmp(b.bcb, a.bcb)
+                        // Sort equal spans by dominator relationship (so dominators always come
+                        // before the dominated equal spans). When later comparing two spans in
+                        // order, the first will either dominate the second, or they will have no
+                        // dominator relationship.
+                        self.basic_coverage_blocks.dominators().rank_partial_cmp(a.bcb, b.bcb)
                     }
                 } else {
                     // Sort hi() in reverse order so shorter spans are attempted after longer spans.
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index c097af61611..39c61a34afc 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -460,6 +460,104 @@ fn replace_local<'tcx>(
     new_local
 }
 
+/// Transforms the `body` of the generator applying the following transforms:
+///
+/// - Eliminates all the `get_context` calls that async lowering created.
+/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
+///
+/// The `Local`s that have their types replaced are:
+/// - The `resume` argument itself.
+/// - The argument to `get_context`.
+/// - The yielded value of a `yield`.
+///
+/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
+/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
+///
+/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
+/// but rather directly use `&mut Context<'_>`, however that would currently
+/// lead to higher-kinded lifetime errors.
+/// See <https://github.com/rust-lang/rust/issues/105501>.
+///
+/// The async lowering step and the type / lifetime inference / checking are
+/// still using the `ResumeTy` indirection for the time being, and that indirection
+/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
+fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let context_mut_ref = tcx.mk_task_context();
+
+    // replace the type of the `resume` argument
+    replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
+
+    let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
+
+    for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
+        let bb_data = &body[bb];
+        if bb_data.is_cleanup {
+            continue;
+        }
+
+        match &bb_data.terminator().kind {
+            TerminatorKind::Call { func, .. } => {
+                let func_ty = func.ty(body, tcx);
+                if let ty::FnDef(def_id, _) = *func_ty.kind() {
+                    if def_id == get_context_def_id {
+                        let local = eliminate_get_context_call(&mut body[bb]);
+                        replace_resume_ty_local(tcx, body, local, context_mut_ref);
+                    }
+                } else {
+                    continue;
+                }
+            }
+            TerminatorKind::Yield { resume_arg, .. } => {
+                replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
+            }
+            _ => {}
+        }
+    }
+}
+
+fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
+    let terminator = bb_data.terminator.take().unwrap();
+    if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
+        let arg = args.pop().unwrap();
+        let local = arg.place().unwrap().local;
+
+        let arg = Rvalue::Use(arg);
+        let assign = Statement {
+            source_info: terminator.source_info,
+            kind: StatementKind::Assign(Box::new((destination, arg))),
+        };
+        bb_data.statements.push(assign);
+        bb_data.terminator = Some(Terminator {
+            source_info: terminator.source_info,
+            kind: TerminatorKind::Goto { target: target.unwrap() },
+        });
+        local
+    } else {
+        bug!();
+    }
+}
+
+#[cfg_attr(not(debug_assertions), allow(unused))]
+fn replace_resume_ty_local<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mut Body<'tcx>,
+    local: Local,
+    context_mut_ref: Ty<'tcx>,
+) {
+    let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
+    // We have to replace the `ResumeTy` that is used for type and borrow checking
+    // with `&mut Context<'_>` in MIR.
+    #[cfg(debug_assertions)]
+    {
+        if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
+            let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
+            assert_eq!(*resume_ty_adt, expected_adt);
+        } else {
+            panic!("expected `ResumeTy`, found `{:?}`", local_ty);
+        };
+    }
+}
+
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
     saved_locals: GeneratorSavedLocals,
@@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
             }
         };
 
-        let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
+        let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
         let (state_adt_ref, state_substs) = if is_async_kind {
             // Compute Poll<return_ty>
-            let state_did = tcx.require_lang_item(LangItem::Poll, None);
-            let state_adt_ref = tcx.adt_def(state_did);
-            let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
-            (state_adt_ref, state_substs)
+            let poll_did = tcx.require_lang_item(LangItem::Poll, None);
+            let poll_adt_ref = tcx.adt_def(poll_did);
+            let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
+            (poll_adt_ref, poll_substs)
         } else {
             // Compute GeneratorState<yield_ty, return_ty>
             let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
@@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // RETURN_PLACE then is a fresh unused local with type ret_ty.
         let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
 
+        // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
+        if is_async_kind {
+            transform_async_context(tcx, body);
+        }
+
         // We also replace the resume argument and insert an `Assign`.
         // This is needed because the resume argument `_2` might be live across a `yield`, in which
         // case there is no `Assign` to it that the transform can turn into a store to the generator
         // state. After the yield the slot in the generator state would then be uninitialized.
         let resume_local = Local::new(2);
-        let new_resume_local =
-            replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
+        let resume_ty =
+            if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
+        let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
 
         // When first entering the generator, move the resume argument into its new local.
         let source_info = SourceInfo::outermost(body.span);
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 4219e6280eb..28c9080d38d 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -542,6 +542,21 @@ impl<'tcx> Inliner<'tcx> {
                     destination
                 };
 
+                // Always create a local to hold the destination, as `RETURN_PLACE` may appear
+                // where a full `Place` is not allowed.
+                let (remap_destination, destination_local) = if let Some(d) = dest.as_local() {
+                    (false, d)
+                } else {
+                    (
+                        true,
+                        self.new_call_temp(
+                            caller_body,
+                            &callsite,
+                            destination.ty(caller_body, self.tcx).ty,
+                        ),
+                    )
+                };
+
                 // Copy the arguments if needed.
                 let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body);
 
@@ -560,7 +575,7 @@ impl<'tcx> Inliner<'tcx> {
                     new_locals: Local::new(caller_body.local_decls.len())..,
                     new_scopes: SourceScope::new(caller_body.source_scopes.len())..,
                     new_blocks: BasicBlock::new(caller_body.basic_blocks.len())..,
-                    destination: dest,
+                    destination: destination_local,
                     callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(),
                     callsite,
                     cleanup_block: cleanup,
@@ -591,6 +606,16 @@ impl<'tcx> Inliner<'tcx> {
                     // To avoid repeated O(n) insert, push any new statements to the end and rotate
                     // the slice once.
                     let mut n = 0;
+                    if remap_destination {
+                        caller_body[block].statements.push(Statement {
+                            source_info: callsite.source_info,
+                            kind: StatementKind::Assign(Box::new((
+                                dest,
+                                Rvalue::Use(Operand::Move(destination_local.into())),
+                            ))),
+                        });
+                        n += 1;
+                    }
                     for local in callee_body.vars_and_temps_iter().rev() {
                         if !callee_body.local_decls[local].internal
                             && integrator.always_live_locals.contains(local)
@@ -959,7 +984,7 @@ struct Integrator<'a, 'tcx> {
     new_locals: RangeFrom<Local>,
     new_scopes: RangeFrom<SourceScope>,
     new_blocks: RangeFrom<BasicBlock>,
-    destination: Place<'tcx>,
+    destination: Local,
     callsite_scope: SourceScopeData<'tcx>,
     callsite: &'a CallSite<'tcx>,
     cleanup_block: Option<BasicBlock>,
@@ -972,7 +997,7 @@ struct Integrator<'a, 'tcx> {
 impl Integrator<'_, '_> {
     fn map_local(&self, local: Local) -> Local {
         let new = if local == RETURN_PLACE {
-            self.destination.local
+            self.destination
         } else {
             let idx = local.index() - 1;
             if idx < self.args.len() {
@@ -1053,27 +1078,6 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> {
         *span = span.fresh_expansion(self.expn_data);
     }
 
-    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
-        for elem in place.projection {
-            // FIXME: Make sure that return place is not used in an indexing projection, since it
-            // won't be rebased as it is supposed to be.
-            assert_ne!(ProjectionElem::Index(RETURN_PLACE), elem);
-        }
-
-        // If this is the `RETURN_PLACE`, we need to rebase any projections onto it.
-        let dest_proj_len = self.destination.projection.len();
-        if place.local == RETURN_PLACE && dest_proj_len > 0 {
-            let mut projs = Vec::with_capacity(dest_proj_len + place.projection.len());
-            projs.extend(self.destination.projection);
-            projs.extend(place.projection);
-
-            place.projection = self.tcx.intern_place_elems(&*projs);
-        }
-        // Handles integrating any locals that occur in the base
-        // or projections
-        self.super_place(place, context, location)
-    }
-
     fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
         self.in_cleanup_block = data.is_cleanup;
         self.super_basic_block_data(block, data);
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 16b8a901f36..20b7fdcfe6d 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -487,7 +487,6 @@ fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>
 fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     let passes: &[&dyn MirPass<'tcx>] = &[
         &cleanup_post_borrowck::CleanupPostBorrowck,
-        &simplify_branches::SimplifyConstCondition::new("initial"),
         &remove_noop_landing_pads::RemoveNoopLandingPads,
         &simplify::SimplifyCfg::new("early-opt"),
         &deref_separator::Derefer,
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 3a2bf051516..42124f5a480 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -215,7 +215,7 @@ struct ReplacementVisitor<'tcx, 'll> {
     replacements: ReplacementMap<'tcx>,
     /// This is used to check that we are not leaving references to replaced locals behind.
     all_dead_locals: BitSet<Local>,
-    /// Pre-computed list of all "new" locals for each "old" local.  This is used to expand storage
+    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
     /// and deinit statement and debuginfo.
     fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
 }