about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_const_eval/src/interpret/terminator.rs1
-rw-r--r--compiler/rustc_middle/src/mir/mono.rs1
-rw-r--r--compiler/rustc_middle/src/mir/visit.rs1
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs23
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs1
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs1
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs1
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs121
-rw-r--r--compiler/rustc_monomorphize/src/collector.rs1
-rw-r--r--compiler/rustc_monomorphize/src/partitioning.rs2
-rw-r--r--compiler/rustc_smir/src/rustc_smir/convert/ty.rs1
-rw-r--r--compiler/rustc_ty_utils/src/abi.rs15
-rw-r--r--compiler/rustc_ty_utils/src/instance.rs17
13 files changed, 175 insertions, 11 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs
index b7ffb4a16fc..4c8f68b25b5 100644
--- a/compiler/rustc_const_eval/src/interpret/terminator.rs
+++ b/compiler/rustc_const_eval/src/interpret/terminator.rs
@@ -545,6 +545,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             ty::InstanceDef::VTableShim(..)
             | ty::InstanceDef::ReifyShim(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
             | ty::InstanceDef::FnPtrShim(..)
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs
index 91fdf0b3129..4a29171d8bf 100644
--- a/compiler/rustc_middle/src/mir/mono.rs
+++ b/compiler/rustc_middle/src/mir/mono.rs
@@ -402,6 +402,7 @@ impl<'tcx> CodegenUnit<'tcx> {
                             | InstanceDef::FnPtrShim(..)
                             | InstanceDef::Virtual(..)
                             | InstanceDef::ClosureOnceShim { .. }
+                            | InstanceDef::ConstructCoroutineInClosureShim { .. }
                             | InstanceDef::DropGlue(..)
                             | InstanceDef::CloneShim(..)
                             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 591104ecc66..6bc58adea0f 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -345,6 +345,7 @@ macro_rules! make_mir_visitor {
                         ty::InstanceDef::Virtual(_def_id, _) |
                         ty::InstanceDef::ThreadLocalShim(_def_id) |
                         ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
+                        ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
                         ty::InstanceDef::DropGlue(_def_id, None) => {}
 
                         ty::InstanceDef::FnPtrShim(_def_id, ty) |
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 293fdb026b6..41ae136851e 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -82,11 +82,25 @@ pub enum InstanceDef<'tcx> {
     /// details on that).
     Virtual(DefId, usize),
 
-    /// `<[FnMut closure] as FnOnce>::call_once`.
+    /// `<[FnMut/Fn closure] as FnOnce>::call_once`.
     ///
     /// The `DefId` is the ID of the `call_once` method in `FnOnce`.
+    ///
+    /// This generates a body that will just borrow the (owned) self type,
+    /// and dispatch to the `FnMut::call_mut` instance for the closure.
     ClosureOnceShim { call_once: DefId, track_caller: bool },
 
+    /// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or
+    /// `<[Fn coroutine-closure] as FnMut>::call_mut`.
+    ///
+    /// The body generated here differs significantly from the `ClosureOnceShim`,
+    /// since we need to generate a distinct coroutine type that will move the
+    /// closure's upvars *out* of the closure.
+    ConstructCoroutineInClosureShim {
+        coroutine_closure_def_id: DefId,
+        target_kind: ty::ClosureKind,
+    },
+
     /// Compiler-generated accessor for thread locals which returns a reference to the thread local
     /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
     /// native support.
@@ -168,6 +182,10 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::Intrinsic(def_id)
             | InstanceDef::ThreadLocalShim(def_id)
             | InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim {
+                coroutine_closure_def_id: def_id,
+                target_kind: _,
+            }
             | InstanceDef::DropGlue(def_id, _)
             | InstanceDef::CloneShim(def_id, _)
             | InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -187,6 +205,7 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::Virtual(..)
             | InstanceDef::Intrinsic(..)
             | InstanceDef::ClosureOnceShim { .. }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::FnPtrAddrShim(..) => None,
@@ -282,6 +301,7 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::DropGlue(_, Some(_)) => false,
             InstanceDef::ClosureOnceShim { .. }
+            | InstanceDef::ConstructCoroutineInClosureShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::Item(_)
             | InstanceDef::Intrinsic(..)
@@ -319,6 +339,7 @@ fn fmt_instance(
         InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"),
         InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
         InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
+        InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
         InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
         InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
         InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 4348d01eff5..05875a9798b 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -1680,6 +1680,7 @@ impl<'tcx> TyCtxt<'tcx> {
             | ty::InstanceDef::FnPtrShim(..)
             | ty::InstanceDef::Virtual(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
             | ty::InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 67668a216de..7c731b070a7 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -317,6 +317,7 @@ impl<'tcx> Inliner<'tcx> {
             | InstanceDef::ReifyShim(_)
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::ClosureOnceShim { .. }
+            | InstanceDef::ConstructCoroutineInClosureShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index d30e0bad813..3f3dc9145b6 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -87,6 +87,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 | InstanceDef::ReifyShim(_)
                 | InstanceDef::FnPtrShim(..)
                 | InstanceDef::ClosureOnceShim { .. }
+                | InstanceDef::ConstructCoroutineInClosureShim { .. }
                 | InstanceDef::ThreadLocalShim { .. }
                 | InstanceDef::CloneShim(..) => {}
 
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 89414ce940e..29b83f58ef5 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::LangItem;
 use rustc_middle::mir::*;
 use rustc_middle::query::Providers;
-use rustc_middle::ty::GenericArgs;
 use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
+use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
 use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
 
 use rustc_index::{Idx, IndexVec};
@@ -66,6 +66,21 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
             build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
         }
 
+        ty::InstanceDef::ConstructCoroutineInClosureShim {
+            coroutine_closure_def_id,
+            target_kind,
+        } => match target_kind {
+            ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
+            ty::ClosureKind::FnMut => {
+                let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
+                // No need to optimize the body, it has already been optimized.
+                return body;
+            }
+            ty::ClosureKind::FnOnce => {
+                build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
+            }
+        },
+
         ty::InstanceDef::DropGlue(def_id, ty) => {
             // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
             // of this function. Is this intentional?
@@ -981,3 +996,107 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
     let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty));
     new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
 }
+
+fn build_construct_coroutine_by_move_shim<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    coroutine_closure_def_id: DefId,
+) -> Body<'tcx> {
+    let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+    let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
+        bug!();
+    };
+
+    let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
+        tcx.mk_fn_sig(
+            [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
+            sig.to_coroutine_given_kind_and_upvars(
+                tcx,
+                args.as_coroutine_closure().parent_args(),
+                tcx.coroutine_for_closure(coroutine_closure_def_id),
+                ty::ClosureKind::FnOnce,
+                tcx.lifetimes.re_erased,
+                args.as_coroutine_closure().tupled_upvars_ty(),
+                args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
+            ),
+            sig.c_variadic,
+            sig.unsafety,
+            sig.abi,
+        )
+    });
+    let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig);
+    let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else {
+        bug!();
+    };
+
+    let span = tcx.def_span(coroutine_closure_def_id);
+    let locals = local_decls_for_sig(&sig, span);
+
+    let mut fields = vec![];
+    for idx in 1..sig.inputs().len() {
+        fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
+    }
+    for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
+        fields.push(Operand::Move(tcx.mk_place_field(
+            Local::from_usize(1).into(),
+            FieldIdx::from_usize(idx),
+            ty,
+        )));
+    }
+
+    let source_info = SourceInfo::outermost(span);
+    let rvalue = Rvalue::Aggregate(
+        Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)),
+        IndexVec::from_raw(fields),
+    );
+    let stmt = Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
+    };
+    let statements = vec![stmt];
+    let start_block = BasicBlockData {
+        statements,
+        terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
+        is_cleanup: false,
+    };
+
+    let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
+        coroutine_closure_def_id,
+        target_kind: ty::ClosureKind::FnOnce,
+    });
+
+    new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
+}
+
+fn build_construct_coroutine_by_mut_shim<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    coroutine_closure_def_id: DefId,
+) -> Body<'tcx> {
+    let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
+    let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+    let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
+        bug!();
+    };
+    let args = args.as_coroutine_closure();
+
+    body.local_decls[RETURN_PLACE].ty =
+        tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
+            sig.to_coroutine_given_kind_and_upvars(
+                tcx,
+                args.parent_args(),
+                tcx.coroutine_for_closure(coroutine_closure_def_id),
+                ty::ClosureKind::FnMut,
+                tcx.lifetimes.re_erased,
+                args.tupled_upvars_ty(),
+                args.coroutine_captures_by_ref_ty(),
+            )
+        }));
+    body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
+        Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
+
+    body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
+        coroutine_closure_def_id,
+        target_kind: ty::ClosureKind::FnMut,
+    });
+
+    body
+}
diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs
index b86244e5a4b..698fd634114 100644
--- a/compiler/rustc_monomorphize/src/collector.rs
+++ b/compiler/rustc_monomorphize/src/collector.rs
@@ -983,6 +983,7 @@ fn visit_instance_use<'tcx>(
         | ty::InstanceDef::VTableShim(..)
         | ty::InstanceDef::ReifyShim(..)
         | ty::InstanceDef::ClosureOnceShim { .. }
