about summary refs log tree commit diff
diff options
context:
space:
mode:
authorhkalbasi <hamidrezakalbasi@protonmail.com>2023-06-02 17:29:17 +0330
committerhkalbasi <hamidrezakalbasi@protonmail.com>2023-06-02 17:29:17 +0330
commita6a27a7ff84cd23bbb9ad65ae2c06c2715247ca2 (patch)
treeb46733cb2301260a1eb70f8a0f1cbf7a2056897f
parentf9e3b180b7cbee429c465408bbe0e8dbfc104cd7 (diff)
downloadrust-a6a27a7ff84cd23bbb9ad65ae2c06c2715247ca2.tar.gz
rust-a6a27a7ff84cd23bbb9ad65ae2c06c2715247ca2.zip
Support floating point intrinsics in const eval
-rw-r--r--crates/hir-ty/src/consteval/tests/intrinsics.rs34
-rw-r--r--crates/hir-ty/src/mir/eval.rs2
-rw-r--r--crates/hir-ty/src/mir/eval/shim.rs341
-rw-r--r--crates/ide/src/hover/tests.rs24
4 files changed, 294 insertions, 107 deletions
diff --git a/crates/hir-ty/src/consteval/tests/intrinsics.rs b/crates/hir-ty/src/consteval/tests/intrinsics.rs
index 1feb9a4441c..e05d824dbac 100644
--- a/crates/hir-ty/src/consteval/tests/intrinsics.rs
+++ b/crates/hir-ty/src/consteval/tests/intrinsics.rs
@@ -175,6 +175,40 @@ fn likely() {
 }
 
 #[test]
+fn floating_point() {
+    check_number(
+        r#"
+        extern "rust-intrinsic" {
+            pub fn sqrtf32(x: f32) -> f32;
+            pub fn powf32(a: f32, x: f32) -> f32;
+            pub fn fmaf32(a: f32, b: f32, c: f32) -> f32;
+        }
+
+        const GOAL: f32 = sqrtf32(1.2) + powf32(3.4, 5.6) + fmaf32(-7.8, 1.3, 2.4);
+        "#,
+        i128::from_le_bytes(pad16(
+            &f32::to_le_bytes(1.2f32.sqrt() + 3.4f32.powf(5.6) + (-7.8f32).mul_add(1.3, 2.4)),
+            true,
+        )),
+    );
+    check_number(
+        r#"
+        extern "rust-intrinsic" {
+            pub fn powif64(a: f64, x: i32) -> f64;
+            pub fn sinf64(x: f64) -> f64;
+            pub fn minnumf64(x: f64, y: f64) -> f64;
+        }
+
+        const GOAL: f64 = powif64(1.2, 5) + sinf64(3.4) + minnumf64(-7.8, 1.3);
+        "#,
+        i128::from_le_bytes(pad16(
+            &f64::to_le_bytes(1.2f64.powi(5) + 3.4f64.sin() + (-7.8f64).min(1.3)),
+            true,
+        )),
+    );
+}
+
+#[test]
 fn atomic() {
     check_number(
         r#"
diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs
index 6e26d1f22aa..6d7701c9e8d 100644
--- a/crates/hir-ty/src/mir/eval.rs
+++ b/crates/hir-ty/src/mir/eval.rs
@@ -870,7 +870,7 @@ impl Evaluator<'_> {
                             Owned(c.to_le_bytes().into())
                         }
                         chalk_ir::FloatTy::F64 => {
-                            let c = -from_bytes!(f32, c);
+                            let c = -from_bytes!(f64, c);
                             Owned(c.to_le_bytes().into())
                         }
                     }
