diff options
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/intrinsic.rs | 31 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/lib.rs | 1 | ||||
| -rw-r--r-- | library/core/src/intrinsics/fallback.rs | 4 | ||||
| -rw-r--r-- | library/core/tests/intrinsics.rs | 10 | ||||
| -rw-r--r-- | tests/codegen/intrinsics/carrying_mul_add.rs | 137 |
5 files changed, 181 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index c38c5d4c644..cabcfc9b42b 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -340,6 +340,37 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { self.const_i32(cache_type), ]) } + sym::carrying_mul_add => { + let (size, signed) = fn_args.type_at(0).int_size_and_signed(self.tcx); + + let wide_llty = self.type_ix(size.bits() * 2); + let args = args.as_array().unwrap(); + let [a, b, c, d] = args.map(|a| self.intcast(a.immediate(), wide_llty, signed)); + + let wide = if signed { + let prod = self.unchecked_smul(a, b); + let acc = self.unchecked_sadd(prod, c); + self.unchecked_sadd(acc, d) + } else { + let prod = self.unchecked_umul(a, b); + let acc = self.unchecked_uadd(prod, c); + self.unchecked_uadd(acc, d) + }; + + let narrow_llty = self.type_ix(size.bits()); + let low = self.trunc(wide, narrow_llty); + let bits_const = self.const_uint(wide_llty, size.bits()); + // No need for ashr when signed; LLVM changes it to lshr anyway. + let high = self.lshr(wide, bits_const); + // FIXME: could be `trunc nuw`, even for signed. + let high = self.trunc(high, narrow_llty); + + let pair_llty = self.type_struct(&[narrow_llty, narrow_llty], false); + let pair = self.const_poison(pair_llty); + let pair = self.insert_value(pair, low, 0); + let pair = self.insert_value(pair, high, 1); + pair + } sym::ctlz | sym::ctlz_nonzero | sym::cttz diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 0de0c6a7a89..dca7738daf7 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -17,6 +17,7 @@ #![feature(iter_intersperse)] #![feature(let_chains)] #![feature(rustdoc_internals)] +#![feature(slice_as_array)] #![feature(try_blocks)] #![warn(unreachable_pub)] // tidy-alphabetical-end diff --git a/library/core/src/intrinsics/fallback.rs b/library/core/src/intrinsics/fallback.rs index d87331f7262..1779126b180 100644 --- a/library/core/src/intrinsics/fallback.rs +++ b/library/core/src/intrinsics/fallback.rs @@ -100,8 +100,8 @@ impl const CarryingMulAdd for i128 { fn carrying_mul_add(self, b: i128, c: i128, d: i128) -> (u128, i128) { let (low, high) = wide_mul_u128(self as u128, b as u128); let mut high = high as i128; - high = high.wrapping_add((self >> 127) * b); - high = high.wrapping_add(self * (b >> 127)); + high = high.wrapping_add(i128::wrapping_mul(self >> 127, b)); + high = high.wrapping_add(i128::wrapping_mul(self, b >> 127)); let (low, carry) = u128::overflowing_add(low, c as u128); high = high.wrapping_add((carry as i128) + (c >> 127)); let (low, carry) = u128::overflowing_add(low, d as u128); diff --git a/library/core/tests/intrinsics.rs b/library/core/tests/intrinsics.rs index 76f42594091..744a6a0d2dd 100644 --- a/library/core/tests/intrinsics.rs +++ b/library/core/tests/intrinsics.rs @@ -153,6 +153,7 @@ fn carrying_mul_add_fallback_i32() { #[test] fn carrying_mul_add_fallback_u128() { + assert_eq!(fallback_cma::<u128>(u128::MAX, u128::MAX, 0, 0), (1, u128::MAX - 1)); assert_eq!(fallback_cma::<u128>(1, 1, 1, 1), (3, 0)); assert_eq!(fallback_cma::<u128>(0, 0, u128::MAX, u128::MAX), (u128::MAX - 1, 1)); assert_eq!( @@ -178,8 +179,17 @@ fn carrying_mul_add_fallback_u128() { #[test] fn carrying_mul_add_fallback_i128() { + assert_eq!(fallback_cma::<i128>(-1, -1, 0, 0), (1, 0)); let r = fallback_cma::<i128>(-1, -1, -1, -1); assert_eq!(r, (u128::MAX, -1)); let r = fallback_cma::<i128>(1, -1, 1, 1); assert_eq!(r, (1, 0)); + assert_eq!( + fallback_cma::<i128>(i128::MAX, i128::MAX, i128::MAX, i128::MAX), + (u128::MAX, i128::MAX / 2), + ); + assert_eq!( + fallback_cma::<i128>(i128::MIN, i128::MIN, i128::MAX, i128::MAX), + (u128::MAX - 1, -(i128::MIN / 2)), + ); } diff --git a/tests/codegen/intrinsics/carrying_mul_add.rs b/tests/codegen/intrinsics/carrying_mul_add.rs new file mode 100644 index 00000000000..174c4077f09 --- /dev/null +++ b/tests/codegen/intrinsics/carrying_mul_add.rs @@ -0,0 +1,137 @@ +//@ revisions: RAW OPT +//@ compile-flags: -C opt-level=1 +//@[RAW] compile-flags: -C no-prepopulate-passes +//@[OPT] min-llvm-version: 19 + +#![crate_type = "lib"] +#![feature(core_intrinsics)] +#![feature(core_intrinsics_fallbacks)] + +// Note that LLVM seems to sometimes permute the order of arguments to mul and add, +// so these tests don't check the arguments in the optimized revision. + +use std::intrinsics::{carrying_mul_add, fallback}; + +// The fallbacks are emitted even when they're never used, but optimize out. + +// RAW: wide_mul_u128 +// OPT-NOT: wide_mul_u128 + +// CHECK-LABEL: @cma_u8 +#[no_mangle] +pub unsafe fn cma_u8(a: u8, b: u8, c: u8, d: u8) -> (u8, u8) { + // CHECK: [[A:%.+]] = zext i8 %a to i16 + // CHECK: [[B:%.+]] = zext i8 %b to i16 + // CHECK: [[C:%.+]] = zext i8 %c to i16 + // CHECK: [[D:%.+]] = zext i8 %d to i16 + // CHECK: [[AB:%.+]] = mul nuw i16 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i16 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i16 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i16 [[ABCD]] to i8 + // CHECK: [[HIGHW:%.+]] = lshr i16 [[ABCD]], 8 + // RAW: [[HIGH:%.+]] = trunc i16 [[HIGHW]] to i8 + // OPT: [[HIGH:%.+]] = trunc nuw i16 [[HIGHW]] to i8 + // CHECK: [[PAIR0:%.+]] = insertvalue { i8, i8 } poison, i8 [[LOW]], 0 + // CHECK: [[PAIR1:%.+]] = insertvalue { i8, i8 } [[PAIR0]], i8 [[HIGH]], 1 + // OPT: ret { i8, i8 } [[PAIR1]] + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_u32 +#[no_mangle] +pub unsafe fn cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) { + // CHECK: [[A:%.+]] = zext i32 %a to i64 + // CHECK: [[B:%.+]] = zext i32 %b to i64 + // CHECK: [[C:%.+]] = zext i32 %c to i64 + // CHECK: [[D:%.+]] = zext i32 %d to i64 + // CHECK: [[AB:%.+]] = mul nuw i64 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i64 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i64 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32 + // CHECK: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32 + // RAW: [[HIGH:%.+]] = trunc i64 [[HIGHW]] to i32 + // OPT: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32 + // CHECK: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0 + // CHECK: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1 + // OPT: ret { i32, i32 } [[PAIR1]] + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_u128 +// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d +#[no_mangle] +pub unsafe fn cma_u128(a: u128, b: u128, c: u128, d: u128) -> (u128, u128) { + // CHECK: [[A:%.+]] = zext i128 %a to i256 + // CHECK: [[B:%.+]] = zext i128 %b to i256 + // CHECK: [[C:%.+]] = zext i128 %c to i256 + // CHECK: [[D:%.+]] = zext i128 %d to i256 + // CHECK: [[AB:%.+]] = mul nuw i256 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i256 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i256 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128 + // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128 + // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128 + // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128 + // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0 + // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1 + // OPT: store i128 [[LOW]], ptr %_0 + // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16 + // OPT: store i128 [[HIGH]], ptr [[P1]] + // CHECK: ret void + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_i128 +// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d +#[no_mangle] +pub unsafe fn cma_i128(a: i128, b: i128, c: i128, d: i128) -> (u128, i128) { + // CHECK: [[A:%.+]] = sext i128 %a to i256 + // CHECK: [[B:%.+]] = sext i128 %b to i256 + // CHECK: [[C:%.+]] = sext i128 %c to i256 + // CHECK: [[D:%.+]] = sext i128 %d to i256 + // CHECK: [[AB:%.+]] = mul nsw i256 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nsw i256 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nsw i256 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128 + // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128 + // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128 + // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128 + // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0 + // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1 + // OPT: store i128 [[LOW]], ptr %_0 + // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16 + // OPT: store i128 [[HIGH]], ptr [[P1]] + // CHECK: ret void + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @fallback_cma_u32 +#[no_mangle] +pub unsafe fn fallback_cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) { + // OPT-DAG: [[A:%.+]] = zext i32 %a to i64 + // OPT-DAG: [[B:%.+]] = zext i32 %b to i64 + // OPT-DAG: [[AB:%.+]] = mul nuw i64 + // OPT-DAG: [[C:%.+]] = zext i32 %c to i64 + // OPT-DAG: [[ABC:%.+]] = add nuw i64{{.+}}[[C]] + // OPT-DAG: [[D:%.+]] = zext i32 %d to i64 + // OPT-DAG: [[ABCD:%.+]] = add nuw i64{{.+}}[[D]] + // OPT-DAG: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32 + // OPT-DAG: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32 + // OPT-DAG: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32 + // OPT-DAG: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0 + // OPT-DAG: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1 + // OPT-DAG: ret { i32, i32 } [[PAIR1]] + fallback::CarryingMulAdd::carrying_mul_add(a, b, c, d) +} |
