about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBastian Kersting <bkersting@google.com>2025-07-01 08:59:52 +0000
committerBastian Kersting <bkersting@google.com>2025-07-02 20:02:27 +0300
commit8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf (patch)
tree0ee83fc516c1e41a6ae154285758775300653c0d
parentf51c9870bab634afb9e7a262b6ca7816bb9e940d (diff)
downloadrust-8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf.tar.gz
rust-8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf.zip
Make the enum check work for negative discriminants
The discriminant check was not working correctly for negative numbers.
This change fixes that by masking out the relevant bits correctly.
-rw-r--r--compiler/rustc_mir_transform/src/check_enums.rs30
-rw-r--r--tests/ui/mir/enum/negative_discr_break.rs14
-rw-r--r--tests/ui/mir/enum/negative_discr_ok.rs53
3 files changed, 93 insertions, 4 deletions
diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs
index 240da87ab27..fae984b4936 100644
--- a/compiler/rustc_mir_transform/src/check_enums.rs
+++ b/compiler/rustc_mir_transform/src/check_enums.rs
@@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> {
     },
 }
 
+#[derive(Debug, Copy, Clone)]
 struct TyAndSize<'tcx> {
     pub ty: Ty<'tcx>,
     pub size: Size,
@@ -337,7 +338,7 @@ fn insert_direct_enum_check<'tcx>(
     let invalid_discr_block_data = BasicBlockData::new(None, false);
     let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
     let block_data = &mut basic_blocks[current_block];
-    let discr = insert_discr_cast_to_u128(
+    let discr_place = insert_discr_cast_to_u128(
         tcx,
         local_decls,
         block_data,
@@ -348,13 +349,34 @@ fn insert_direct_enum_check<'tcx>(
         source_info,
     );
 
+    // Mask out the bits of the discriminant type.
+    let mask = discr.size.unsigned_int_max();
+    let discr_masked =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
+    let rvalue = Rvalue::BinaryOp(
+        BinOp::BitAnd,
+        Box::new((
+            Operand::Copy(discr_place),
+            Operand::Constant(Box::new(ConstOperand {
+                span: source_info.span,
+                user_ty: None,
+                const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
+            })),
+        )),
+    );
+    block_data
+        .statements
+        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
+
     // Branch based on the discriminant value.
     block_data.terminator = Some(Terminator {
         source_info,
         kind: TerminatorKind::SwitchInt {
-            discr: Operand::Copy(discr),
+            discr: Operand::Copy(discr_masked),
             targets: SwitchTargets::new(
-                discriminants.into_iter().map(|discr| (discr, new_block)),
+                discriminants
+                    .into_iter()
+                    .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
                 invalid_discr_block,
             ),
         },
@@ -371,7 +393,7 @@ fn insert_direct_enum_check<'tcx>(
             })),
             expected: true,
             target: new_block,
-            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
+            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
             // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
             // We never want to insert an unwind into unsafe code, because unwinding could
             // make a failing UB check turn into much worse UB when we start unwinding.
diff --git a/tests/ui/mir/enum/negative_discr_break.rs b/tests/ui/mir/enum/negative_discr_break.rs
new file mode 100644
index 00000000000..fa1284f72a0
--- /dev/null
+++ b/tests/ui/mir/enum/negative_discr_break.rs
@@ -0,0 +1,14 @@
+//@ run-fail
+//@ compile-flags: -C debug-assertions
+//@ error-pattern: trying to construct an enum from an invalid value 0xfd
+
+#[allow(dead_code)]
+enum Foo {
+    A = -2,
+    B = -1,
+    C = 1,
+}
+
+fn main() {
+    let _val: Foo = unsafe { std::mem::transmute::<i8, Foo>(-3) };
+}
diff --git a/tests/ui/mir/enum/negative_discr_ok.rs b/tests/ui/mir/enum/negative_discr_ok.rs
new file mode 100644
index 00000000000..5c15b33fa84
--- /dev/null
+++ b/tests/ui/mir/enum/negative_discr_ok.rs
@@ -0,0 +1,53 @@
+//@ run-pass
+//@ compile-flags: -C debug-assertions
+
+#[allow(dead_code)]
+#[derive(Debug, PartialEq)]
+enum Foo {
+    A = -12121,
+    B = -2,
+    C = -1,
+    D = 1,
+    E = 2,
+    F = 12121,
+}
+
+#[allow(dead_code)]
+#[repr(i64)]
+#[derive(Debug, PartialEq)]
+enum Bar {
+    A = i64::MIN,
+    B = -2,
+    C = -1,
+    D = 1,
+    E = 2,
+    F = i64::MAX,
+}
+
+fn main() {
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-12121) };
+    assert_eq!(val, Foo::A);
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-2) };
+    assert_eq!(val, Foo::B);
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-1) };
+    assert_eq!(val, Foo::C);
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(1) };
+    assert_eq!(val, Foo::D);
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(2) };
+    assert_eq!(val, Foo::E);
+    let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(12121) };
+    assert_eq!(val, Foo::F);
+
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MIN) };
+    assert_eq!(val, Bar::A);
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-2) };
+    assert_eq!(val, Bar::B);
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-1) };
+    assert_eq!(val, Bar::C);
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(1) };
+    assert_eq!(val, Bar::D);
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(2) };
+    assert_eq!(val, Bar::E);
+    let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MAX) };
+    assert_eq!(val, Bar::F);
+}