about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMazdak Farrokhzad <twingoow@gmail.com>2019-05-20 01:01:38 +0200
committerGitHub <noreply@github.com>2019-05-20 01:01:38 +0200
commit5c84d779b2a52519b72a2ba0dd492e2a7eb6552e (patch)
treee7a3fc91989cf876f94050fc538a6d9fc2f0fafb
parentf9d65c000dfa92825589dd1a13a20b95a328493b (diff)
parentec853ba026261bc1c8c53a99d210c02f88fde54f (diff)
downloadrust-5c84d779b2a52519b72a2ba0dd492e2a7eb6552e.tar.gz
rust-5c84d779b2a52519b72a2ba0dd492e2a7eb6552e.zip
Rollup merge of #60745 - wesleywiser:const_prop_into_terminators, r=oli-obk
Perform constant propagation into terminators

Perform constant propagation into MIR `Assert` and `SwitchInt` `Terminator`s which in some cases allows them to be removed by the branch simplification pass.

r? @oli-obk
-rw-r--r--src/librustc_mir/transform/const_prop.rs175
-rw-r--r--src/test/mir-opt/const_prop/array_index.rs2
-rw-r--r--src/test/mir-opt/const_prop/checked_add.rs2
-rw-r--r--src/test/mir-opt/const_prop/switch_int.rs38
-rw-r--r--src/test/mir-opt/simplify_if.rs8
5 files changed, 152 insertions, 73 deletions
diff --git a/src/librustc_mir/transform/const_prop.rs b/src/librustc_mir/transform/const_prop.rs
index 4e214c3c725..8f3dd72c4f2 100644
--- a/src/librustc_mir/transform/const_prop.rs
+++ b/src/librustc_mir/transform/const_prop.rs
@@ -546,6 +546,10 @@ impl<'a, 'mir, 'tcx> ConstPropagator<'a, 'mir, 'tcx> {
             }
         }
     }
+
+    fn should_const_prop(&self) -> bool {
+        self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2
+    }
 }
 
 fn type_size_of<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
@@ -639,7 +643,7 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
                             assert!(self.places[local].is_none());
                             self.places[local] = Some(value);
 
-                            if self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2 {
+                            if self.should_const_prop() {
                                 self.replace_with_const(rval, value, statement.source_info.span);
                             }
                         }
