about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNilstrieb <48135649+Nilstrieb@users.noreply.github.com>2024-07-27 15:08:11 +0200
committerNilstrieb <48135649+Nilstrieb@users.noreply.github.com>2024-07-27 15:11:59 +0200
commitf305e188041b586fb162161f961298f1532fe83b (patch)
treee19d0aeb1476a77a00df3f28c72e6e8abbbf705a
parent3942254d00bf95cd5920980f85ebea57a1e6ce2a (diff)
downloadrust-f305e188041b586fb162161f961298f1532fe83b.tar.gz
rust-f305e188041b586fb162161f961298f1532fe83b.zip
Disable jump threading of float equality
Jump threading stores values as `u128` (`ScalarInt`) and does its
comparisons for equality as integer comparisons.
This works great for integers. Sadly, not everything is an integer.

Floats famously have wonky equality semantcs, with `NaN!=NaN` and
`0.0 == -0.0`. This does not match our beautiful integer bitpattern
equality and therefore causes things to go horribly wrong.

While jump threading could be extended to support floats by remembering
that they're floats in the value state and handling them properly,
it's signficantly easier to just disable it for now.
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs7
-rw-r--r--tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff59
-rw-r--r--tests/mir-opt/jump_threading.floats.JumpThreading.panic-unwind.diff59
-rw-r--r--tests/mir-opt/jump_threading.rs12
4 files changed, 137 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index 2100f4b4a1a..96c52845a4a 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -509,6 +509,13 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
                     BinOp::Ne => ScalarInt::FALSE,
                     _ => return None,
                 };
+                if value.const_.ty().is_floating_point() {
+                    // Floating point equality does not follow bit-patterns.
+                    // -0.0 and NaN both have special rules for equality,
+                    // and therefore we cannot use integer comparisons for them.
+                    // Avoid handling them, though this could be extended in the future.
+                    return None;
+                }
                 let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
                 let conds = conditions.map(self.arena, |c| Condition {
                     value,
diff --git a/tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff b/tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff
new file mode 100644
index 00000000000..6ca37e96d29
--- /dev/null
+++ b/tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff
@@ -0,0 +1,59 @@
+- // MIR for `floats` before JumpThreading
++ // MIR for `floats` after JumpThreading
+  
+  fn floats() -> u32 {
+      let mut _0: u32;
+      let _1: f64;
+      let mut _2: bool;
+      let mut _3: bool;
+      let mut _4: f64;
+      scope 1 {
+          debug x => _1;
+      }
+  
+      bb0: {
+          StorageLive(_1);
+          StorageLive(_2);
+          _2 = const true;
+-         switchInt(move _2) -> [0: bb2, otherwise: bb1];
++         goto -> bb1;
+      }
+  
+      bb1: {
+          _1 = const -0f64;
+          goto -> bb3;
+      }
+  
+      bb2: {
+          _1 = const 1f64;
+          goto -> bb3;
+      }
+  
+      bb3: {
+          StorageDead(_2);
+          StorageLive(_3);
+          StorageLive(_4);
+          _4 = _1;
+          _3 = Eq(move _4, const 0f64);
+          switchInt(move _3) -> [0: bb5, otherwise: bb4];
+      }
+  
+      bb4: {
+          StorageDead(_4);
+          _0 = const 0_u32;
+          goto -> bb6;
+      }
+  
+      bb5: {
+          StorageDead(_4);
+          _0 = const 1_u32;
+          goto -> bb6;
+      }
+  
+      bb6: {
+          StorageDead(_3);
+          StorageDead(_1);
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.floats.JumpThreading.panic-unwind.diff b/tests/mir-opt/jump_threading.floats.JumpThreading.panic-unwind.diff
new file mode 100644
index 00000000000..6ca37e96d29
--- /dev/null
+++ b/tests/mir-opt/jump_threading.floats.JumpThreading.panic-unwind.diff
@@ -0,0 +1,59 @@
+- // MIR for `floats` before JumpThreading
++ // MIR for `floats` after JumpThreading
+  
+  fn floats() -> u32 {
+      let mut _0: u32;
+      let _1: f64;
+      let mut _2: bool;
+      let mut _3: bool;
+      let mut _4: f64;
+      scope 1 {
+          debug x => _1;
+      }
+  
+      bb0: {
+          StorageLive(_1);
+          StorageLive(_2);
+          _2 = const true;
+-         switchInt(move _2) -> [0: bb2, otherwise: bb1];
++         goto -> bb1;
+      }
+  
+      bb1: {
+          _1 = const -0f64;
+          goto -> bb3;
+      }
+  
+      bb2: {
+          _1 = const 1f64;
+          goto -> bb3;
+      }
+  
+      bb3: {
+          StorageDead(_2);
+          StorageLive(_3);
+          StorageLive(_4);
+          _4 = _1;
+          _3 = Eq(move _4, const 0f64);
+          switchInt(move _3) -> [0: bb5, otherwise: bb4];
+      }
+  
+      bb4: {
+          StorageDead(_4);
+          _0 = const 0_u32;
+          goto -> bb6;
+      }
+  
+      bb5: {
+          StorageDead(_4);
+          _0 = const 1_u32;
+          goto -> bb6;
+      }
+  
+      bb6: {
+          StorageDead(_3);
+          StorageDead(_1);
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/jump_threading.rs b/tests/mir-opt/jump_threading.rs
index de290c1ef44..e5d8525dcac 100644
--- a/tests/mir-opt/jump_threading.rs
+++ b/tests/mir-opt/jump_threading.rs
@@ -521,6 +521,16 @@ fn aggregate_copy() -> u32 {
     if c == 2 { b.0 } else { 13 }
 }
 
+fn floats() -> u32 {
+    // CHECK-LABEL: fn floats(
+    // CHECK: switchInt(
+
+    // Test for issue #128243, where float equality was assumed to be bitwise.
+    // When adding float support, it must be ensured that this continues working properly.
+    let x = if true { -0.0 } else { 1.0 };
+    if x == 0.0 { 0 } else { 1 }
+}
+
 fn main() {
     // CHECK-LABEL: fn main(
     too_complex(Ok(0));
@@ -535,6 +545,7 @@ fn main() {
     disappearing_bb(7);
     aggregate(7);
     assume(7, false);
+    floats();
 }
 
 // EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -550,3 +561,4 @@ fn main() {
 // EMIT_MIR jump_threading.aggregate.JumpThreading.diff
 // EMIT_MIR jump_threading.assume.JumpThreading.diff
 // EMIT_MIR jump_threading.aggregate_copy.JumpThreading.diff
+// EMIT_MIR jump_threading.floats.JumpThreading.diff