about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_const_eval/src/transform/promote_consts.rs2
-rw-r--r--compiler/rustc_middle/src/mir/mod.rs29
-rw-r--r--compiler/rustc_middle/src/thir.rs2
-rw-r--r--compiler/rustc_mir_build/src/build/mod.rs120
-rw-r--r--compiler/rustc_mir_build/src/build/scope.rs10
5 files changed, 85 insertions, 78 deletions
diff --git a/compiler/rustc_const_eval/src/transform/promote_consts.rs b/compiler/rustc_const_eval/src/transform/promote_consts.rs
index 8b2ea2dc21d..155cf4ff9e2 100644
--- a/compiler/rustc_const_eval/src/transform/promote_consts.rs
+++ b/compiler/rustc_const_eval/src/transform/promote_consts.rs
@@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>(
             0,
             vec![],
             body.span,
-            body.coroutine_kind(),
+            None,
             body.tainted_by_errors,
         );
         promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index 45dbfe6b8a7..d426f6d8969 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> {
     pub coroutine_kind: CoroutineKind,
 }
 
+impl<'tcx> CoroutineInfo<'tcx> {
+    // Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body.
+    pub fn initial(
+        coroutine_kind: CoroutineKind,
+        yield_ty: Ty<'tcx>,
+        resume_ty: Ty<'tcx>,
+    ) -> CoroutineInfo<'tcx> {
+        CoroutineInfo {
+            coroutine_kind,
+            yield_ty: Some(yield_ty),
+            resume_ty: Some(resume_ty),
+            coroutine_drop: None,
+            coroutine_layout: None,
+        }
+    }
+}
+
 /// The lowered representation of a single function.
 #[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
 pub struct Body<'tcx> {
@@ -367,7 +384,7 @@ impl<'tcx> Body<'tcx> {
         arg_count: usize,
         var_debug_info: Vec<VarDebugInfo<'tcx>>,
         span: Span,
-        coroutine_kind: Option<CoroutineKind>,
+        coroutine: Option<Box<CoroutineInfo<'tcx>>>,
         tainted_by_errors: Option<ErrorGuaranteed>,
     ) -> Self {
         // We need `arg_count` locals, and one for the return place.
@@ -384,15 +401,7 @@ impl<'tcx> Body<'tcx> {
             source,
             basic_blocks: BasicBlocks::new(basic_blocks),
             source_scopes,
-            coroutine: coroutine_kind.map(|coroutine_kind| {
-                Box::new(CoroutineInfo {
-                    yield_ty: None,
-                    resume_ty: None,
-                    coroutine_drop: None,
-                    coroutine_layout: None,
-                    coroutine_kind,
-                })
-            }),
+            coroutine,
             local_decls,
             user_type_annotations,
             arg_count,
diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs
index 2b5983314ee..b4b8387c262 100644
--- a/compiler/rustc_middle/src/thir.rs
+++ b/compiler/rustc_middle/src/thir.rs
@@ -86,8 +86,6 @@ macro_rules! thir_with_elements {
     }
 }
 
-pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0);
-
 thir_with_elements! {
     body_type: BodyTy<'tcx>,
 
diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs
index c4cade83947..b8d08319422 100644
--- a/compiler/rustc_mir_build/src/build/mod.rs
+++ b/compiler/rustc_mir_build/src/build/mod.rs
@@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
 use rustc_hir as hir;
 use rustc_hir::def::DefKind;
 use rustc_hir::def_id::{DefId, LocalDefId};
-use rustc_hir::{CoroutineKind, Node};
+use rustc_hir::Node;
 use rustc_index::bit_set::GrowableBitSet;
 use rustc_index::{Idx, IndexSlice, IndexVec};
 use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
@@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
     check_overflow: bool,
     fn_span: Span,
     arg_count: usize,
-    coroutine_kind: Option<CoroutineKind>,
+    coroutine: Option<Box<CoroutineInfo<'tcx>>>,
 
     /// The current set of scopes, updated as we traverse;
     /// see the `scope` module for more details.
@@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
 ) -> Body<'tcx> {
     let span = tcx.def_span(fn_def);
     let fn_id = tcx.local_def_id_to_hir_id(fn_def);