diff --git a/crates/hir-ty/src/mir/eval/shim.rs b/crates/hir-ty/src/mir/eval/shim.rs
index e05004eeb6a..77ee7b6b6e2 100644
--- a/crates/hir-ty/src/mir/eval/shim.rs
+++ b/crates/hir-ty/src/mir/eval/shim.rs
@@ -316,119 +316,145 @@ impl Evaluator<'_> {
 
     fn exec_intrinsic(
         &mut self,
-        as_str: &str,
+        name: &str,
         args: &[IntervalAndTy],
         generic_args: &Substitution,
         destination: Interval,
         locals: &Locals<'_>,
         span: MirSpan,
     ) -> Result<()> {
-        // We are a single threaded runtime with no UB checking and no optimization, so
-        // we can implement these as normal functions.
-        if let Some(name) = as_str.strip_prefix("atomic_") {
-            let Some(ty) = generic_args.as_slice(Interner).get(0).and_then(|x| x.ty(Interner)) else {
-                return Err(MirEvalError::TypeError("atomic intrinsic generic arg is not provided"));
-            };
-            let Some(arg0) = args.get(0) else {
-                return Err(MirEvalError::TypeError("atomic intrinsic arg0 is not provided"));
-            };
-            let arg0_addr = Address::from_bytes(arg0.get(self)?)?;
-            let arg0_interval = Interval::new(
-                arg0_addr,
-                self.size_of_sized(ty, locals, "atomic intrinsic type arg")?,
-            );
-            if name.starts_with("load_") {
-                return destination.write_from_interval(self, arg0_interval);
-            }
-            let Some(arg1) = args.get(1) else {
-                return Err(MirEvalError::TypeError("atomic intrinsic arg1 is not provided"));
+        if let Some(name) = name.strip_prefix("atomic_") {
+            return self.exec_atomic_intrinsic(name, args, generic_args, destination, locals, span);
+        }
+        if let Some(name) = name.strip_suffix("f64") {
+            let result = match name {
+                "sqrt" | "sin" | "cos" | "exp" | "exp2" | "log" | "log10" | "log2" | "fabs"
+                | "floor" | "ceil" | "trunc" | "rint" | "nearbyint" | "round" | "roundeven" => {
+                    let [arg] = args else {
+                        return Err(MirEvalError::TypeError("f64 intrinsic signature doesn't match fn (f64) -> f64"));
+                    };
+                    let arg = from_bytes!(f64, arg.get(self)?);
+                    match name {
+                        "sqrt" => arg.sqrt(),
+                        "sin" => arg.sin(),
+                        "cos" => arg.cos(),
+                        "exp" => arg.exp(),
+                        "exp2" => arg.exp2(),
+                        "log" => arg.ln(),
+                        "log10" => arg.log10(),
+                        "log2" => arg.log2(),
+                        "fabs" => arg.abs(),
+                        "floor" => arg.floor(),
+                        "ceil" => arg.ceil(),
+                        "trunc" => arg.trunc(),
+                        // FIXME: these rounds should be different, but only `.round()` is stable now.
+                        "rint" => arg.round(),
+                        "nearbyint" => arg.round(),
+                        "round" => arg.round(),
+                        "roundeven" => arg.round(),
+                        _ => unreachable!(),
+                    }
+                }
+                "pow" | "minnum" | "maxnum" | "copysign" => {
+                    let [arg1, arg2] = args else {
+                        return Err(MirEvalError::TypeError("f64 intrinsic signature doesn't match fn (f64, f64) -> f64"));
+                    };
+                    let arg1 = from_bytes!(f64, arg1.get(self)?);
+                    let arg2 = from_bytes!(f64, arg2.get(self)?);
+                    match name {
+                        "pow" => arg1.powf(arg2),
+                        "minnum" => arg1.min(arg2),
+                        "maxnum" => arg1.max(arg2),
+                        "copysign" => arg1.copysign(arg2),
+                        _ => unreachable!(),
+                    }
+                }
+                "powi" => {
+                    let [arg1, arg2] = args else {
+                        return Err(MirEvalError::TypeError("powif64 signature doesn't match fn (f64, i32) -> f64"));
+                    };
+                    let arg1 = from_bytes!(f64, arg1.get(self)?);
+                    let arg2 = from_bytes!(i32, arg2.get(self)?);
+                    arg1.powi(arg2)
+                }
+                "fma" => {
+                    let [arg1, arg2, arg3] = args else {
+                        return Err(MirEvalError::TypeError("fmaf64 signature doesn't match fn (f64, f64, f64) -> f64"));
+                    };
+                    let arg1 = from_bytes!(f64, arg1.get(self)?);
+                    let arg2 = from_bytes!(f64, arg2.get(self)?);
+                    let arg3 = from_bytes!(f64, arg3.get(self)?);
+                    arg1.mul_add(arg2, arg3)
+                }
+                _ => not_supported!("unknown f64 intrinsic {name}"),
             };
-            if name.starts_with("store_") {
-                return arg0_interval.write_from_interval(self, arg1.interval);
-            }
-            if name.starts_with("xchg_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                return arg0_interval.write_from_interval(self, arg1.interval);
-            }
-            if name.starts_with("xadd_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = lhs.wrapping_add(rhs);
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            if name.starts_with("xsub_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = lhs.wrapping_sub(rhs);
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            if name.starts_with("and_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = lhs & rhs;
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            if name.starts_with("or_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = lhs | rhs;
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            if name.starts_with("xor_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = lhs ^ rhs;
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            if name.starts_with("nand_") {
-                destination.write_from_interval(self, arg0_interval)?;
-                let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
-                let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
-                let ans = !(lhs & rhs);
-                return arg0_interval
-                    .write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
-            }
-            let Some(arg2) = args.get(2) else {
-                return Err(MirEvalError::TypeError("atomic intrinsic arg2 is not provided"));
+            return destination.write_from_bytes(self, &result.to_le_bytes());
+        }
+        if let Some(name) = name.strip_suffix("f32") {
+            let result = match name {
+                "sqrt" | "sin" | "cos" | "exp" | "exp2" | "log" | "log10" | "log2" | "fabs"
+                | "floor" | "ceil" | "trunc" | "rint" | "nearbyint" | "round" | "roundeven" => {
+                    let [arg] = args else {
+                        return Err(MirEvalError::TypeError("f32 intrinsic signature doesn't match fn (f32) -> f32"));
+                    };
+                    let arg = from_bytes!(f32, arg.get(self)?);
+                    match name {
+                        "sqrt" => arg.sqrt(),
+                        "sin" => arg.sin(),
+                        "cos" => arg.cos(),
+                        "exp" => arg.exp(),
+                        "exp2" => arg.exp2(),
+                        "log" => arg.ln(),
+                        "log10" => arg.log10(),
+                        "log2" => arg.log2(),
+                        "fabs" => arg.abs(),
+                        "floor" => arg.floor(),
+                        "ceil" => arg.ceil(),
+                        "trunc" => arg.trunc(),
+                        // FIXME: these rounds should be different, but only `.round()` is stable now.
+                        "rint" => arg.round(),
+                        "nearbyint" => arg.round(),
+                        "round" => arg.round(),
+                        "roundeven" => arg.round(),
+                        _ => unreachable!(),
+                    }
+                }
+                "pow" | "minnum" | "maxnum" | "copysign" => {
+                    let [arg1, arg2] = args else {
+                        return Err(MirEvalError::TypeError("f32 intrinsic signature doesn't match fn (f32, f32) -> f32"));
+                    };
+                    let arg1 = from_bytes!(f32, arg1.get(self)?);
+                    let arg2 = from_bytes!(f32, arg2.get(self)?);
+                    match name {
+                        "pow" => arg1.powf(arg2),
+                        "minnum" => arg1.min(arg2),
+                        "maxnum" => arg1.max(arg2),
+                        "copysign" => arg1.copysign(arg2),
+                        _ => unreachable!(),
+                    }
+                }
+                "powi" => {
+                    let [arg1, arg2] = args else {
+                        return Err(MirEvalError::TypeError("powif32 signature doesn't match fn (f32, i32) -> f32"));
+                    };
+                    let arg1 = from_bytes!(f32, arg1.get(self)?);
+                    let arg2 = from_bytes!(i32, arg2.get(self)?);
+                    arg1.powi(arg2)
+                }
+                "fma" => {
+                    let [arg1, arg2, arg3] = args else {
+                        return Err(MirEvalError::TypeError("fmaf32 signature doesn't match fn (f32, f32, f32) -> f32"));
+                    };
+                    let arg1 = from_bytes!(f32, arg1.get(self)?);
+                    let arg2 = from_bytes!(f32, arg2.get(self)?);
+                    let arg3 = from_bytes!(f32, arg3.get(self)?);
+                    arg1.mul_add(arg2, arg3)
+                }
+                _ => not_supported!("unknown f32 intrinsic {name}"),
             };
-            if name.starts_with("cxchg_") || name.starts_with("cxchgweak_") {
-                let dest = if arg1.get(self)? == arg0_interval.get(self)? {
-                    arg0_interval.write_from_interval(self, arg2.interval)?;
-                    (arg1.interval, true)
-                } else {
-                    (arg0_interval, false)
-                };
-                let result_ty = TyKind::Tuple(
-                    2,
-                    Substitution::from_iter(Interner, [ty.clone(), TyBuilder::bool()]),
-                )
-                .intern(Interner);
-                let layout = self.layout(&result_ty)?;
-                let result = self.make_by_layout(
-                    layout.size.bytes_usize(),
-                    &layout,
-                    None,
-                    [
-                        IntervalOrOwned::Borrowed(dest.0),
-                        IntervalOrOwned::Owned(vec![u8::from(dest.1)]),
-                    ]
-                    .into_iter(),
-                )?;
-                return destination.write_from_bytes(self, &result);
-            }
-            not_supported!("unknown atomic intrinsic {name}");
+            return destination.write_from_bytes(self, &result.to_le_bytes());
         }
-        match as_str {
+        match name {
             "size_of" => {
                 let Some(ty) = generic_args.as_slice(Interner).get(0).and_then(|x| x.ty(Interner)) else {
                     return Err(MirEvalError::TypeError("size_of generic arg is not provided"));
@@ -539,7 +565,7 @@ impl Evaluator<'_> {
                     self.size_of_sized(&lhs.ty, locals, "operand of add_with_overflow")?;
                 let lhs = u128::from_le_bytes(pad16(lhs.get(self)?, false));
                 let rhs = u128::from_le_bytes(pad16(rhs.get(self)?, false));
-                let (ans, u128overflow) = match as_str {
+                let (ans, u128overflow) = match name {
                     "add_with_overflow" => lhs.overflowing_add(rhs),
                     "sub_with_overflow" => lhs.overflowing_sub(rhs),
                     "mul_with_overflow" => lhs.overflowing_mul(rhs),
@@ -641,7 +667,110 @@ impl Evaluator<'_> {
                 }
                 self.exec_fn_trait(&args, destination, locals, span)
             }
-            _ => not_supported!("unknown intrinsic {as_str}"),
+            _ => not_supported!("unknown intrinsic {name}"),
+        }
+    }
+
+    fn exec_atomic_intrinsic(
+        &mut self,
+        name: &str,
+        args: &[IntervalAndTy],
+        generic_args: &Substitution,
+        destination: Interval,
+        locals: &Locals<'_>,
+        _span: MirSpan,
+    ) -> Result<()> {
+        // We are a single threaded runtime with no UB checking and no optimization, so
+        // we can implement these as normal functions.
+        let Some(ty) = generic_args.as_slice(Interner).get(0).and_then(|x| x.ty(Interner)) else {
+            return Err(MirEvalError::TypeError("atomic intrinsic generic arg is not provided"));
+        };
+        let Some(arg0) = args.get(0) else {
+            return Err(MirEvalError::TypeError("atomic intrinsic arg0 is not provided"));
+        };
+        let arg0_addr = Address::from_bytes(arg0.get(self)?)?;
+        let arg0_interval =
+            Interval::new(arg0_addr, self.size_of_sized(ty, locals, "atomic intrinsic type arg")?);
+        if name.starts_with("load_") {
+            return destination.write_from_interval(self, arg0_interval);
+        }
+        let Some(arg1) = args.get(1) else {
+            return Err(MirEvalError::TypeError("atomic intrinsic arg1 is not provided"));
+        };
+        if name.starts_with("store_") {
+            return arg0_interval.write_from_interval(self, arg1.interval);
+        }
+        if name.starts_with("xchg_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            return arg0_interval.write_from_interval(self, arg1.interval);
+        }
+        if name.starts_with("xadd_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = lhs.wrapping_add(rhs);
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        if name.starts_with("xsub_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = lhs.wrapping_sub(rhs);
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        if name.starts_with("and_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = lhs & rhs;
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        if name.starts_with("or_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = lhs | rhs;
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        if name.starts_with("xor_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = lhs ^ rhs;
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        if name.starts_with("nand_") {
+            destination.write_from_interval(self, arg0_interval)?;
+            let lhs = u128::from_le_bytes(pad16(arg0_interval.get(self)?, false));
+            let rhs = u128::from_le_bytes(pad16(arg1.get(self)?, false));
+            let ans = !(lhs & rhs);
+            return arg0_interval.write_from_bytes(self, &ans.to_le_bytes()[0..destination.size]);
+        }
+        let Some(arg2) = args.get(2) else {
+            return Err(MirEvalError::TypeError("atomic intrinsic arg2 is not provided"));
+        };
+        if name.starts_with("cxchg_") || name.starts_with("cxchgweak_") {
+            let dest = if arg1.get(self)? == arg0_interval.get(self)? {
+                arg0_interval.write_from_interval(self, arg2.interval)?;
+                (arg1.interval, true)
+            } else {
+                (arg0_interval, false)
+            };
+            let result_ty = TyKind::Tuple(
+                2,
+                Substitution::from_iter(Interner, [ty.clone(), TyBuilder::bool()]),
+            )
+            .intern(Interner);
+            let layout = self.layout(&result_ty)?;
+            let result = self.make_by_layout(
+                layout.size.bytes_usize(),
+                &layout,
+                None,
+                [IntervalOrOwned::Borrowed(dest.0), IntervalOrOwned::Owned(vec![u8::from(dest.1)])]
+                    .into_iter(),
+            )?;
+            return destination.write_from_bytes(self, &result);
         }
+        not_supported!("unknown atomic intrinsic {name}");
     }
 }
diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs
index d2c035c471e..a2f96977581 100644
--- a/crates/ide/src/hover/tests.rs
+++ b/crates/ide/src/hover/tests.rs
@@ -4338,6 +4338,30 @@ const FOO$0: f64 = 1.0f64;
 }
 
 #[test]
+fn hover_const_eval_floating_point() {
+    check(
+        r#"
+extern "rust-intrinsic" {
+    pub fn expf64(x: f64) -> f64;
+}
+
+const FOO$0: f64 = expf64(1.2);
+"#,
+        expect![[r#"
+            *FOO*
+
+            ```rust
+            test
+            ```
+
+            ```rust
+            const FOO: f64 = 3.3201169227365472
+            ```
+        "#]],
+    );
+}
+
+#[test]
 fn hover_const_eval_enum() {
     check(
         r#"