about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs3
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs108
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs1
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs1
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs4
-rw-r--r--compiler/rustc_mir_transform/src/pass_manager.rs6
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs12
7 files changed, 135 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index bde879f6067..297b2fa143d 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -50,6 +50,9 @@
 //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
 //! Otherwise it drops all the values in scope at the last suspension point.
 
+mod by_move_body;
+pub use by_move_body::ByMoveBody;
+
 use crate::abort_unwinding_calls;
 use crate::deref_separator::deref_finder;
 use crate::errors;
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
new file mode 100644
index 00000000000..4e3e70bdafe
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -0,0 +1,108 @@
+use rustc_data_structures::fx::FxIndexSet;
+use rustc_hir as hir;
+use rustc_middle::mir::visit::MutVisitor;
+use rustc_middle::mir::{self, MirPass};
+use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
+use rustc_target::abi::FieldIdx;
+
+pub struct ByMoveBody;
+
+impl<'tcx> MirPass<'tcx> for ByMoveBody {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
+        let Some(coroutine_def_id) = body.source.def_id().as_local() else {
+            return;
+        };
+        let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
+            tcx.coroutine_kind(coroutine_def_id)
+        else {
+            return;
+        };
+        let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+        let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
+        if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
+            return;
+        }
+
+        let mut by_ref_fields = FxIndexSet::default();
+        let by_move_upvars = Ty::new_tup_from_iter(
+            tcx,
+            tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
+                if capture.is_by_ref() {
+                    by_ref_fields.insert(FieldIdx::from_usize(idx));
+                }
+                capture.place.ty()
+            }),
+        );
+        let by_move_coroutine_ty = Ty::new_coroutine(
+            tcx,
+            coroutine_def_id.to_def_id(),
+            ty::CoroutineArgs::new(
+                tcx,
+                ty::CoroutineArgsParts {
+                    parent_args: args.as_coroutine().parent_args(),
+                    kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
+                    resume_ty: args.as_coroutine().resume_ty(),
+                    yield_ty: args.as_coroutine().yield_ty(),
+                    return_ty: args.as_coroutine().return_ty(),
+                    witness: args.as_coroutine().witness(),
+                    tupled_upvars_ty: by_move_upvars,
+                },
+            )
+            .args,
+        );
+
+        let mut by_move_body = body.clone();
+        MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
+        by_move_body.source = mir::MirSource {
+            instance: InstanceDef::CoroutineByMoveShim {
+                coroutine_def_id: coroutine_def_id.to_def_id(),
+            },
+            promoted: None,
+        };
+
+        body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
+    }
+}
+
+struct MakeByMoveBody<'tcx> {
+    tcx: TyCtxt<'tcx>,
+    by_ref_fields: FxIndexSet<FieldIdx>,
+    by_move_coroutine_ty: Ty<'tcx>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_place(
+        &mut self,
+        place: &mut mir::Place<'tcx>,
+        context: mir::visit::PlaceContext,
+        location: mir::Location,
+    ) {
+        if place.local == ty::CAPTURE_STRUCT_LOCAL
+            && !place.projection.is_empty()
+            && let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
+            && self.by_ref_fields.contains(&idx)
+        {
+            let (begin, end) = place.projection[1..].split_first().unwrap();
+            assert_eq!(*begin, mir::ProjectionElem::Deref);
+            *place = mir::Place {
+                local: place.local,
+                projection: self.tcx.mk_place_elems_from_iter(
+                    [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
+                        .into_iter()
+                        .chain(end.iter().copied()),
+                ),
+            };
+        }
+        self.super_place(place, context, location);
+    }
+
+    fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
+        if local == ty::CAPTURE_STRUCT_LOCAL {
+            local_decl.ty = self.by_move_coroutine_ty;
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 7c731b070a7..24bc84a235c 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> {
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::ClosureOnceShim { .. }
             | InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | InstanceDef::CoroutineByMoveShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index 3f3dc9145b6..77ff780393e 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 | InstanceDef::FnPtrShim(..)
                 | InstanceDef::ClosureOnceShim { .. }
                 | InstanceDef::ConstructCoroutineInClosureShim { .. }
+                | InstanceDef::CoroutineByMoveShim { .. }
                 | InstanceDef::ThreadLocalShim { .. }
                 | InstanceDef::CloneShim(..) => {}
 
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 69f93fa3a0e..031515ea958 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> {
             &Lint(check_packed_ref::CheckPackedRef),
             &Lint(check_const_item_mutation::CheckConstItemMutation),
             &Lint(function_item_references::FunctionItemReferences),
+            // If this is an async closure's output coroutine, generate
+            // by-move and by-mut bodies if needed. We do this first so
+            // they can be optimized in lockstep with their parent bodies.
+            &coroutine::ByMoveBody,
             // What we need to do constant evaluation.
             &simplify::SimplifyCfg::Initial,
             &rustc_peek::SanityCheck, // Just a lint
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index c1ef2b9f887..c7e770904fb 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>(
 
         body.pass_count = 1;
     }
+
+    if let Some(coroutine) = body.coroutine.as_mut()
+        && let Some(by_move_body) = coroutine.by_move_body.as_mut()
+    {
+        run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
+    }
 }
 
 pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) {
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 29b83f58ef5..668ccdd8735 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
             }
         },
 
+        ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => {
+            return tcx
+                .optimized_mir(coroutine_def_id)
+                .coroutine
+                .as_ref()
+                .unwrap()
+                .by_move_body
+                .as_ref()
+                .unwrap()
+                .clone();
+        }
+
         ty::InstanceDef::DropGlue(def_id, ty) => {
             // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
             // of this function. Is this intentional?