-    let coroutine_kind = tcx.coroutine_kind(fn_def);
 
     // The representation of thir for `-Zunpretty=thir-tree` relies on
     // the entry expression being the last element of `thir.exprs`.
@@ -488,17 +487,15 @@ fn construct_fn<'tcx>(
 
     let arguments = &thir.params;
 
-    let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
-        let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
-        let coroutine_sig = match coroutine_ty.kind() {
-            ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
-            _ => {
-                span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
-            }
-        };
-        (Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
-    } else {
-        (None, None, fn_sig.output())
+    let return_ty = fn_sig.output();
+    let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() {
+        ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial(
+            tcx.coroutine_kind(fn_def).unwrap(),
+            args.as_coroutine().yield_ty(),
+            args.as_coroutine().resume_ty(),
+        ))),
+        ty::Closure(..) | ty::FnDef(..) => None,
+        ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"),
     };
 
     if let Some(custom_mir_attr) =
@@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
         safety,
         return_ty,
         return_ty_span,
-        coroutine_kind,
+        coroutine,
     );
 
     let call_site_scope =
@@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
         None
     };
 
-    if coroutine_kind.is_some() {
-        body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
-        body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
-    }
-
     body
 }
 
@@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
 fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> {
     let span = tcx.def_span(def_id);
     let hir_id = tcx.local_def_id_to_hir_id(def_id);
-    let coroutine_kind = tcx.coroutine_kind(def_id);
 
-    let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
+    let (inputs, output, coroutine) = match tcx.def_kind(def_id) {
         DefKind::Const
         | DefKind::AssocConst
         | DefKind::AnonConst
         | DefKind::InlineConst
-        | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
+        | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
         DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
             let sig = tcx.liberate_late_bound_regions(
                 def_id.to_def_id(),
                 tcx.fn_sig(def_id).instantiate_identity(),
             );
-            (sig.inputs().to_vec(), sig.output(), None, None)
-        }
-        DefKind::Closure if coroutine_kind.is_some() => {
-            let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
-            let ty::Coroutine(_, args) = coroutine_ty.kind() else {
-                bug!("expected type of coroutine-like closure to be a coroutine")
-            };
-            let args = args.as_coroutine();
-            let resume_ty = args.resume_ty();
-            let yield_ty = args.yield_ty();
-            let return_ty = args.return_ty();
-            (vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
+            (sig.inputs().to_vec(), sig.output(), None)
         }
         DefKind::Closure => {
             let closure_ty = tcx.type_of(def_id).instantiate_identity();
-            let ty::Closure(_, args) = closure_ty.kind() else {
-                bug!("expected type of closure to be a closure")
-            };
-            let args = args.as_closure();
-            let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
-            let self_ty = match args.kind() {
-                ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
-                ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
-                ty::ClosureKind::FnOnce => closure_ty,
-            };
-            ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
+            match closure_ty.kind() {
+                ty::Closure(_, args) => {
+                    let args = args.as_closure();
+                    let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
+                    let self_ty = match args.kind() {
+                        ty::ClosureKind::Fn => {
+                            Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
+                        }
+                        ty::ClosureKind::FnMut => {
+                            Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
+                        }
+                        ty::ClosureKind::FnOnce => closure_ty,
+                    };
+                    (
+                        [self_ty].into_iter().chain(sig.inputs().to_vec()).collect(),
+                        sig.output(),
+                        None,
+                    )
+                }
+                ty::Coroutine(_, args) => {
+                    let args = args.as_coroutine();
+                    let resume_ty = args.resume_ty();
+                    let yield_ty = args.yield_ty();
+                    let return_ty = args.return_ty();
+                    (
+                        vec![closure_ty, args.resume_ty()],
+                        return_ty,
+                        Some(Box::new(CoroutineInfo::initial(
+                            tcx.coroutine_kind(def_id).unwrap(),
+                            yield_ty,
+                            resume_ty,
+                        ))),
+                    )
+                }
+                _ => {
+                    span_bug!(span, "expected type of closure body to be a closure or coroutine");
+                }
+            }
         }
-        dk => bug!("{:?} is not a body: {:?}", def_id, dk),
+        dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk),
     };
 
     let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE };
