about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLeón Orell Valerian Liehr <me@fmease.dev>2024-01-23 21:53:57 +0100
committerGitHub <noreply@github.com>2024-01-23 21:53:57 +0100
commit6cca9b33eca9d6a0344b643021706ea2a3d48dfb (patch)
tree8ab4d52091ed36fc4c504f869a3084e2a52e245c
parent19a840d6760222f89ad1ec55e2bcd07cb380bcc1 (diff)
parentafaac75ac76cfbc38066d0474f8ca69d92ca184d (diff)
downloadrust-6cca9b33eca9d6a0344b643021706ea2a3d48dfb.tar.gz
rust-6cca9b33eca9d6a0344b643021706ea2a3d48dfb.zip
Rollup merge of #120171 - cjgillot:jump-threading-assume-assert, r=tmiasko
Fix assume and assert in jump threading

r? ``@tmiasko``
-rw-r--r--compiler/rustc_mir_build/src/build/custom/parse/instruction.rs4
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs17
-rw-r--r--compiler/rustc_span/src/symbol.rs1
-rw-r--r--library/core/src/intrinsics/mir.rs2
-rw-r--r--tests/mir-opt/building/custom/assume.assume_constant.built.after.mir10
-rw-r--r--tests/mir-opt/building/custom/assume.assume_local.built.after.mir10
-rw-r--r--tests/mir-opt/building/custom/assume.assume_place.built.after.mir10
-rw-r--r--tests/mir-opt/building/custom/assume.rs44
-rw-r--r--tests/mir-opt/jump_threading.assume.JumpThreading.panic-abort.diff39
-rw-r--r--tests/mir-opt/jump_threading.assume.JumpThreading.panic-unwind.diff39
-rw-r--r--tests/mir-opt/jump_threading.rs48
11 files changed, 209 insertions, 15 deletions
diff --git a/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs b/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
index 5428333a116..c669d3fd623 100644
--- a/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
+++ b/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
@@ -20,6 +20,10 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
             @call(mir_storage_dead, args) => {
                 Ok(StatementKind::StorageDead(self.parse_local(args[0])?))
             },