+        | InstanceDef::ConstructCoroutineInClosureShim { .. }
         | ty::InstanceDef::Item(..)
         | ty::InstanceDef::FnPtrShim(..)
         | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs
index 7ff182381b8..46bd33c89e7 100644
--- a/compiler/rustc_monomorphize/src/partitioning.rs
+++ b/compiler/rustc_monomorphize/src/partitioning.rs
@@ -620,6 +620,7 @@ fn characteristic_def_id_of_mono_item<'tcx>(
                 | ty::InstanceDef::ReifyShim(..)
                 | ty::InstanceDef::FnPtrShim(..)
                 | ty::InstanceDef::ClosureOnceShim { .. }
+                | InstanceDef::ConstructCoroutineInClosureShim { .. }
                 | ty::InstanceDef::Intrinsic(..)
                 | ty::InstanceDef::DropGlue(..)
                 | ty::InstanceDef::Virtual(..)
@@ -783,6 +784,7 @@ fn mono_item_visibility<'tcx>(
         | InstanceDef::Virtual(..)
         | InstanceDef::Intrinsic(..)
         | InstanceDef::ClosureOnceShim { .. }
+        | InstanceDef::ConstructCoroutineInClosureShim { .. }
         | InstanceDef::DropGlue(..)
         | InstanceDef::CloneShim(..)
         | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden,
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
index e3fd64eb189..e0e9815cf40 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
@@ -799,6 +799,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> {
             | ty::InstanceDef::ReifyShim(..)
             | ty::InstanceDef::FnPtrAddrShim(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
             | ty::InstanceDef::ThreadLocalShim(..)
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs
index 8d1bdec6684..3ea48ee824d 100644
--- a/compiler/rustc_ty_utils/src/abi.rs
+++ b/compiler/rustc_ty_utils/src/abi.rs
@@ -111,11 +111,14 @@ fn fn_sig_for_fn_abi<'tcx>(
                 kind: ty::BoundRegionKind::BrEnv,
             };
             let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
-            let env_ty = tcx.closure_env_ty(
-                Ty::new_coroutine_closure(tcx, def_id, args),
-                args.as_coroutine_closure().kind(),
-                env_region,
-            );
+
+            let mut kind = args.as_coroutine_closure().kind();
+            if let InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } = instance.def {
+                kind = target_kind;
+            }
+
+            let env_ty =
+                tcx.closure_env_ty(Ty::new_coroutine_closure(tcx, def_id, args), kind, env_region);
 
             let sig = sig.skip_binder();
             ty::Binder::bind_with_vars(
@@ -125,7 +128,7 @@ fn fn_sig_for_fn_abi<'tcx>(
                         tcx,
                         args.as_coroutine_closure().parent_args(),
                         tcx.coroutine_for_closure(def_id),
-                        args.as_coroutine_closure().kind(),
+                        kind,
                         env_region,
                         args.as_coroutine_closure().tupled_upvars_ty(),
                         args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs
index 50d27ed4ced..fc224d242db 100644
--- a/compiler/rustc_ty_utils/src/instance.rs
+++ b/compiler/rustc_ty_utils/src/instance.rs
@@ -283,10 +283,21 @@ fn resolve_associated_item<'tcx>(
                         tcx.item_name(trait_item_id)
                     ),
                 }
-            } else if tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id).is_some() {
+            } else if let Some(target_kind) = tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id)
+            {
                 match *rcvr_args.type_at(0).kind() {
-                    ty::CoroutineClosure(closure_def_id, args) => {
-                        Some(Instance::new(closure_def_id, args))
+                    ty::CoroutineClosure(coroutine_closure_def_id, args) => {
+                        if target_kind > args.as_coroutine_closure().kind() {
+                            Some(Instance {
+                                def: ty::InstanceDef::ConstructCoroutineInClosureShim {
+                                    coroutine_closure_def_id,
+                                    target_kind,
+                                },
+                                args,
+                            })
+                        } else {
+                            Some(Instance::new(coroutine_closure_def_id, args))
+                        }
                     }
                     _ => bug!(
                         "no built-in definition for `{trait_ref}::{}` for non-lending-closure type",