@@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
 
     cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable);
 
-    let mut body = Body::new(
+    Body::new(
         MirSource::item(def_id.to_def_id()),
         cfg.basic_blocks,
         source_scopes,
@@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
         inputs.len(),
         vec![],
         span,
-        coroutine_kind,
+        coroutine,
         Some(guar),
-    );
-
-    body.coroutine.as_mut().map(|gen| {
-        gen.yield_ty = yield_ty;
-        gen.resume_ty = resume_ty;
-    });
-
-    body
+    )
 }
 
 impl<'a, 'tcx> Builder<'a, 'tcx> {
@@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         safety: Safety,
         return_ty: Ty<'tcx>,
         return_span: Span,
-        coroutine_kind: Option<CoroutineKind>,
+        coroutine: Option<Box<CoroutineInfo<'tcx>>>,
     ) -> Builder<'a, 'tcx> {
         let tcx = infcx.tcx;
         let attrs = tcx.hir().attrs(hir_id);
@@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             cfg: CFG { basic_blocks: IndexVec::new() },
             fn_span: span,
             arg_count,
-            coroutine_kind,
+            coroutine,
             scopes: scope::Scopes::new(),
             block_context: BlockContext::new(),
             source_scopes: IndexVec::new(),
@@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             self.arg_count,
             self.var_debug_info,
             self.fn_span,
-            self.coroutine_kind,
+            self.coroutine,
             None,
         )
     }
diff --git a/compiler/rustc_mir_build/src/build/scope.rs b/compiler/rustc_mir_build/src/build/scope.rs
index 1a700ac7342..48b237f3ae6 100644
--- a/compiler/rustc_mir_build/src/build/scope.rs
+++ b/compiler/rustc_mir_build/src/build/scope.rs
@@ -706,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // If we are emitting a `drop` statement, we need to have the cached
         // diverge cleanup pads ready in case that drop panics.
         let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup());
-        let is_coroutine = self.coroutine_kind.is_some();
+        let is_coroutine = self.coroutine.is_some();
         let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX };
 
         let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes");
@@ -960,7 +960,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // path, we only need to invalidate the cache for drops that happen on
         // the unwind or coroutine drop paths. This means that for
         // non-coroutines we don't need to invalidate caches for `DropKind::Storage`.
-        let invalidate_caches = needs_drop || self.coroutine_kind.is_some();
+        let invalidate_caches = needs_drop || self.coroutine.is_some();
         for scope in self.scopes.scopes.iter_mut().rev() {
             if invalidate_caches {
                 scope.invalidate_cache();
@@ -1073,7 +1073,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             return cached_drop;
         }
 
-        let is_coroutine = self.coroutine_kind.is_some();
+        let is_coroutine = self.coroutine.is_some();
         for scope in &mut self.scopes.scopes[uncached_scope..=target] {
             for drop in &scope.drops {
                 if is_coroutine || drop.kind == DropKind::Value {
@@ -1318,7 +1318,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
         blocks[ROOT_NODE] = continue_block;
 
         drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks);
-        let is_coroutine = self.coroutine_kind.is_some();
+        let is_coroutine = self.coroutine.is_some();
 
         // Link the exit drop tree to unwind drop tree.
         if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) {
@@ -1355,7 +1355,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
 
     /// Build the unwind and coroutine drop trees.
     pub(crate) fn build_drop_trees(&mut self) {
-        if self.coroutine_kind.is_some() {
+        if self.coroutine.is_some() {
             self.build_coroutine_drop_trees();
         } else {
             Self::build_unwind_tree(