+            @call(mir_assume, args) => {
+                let op = self.parse_operand(args[0])?;
+                Ok(StatementKind::Intrinsic(Box::new(NonDivergingIntrinsic::Assume(op))))
+            },
             @call(mir_deinit, args) => {
                 Ok(StatementKind::Deinit(Box::new(self.parse_place(args[0])?)))
             },
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index e87f68a0905..7a70ed5cb7f 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -566,11 +566,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
         cost: &CostChecker<'_, 'tcx>,
         depth: usize,
     ) {
-        let register_opportunity = |c: Condition| {
-            debug!(?bb, ?c.target, "register");
-            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
-        };
-
         let term = self.body.basic_blocks[bb].terminator();
         let place_to_flood = match term.kind {
             // We come from a target, so those are not possible.
@@ -592,16 +587,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
             // Flood the overwritten place, and progress through.
             TerminatorKind::Drop { place: destination, .. }
             | TerminatorKind::Call { destination, .. } => Some(destination),
-            // Treat as an `assume(cond == expected)`.
-            TerminatorKind::Assert { ref cond, expected, .. } => {
-                if let Some(place) = cond.place()
-                    && let Some(conditions) = state.try_get(place.as_ref(), self.map)
-                {
-                    let expected = if expected { ScalarInt::TRUE } else { ScalarInt::FALSE };
-                    conditions.iter_matches(expected).for_each(register_opportunity);
-                }
-                None
-            }
+            // Ignore, as this can be a no-op at codegen time.
+            TerminatorKind::Assert { .. } => None,
         };
 
         // We can recurse through this terminator.
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 6c39a38750e..90a38b26f73 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -1028,6 +1028,7 @@ symbols! {
         minnumf32,
         minnumf64,
         mips_target_feature,
+        mir_assume,
         mir_basic_block,
         mir_call,
         mir_cast_transmute,
diff --git a/library/core/src/intrinsics/mir.rs b/library/core/src/intrinsics/mir.rs
index 334e32b26b1..d348e31609d 100644
--- a/library/core/src/intrinsics/mir.rs
+++ b/library/core/src/intrinsics/mir.rs
@@ -357,6 +357,8 @@ define!("mir_unwind_resume",
 
 define!("mir_storage_live", fn StorageLive<T>(local: T));
 define!("mir_storage_dead", fn StorageDead<T>(local: T));
+#[cfg(not(bootstrap))]
+define!("mir_assume", fn Assume(operand: bool));
 define!("mir_deinit", fn Deinit<T>(place: T));
 define!("mir_checked", fn Checked<T>(binop: T) -> (T, bool));
 define!("mir_len", fn Len<T>(place: T) -> usize);
diff --git a/tests/mir-opt/building/custom/assume.assume_constant.built.after.mir b/tests/mir-opt/building/custom/assume.assume_constant.built.after.mir
new file mode 100644
index 00000000000..8e70d0a1e9b
--- /dev/null
+++ b/tests/mir-opt/building/custom/assume.assume_constant.built.after.mir
@@ -0,0 +1,10 @@
+// MIR for `assume_constant` after built
+
+fn assume_constant() -> () {
+    let mut _0: ();
+
+    bb0: {
+        assume(const true);
+        return;
+    }
+}
diff --git a/tests/mir-opt/building/custom/assume.assume_local.built.after.mir b/tests/mir-opt/building/custom/assume.assume_local.built.after.mir
new file mode 100644
index 00000000000..7ea1fcd30c2
--- /dev/null
+++ b/tests/mir-opt/building/custom/assume.assume_local.built.after.mir
@@ -0,0 +1,10 @@
+// MIR for `assume_local` after built
+
+fn assume_local(_1: bool) -> () {
+    let mut _0: ();
+
+    bb0: {
+        assume(_1);
+        return;
+    }
+}
diff --git a/tests/mir-opt/building/custom/assume.assume_place.built.after.mir b/tests/mir-opt/building/custom/assume.assume_place.built.after.mir
new file mode 100644
index 00000000000..ce914618d3d
--- /dev/null
+++ b/tests/mir-opt/building/custom/assume.assume_place.built.after.mir
@@ -0,0 +1,10 @@
+// MIR for `assume_place` after built
+
+fn assume_place(_1: (bool, u8)) -> () {
+    let mut _0: ();
+
+    bb0: {
+        assume((_1.0: bool));
+        return;
+    }
+}
diff --git a/tests/mir-opt/building/custom/assume.rs b/tests/mir-opt/building/custom/assume.rs
new file mode 100644
index 00000000000..a477e12f0e0
--- /dev/null
+++ b/tests/mir-opt/building/custom/assume.rs
@@ -0,0 +1,44 @@
+// skip-filecheck
+#![feature(custom_mir, core_intrinsics)]
+
+extern crate core;
+use core::intrinsics::mir::*;
+
+// EMIT_MIR assume.assume_local.built.after.mir
+#[custom_mir(dialect = "built")]
+fn assume_local(x: bool) {
+    mir!(
+        {
+            Assume(x);
+            Return()
+        }
+    )
+}
+
+// EMIT_MIR assume.assume_place.built.after.mir
+#[custom_mir(dialect = "built")]
+fn assume_place(p: (bool, u8)) {
+    mir!(
+        {
+            Assume(p.0);
+            Return()
+        }
+    )
+}
+
+// EMIT_MIR assume.assume_constant.built.after.mir
+#[custom_mir(dialect = "built")]
+fn assume_constant() {
+    mir!(
+        {
+            Assume(true);
+            Return()
+        }
+    )
+}
+
+fn main() {
+    assume_local(true);
+    assume_place((true, 50));
+    assume_constant();
+}
diff --git a/tests/mir-opt/jump_threading.assume.JumpThreading.panic-abort.diff b/tests/mir-opt/jump_threading.assume.JumpThreading.panic-abort.diff
new file mode 100644
index 00000000000..f1f0106fdbc
--- /dev/null
+++ b/tests/mir-opt/jump_threading.assume.JumpThreading.panic-abort.diff
@@ -0,0 +1,39 @@
+- // MIR for `assume` before JumpThreading
++ // MIR for `assume` after JumpThreading
+  
+  fn assume(_1: u8, _2: bool) -> u8 {
+      let mut _0: u8;
+  
+      bb0: {
+          switchInt(_1) -> [7: bb1, otherwise: bb2];
+      }
+  
+      bb1: {
+          assume(_2);
+-         goto -> bb3;
++         goto -> bb6;
+      }
+  
+      bb2: {
+          goto -> bb3;
+      }
+  
+      bb3: {
+          switchInt(_2) -> [0: bb4, otherwise: bb5];
+      }
+  
+      bb4: {
+          _0 = const 4_u8;
+          return;
+      }
+  
+      bb5: {
+          _0 = const 5_u8;
+          return;
++     }
++ 
++     bb6: {
++         goto -> bb5;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.assume.JumpThreading.panic-unwind.diff b/tests/mir-opt/jump_threading.assume.JumpThreading.panic-unwind.diff
new file mode 100644
index 00000000000..f1f0106fdbc
--- /dev/null
+++ b/tests/mir-opt/jump_threading.assume.JumpThreading.panic-unwind.diff
@@ -0,0 +1,39 @@
+- // MIR for `assume` before JumpThreading
++ // MIR for `assume` after JumpThreading
+  
+  fn assume(_1: u8, _2: bool) -> u8 {
+      let mut _0: u8;
+  
+      bb0: {
+          switchInt(_1) -> [7: bb1, otherwise: bb2];
+      }
+  
+      bb1: {
+          assume(_2);
+-         goto -> bb3;
++         goto -> bb6;
+      }
+  
+      bb2: {
+          goto -> bb3;
+      }
+  
+      bb3: {
+          switchInt(_2) -> [0: bb4, otherwise: bb5];
+      }
+  
+      bb4: {
+          _0 = const 4_u8;
+          return;
+      }
+  
+      bb5: {
+          _0 = const 5_u8;
+          return;
++     }
++ 
++     bb6: {
++         goto -> bb5;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.rs b/tests/mir-opt/jump_threading.rs
index 7c2fa42828b..a66fe8b57e7 100644
--- a/tests/mir-opt/jump_threading.rs
+++ b/tests/mir-opt/jump_threading.rs
@@ -468,6 +468,52 @@ fn aggregate(x: u8) -> u8 {
     }
 }
 
+/// Verify that we can leverage the existence of an `Assume` terminator.
+#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
+fn assume(a: u8, b: bool) -> u8 {
+    // CHECK-LABEL: fn assume(
+    mir!(
+        {
+            // CHECK: bb0: {
+            // CHECK-NEXT: switchInt(_1) -> [7: bb1, otherwise: bb2]
+            match a { 7 => bb1, _ => bb2 }
+        }
+        bb1 = {
+            // CHECK: bb1: {
+            // CHECK-NEXT: assume(_2);
+            // CHECK-NEXT: goto -> bb6;
+            Assume(b);
+            Goto(bb3)
+        }
+        bb2 = {
+            // CHECK: bb2: {
+            // CHECK-NEXT: goto -> bb3;
+            Goto(bb3)
+        }
+        bb3 = {
+            // CHECK: bb3: {
+            // CHECK-NEXT: switchInt(_2) -> [0: bb4, otherwise: bb5];
+            match b { false => bb4, _ => bb5 }
+        }
+        bb4 = {
+            // CHECK: bb4: {
+            // CHECK-NEXT: _0 = const 4_u8;
+            // CHECK-NEXT: return;
+            RET = 4;
+            Return()
+        }
+        bb5 = {
+            // CHECK: bb5: {
+            // CHECK-NEXT: _0 = const 5_u8;
+            // CHECK-NEXT: return;
+            RET = 5;
+            Return()
+        }
+        // CHECK: bb6: {
+        // CHECK-NEXT: goto -> bb5;
+    )
+}
+
 fn main() {
     // CHECK-LABEL: fn main(
     too_complex(Ok(0));
@@ -481,6 +527,7 @@ fn main() {
     renumbered_bb(true);
     disappearing_bb(7);
     aggregate(7);
+    assume(7, false);
 }
 
 // EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -494,3 +541,4 @@ fn main() {
 // EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
 // EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
 // EMIT_MIR jump_threading.aggregate.JumpThreading.diff
+// EMIT_MIR jump_threading.assume.JumpThreading.diff