about summary refs log tree commit diff
diff options
context:
space:
mode:
authorChai T. Rex <ChaiTRex@users.noreply.github.com>2024-08-26 01:34:25 -0400
committerChai T. Rex <ChaiTRex@users.noreply.github.com>2024-08-28 23:06:54 -0400
commit0cac9152116b5b4e4c5f84a05e0ceb20dc325a0c (patch)
tree12940e47527f4794730e791098a65cad16b977fa
parentbf662eb80838008acabc307dd64d84935ce3a20d (diff)
downloadrust-0cac9152116b5b4e4c5f84a05e0ceb20dc325a0c.tar.gz
rust-0cac9152116b5b4e4c5f84a05e0ceb20dc325a0c.zip
Improve `isqrt` tests and add benchmarks
* Choose test inputs more thoroughly and systematically.
* Check that `isqrt` and `checked_isqrt` have equivalent results for
  signed types, either equivalent numerically or equivalent as a panic
  and a `None`.
* Check that `isqrt` has numerically-equivalent results for unsigned
  types and their `NonZero` counterparts.
* Reuse `ilog10` benchmarks, plus benchmarks that use a uniform
  distribution.
-rw-r--r--library/core/benches/lib.rs1
-rw-r--r--library/core/benches/num/int_sqrt/mod.rs62
-rw-r--r--library/core/benches/num/mod.rs1
-rw-r--r--library/core/tests/num/int_macros.rs32
-rw-r--r--library/core/tests/num/int_sqrt.rs248
-rw-r--r--library/core/tests/num/mod.rs1
6 files changed, 313 insertions, 32 deletions
diff --git a/library/core/benches/lib.rs b/library/core/benches/lib.rs
index 32d15c386cb..3f1c58bbd72 100644
--- a/library/core/benches/lib.rs
+++ b/library/core/benches/lib.rs
@@ -8,6 +8,7 @@
 #![feature(iter_array_chunks)]
 #![feature(iter_next_chunk)]
 #![feature(iter_advance_by)]
+#![feature(isqrt)]
 
 extern crate test;
 
diff --git a/library/core/benches/num/int_sqrt/mod.rs b/library/core/benches/num/int_sqrt/mod.rs
new file mode 100644
index 00000000000..3c9d173e456
--- /dev/null
+++ b/library/core/benches/num/int_sqrt/mod.rs
@@ -0,0 +1,62 @@
+use rand::Rng;
+use test::{black_box, Bencher};
+
+macro_rules! int_sqrt_bench {
+    ($t:ty, $predictable:ident, $random:ident, $random_small:ident, $random_uniform:ident) => {
+        #[bench]
+        fn $predictable(bench: &mut Bencher) {
+            bench.iter(|| {
+                for n in 0..(<$t>::BITS / 8) {
+                    for i in 1..=(100 as $t) {
+                        let x = black_box(i << (n * 8));
+                        black_box(x.isqrt());
+                    }
+                }
+            });
+        }
+
+        #[bench]
+        fn $random(bench: &mut Bencher) {
+            let mut rng = crate::bench_rng();
+            /* Exponentially distributed random numbers from the whole range of the type.  */
+            let numbers: Vec<$t> =
+                (0..256).map(|_| rng.gen::<$t>() >> rng.gen_range(0..<$t>::BITS)).collect();
+            bench.iter(|| {
+                for x in &numbers {
+                    black_box(black_box(x).isqrt());
+                }
+            });
+        }
+
+        #[bench]
+        fn $random_small(bench: &mut Bencher) {
+            let mut rng = crate::bench_rng();
+            /* Exponentially distributed random numbers from the range 0..256.  */
+            let numbers: Vec<$t> =
+                (0..256).map(|_| (rng.gen::<u8>() >> rng.gen_range(0..u8::BITS)) as $t).collect();
+            bench.iter(|| {
+                for x in &numbers {
+                    black_box(black_box(x).isqrt());
+                }
+            });
+        }
+
+        #[bench]
+        fn $random_uniform(bench: &mut Bencher) {
+            let mut rng = crate::bench_rng();
+            /* Exponentially distributed random numbers from the whole range of the type.  */
+            let numbers: Vec<$t> = (0..256).map(|_| rng.gen::<$t>()).collect();
+            bench.iter(|| {
+                for x in &numbers {
+                    black_box(black_box(x).isqrt());
+                }
+            });
+        }
+    };
+}
+
+int_sqrt_bench! {u8, u8_sqrt_predictable, u8_sqrt_random, u8_sqrt_random_small, u8_sqrt_uniform}
+int_sqrt_bench! {u16, u16_sqrt_predictable, u16_sqrt_random, u16_sqrt_random_small, u16_sqrt_uniform}
+int_sqrt_bench! {u32, u32_sqrt_predictable, u32_sqrt_random, u32_sqrt_random_small, u32_sqrt_uniform}
+int_sqrt_bench! {u64, u64_sqrt_predictable, u64_sqrt_random, u64_sqrt_random_small, u64_sqrt_uniform}
+int_sqrt_bench! {u128, u128_sqrt_predictable, u128_sqrt_random, u128_sqrt_random_small, u128_sqrt_uniform}
diff --git a/library/core/benches/num/mod.rs b/library/core/benches/num/mod.rs
index c1dc3a30622..7ff7443cfa7 100644
--- a/library/core/benches/num/mod.rs
+++ b/library/core/benches/num/mod.rs
@@ -2,6 +2,7 @@ mod dec2flt;
 mod flt2dec;
 mod int_log;
 mod int_pow;
