about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs297
1 files changed, 163 insertions, 134 deletions
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 793dcf0d994..7b5697bc949 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -32,7 +32,6 @@ struct CallSite<'tcx> {
     callee: Instance<'tcx>,
     fn_sig: ty::PolyFnSig<'tcx>,
     block: BasicBlock,
-    target: Option<BasicBlock>,
     source_info: SourceInfo,
 }
 
@@ -367,7 +366,7 @@ impl<'tcx> Inliner<'tcx> {
     ) -> Option<CallSite<'tcx>> {
         // Only consider direct calls to functions
         let terminator = bb_data.terminator();
-        if let TerminatorKind::Call { ref func, target, fn_span, .. } = terminator.kind {
+        if let TerminatorKind::Call { ref func, fn_span, .. } = terminator.kind {
             let func_ty = func.ty(caller_body, self.tcx);
             if let ty::FnDef(def_id, args) = *func_ty.kind() {
                 // To resolve an instance its args have to be fully normalized.
@@ -386,7 +385,7 @@ impl<'tcx> Inliner<'tcx> {
                 let fn_sig = self.tcx.fn_sig(def_id).instantiate(self.tcx, args);
                 let source_info = SourceInfo { span: fn_span, ..terminator.source_info };
 
-                return Some(CallSite { callee, fn_sig, block: bb, target, source_info });
+                return Some(CallSite { callee, fn_sig, block: bb, source_info });
             }
         }
 
@@ -541,142 +540,158 @@ impl<'tcx> Inliner<'tcx> {
         mut callee_body: Body<'tcx>,
     ) {
         let terminator = caller_body[callsite.block].terminator.take().unwrap();
-        match terminator.kind {
-            TerminatorKind::Call { args, destination, unwind, .. } => {
-                // If the call is something like `a[*i] = f(i)`, where
-                // `i : &mut usize`, then just duplicating the `a[*i]`
-                // Place could result in two different locations if `f`
-                // writes to `i`. To prevent this we need to create a temporary
-                // borrow of the place and pass the destination as `*temp` instead.
-                fn dest_needs_borrow(place: Place<'_>) -> bool {
-                    for elem in place.projection.iter() {
-                        match elem {
-                            ProjectionElem::Deref | ProjectionElem::Index(_) => return true,
-                            _ => {}
-                        }
-                    }
+        let TerminatorKind::Call { args, destination, unwind, target, .. } = terminator.kind else {
+            bug!("unexpected terminator kind {:?}", terminator.kind);
+        };
+
+        let return_block = if let Some(block) = target {
+            // Prepare a new block for code that should execute when call returns. We don't use
+            // target block directly since it might have other predecessors.
+            let mut data = BasicBlockData::new(Some(Terminator {
+                source_info: terminator.source_info,
+                kind: TerminatorKind::Goto { target: block },
+            }));
+            data.is_cleanup = caller_body[block].is_cleanup;
+            Some(caller_body.basic_blocks_mut().push(data))
+        } else {
+            None
+        };
 
-                    false
+        // If the call is something like `a[*i] = f(i)`, where
+        // `i : &mut usize`, then just duplicating the `a[*i]`
+        // Place could result in two different locations if `f`
+        // writes to `i`. To prevent this we need to create a temporary
+        // borrow of the place and pass the destination as `*temp` instead.
+        fn dest_needs_borrow(place: Place<'_>) -> bool {
+            for elem in place.projection.iter() {
+                match elem {
+                    ProjectionElem::Deref | ProjectionElem::Index(_) => return true,
+                    _ => {}
                 }
+            }
 
-                let dest = if dest_needs_borrow(destination) {
-                    trace!("creating temp for return destination");
-                    let dest = Rvalue::Ref(
-                        self.tcx.lifetimes.re_erased,
-                        BorrowKind::Mut { kind: MutBorrowKind::Default },
-                        destination,
-                    );
-                    let dest_ty = dest.ty(caller_body, self.tcx);
-                    let temp = Place::from(self.new_call_temp(caller_body, &callsite, dest_ty));
-                    caller_body[callsite.block].statements.push(Statement {
-                        source_info: callsite.source_info,
-                        kind: StatementKind::Assign(Box::new((temp, dest))),
-                    });
-                    self.tcx.mk_place_deref(temp)
-                } else {
-                    destination
-                };
+            false
+        }
 
-                // Always create a local to hold the destination, as `RETURN_PLACE` may appear
-                // where a full `Place` is not allowed.
-                let (remap_destination, destination_local) = if let Some(d) = dest.as_local() {
-                    (false, d)
-                } else {
-                    (
-                        true,
-                        self.new_call_temp(
-                            caller_body,
-                            &callsite,
-                            destination.ty(caller_body, self.tcx).ty,
-                        ),
-                    )
-                };
+        let dest = if dest_needs_borrow(destination) {
+            trace!("creating temp for return destination");
+            let dest = Rvalue::Ref(
+                self.tcx.lifetimes.re_erased,
+                BorrowKind::Mut { kind: MutBorrowKind::Default },
+                destination,
+            );
+            let dest_ty = dest.ty(caller_body, self.tcx);
+            let temp =
+                Place::from(self.new_call_temp(caller_body, &callsite, dest_ty, return_block));
+            caller_body[callsite.block].statements.push(Statement {
+                source_info: callsite.source_info,
+                kind: StatementKind::Assign(Box::new((temp, dest))),
+            });
+            self.tcx.mk_place_deref(temp)
+        } else {
+            destination
+        };
 
-                // Copy the arguments if needed.
-                let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body);
-
-                let mut integrator = Integrator {
-                    args: &args,
-                    new_locals: Local::new(caller_body.local_decls.len())..,
-                    new_scopes: SourceScope::new(caller_body.source_scopes.len())..,
-                    new_blocks: BasicBlock::new(caller_body.basic_blocks.len())..,
-                    destination: destination_local,
-                    callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(),
-                    callsite,
-                    cleanup_block: unwind,
-                    in_cleanup_block: false,
-                    tcx: self.tcx,
-                    always_live_locals: BitSet::new_filled(callee_body.local_decls.len()),
-                };
+        // Always create a local to hold the destination, as `RETURN_PLACE` may appear
+        // where a full `Place` is not allowed.
+        let (remap_destination, destination_local) = if let Some(d) = dest.as_local() {
+            (false, d)
+        } else {
+            (
+                true,
+                self.new_call_temp(
+                    caller_body,
+                    &callsite,
+                    destination.ty(caller_body, self.tcx).ty,
+                    return_block,
+                ),
+            )
+        };
 
-                // Map all `Local`s, `SourceScope`s and `BasicBlock`s to new ones
-                // (or existing ones, in a few special cases) in the caller.
-                integrator.visit_body(&mut callee_body);
-
-                // If there are any locals without storage markers, give them storage only for the
-                // duration of the call.
-                for local in callee_body.vars_and_temps_iter() {
-                    if integrator.always_live_locals.contains(local) {
-                        let new_local = integrator.map_local(local);
-                        caller_body[callsite.block].statements.push(Statement {
-                            source_info: callsite.source_info,
-                            kind: StatementKind::StorageLive(new_local),
-                        });
-                    }
-                }
-                if let Some(block) = callsite.target {
-                    // To avoid repeated O(n) insert, push any new statements to the end and rotate
-                    // the slice once.
-                    let mut n = 0;
-                    if remap_destination {
-                        caller_body[block].statements.push(Statement {
-                            source_info: callsite.source_info,
-                            kind: StatementKind::Assign(Box::new((
-                                dest,
-                                Rvalue::Use(Operand::Move(destination_local.into())),
-                            ))),
-                        });
-                        n += 1;
-                    }
-                    for local in callee_body.vars_and_temps_iter().rev() {
-                        if integrator.always_live_locals.contains(local) {
-                            let new_local = integrator.map_local(local);
-                            caller_body[block].statements.push(Statement {
-                                source_info: callsite.source_info,
-                                kind: StatementKind::StorageDead(new_local),
-                            });
-                            n += 1;
-                        }
-                    }
-                    caller_body[block].statements.rotate_right(n);
-                }
+        // Copy the arguments if needed.
+        let args: Vec<_> =
+            self.make_call_args(args, &callsite, caller_body, &callee_body, return_block);
+
+        let mut integrator = Integrator {
+            args: &args,
+            new_locals: Local::new(caller_body.local_decls.len())..,
+            new_scopes: SourceScope::new(caller_body.source_scopes.len())..,
+            new_blocks: BasicBlock::new(caller_body.basic_blocks.len())..,
+            destination: destination_local,
+            callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(),
+            callsite,
+            cleanup_block: unwind,
+            in_cleanup_block: false,
+            return_block,
+            tcx: self.tcx,
+            always_live_locals: BitSet::new_filled(callee_body.local_decls.len()),
+        };
 
-                // Insert all of the (mapped) parts of the callee body into the caller.
-                caller_body.local_decls.extend(callee_body.drain_vars_and_temps());
-                caller_body.source_scopes.extend(&mut callee_body.source_scopes.drain(..));
-                caller_body.var_debug_info.append(&mut callee_body.var_debug_info);
-                caller_body.basic_blocks_mut().extend(callee_body.basic_blocks_mut().drain(..));
+        // Map all `Local`s, `SourceScope`s and `BasicBlock`s to new ones
+        // (or existing ones, in a few special cases) in the caller.
+        integrator.visit_body(&mut callee_body);
 
-                caller_body[callsite.block].terminator = Some(Terminator {
+        // If there are any locals without storage markers, give them storage only for the
+        // duration of the call.
+        for local in callee_body.vars_and_temps_iter() {
+            if integrator.always_live_locals.contains(local) {
+                let new_local = integrator.map_local(local);
+                caller_body[callsite.block].statements.push(Statement {
                     source_info: callsite.source_info,
-                    kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) },
+                    kind: StatementKind::StorageLive(new_local),
                 });
-
-                // Copy only unevaluated constants from the callee_body into the caller_body.
-                // Although we are only pushing `ConstKind::Unevaluated` consts to
-                // `required_consts`, here we may not only have `ConstKind::Unevaluated`
-                // because we are calling `instantiate_and_normalize_erasing_regions`.
-                caller_body.required_consts.extend(
-                    callee_body.required_consts.iter().copied().filter(|&ct| match ct.const_ {
-                        Const::Ty(_) => {
-                            bug!("should never encounter ty::UnevaluatedConst in `required_consts`")
-                        }
-                        Const::Val(..) | Const::Unevaluated(..) => true,
-                    }),
-                );
             }
-            kind => bug!("unexpected terminator kind {:?}", kind),
         }
+        if let Some(block) = return_block {
+            // To avoid repeated O(n) insert, push any new statements to the end and rotate
+            // the slice once.
+            let mut n = 0;
+            if remap_destination {
+                caller_body[block].statements.push(Statement {
+                    source_info: callsite.source_info,
+                    kind: StatementKind::Assign(Box::new((
+                        dest,
+                        Rvalue::Use(Operand::Move(destination_local.into())),
+                    ))),
+                });
+                n += 1;
+            }
+            for local in callee_body.vars_and_temps_iter().rev() {
+                if integrator.always_live_locals.contains(local) {
+                    let new_local = integrator.map_local(local);
+                    caller_body[block].statements.push(Statement {
+                        source_info: callsite.source_info,
+                        kind: StatementKind::StorageDead(new_local),
+                    });
+                    n += 1;
+                }
+            }
+            caller_body[block].statements.rotate_right(n);
+        }
+
+        // Insert all of the (mapped) parts of the callee body into the caller.
+        caller_body.local_decls.extend(callee_body.drain_vars_and_temps());
+        caller_body.source_scopes.extend(&mut callee_body.source_scopes.drain(..));
+        caller_body.var_debug_info.append(&mut callee_body.var_debug_info);
+        caller_body.basic_blocks_mut().extend(callee_body.basic_blocks_mut().drain(..));
+
+        caller_body[callsite.block].terminator = Some(Terminator {
+            source_info: callsite.source_info,
+            kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) },
+        });
+
+        // Copy only unevaluated constants from the callee_body into the caller_body.
+        // Although we are only pushing `ConstKind::Unevaluated` consts to
+        // `required_consts`, here we may not only have `ConstKind::Unevaluated`
+        // because we are calling `instantiate_and_normalize_erasing_regions`.
+        caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter(
+            |&ct| match ct.const_ {
+                Const::Ty(_) => {
+                    bug!("should never encounter ty::UnevaluatedConst in `required_consts`")
+                }
+                Const::Val(..) | Const::Unevaluated(..) => true,
+            },
+        ));
     }
 
     fn make_call_args(
@@ -685,6 +700,7 @@ impl<'tcx> Inliner<'tcx> {
         callsite: &CallSite<'tcx>,
         caller_body: &mut Body<'tcx>,
         callee_body: &Body<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Vec<Local> {
         let tcx = self.tcx;
 
@@ -713,8 +729,18 @@ impl<'tcx> Inliner<'tcx> {
         // and the vector is `[closure_ref, tmp0, tmp1, tmp2]`.
         if callsite.fn_sig.abi() == Abi::RustCall && callee_body.spread_arg.is_none() {
             let mut args = args.into_iter();
-            let self_ = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
-            let tuple = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
+            let self_ = self.create_temp_if_necessary(
+                args.next().unwrap(),
+                callsite,
+                caller_body,
+                return_block,
+            );
+            let tuple = self.create_temp_if_necessary(
+                args.next().unwrap(),
+                callsite,
+                caller_body,
+                return_block,
+            );
             assert!(args.next().is_none());
 
             let tuple = Place::from(tuple);
@@ -731,13 +757,13 @@ impl<'tcx> Inliner<'tcx> {
                 let tuple_field = Operand::Move(tcx.mk_place_field(tuple, FieldIdx::new(i), ty));
 
                 // Spill to a local to make e.g., `tmp0`.
-                self.create_temp_if_necessary(tuple_field, callsite, caller_body)
+                self.create_temp_if_necessary(tuple_field, callsite, caller_body, return_block)
             });
 
             closure_ref_arg.chain(tuple_tmp_args).collect()
         } else {
             args.into_iter()
-                .map(|a| self.create_temp_if_necessary(a, callsite, caller_body))
+                .map(|a| self.create_temp_if_necessary(a, callsite, caller_body, return_block))
                 .collect()
         }
     }
@@ -749,6 +775,7 @@ impl<'tcx> Inliner<'tcx> {
         arg: Operand<'tcx>,
         callsite: &CallSite<'tcx>,
         caller_body: &mut Body<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Local {
         // Reuse the operand if it is a moved temporary.
         if let Operand::Move(place) = &arg
@@ -761,7 +788,7 @@ impl<'tcx> Inliner<'tcx> {
         // Otherwise, create a temporary for the argument.
         trace!("creating temp for argument {:?}", arg);
         let arg_ty = arg.ty(caller_body, self.tcx);
-        let local = self.new_call_temp(caller_body, callsite, arg_ty);
+        let local = self.new_call_temp(caller_body, callsite, arg_ty, return_block);
         caller_body[callsite.block].statements.push(Statement {
             source_info: callsite.source_info,
             kind: StatementKind::Assign(Box::new((Place::from(local), Rvalue::Use(arg)))),
@@ -775,6 +802,7 @@ impl<'tcx> Inliner<'tcx> {
         caller_body: &mut Body<'tcx>,
         callsite: &CallSite<'tcx>,
         ty: Ty<'tcx>,
+        return_block: Option<BasicBlock>,
     ) -> Local {
         let local = caller_body.local_decls.push(LocalDecl::new(ty, callsite.source_info.span));
 
@@ -783,7 +811,7 @@ impl<'tcx> Inliner<'tcx> {
             kind: StatementKind::StorageLive(local),
         });
 
-        if let Some(block) = callsite.target {
+        if let Some(block) = return_block {
             caller_body[block].statements.insert(
                 0,
                 Statement {
@@ -814,6 +842,7 @@ struct Integrator<'a, 'tcx> {
     callsite: &'a CallSite<'tcx>,
     cleanup_block: UnwindAction,
     in_cleanup_block: bool,
+    return_block: Option<BasicBlock>,
     tcx: TyCtxt<'tcx>,
     always_live_locals: BitSet<Local>,
 }
@@ -957,7 +986,7 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> {
                 *unwind = self.map_unwind(*unwind);
             }
             TerminatorKind::Return => {
-                terminator.kind = if let Some(tgt) = self.callsite.target {
+                terminator.kind = if let Some(tgt) = self.return_block {
                     TerminatorKind::Goto { target: tgt }
                 } else {
                     TerminatorKind::Unreachable