about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs21
-rw-r--r--compiler/rustc_codegen_gcc/src/builder.rs25
-rw-r--r--compiler/rustc_codegen_llvm/src/builder.rs48
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs4
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs1
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/intrinsic.rs32
-rw-r--r--compiler/rustc_codegen_ssa/src/traits/builder.rs5
-rw-r--r--compiler/rustc_hir_analysis/src/check/intrinsic.rs12
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp25
-rw-r--r--compiler/rustc_span/src/symbol.rs5
-rw-r--r--library/core/src/intrinsics.rs40
-rw-r--r--tests/codegen/simd/issue-120720-reduce-nan.rs22
12 files changed, 226 insertions, 14 deletions
diff --git a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
index 476752c7230..199d5df29e7 100644
--- a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
+++ b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
@@ -1152,17 +1152,26 @@ fn codegen_regular_intrinsic_call<'tcx>(
             ret.write_cvalue(fx, ret_val);
         }
 
-        sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
+        sym::fadd_fast
+        | sym::fsub_fast
+        | sym::fmul_fast
+        | sym::fdiv_fast
+        | sym::frem_fast
+        | sym::fadd_algebraic
+        | sym::fsub_algebraic
+        | sym::fmul_algebraic
+        | sym::fdiv_algebraic
+        | sym::frem_algebraic => {
             intrinsic_args!(fx, args => (x, y); intrinsic);
 
             let res = crate::num::codegen_float_binop(
                 fx,
                 match intrinsic {
-                    sym::fadd_fast => BinOp::Add,
-                    sym::fsub_fast => BinOp::Sub,
-                    sym::fmul_fast => BinOp::Mul,
-                    sym::fdiv_fast => BinOp::Div,
-                    sym::frem_fast => BinOp::Rem,
+                    sym::fadd_fast | sym::fadd_algebraic => BinOp::Add,
+                    sym::fsub_fast | sym::fsub_algebraic => BinOp::Sub,
+                    sym::fmul_fast | sym::fmul_algebraic => BinOp::Mul,
+                    sym::fdiv_fast | sym::fdiv_algebraic => BinOp::Div,
+                    sym::frem_fast | sym::frem_algebraic => BinOp::Rem,
                     _ => unreachable!(),
                 },
                 x,
diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs
index 42e61b3ccb5..5f1e4538376 100644
--- a/compiler/rustc_codegen_gcc/src/builder.rs
+++ b/compiler/rustc_codegen_gcc/src/builder.rs
@@ -705,6 +705,31 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
         self.frem(lhs, rhs)
     }
 
+    fn fadd_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
+        // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
+        lhs + rhs
+    }
+
+    fn fsub_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
+        // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
+        lhs - rhs
+    }
+
+    fn fmul_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
+        // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
+        lhs * rhs
+    }
+
+    fn fdiv_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
+        // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
+        lhs / rhs
+    }
+
+    fn frem_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
+        // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
+        self.frem(lhs, rhs)
+    }
+
     fn checked_binop(&mut self, oop: OverflowOp, typ: Ty<'_>, lhs: Self::Value, rhs: Self::Value) -> (Self::Value, Self::Value) {
         self.gcc_checked_binop(oop, typ, lhs, rhs)
     }
diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs
index 7ed27b33dce..cfa266720d2 100644
--- a/compiler/rustc_codegen_llvm/src/builder.rs
+++ b/compiler/rustc_codegen_llvm/src/builder.rs
@@ -340,6 +340,46 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         }
     }
 
+    fn fadd_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
+        unsafe {
+            let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, UNNAMED);
+            llvm::LLVMRustSetAlgebraicMath(instr);
+            instr
+        }
+    }
+
+    fn fsub_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
+        unsafe {
+            let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, UNNAMED);
+            llvm::LLVMRustSetAlgebraicMath(instr);
+            instr
+        }
+    }
+
+    fn fmul_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
+        unsafe {
+            let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, UNNAMED);
+            llvm::LLVMRustSetAlgebraicMath(instr);
+            instr
+        }
+    }
+
+    fn fdiv_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
+        unsafe {
+            let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, UNNAMED);
+            llvm::LLVMRustSetAlgebraicMath(instr);
+            instr
+        }
+    }
+
+    fn frem_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
+        unsafe {
+            let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, UNNAMED);
+            llvm::LLVMRustSetAlgebraicMath(instr);
+            instr
+        }
+    }
+
     fn checked_binop(
         &mut self,
         oop: OverflowOp,
@@ -1327,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
     pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
         unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) }
     }
-    pub fn vector_reduce_fadd_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
+    pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
         unsafe {
             let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
-            llvm::LLVMRustSetFastMath(instr);
+            llvm::LLVMRustSetAlgebraicMath(instr);
             instr
         }
     }
