about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Benfield <mbenfield@google.com>2022-11-17 15:03:02 +0000
committerMichael Benfield <mbenfield@google.com>2022-11-18 21:16:12 +0000
commit31c0645b9d2539f47eecb096142474b29dc542f7 (patch)
tree30828b9875ceedf215a0858810ebeb08b5ad5e94
parentfd3bfb35511cbcff59ce1454d3db627b576d7e92 (diff)
downloadrust-31c0645b9d2539f47eecb096142474b29dc542f7.tar.gz
rust-31c0645b9d2539f47eecb096142474b29dc542f7.zip
rustc_codegen_ssa: Fix for codegen_get_discr
When doing the optimized implementation of getting the discriminant, the
arithmetic needs to be done in the tag type so wrapping behavior works
correctly.

Fixes #104519
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/place.rs27
-rw-r--r--src/test/codegen/enum-match.rs7
-rw-r--r--src/test/ui/enum-discriminant/issue-104519.rs36
3 files changed, 56 insertions, 14 deletions
diff --git a/compiler/rustc_codegen_ssa/src/mir/place.rs b/compiler/rustc_codegen_ssa/src/mir/place.rs
index 90855538589..fbe30154a7c 100644
--- a/compiler/rustc_codegen_ssa/src/mir/place.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/place.rs
@@ -309,14 +309,14 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
                 // In the algorithm above, we can change
                 // cast(relative_tag) + niche_variants.start()
                 // into
-                // cast(tag) + (niche_variants.start() - niche_start)
+                // cast(tag + (niche_variants.start() - niche_start))
                 // if either the casted type is no larger than the original
                 // type, or if the niche values are contiguous (in either the
                 // signed or unsigned sense).
-                let can_incr_after_cast = cast_smaller || niches_ule || niches_sle;
+                let can_incr = cast_smaller || niches_ule || niches_sle;
 
                 let data_for_boundary_niche = || -> Option<(IntPredicate, u128)> {
-                    if !can_incr_after_cast {
+                    if !can_incr {
                         None
                     } else if niche_start == low_unsigned {
                         Some((IntPredicate::IntULE, niche_end))
@@ -353,24 +353,33 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
                     // The algorithm is now this:
                     // is_niche = tag <= niche_end
                     // discr = if is_niche {
-                    //     cast(tag) + (niche_variants.start() - niche_start)
+                    //     cast(tag + (niche_variants.start() - niche_start))
                     // } else {
                     //     untagged_variant
                     // }
                     // (the first line may instead be tag >= niche_start,
                     // and may be a signed or unsigned comparison)
+                    // The arithmetic must be done before the cast, so we can
+                    // have the correct wrapping behavior. See issue #104519 for
+                    // the consequences of getting this wrong.
                     let is_niche =
                         bx.icmp(predicate, tag, bx.cx().const_uint_big(tag_llty, constant));
+                    let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
+                    let incr_tag = if delta == 0 {
+                        tag
+                    } else {
+                        bx.add(tag, bx.cx().const_uint_big(tag_llty, delta))
+                    };
+
                     let cast_tag = if cast_smaller {
-                        bx.intcast(tag, cast_to, false)
+                        bx.intcast(incr_tag, cast_to, false)
                     } else if niches_ule {
-                        bx.zext(tag, cast_to)
+                        bx.zext(incr_tag, cast_to)
                     } else {
-                        bx.sext(tag, cast_to)
+                        bx.sext(incr_tag, cast_to)
                     };
 
-                    let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
-                    (is_niche, cast_tag, delta)
+                    (is_niche, cast_tag, 0)
                 } else {
                     // The special cases don't apply, so we'll have to go with
                     // the general algorithm.
diff --git a/src/test/codegen/enum-match.rs b/src/test/codegen/enum-match.rs
index efab189fd7b..44f1b408d21 100644
--- a/src/test/codegen/enum-match.rs
+++ b/src/test/codegen/enum-match.rs
@@ -34,11 +34,8 @@ pub enum Enum1 {
 
 // CHECK: define i8 @match1{{.*}}
 // CHECK-NEXT: start:
-// CHECK-NEXT: %1 = icmp ugt i8 %0, 1
-// CHECK-NEXT: %2 = zext i8 %0 to i64
-// CHECK-NEXT: %3 = add nsw i64 %2, -1
-// CHECK-NEXT: %_2 = select i1 %1, i64 %3, i64 0
-// CHECK-NEXT: switch i64 %_2, label {{.*}} [
+// CHECK-NEXT: %1 = {{.*}}call i8 @llvm.usub.sat.i8(i8 %0, i8 1)
+// CHECK-NEXT: switch i8 %1, label {{.*}} [
 #[no_mangle]
 pub fn match1(e: Enum1) -> u8 {
     use Enum1::*;
diff --git a/src/test/ui/enum-discriminant/issue-104519.rs b/src/test/ui/enum-discriminant/issue-104519.rs
new file mode 100644
index 00000000000..c4630f76b3a
--- /dev/null
+++ b/src/test/ui/enum-discriminant/issue-104519.rs
@@ -0,0 +1,36 @@
+// run-pass
+#![allow(dead_code)]
+
+enum OpenResult {
+    Ok(()),
+    Err(()),
+    TransportErr(TransportErr),
+}
+
+#[repr(i32)]
+enum TransportErr {
+    UnknownMethod = -2,
+}
+
+#[inline(never)]
+fn some_match(result: OpenResult) -> u8 {
+    match result {
+        OpenResult::Ok(()) => 0,
+        _ => 1,
+    }
+}
+
+fn main() {
+    let result = OpenResult::Ok(());
+    assert_eq!(some_match(result), 0);
+
+    let result = OpenResult::Ok(());
+    match result {
+        OpenResult::Ok(()) => (),
+        _ => unreachable!("message a"),
+    }
+    match result {
+        OpenResult::Ok(()) => (),
+        _ => unreachable!("message b"),
+    }
+}