about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/inline.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/inline.rs')
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs58
1 files changed, 44 insertions, 14 deletions
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 5d4560b7d5f..7b5697bc949 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -32,7 +32,6 @@ struct CallSite<'tcx> {
     callee: Instance<'tcx>,
     fn_sig: ty::PolyFnSig<'tcx>,
     block: BasicBlock,
-    target: Option<BasicBlock>,
     source_info: SourceInfo,
 }
 
@@ -367,7 +366,7 @@ impl<'tcx> Inliner<'tcx> {
     ) -> Option<CallSite<'tcx>> {
         // Only consider direct calls to functions
         let terminator = bb_data.terminator();
-        if let TerminatorKind::Call { ref func, target, fn_span, .. } = terminator.kind {
+        if let TerminatorKind::Call { ref func, fn_span, .. } = terminator.kind {
             let func_ty = func.ty(caller_body, self.tcx);
             if let ty::FnDef(def_id, args) = *func_ty.kind() {
                 // To resolve an instance its args have to be fully normalized.
@@ -386,7 +385,7 @@ impl<'tcx> Inliner<'tcx> {
                 let fn_sig = self.tcx.fn_sig(def_id).instantiate(self.tcx, args);
                 let source_info = SourceInfo { span: fn_span, ..terminator.source_info };
 
-                return Some(CallSite { callee, fn_sig, block: bb, target, source_info });
+                return Some(CallSite { callee, fn_sig, block: bb, source_info });
             }
         }
 
@@ -541,10 +540,23 @@ impl<'tcx> Inliner<'tcx> {
         mut callee_body: Body<'tcx>,
     ) {
         let terminator = caller_body[callsite.block].terminator.take().unwrap();
-        let TerminatorKind::Call { args, destination, unwind, .. } = terminator.kind else {
+        let TerminatorKind::Call { args, destination, unwind, target, .. } = terminator.kind else {
             bug!("unexpected terminator kind {:?}", terminator.kind);
         };
 
+        let return_block = if let Some(block) = target {
+            // Prepare a new block for code that should execute when call returns. We don't use
+            // target block directly since it might have other predecessors.
+            let mut data = BasicBlockData::new(Some(Terminator {
+                source_info: terminator.source_info,
+                kind: TerminatorKind::Goto { target: block },
+            }));
+            data.is_cleanup = caller_body[block].is_cleanup;
+            Some(caller_body.basic_blocks_mut().push(data))
+        } else {
+            None
+        };
+
         // If the call is something like `a[*i] = f(i)`, where
         // `i : &mut usize`, then just duplicating the `a[*i]`
         // Place could result in two different locations if `f`
@@ -569,7 +581,8 @@ impl<'tcx> Inliner<'tcx> {
                 destination,
             );
             let dest_ty = dest.ty(caller_body, self.tcx);
-            let temp = Place::from(self.new_call_temp(caller_body, &callsite, dest_ty));
+            let temp =
+                Place::from(self.new_call_temp(caller_body, &callsite, dest_ty, return_block));
             caller_body[callsite.block].statements.push(Statement {
                 source_info: callsite.source_info,
                 kind: StatementKind::Assign(Box::new((temp, dest))),
@@ -590,12 +603,14 @@ impl<'tcx> Inliner<'tcx> {
                     caller_body,
                     &callsite,
                     destination.ty(caller_body, self.tcx).ty,
+                    return_block,
                 ),
             )
         };
 
         // Copy the arguments if needed.
-        let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body);
+        let args: Vec<_> =
+            self.make_call_args(args, &callsite, caller_body, &callee_body, return_block);
 
         let mut integrator = Integrator {
             args: &args,
@@ -607,6 +622,7 @@ impl<'tcx> Inliner<'tcx> {
             callsite,
             cleanup_block: unwind,
             in_cleanup_block: false,
+            return_block,
             tcx: self.tcx,
             always_live_locals: BitSet::new_filled(callee_body.local_decls.len()),
         };
@@ -626,7 +642,7 @@ impl<'tcx> Inliner<'tcx> {
                 });
             }
         }
-        if let Some(block) = callsite.target {
+        if let Some(block) = return_block {
             // To avoid repeated O(n) insert, push any new statements to the end and rotate
             // the slice once.
             let mut n = 0;
@@ -684,6 +700,7 @@ impl<'tcx> Inliner<'tcx> {
         callsite: &CallSite<'tcx>,
         caller_body: &mut Body<'tcx>,
         callee_body: &Body<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Vec<Local> {
         let tcx = self.tcx;
 
@@ -712,8 +729,18 @@ impl<'tcx> Inliner<'tcx> {
         // and the vector is `[closure_ref, tmp0, tmp1, tmp2]`.
         if callsite.fn_sig.abi() == Abi::RustCall && callee_body.spread_arg.is_none() {
             let mut args = args.into_iter();
-            let self_ = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
-            let tuple = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
+            let self_ = self.create_temp_if_necessary(
+                args.next().unwrap(),
+                callsite,
+                caller_body,
+                return_block,
+            );
+            let tuple = self.create_temp_if_necessary(
+                args.next().unwrap(),
+                callsite,
+                caller_body,
+                return_block,
+            );
             assert!(args.next().is_none());
 
             let tuple = Place::from(tuple);
@@ -730,13 +757,13 @@ impl<'tcx> Inliner<'tcx> {
                 let tuple_field = Operand::Move(tcx.mk_place_field(tuple, FieldIdx::new(i), ty));
 
                 // Spill to a local to make e.g., `tmp0`.
-                self.create_temp_if_necessary(tuple_field, callsite, caller_body)
+                self.create_temp_if_necessary(tuple_field, callsite, caller_body, return_block)
             });
 
             closure_ref_arg.chain(tuple_tmp_args).collect()
         } else {
             args.into_iter()
-                .map(|a| self.create_temp_if_necessary(a, callsite, caller_body))
+                .map(|a| self.create_temp_if_necessary(a, callsite, caller_body, return_block))
                 .collect()
         }
     }
@@ -748,6 +775,7 @@ impl<'tcx> Inliner<'tcx> {
         arg: Operand<'tcx>,
         callsite: &CallSite<'tcx>,
         caller_body: &mut Body<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Local {
         // Reuse the operand if it is a moved temporary.
         if let Operand::Move(place) = &arg
@@ -760,7 +788,7 @@ impl<'tcx> Inliner<'tcx> {
         // Otherwise, create a temporary for the argument.
         trace!("creating temp for argument {:?}", arg);
         let arg_ty = arg.ty(caller_body, self.tcx);
-        let local = self.new_call_temp(caller_body, callsite, arg_ty);
+        let local = self.new_call_temp(caller_body, callsite, arg_ty, return_block);
         caller_body[callsite.block].statements.push(Statement {
             source_info: callsite.source_info,
             kind: StatementKind::Assign(Box::new((Place::from(local), Rvalue::Use(arg)))),
@@ -774,6 +802,7 @@ impl<'tcx> Inliner<'tcx> {
         caller_body: &mut Body<'tcx>,
         callsite: &CallSite<'tcx>,
         ty: Ty<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Local {
         let local = caller_body.local_decls.push(LocalDecl::new(ty, callsite.source_info.span));
 
@@ -782,7 +811,7 @@ impl<'tcx> Inliner<'tcx> {
             kind: StatementKind::StorageLive(local),
         });
 
-        if let Some(block) = callsite.target {
+        if let Some(block) = return_block {
             caller_body[block].statements.insert(
                 0,
                 Statement {
@@ -813,6 +842,7 @@ struct Integrator<'a, 'tcx> {
     callsite: &'a CallSite<'tcx>,
     cleanup_block: UnwindAction,
     in_cleanup_block: bool,
+    return_block: Option<BasicBlock>,
     tcx: TyCtxt<'tcx>,
     always_live_locals: BitSet<Local>,
 }
@@ -956,7 +986,7 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> {
                 *unwind = self.map_unwind(*unwind);
             }
             TerminatorKind::Return => {
-                terminator.kind = if let Some(tgt) = self.callsite.target {
+                terminator.kind = if let Some(tgt) = self.return_block {
                     TerminatorKind::Goto { target: tgt }
                 } else {
                     TerminatorKind::Unreachable