-    pub fn vector_reduce_fmul_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
+    pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
         unsafe {
             let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
-            llvm::LLVMRustSetFastMath(instr);
+            llvm::LLVMRustSetAlgebraicMath(instr);
             instr
         }
     }
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 4415c51acf6..3b091fca28b 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
     arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0);
     arith_red!(
         simd_reduce_add_unordered: vector_reduce_add,
-        vector_reduce_fadd_fast,
+        vector_reduce_fadd_algebraic,
         false,
         add,
         0.0
     );
     arith_red!(
         simd_reduce_mul_unordered: vector_reduce_mul,
-        vector_reduce_fmul_fast,
+        vector_reduce_fmul_algebraic,
         false,
         mul,
         1.0
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index d0044086c61..f9eb1da5dc7 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -1618,6 +1618,7 @@ extern "C" {
     ) -> &'a Value;
 
     pub fn LLVMRustSetFastMath(Instr: &Value);
+    pub fn LLVMRustSetAlgebraicMath(Instr: &Value);
 
     // Miscellaneous instructions
     pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value;
diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
index e4633acd817..82488829b6e 100644
--- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
@@ -250,6 +250,38 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                     }
                 }
             }
+            sym::fadd_algebraic
+            | sym::fsub_algebraic
+            | sym::fmul_algebraic
+            | sym::fdiv_algebraic
+            | sym::frem_algebraic => match float_type_width(arg_tys[0]) {
+                Some(_width) => match name {
+                    sym::fadd_algebraic => {
+                        bx.fadd_algebraic(args[0].immediate(), args[1].immediate())
+                    }
+                    sym::fsub_algebraic => {
+                        bx.fsub_algebraic(args[0].immediate(), args[1].immediate())
+                    }
+                    sym::fmul_algebraic => {
+                        bx.fmul_algebraic(args[0].immediate(), args[1].immediate())
+                    }
+                    sym::fdiv_algebraic => {
+                        bx.fdiv_algebraic(args[0].immediate(), args[1].immediate())
+                    }
+                    sym::frem_algebraic => {
+                        bx.frem_algebraic(args[0].immediate(), args[1].immediate())
+                    }
+                    _ => bug!(),
+                },
+                None => {
+                    bx.tcx().dcx().emit_err(InvalidMonomorphization::BasicFloatType {
+                        span,
+                        name,
+                        ty: arg_tys[0],
+                    });
+                    return Ok(());
+                }
+            },
 
             sym::float_to_int_unchecked => {
                 if float_type_width(arg_tys[0]).is_none() {
diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs
index 1c5c78e6ca2..86d3d1260c3 100644
--- a/compiler/rustc_codegen_ssa/src/traits/builder.rs
+++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs
@@ -86,22 +86,27 @@ pub trait BuilderMethods<'a, 'tcx>:
     fn add(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fadd_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
+    fn fadd_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn sub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fsub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fsub_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
+    fn fsub_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn mul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fmul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fmul_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
+    fn fmul_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn udiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn exactudiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn sdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn exactsdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn fdiv_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
+    fn fdiv_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn urem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn srem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn frem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn frem_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
+    fn frem_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn shl(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn lshr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
     fn ashr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
index 903c98e8317..05fab60fd8d 100644
--- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs
+++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -123,7 +123,12 @@ pub fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -
         | sym::variant_count
         | sym::is_val_statically_known
         | sym::ptr_mask
-        | sym::debug_assertions => hir::Unsafety::Normal,
+        | sym::debug_assertions
+        | sym::fadd_algebraic
+        | sym::fsub_algebraic
+        | sym::fmul_algebraic
+        | sym::fdiv_algebraic
+        | sym::frem_algebraic => hir::Unsafety::Normal,
         _ => hir::Unsafety::Unsafe,
     };
 
@@ -405,6 +410,11 @@ pub fn check_intrinsic_type(
             sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
                 (1, 0, vec![param(0), param(0)], param(0))
             }
+            sym::fadd_algebraic
+            | sym::fsub_algebraic
+            | sym::fmul_algebraic
+            | sym::fdiv_algebraic
+            | sym::frem_algebraic => (1, 0, vec![param(0), param(0)], param(0)),
             sym::float_to_int_unchecked => (2, 0, vec![param(0)], param(1)),
 
             sym::assume => (0, 1, vec![tcx.types.bool], Ty::new_unit(tcx)),
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index b45706fd1e5..7326f2e8e2a 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -418,7 +418,11 @@ extern "C" LLVMAttributeRef LLVMRustCreateMemoryEffectsAttr(LLVMContextRef C,
   }
 }
 
-// Enable a fast-math flag
+// Enable all fast-math flags, including those which will cause floating-point operations
+// to return poison for some well-defined inputs. This function can only be used to build
+// unsafe Rust intrinsics. That unsafety does permit additional optimizations, but at the
+// time of writing, their value is not well-understood relative to those enabled by
+// LLVMRustSetAlgebraicMath.
 //
 // https://llvm.org/docs/LangRef.html#fast-math-flags
 extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
@@ -427,6 +431,25 @@ extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
   }
 }
 
