diff options
| author | Bastian Kersting <bkersting@google.com> | 2025-07-01 08:59:52 +0000 |
|---|---|---|
| committer | Bastian Kersting <bkersting@google.com> | 2025-07-02 20:02:27 +0300 |
| commit | 8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf (patch) | |
| tree | 0ee83fc516c1e41a6ae154285758775300653c0d | |
| parent | f51c9870bab634afb9e7a262b6ca7816bb9e940d (diff) | |
| download | rust-8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf.tar.gz rust-8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf.zip | |
Make the enum check work for negative discriminants
The discriminant check was not working correctly for negative numbers. This change fixes that by masking out the relevant bits correctly.
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_enums.rs | 30 | ||||
| -rw-r--r-- | tests/ui/mir/enum/negative_discr_break.rs | 14 | ||||
| -rw-r--r-- | tests/ui/mir/enum/negative_discr_ok.rs | 53 |
3 files changed, 93 insertions, 4 deletions
diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs index 240da87ab27..fae984b4936 100644 --- a/compiler/rustc_mir_transform/src/check_enums.rs +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> { }, } +#[derive(Debug, Copy, Clone)] struct TyAndSize<'tcx> { pub ty: Ty<'tcx>, pub size: Size, @@ -337,7 +338,7 @@ fn insert_direct_enum_check<'tcx>( let invalid_discr_block_data = BasicBlockData::new(None, false); let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); let block_data = &mut basic_blocks[current_block]; - let discr = insert_discr_cast_to_u128( + let discr_place = insert_discr_cast_to_u128( tcx, local_decls, block_data, @@ -348,13 +349,34 @@ fn insert_direct_enum_check<'tcx>( source_info, ); + // Mask out the bits of the discriminant type. + let mask = discr.size.unsigned_int_max(); + let discr_masked = + local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into(); + let rvalue = Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new(( + Operand::Copy(discr_place), + Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128), + })), + )), + ); + block_data + .statements + .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue))))); + // Branch based on the discriminant value. block_data.terminator = Some(Terminator { source_info, kind: TerminatorKind::SwitchInt { - discr: Operand::Copy(discr), + discr: Operand::Copy(discr_masked), targets: SwitchTargets::new( - discriminants.into_iter().map(|discr| (discr, new_block)), + discriminants + .into_iter() + .map(|discr_val| (discr.size.truncate(discr_val), new_block)), invalid_discr_block, ), }, @@ -371,7 +393,7 @@ fn insert_direct_enum_check<'tcx>( })), expected: true, target: new_block, - msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))), // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. // We never want to insert an unwind into unsafe code, because unwinding could // make a failing UB check turn into much worse UB when we start unwinding. diff --git a/tests/ui/mir/enum/negative_discr_break.rs b/tests/ui/mir/enum/negative_discr_break.rs new file mode 100644 index 00000000000..fa1284f72a0 --- /dev/null +++ b/tests/ui/mir/enum/negative_discr_break.rs @@ -0,0 +1,14 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0xfd + +#[allow(dead_code)] +enum Foo { + A = -2, + B = -1, + C = 1, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::<i8, Foo>(-3) }; +} diff --git a/tests/ui/mir/enum/negative_discr_ok.rs b/tests/ui/mir/enum/negative_discr_ok.rs new file mode 100644 index 00000000000..5c15b33fa84 --- /dev/null +++ b/tests/ui/mir/enum/negative_discr_ok.rs @@ -0,0 +1,53 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[derive(Debug, PartialEq)] +enum Foo { + A = -12121, + B = -2, + C = -1, + D = 1, + E = 2, + F = 12121, +} + +#[allow(dead_code)] +#[repr(i64)] +#[derive(Debug, PartialEq)] +enum Bar { + A = i64::MIN, + B = -2, + C = -1, + D = 1, + E = 2, + F = i64::MAX, +} + +fn main() { + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-12121) }; + assert_eq!(val, Foo::A); + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-2) }; + assert_eq!(val, Foo::B); + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-1) }; + assert_eq!(val, Foo::C); + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(1) }; + assert_eq!(val, Foo::D); + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(2) }; + assert_eq!(val, Foo::E); + let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(12121) }; + assert_eq!(val, Foo::F); + + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MIN) }; + assert_eq!(val, Bar::A); + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-2) }; + assert_eq!(val, Bar::B); + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-1) }; + assert_eq!(val, Bar::C); + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(1) }; + assert_eq!(val, Bar::D); + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(2) }; + assert_eq!(val, Bar::E); + let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MAX) }; + assert_eq!(val, Bar::F); +} |
