about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_transform/src/jump_threading.rs22
-rw-r--r--tests/mir-opt/set_no_discriminant.f.JumpThreading.diff26
-rw-r--r--tests/mir-opt/set_no_discriminant.generic.JumpThreading.diff26
-rw-r--r--tests/mir-opt/set_no_discriminant.rs78
4 files changed, 150 insertions, 2 deletions
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
index a41d8e21245..dcab124505e 100644
--- a/compiler/rustc_mir_transform/src/jump_threading.rs
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -43,6 +43,7 @@ use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
+use rustc_target::abi::{TagEncoding, Variants};
 
 use crate::cost_checker::CostChecker;
 
@@ -391,8 +392,25 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
             StatementKind::SetDiscriminant { box place, variant_index } => {
                 let discr_target = self.map.find_discr(place.as_ref())?;
                 let enum_ty = place.ty(self.body, self.tcx).ty;
-                let discr = discriminant_for_variant(enum_ty, *variant_index)?;
-                self.process_operand(bb, discr_target, &discr, state)?;
+                // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
+                // of a niche encoding. If we cannot ensure that we write to the discriminant, do
+                // nothing.
+                let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
+                let writes_discriminant = match enum_layout.variants {
+                    Variants::Single { index } => {
+                        assert_eq!(index, *variant_index);
+                        true
+                    }
+                    Variants::Multiple { tag_encoding: TagEncoding::Direct, .. } => true,
+                    Variants::Multiple {
+                        tag_encoding: TagEncoding::Niche { untagged_variant, .. },
+                        ..
+                    } => *variant_index != untagged_variant,
+                };
+                if writes_discriminant {
+                    let discr = discriminant_for_variant(enum_ty, *variant_index)?;
+                    self.process_operand(bb, discr_target, &discr, state)?;
+                }
             }
             // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
             StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
diff --git a/tests/mir-opt/set_no_discriminant.f.JumpThreading.diff b/tests/mir-opt/set_no_discriminant.f.JumpThreading.diff
new file mode 100644
index 00000000000..bc28e81c9a8
--- /dev/null
+++ b/tests/mir-opt/set_no_discriminant.f.JumpThreading.diff
@@ -0,0 +1,26 @@
+- // MIR for `f` before JumpThreading
++ // MIR for `f` after JumpThreading
+  
+  fn f() -> usize {
+      let mut _0: usize;
+      let mut _1: isize;
+      let mut _2: E<char>;
+  
+      bb0: {
+          _2 = E::<char>::A;
+          discriminant(_2) = 1;
+          _1 = discriminant(_2);
+          switchInt(_1) -> [0: bb1, otherwise: bb2];
+      }
+  
+      bb1: {
+          _0 = const 0_usize;
+          return;
+      }
+  
+      bb2: {
+          _0 = const 1_usize;
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/set_no_discriminant.generic.JumpThreading.diff b/tests/mir-opt/set_no_discriminant.generic.JumpThreading.diff
new file mode 100644
index 00000000000..78bfeef3c64
--- /dev/null
+++ b/tests/mir-opt/set_no_discriminant.generic.JumpThreading.diff
@@ -0,0 +1,26 @@
+- // MIR for `generic` before JumpThreading
++ // MIR for `generic` after JumpThreading
+  
+  fn generic() -> usize {
+      let mut _0: usize;
+      let mut _1: isize;
+      let mut _2: E<T>;
+  
+      bb0: {
+          _2 = E::<T>::A;
+          discriminant(_2) = 1;
+          _1 = discriminant(_2);
+          switchInt(_1) -> [0: bb1, otherwise: bb2];
+      }
+  
+      bb1: {
+          _0 = const 0_usize;
+          return;
+      }
+  
+      bb2: {
+          _0 = const 1_usize;
+          return;
+      }
+  }
+  
diff --git a/tests/mir-opt/set_no_discriminant.rs b/tests/mir-opt/set_no_discriminant.rs
new file mode 100644
index 00000000000..8ffb9a2910a
--- /dev/null
+++ b/tests/mir-opt/set_no_discriminant.rs
@@ -0,0 +1,78 @@
+// `SetDiscriminant` does not actually write anything if the chosen variant is the untagged variant
+// of a niche encoding. Verify that we do not thread over this case.
+// unit-test: JumpThreading
+
+#![feature(custom_mir)]
+#![feature(core_intrinsics)]
+
+use std::intrinsics::mir::*;
+
+enum E<T> {
+    A,
+    B(T),
+}
+
+// EMIT_MIR set_no_discriminant.f.JumpThreading.diff
+#[custom_mir(dialect = "runtime")]
+pub fn f() -> usize {
+    // CHECK-LABEL: fn f(
+    // CHECK-NOT: goto
+    // CHECK: switchInt(
+    // CHECK-NOT: goto
+    mir!(
+        let a: isize;
+        let e: E<char>;
+        {
+            e = E::A;
+            SetDiscriminant(e, 1);
+            a = Discriminant(e);
+            match a {
+                0 => bb0,
+                _ => bb1,
+            }
+        }
+        bb0 = {
+            RET = 0;
+            Return()
+        }
+        bb1 = {
+            RET = 1;
+            Return()
+        }
+    )
+}
+
+// EMIT_MIR set_no_discriminant.generic.JumpThreading.diff
+#[custom_mir(dialect = "runtime")]
+pub fn generic<T>() -> usize {
+    // CHECK-LABEL: fn generic(
+    // CHECK-NOT: goto
+    // CHECK: switchInt(
+    // CHECK-NOT: goto
+    mir!(
+        let a: isize;
+        let e: E<T>;
+        {
+            e = E::A;
+            SetDiscriminant(e, 1);
+            a = Discriminant(e);
+            match a {
+                0 => bb0,
+                _ => bb1,
+            }
+        }
+        bb0 = {
+            RET = 0;
+            Return()
+        }
+        bb1 = {
+            RET = 1;
+            Return()
+        }
+    )
+}
+
+fn main() {
+    assert_eq!(f(), 0);
+    assert_eq!(generic::<char>(), 0);
+}