@@ -656,75 +660,112 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
         location: Location,
     ) {
         self.super_terminator(terminator, location);
-        let source_info = terminator.source_info;;
-        if let TerminatorKind::Assert { expected, msg, cond, .. } = &terminator.kind {
-            if let Some(value) = self.eval_operand(&cond, source_info) {
-                trace!("assertion on {:?} should be {:?}", value, expected);
-                let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
-                if expected != self.ecx.read_scalar(value).unwrap() {
-                    // poison all places this operand references so that further code
-                    // doesn't use the invalid value
-                    match cond {
-                        Operand::Move(ref place) | Operand::Copy(ref place) => {
-                            let mut place = place;
-                            while let Place::Projection(ref proj) = *place {
-                                place = &proj.base;
-                            }
-                            if let Place::Base(PlaceBase::Local(local)) = *place {
-                                self.places[local] = None;
+        let source_info = terminator.source_info;
+        match &mut terminator.kind {
+            TerminatorKind::Assert { expected, msg, ref mut cond, .. } => {
+                if let Some(value) = self.eval_operand(&cond, source_info) {
+                    trace!("assertion on {:?} should be {:?}", value, expected);
+                    let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
+                    let value_const = self.ecx.read_scalar(value).unwrap();
+                    if expected != value_const {
+                        // poison all places this operand references so that further code
+                        // doesn't use the invalid value
+                        match cond {
+                            Operand::Move(ref place) | Operand::Copy(ref place) => {
+                                let mut place = place;
+                                while let Place::Projection(ref proj) = *place {
+                                    place = &proj.base;
+                                }
+                                if let Place::Base(PlaceBase::Local(local)) = *place {
+                                    self.places[local] = None;
+                                }
+                            },
+                            Operand::Constant(_) => {}
+                        }
+                        let span = terminator.source_info.span;
+                        let hir_id = self
+                            .tcx
+                            .hir()
+                            .as_local_hir_id(self.source.def_id())
+                            .expect("some part of a failing const eval must be local");
+                        use rustc::mir::interpret::InterpError::*;
+                        let msg = match msg {
+                            Overflow(_) |
+                            OverflowNeg |
+                            DivisionByZero |
+                            RemainderByZero => msg.description().to_owned(),
+                            BoundsCheck { ref len, ref index } => {
+                                let len = self
+                                    .eval_operand(len, source_info)
+                                    .expect("len must be const");
+                                let len = match self.ecx.read_scalar(len) {
+                                    Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
+                                        bits, ..
+                                    })) => bits,
+                                    other => bug!("const len not primitive: {:?}", other),
+                                };
+                                let index = self
+                                    .eval_operand(index, source_info)
+                                    .expect("index must be const");
+                                let index = match self.ecx.read_scalar(index) {
+                                    Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
+                                        bits, ..
+                                    })) => bits,
+                                    other => bug!("const index not primitive: {:?}", other),
+                                };
+                                format!(
+                                    "index out of bounds: \
+                                    the len is {} but the index is {}",
+                                    len,
+                                    index,
+                                )
+                            },
+                            // Need proper const propagator for these
+                            _ => return,
+                        };
+                        self.tcx.lint_hir(
+                            ::rustc::lint::builtin::CONST_ERR,
+                            hir_id,
+                            span,
+                            &msg,
+                        );
+                    } else {
+                        if self.should_const_prop() {
+                            if let ScalarMaybeUndef::Scalar(scalar) = value_const {
+                                *cond = self.operand_from_scalar(
+                                    scalar,
+                                    self.tcx.types.bool,
+                                    source_info.span,
+                                );
                             }
-                        },
-                        Operand::Constant(_) => {}
+                        }
                     }
-                    let span = terminator.source_info.span;
-                    let hir_id = self
-                        .tcx
-                        .hir()
-                        .as_local_hir_id(self.source.def_id())
-                        .expect("some part of a failing const eval must be local");
-                    use rustc::mir::interpret::InterpError::*;
-                    let msg = match msg {
-                        Overflow(_) |
-                        OverflowNeg |
-                        DivisionByZero |
-                        RemainderByZero => msg.description().to_owned(),
-                        BoundsCheck { ref len, ref index } => {
-                            let len = self
-                                .eval_operand(len, source_info)
-                                .expect("len must be const");
-                            let len = match self.ecx.read_scalar(len) {
-                                Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
-                                    bits, ..
-                                })) => bits,
-                                other => bug!("const len not primitive: {:?}", other),
-                            };
-                            let index = self
-                                .eval_operand(index, source_info)
-                                .expect("index must be const");
-                            let index = match self.ecx.read_scalar(index) {
-                                Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
-                                    bits, ..
-                                })) => bits,
-                                other => bug!("const index not primitive: {:?}", other),
-                            };
-                            format!(
-                                "index out of bounds: \
-                                the len is {} but the index is {}",
-                                len,
-                                index,
-                            )
-                        },
-                        // Need proper const propagator for these
-                        _ => return,
-                    };
-                    self.tcx.lint_hir(
-                        ::rustc::lint::builtin::CONST_ERR,
-                        hir_id,
-                        span,
-                        &msg,
-                    );
                 }
-            }
+            },
+            TerminatorKind::SwitchInt { ref mut discr, switch_ty, .. } => {
+                if self.should_const_prop() {
+                    if let Some(value) = self.eval_operand(&discr, source_info) {
+                        if let ScalarMaybeUndef::Scalar(scalar) =
+                                self.ecx.read_scalar(value).unwrap() {
+                            *discr = self.operand_from_scalar(scalar, switch_ty, source_info.span);
+                        }
+                    }
+                }
+            },
+            //none of these have Operands to const-propagate
+            TerminatorKind::Goto { .. } |
+            TerminatorKind::Resume |
+            TerminatorKind::Abort |
+            TerminatorKind::Return |
+            TerminatorKind::Unreachable |
+            TerminatorKind::Drop { .. } |
+            TerminatorKind::DropAndReplace { .. } |
+            TerminatorKind::Yield { .. } |
+            TerminatorKind::GeneratorDrop |
+            TerminatorKind::FalseEdges { .. } |
+            TerminatorKind::FalseUnwind { .. } => { }
+            //FIXME(wesleywiser) Call does have Operands that could be const-propagated
+            TerminatorKind::Call { .. } => { }
         }
     }
 }
