about summary refs log tree commit diff
diff options
context:
space:
mode:
authorTomasz Miąsko <tomasz.miasko@gmail.com>2020-11-11 00:00:00 +0000
committerTomasz Miąsko <tomasz.miasko@gmail.com>2020-11-12 20:09:04 +0100
commit66cadec1763ac645337c1ac58f06ea48b9b72a26 (patch)
tree2290df57297ddb43400e090e6dab732c5875acf2
parent9bb3d6b7d472e2116312ea45db07a5338af205fb (diff)
downloadrust-66cadec1763ac645337c1ac58f06ea48b9b72a26.tar.gz
rust-66cadec1763ac645337c1ac58f06ea48b9b72a26.zip
Fix generator inlining by checking for rust-call abi and spread arg
-rw-r--r--compiler/rustc_mir/src/transform/inline.rs26
-rw-r--r--src/test/mir-opt/inline/inline-generator.rs16
2 files changed, 30 insertions, 12 deletions
diff --git a/compiler/rustc_mir/src/transform/inline.rs b/compiler/rustc_mir/src/transform/inline.rs
index 0d6d9e397ac..2ccb9b3709f 100644
--- a/compiler/rustc_mir/src/transform/inline.rs
+++ b/compiler/rustc_mir/src/transform/inline.rs
@@ -7,6 +7,7 @@ use rustc_index::vec::Idx;
 use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
+use rustc_middle::ty::subst::Subst;
 use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
 use rustc_span::{hygiene::ExpnKind, ExpnData, Span};
 use rustc_target::spec::abi::Abi;
@@ -28,6 +29,7 @@ pub struct Inline;
 #[derive(Copy, Clone, Debug)]
 struct CallSite<'tcx> {
     callee: Instance<'tcx>,
+    fn_sig: ty::PolyFnSig<'tcx>,
     block: BasicBlock,
     target: Option<BasicBlock>,
     source_info: SourceInfo,
@@ -173,22 +175,23 @@ impl Inliner<'tcx> {
 
         // Only consider direct calls to functions
         let terminator = bb_data.terminator();
-        if let TerminatorKind::Call { func: ref op, ref destination, .. } = terminator.kind {
-            if let ty::FnDef(callee_def_id, substs) = *op.ty(caller_body, self.tcx).kind() {
-                // To resolve an instance its substs have to be fully normalized, so
-                // we do this here.
-                let normalized_substs = self.tcx.normalize_erasing_regions(self.param_env, substs);
+        if let TerminatorKind::Call { ref func, ref destination, .. } = terminator.kind {
+            let func_ty = func.ty(caller_body, self.tcx);
+            if let ty::FnDef(def_id, substs) = *func_ty.kind() {
+                // To resolve an instance its substs have to be fully normalized.
+                let substs = self.tcx.normalize_erasing_regions(self.param_env, substs);
                 let callee =
-                    Instance::resolve(self.tcx, self.param_env, callee_def_id, normalized_substs)
-                        .ok()
-                        .flatten()?;
+                    Instance::resolve(self.tcx, self.param_env, def_id, substs).ok().flatten()?;
 
                 if let InstanceDef::Virtual(..) | InstanceDef::Intrinsic(_) = callee.def {
                     return None;
                 }
 
+                let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs);
+
                 return Some(CallSite {
                     callee,
+                    fn_sig,
                     block: bb,
                     target: destination.map(|(_, target)| target),
                     source_info: terminator.source_info,
@@ -437,7 +440,7 @@ impl Inliner<'tcx> {
                 };
 
                 // Copy the arguments if needed.
-                let args: Vec<_> = self.make_call_args(args, &callsite, caller_body);
+                let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body);
 
                 let mut integrator = Integrator {
                     args: &args,
@@ -518,6 +521,7 @@ impl Inliner<'tcx> {
         args: Vec<Operand<'tcx>>,
         callsite: &CallSite<'tcx>,
         caller_body: &mut Body<'tcx>,
+        callee_body: &Body<'tcx>,
     ) -> Vec<Local> {
         let tcx = self.tcx;
 
@@ -544,9 +548,7 @@ impl Inliner<'tcx> {
         //     tmp2 = tuple_tmp.2
         //
         // and the vector is `[closure_ref, tmp0, tmp1, tmp2]`.
-        // FIXME(eddyb) make this check for `"rust-call"` ABI combined with
-        // `callee_body.spread_arg == None`, instead of special-casing closures.
-        if tcx.is_closure(callsite.callee.def_id()) {
+        if callsite.fn_sig.abi() == Abi::RustCall && callee_body.spread_arg.is_none() {
             let mut args = args.into_iter();
             let self_ = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
             let tuple = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body);
diff --git a/src/test/mir-opt/inline/inline-generator.rs b/src/test/mir-opt/inline/inline-generator.rs
new file mode 100644
index 00000000000..d11b3e548f7
--- /dev/null
+++ b/src/test/mir-opt/inline/inline-generator.rs
@@ -0,0 +1,16 @@
+// ignore-wasm32-bare compiled with panic=abort by default
+#![feature(generators, generator_trait)]
+
+use std::ops::Generator;
+use std::pin::Pin;
+
+// EMIT_MIR inline_generator.main.Inline.diff
+fn main() {
+    let _r = Pin::new(&mut g()).resume(false);
+}
+
+#[inline(always)]
+pub fn g() -> impl Generator<bool> {
+    #[inline(always)]
+    |a| { yield if a { 7 } else { 13 } }
+}