+// Enable fast-math flags which permit algebraic transformations that are not allowed by
+// IEEE floating point. For example:
+// a + (b + c) = (a + b) + c
+// and
+// a / b = a * (1 / b)
+// Note that this does NOT enable any flags which can cause a floating-point operation on
+// well-defined inputs to return poison, and therefore this function can be used to build
+// safe Rust intrinsics (such as fadd_algebraic).
+//
+// https://llvm.org/docs/LangRef.html#fast-math-flags
+extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
+  if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
+    I->setHasAllowReassoc(true);
+    I->setHasAllowContract(true);
+    I->setHasAllowReciprocal(true);
+    I->setHasNoSignedZeros(true);
+  }
+}
+
 extern "C" LLVMValueRef
 LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
                         const char *Name, LLVMAtomicOrdering Order) {
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 29c88783357..181ab0d4d56 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -764,8 +764,10 @@ symbols! {
         f64_nan,
         fabsf32,
         fabsf64,
+        fadd_algebraic,
         fadd_fast,
         fake_variadic,
+        fdiv_algebraic,
         fdiv_fast,
         feature,
         fence,
@@ -785,6 +787,7 @@ symbols! {
         fmaf32,
         fmaf64,
         fmt,
+        fmul_algebraic,
         fmul_fast,
         fn_align,
         fn_delegation,
@@ -810,6 +813,7 @@ symbols! {
         format_unsafe_arg,
         freeze,
         freg,
+        frem_algebraic,
         frem_fast,
         from,
         from_desugaring,
@@ -823,6 +827,7 @@ symbols! {
         from_usize,
         from_yeet,
         fs_create_dir,
+        fsub_algebraic,
         fsub_fast,
         fundamental,
         future,
diff --git a/library/core/src/intrinsics.rs b/library/core/src/intrinsics.rs
index 067457e235b..4a1187561b3 100644
--- a/library/core/src/intrinsics.rs
+++ b/library/core/src/intrinsics.rs
@@ -1898,6 +1898,46 @@ extern "rust-intrinsic" {
     #[rustc_nounwind]
     pub fn frem_fast<T: Copy>(a: T, b: T) -> T;
 
+    /// Float addition that allows optimizations based on algebraic rules.
+    ///
+    /// This intrinsic does not have a stable counterpart.
+    #[rustc_nounwind]
+    #[rustc_safe_intrinsic]
+    #[cfg(not(bootstrap))]
+    pub fn fadd_algebraic<T: Copy>(a: T, b: T) -> T;
+
+    /// Float subtraction that allows optimizations based on algebraic rules.
+    ///
+    /// This intrinsic does not have a stable counterpart.
+    #[rustc_nounwind]
+    #[rustc_safe_intrinsic]
+    #[cfg(not(bootstrap))]
+    pub fn fsub_algebraic<T: Copy>(a: T, b: T) -> T;
+
+    /// Float multiplication that allows optimizations based on algebraic rules.
+    ///
+    /// This intrinsic does not have a stable counterpart.
+    #[rustc_nounwind]
+    #[rustc_safe_intrinsic]
+    #[cfg(not(bootstrap))]
+    pub fn fmul_algebraic<T: Copy>(a: T, b: T) -> T;
+
+    /// Float division that allows optimizations based on algebraic rules.
+    ///
+    /// This intrinsic does not have a stable counterpart.
+    #[rustc_nounwind]
+    #[rustc_safe_intrinsic]
+    #[cfg(not(bootstrap))]
+    pub fn fdiv_algebraic<T: Copy>(a: T, b: T) -> T;
+
+    /// Float remainder that allows optimizations based on algebraic rules.
+    ///
+    /// This intrinsic does not have a stable counterpart.
+    #[rustc_nounwind]
+    #[rustc_safe_intrinsic]
+    #[cfg(not(bootstrap))]
+    pub fn frem_algebraic<T: Copy>(a: T, b: T) -> T;
+
     /// Convert with LLVM’s fptoui/fptosi, which may return undef for values out of range
     /// (<https://github.com/rust-lang/rust/issues/10184>)
     ///
diff --git a/tests/codegen/simd/issue-120720-reduce-nan.rs b/tests/codegen/simd/issue-120720-reduce-nan.rs
new file mode 100644
index 00000000000..233131aa01c
--- /dev/null
+++ b/tests/codegen/simd/issue-120720-reduce-nan.rs
@@ -0,0 +1,22 @@
+// compile-flags: -C opt-level=3 -C target-cpu=cannonlake
+// only-x86_64
+
+// In a previous implementation, _mm512_reduce_add_pd did the reduction with all fast-math flags
+// enabled, making it UB to reduce a vector containing a NaN.
+
+#![crate_type = "lib"]
+#![feature(stdarch_x86_avx512, avx512_target_feature)]
+use std::arch::x86_64::*;
+
+// CHECK-label: @demo(
+#[no_mangle]
+#[target_feature(enable = "avx512f")] // Function-level target feature mismatches inhibit inlining
+pub unsafe fn demo() -> bool {
+    // CHECK: %0 = tail call reassoc nsz arcp contract double @llvm.vector.reduce.fadd.v8f64(
+    // CHECK: %_0.i = fcmp uno double %0, 0.000000e+00
+    // CHECK: ret i1 %_0.i
+    let res = unsafe {
+        _mm512_reduce_add_pd(_mm512_set_pd(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, f64::NAN))
+    };
+    res.is_nan()
+}