+mod int_sqrt;
 
 use std::str::FromStr;
 
diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs
index 7cd3b54e3f3..830a96204ca 100644
--- a/library/core/tests/num/int_macros.rs
+++ b/library/core/tests/num/int_macros.rs
@@ -289,38 +289,6 @@ macro_rules! int_module {
         }
 
         #[test]
-        fn test_isqrt() {
-            assert_eq!($T::MIN.checked_isqrt(), None);
-            assert_eq!((-1 as $T).checked_isqrt(), None);
-            assert_eq!((0 as $T).isqrt(), 0 as $T);
-            assert_eq!((1 as $T).isqrt(), 1 as $T);
-            assert_eq!((2 as $T).isqrt(), 1 as $T);
-            assert_eq!((99 as $T).isqrt(), 9 as $T);
-            assert_eq!((100 as $T).isqrt(), 10 as $T);
-        }
-
-        #[cfg(not(miri))] // Miri is too slow
-        #[test]
-        fn test_lots_of_isqrt() {
-            let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
-            for n in 0..=n_max {
-                let isqrt: $T = n.isqrt();
-
-                assert!(isqrt.pow(2) <= n);
-                let (square, overflow) = (isqrt + 1).overflowing_pow(2);
-                assert!(overflow || square > n);
-            }
-
-            for n in ($T::MAX - 127)..=$T::MAX {
-                let isqrt: $T = n.isqrt();
-
-                assert!(isqrt.pow(2) <= n);
-                let (square, overflow) = (isqrt + 1).overflowing_pow(2);
-                assert!(overflow || square > n);
-            }
-        }
-
-        #[test]
         fn test_div_floor() {
             let a: $T = 8;
             let b = 3;
diff --git a/library/core/tests/num/int_sqrt.rs b/library/core/tests/num/int_sqrt.rs
new file mode 100644
index 00000000000..d68db0787d2
--- /dev/null
+++ b/library/core/tests/num/int_sqrt.rs
@@ -0,0 +1,248 @@
+macro_rules! tests {
+    ($isqrt_consistency_check_fn_macro:ident : $($T:ident)+) => {
+        $(
+            mod $T {
+                $isqrt_consistency_check_fn_macro!($T);
+
+                // Check that the following produce the correct values from
+                // `isqrt`:
+                //
+                // * the first and last 128 nonnegative values
+                // * powers of two, minus one
+                // * powers of two
+                //
+                // For signed types, check that `checked_isqrt` and `isqrt`
+                // either produce the same numeric value or respectively
+                // produce `None` and a panic. Make sure to do a consistency
+                // check for `<$T>::MIN` as well, as no nonnegative values
+                // negate to it.
+                //
+                // For unsigned types check that `isqrt` produces the same
+                // numeric value for `$T` and `NonZero<$T>`.
+                #[test]
+                fn isqrt() {
+                    isqrt_consistency_check(<$T>::MIN);
+
+                    for n in (0..=127)
+                        .chain(<$T>::MAX - 127..=<$T>::MAX)
+                        .chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1))
+                        .chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent))
+                    {
+                        isqrt_consistency_check(n);
+
+                        let isqrt_n = n.isqrt();
+                        assert!(
+                            isqrt_n
+                                .checked_mul(isqrt_n)
+                                .map(|isqrt_n_squared| isqrt_n_squared <= n)
+                                .unwrap_or(false),
+                            "`{n}.isqrt()` should be lower than {isqrt_n}."
+                        );
+                        assert!(
+                            (isqrt_n + 1)
+                                .checked_mul(isqrt_n + 1)
+                                .map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared)
+                                .unwrap_or(true),
+                            "`{n}.isqrt()` should be higher than {isqrt_n})."
+                        );
+                    }
+                }
+
+                // Check the square roots of:
+                //
+                // * the first 1,024 perfect squares
+                // * halfway between each of the first 1,024 perfect squares
+                //   and the next perfect square
+                // * the next perfect square after the each of the first 1,024
+                //   perfect squares, minus one
+                // * the last 1,024 perfect squares
+                // * the last 1,024 perfect squares, minus one
+                // * halfway between each of the last 1,024 perfect squares
+                //   and the previous perfect square
+                #[test]
+                // Skip this test on Miri, as it takes too long to run.
+                #[cfg(not(miri))]
+                fn isqrt_extended() {
+                    // The correct value is worked out by using the fact that
+                    // the nth nonzero perfect square is the sum of the first n
+                    // odd numbers:
+                    //
+                    //  1 = 1
+                    //  4 = 1 + 3
+                    //  9 = 1 + 3 + 5
+                    // 16 = 1 + 3 + 5 + 7
+                    //
+                    // Note also that the last odd number added in is two times
+                    // the square root of the previous perfect square, plus
+                    // one:
+                    //
+                    // 1 = 2*0 + 1
+                    // 3 = 2*1 + 1
+                    // 5 = 2*2 + 1
+                    // 7 = 2*3 + 1
+                    //
+                    // That means we can add the square root of this perfect
+                    // square once to get about halfway to the next perfect
+                    // square, then we can add the square root of this perfect
+                    // square again to get to the next perfect square, minus
+                    // one, then we can add one to get to the next perfect
+                    // square.
+                    //
+                    // This allows us to, for each of the first 1,024 perfect
+                    // squares, test that the square roots of the following are
+                    // all correct and equal to each other:
+                    //
+                    // * the current perfect square
+                    // * about halfway to the next perfect square
+                    // * the next perfect square, minus one
+                    let mut n: $T = 0;
+                    for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T {
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n,
+                            "`{sqrt_n}.pow(2).isqrt()` should be {sqrt_n}."
+                        );
+
+                        n += sqrt_n;
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n,
+                            "{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
+                            sqrt_n + 1
+                        );
+
+                        n += sqrt_n;
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n,
+                            "`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
+                            sqrt_n + 1
+                        );
+
+                        n += 1;
+                    }
+
+                    // Similarly, for each of the last 1,024 perfect squares,
+                    // check:
+                    //
+                    // * the current perfect square
+                    // * the current perfect square, minus one
+                    // * about halfway to the previous perfect square
+                    //
+                    // `MAX`'s `isqrt` return value is verified in the `isqrt`
+                    // test function above.
+                    let maximum_sqrt = <$T>::MAX.isqrt();
+                    let mut n = maximum_sqrt * maximum_sqrt;
+
+                    for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() {
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n + 1,
+                            "`{0}.pow(2).isqrt()` should be {0}.",
+                            sqrt_n + 1
+                        );
+
+                        n -= 1;
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n,
+                            "`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
+                            sqrt_n + 1
+                        );
+
+                        n -= sqrt_n;
+                        isqrt_consistency_check(n);
+                        assert_eq!(
+                            n.isqrt(),
+                            sqrt_n,
+                            "{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
+                            sqrt_n + 1
+                        );
+
+                        n -= sqrt_n;
+                    }
+                }
+            }
+        )*
+    };
+}
+
+macro_rules! signed_check {
+    ($T:ident) => {
+        /// This takes an input and, if it's nonnegative or
+        #[doc = concat!("`", stringify!($T), "::MIN`,")]
+        /// checks that `isqrt` and `checked_isqrt` produce equivalent results
+        /// for that input and for the negative of that input.
+        ///
+        /// # Note
+        ///
+        /// This cannot check that negative inputs to `isqrt` cause panics if
+        /// panics abort instead of unwind.
+        fn isqrt_consistency_check(n: $T) {
+            // `<$T>::MIN` will be negative, so ignore it in this nonnegative
+            // section.
+            if n >= 0 {
+                assert_eq!(
+                    Some(n.isqrt()),
+                    n.checked_isqrt(),
+                    "`{n}.checked_isqrt()` should match `Some({n}.isqrt())`.",
+                );
+            }
+
+            // `wrapping_neg` so that `<$T>::MIN` will negate to itself rather
+            // than panicking.
+            let negative_n = n.wrapping_neg();
+
+            // Zero negated will still be nonnegative, so ignore it in this
+            // negative section.
+            if negative_n < 0 {
+                assert_eq!(
+                    negative_n.checked_isqrt(),
+                    None,
+                    "`({negative_n}).checked_isqrt()` should be `None`, as {negative_n} is negative.",
+                );
+
+                // `catch_unwind` only works when panics unwind rather than abort.
+                #[cfg(panic = "unwind")]
+                {
+                    std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err(
+                        &format!("`({negative_n}).isqrt()` should have panicked, as {negative_n} is negative.")
+                    );
+                }
+            }
+        }
+    };
+}
+
+macro_rules! unsigned_check {
+    ($T:ident) => {
+        /// This takes an input and, if it's nonzero, checks that `isqrt`
+        /// produces the same numeric value for both
+        #[doc = concat!("`", stringify!($T), "` and ")]
+        #[doc = concat!("`NonZero<", stringify!($T), ">`.")]
+        fn isqrt_consistency_check(n: $T) {
+            // Zero cannot be turned into a `NonZero` value, so ignore it in
+            // this nonzero section.
+            if n > 0 {
+                assert_eq!(
+                    n.isqrt(),
+                    core::num::NonZero::<$T>::new(n)
+                        .expect(
+                            "Was not able to create a new `NonZero` value from a nonzero number."
+                        )
+                        .isqrt()
+                        .get(),
+                    "`{n}.isqrt` should match `NonZero`'s `{n}.isqrt().get()`.",
+                );
+            }
+        }
+    };
+}
+
+tests!(signed_check: i8 i16 i32 i64 i128);
+tests!(unsigned_check: u8 u16 u32 u64 u128);
diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs
index dad46ad88fe..b14fe0b22c3 100644
--- a/library/core/tests/num/mod.rs
+++ b/library/core/tests/num/mod.rs
@@ -27,6 +27,7 @@ mod const_from;
 mod dec2flt;
 mod flt2dec;
 mod int_log;
+mod int_sqrt;
 mod ops;
 mod wrapping;