about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNicholas Nethercote <n.nethercote@gmail.com>2024-08-28 13:47:22 +1000
committerNicholas Nethercote <n.nethercote@gmail.com>2024-08-30 10:30:57 +1000
commit590a02173bba33fd7cd50bf9ae9061727ebe24ca (patch)
tree0bd6b53380ff8bcf03e85b1461edae3f4f2c733f
parent408481f4d876e58ab2b4d520706cfe07b223dc6f (diff)
downloadrust-590a02173bba33fd7cd50bf9ae9061727ebe24ca.tar.gz
rust-590a02173bba33fd7cd50bf9ae9061727ebe24ca.zip
Factor out some repetitive code.
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs119
1 files changed, 38 insertions, 81 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index b9480d31e88..8ee2b8cbb9c 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -63,7 +63,9 @@ use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
 use rustc_index::{Idx, IndexVec};
 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, InstanceKind, Ty, TyCtxt};
+use rustc_middle::ty::{
+    self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
+};
 use rustc_middle::{bug, span_bug};
 use rustc_mir_dataflow::impls::{
     MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
@@ -210,14 +212,10 @@ impl<'tcx> TransformVisitor<'tcx> {
             // `gen` continues return `None`
             CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
-                Rvalue::Aggregate(
-                    Box::new(AggregateKind::Adt(
-                        option_def_id,
-                        VariantIdx::ZERO,
-                        self.tcx.mk_args(&[self.old_yield_ty.into()]),
-                        None,
-                        None,
-                    )),
+                make_aggregate_adt(
+                    option_def_id,
+                    VariantIdx::ZERO,
+                    self.tcx.mk_args(&[self.old_yield_ty.into()]),
                     IndexVec::new(),
                 )
             }
@@ -266,64 +264,28 @@ impl<'tcx> TransformVisitor<'tcx> {
         is_return: bool,
         statements: &mut Vec<Statement<'tcx>>,
     ) {
+        const ZERO: VariantIdx = VariantIdx::ZERO;
+        const ONE: VariantIdx = VariantIdx::from_usize(1);
         let rvalue = match self.coroutine_kind {
             CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                 let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
                 let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
-                if is_return {
-                    // Poll::Ready(val)
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            poll_def_id,
-                            VariantIdx::ZERO,
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::from_raw(vec![val]),
-                    )
+                let (variant_idx, operands) = if is_return {
+                    (ZERO, IndexVec::from_raw(vec![val])) // Poll::Ready(val)
                 } else {
-                    // Poll::Pending
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            poll_def_id,
-                            VariantIdx::from_usize(1),
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::new(),
-                    )
-                }
+                    (ONE, IndexVec::new()) // Poll::Pending
+                };
+                make_aggregate_adt(poll_def_id, variant_idx, args, operands)
             }
             CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
                 let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
-                if is_return {
-                    // None
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            option_def_id,
-                            VariantIdx::ZERO,
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::new(),
-                    )
+                let (variant_idx, operands) = if is_return {
+                    (ZERO, IndexVec::new()) // None
                 } else {
-                    // Some(val)
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            option_def_id,
-                            VariantIdx::from_usize(1),
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::from_raw(vec![val]),
-                    )
-                }
+                    (ONE, IndexVec::from_raw(vec![val])) // Some(val)
+                };
+                make_aggregate_adt(option_def_id, variant_idx, args, operands)
             }
             CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                 if is_return {
@@ -349,31 +311,17 @@ impl<'tcx> TransformVisitor<'tcx> {
                 let coroutine_state_def_id =
                     self.tcx.require_lang_item(LangItem::CoroutineState, None);
                 let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
-                if is_return {
-                    // CoroutineState::Complete(val)
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            coroutine_state_def_id,
-                            VariantIdx::from_usize(1),
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::from_raw(vec![val]),
-                    )
+                let variant_idx = if is_return {
+                    ONE // CoroutineState::Complete(val)
                 } else {
-                    // CoroutineState::Yielded(val)
-                    Rvalue::Aggregate(
-                        Box::new(AggregateKind::Adt(
-                            coroutine_state_def_id,
-                            VariantIdx::ZERO,
-                            args,
-                            None,
-                            None,
-                        )),
-                        IndexVec::from_raw(vec![val]),
-                    )
-                }
+                    ZERO // CoroutineState::Yielded(val)
+                };
+                make_aggregate_adt(
+                    coroutine_state_def_id,
+                    variant_idx,
+                    args,
+                    IndexVec::from_raw(vec![val]),
+                )
             }
         };
 
@@ -509,6 +457,15 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
     }
 }
 
+fn make_aggregate_adt<'tcx>(
+    def_id: DefId,
+    variant_idx: VariantIdx,
+    args: GenericArgsRef<'tcx>,
+    operands: IndexVec<FieldIdx, Operand<'tcx>>,
+) -> Rvalue<'tcx> {
+    Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
+}
+
 fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     let coroutine_ty = body.local_decls.raw[1].ty;