diff --git a/src/test/mir-opt/const_prop/array_index.rs b/src/test/mir-opt/const_prop/array_index.rs
index 4b97af68ff0..dd22eb5d604 100644
--- a/src/test/mir-opt/const_prop/array_index.rs
+++ b/src/test/mir-opt/const_prop/array_index.rs
@@ -23,7 +23,7 @@ fn main() {
 //  bb0: {
 //      ...
 //      _5 = const true;
-//      assert(move _5, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
+//      assert(const true, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
 //  }
 //  bb1: {
 //      _1 = _2[_3];
diff --git a/src/test/mir-opt/const_prop/checked_add.rs b/src/test/mir-opt/const_prop/checked_add.rs
index 0718316307c..fe98cf24eec 100644
--- a/src/test/mir-opt/const_prop/checked_add.rs
+++ b/src/test/mir-opt/const_prop/checked_add.rs
@@ -16,6 +16,6 @@ fn main() {
 //  bb0: {
 //      ...
 //      _2 = (const 2u32, const false);
-//      assert(!move (_2.1: bool), "attempt to add with overflow") -> bb1;
+//      assert(!const false, "attempt to add with overflow") -> bb1;
 //  }
 // END rustc.main.ConstProp.after.mir
diff --git a/src/test/mir-opt/const_prop/switch_int.rs b/src/test/mir-opt/const_prop/switch_int.rs
new file mode 100644
index 00000000000..0df1112ec3e
--- /dev/null
+++ b/src/test/mir-opt/const_prop/switch_int.rs
@@ -0,0 +1,38 @@
+#[inline(never)]
+fn foo(_: i32) { }
+
+fn main() {
+    match 1 {
+        1 => foo(0),
+        _ => foo(-1),
+    }
+}
+
+// END RUST SOURCE
+// START rustc.main.ConstProp.before.mir
+//  bb0: {
+//      ...
+//      _1 = const 1i32;
+//      switchInt(_1) -> [1i32: bb1, otherwise: bb2];
+//  }
+// END rustc.main.ConstProp.before.mir
+// START rustc.main.ConstProp.after.mir
+//  bb0: {
+//      ...
+//      switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
+//  }
+// END rustc.main.ConstProp.after.mir
+// START rustc.main.SimplifyBranches-after-const-prop.before.mir
+//  bb0: {
+//      ...
+//      _1 = const 1i32;
+//      switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
+//  }
+// END rustc.main.SimplifyBranches-after-const-prop.before.mir
+// START rustc.main.SimplifyBranches-after-const-prop.after.mir
+//  bb0: {
+//      ...
+//      _1 = const 1i32;
+//      goto -> bb1;
+//  }
+// END rustc.main.SimplifyBranches-after-const-prop.after.mir
diff --git a/src/test/mir-opt/simplify_if.rs b/src/test/mir-opt/simplify_if.rs
index b2a99a6d446..35512b94c0c 100644
--- a/src/test/mir-opt/simplify_if.rs
+++ b/src/test/mir-opt/simplify_if.rs
@@ -5,15 +5,15 @@ fn main() {
 }
 
 // END RUST SOURCE
-// START rustc.main.SimplifyBranches-after-copy-prop.before.mir
+// START rustc.main.SimplifyBranches-after-const-prop.before.mir
 // bb0: {
 //     ...
 //     switchInt(const false) -> [false: bb3, otherwise: bb1];
 // }
-// END rustc.main.SimplifyBranches-after-copy-prop.before.mir
-// START rustc.main.SimplifyBranches-after-copy-prop.after.mir
+// END rustc.main.SimplifyBranches-after-const-prop.before.mir
+// START rustc.main.SimplifyBranches-after-const-prop.after.mir
 // bb0: {
 //     ...
 //     goto -> bb3;
 // }
-// END rustc.main.SimplifyBranches-after-copy-prop.after.mir
+// END rustc.main.SimplifyBranches-after-const-prop.after.mir