about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-09-29 07:35:44 +0000
committerbors <bors@rust-lang.org>2023-09-29 07:35:44 +0000
commitb8536c1aa1973dd2438841815b1eeec129480e45 (patch)
treec8ef27dc78f977d2db8d7bb22f4ac8714fb653a2
parentc1f86f0bc87eaa0cf16bcf3de38793096ec4db94 (diff)
parent25648de28f10799fa6274f64fa12475292231c72 (diff)
downloadrust-b8536c1aa1973dd2438841815b1eeec129480e45.tar.gz
rust-b8536c1aa1973dd2438841815b1eeec129480e45.zip
Auto merge of #116176 - FedericoStra:isqrt, r=dtolnay
Add "integer square root" method to integer primitive types

For every suffix `N` among `8`, `16`, `32`, `64`, `128` and `size`, this PR adds the methods

```rust
const fn uN::isqrt() -> uN;
const fn iN::isqrt() -> iN;
const fn iN::checked_isqrt() -> Option<iN>;
```

to compute the [integer square root](https://en.wikipedia.org/wiki/Integer_square_root), addressing issue #89273.

The implementation is based on the [base 2 digit-by-digit algorithm](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)) on Wikipedia, which after some benchmarking has proved to be faster than both binary search and Heron's/Newton's method. I haven't had the time to understand and port [this code](http://atoms.alife.co.uk/sqrt/SquareRoot.java) based on lookup tables instead, but I'm not sure whether it's worth complicating such a function this much for relatively little benefit.
-rw-r--r--library/core/src/lib.rs1
-rw-r--r--library/core/src/num/int_macros.rs54
-rw-r--r--library/core/src/num/uint_macros.rs48
-rw-r--r--library/core/tests/lib.rs1
-rw-r--r--library/core/tests/num/int_macros.rs32
-rw-r--r--library/core/tests/num/uint_macros.rs29
6 files changed, 165 insertions, 0 deletions
diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs
index 8b04bafcda5..be734a9ba52 100644
--- a/library/core/src/lib.rs
+++ b/library/core/src/lib.rs
@@ -178,6 +178,7 @@
 #![feature(ip)]
 #![feature(ip_bits)]
 #![feature(is_ascii_octdigit)]
+#![feature(isqrt)]
 #![feature(maybe_uninit_uninit_array)]
 #![feature(ptr_alignment_type)]
 #![feature(ptr_metadata)]
diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs
index 1f43520e1b3..3cbb55af3bc 100644
--- a/library/core/src/num/int_macros.rs
+++ b/library/core/src/num/int_macros.rs
@@ -898,6 +898,30 @@ macro_rules! int_impl {
             acc.checked_mul(base)
         }
 
+        /// Returns the square root of the number, rounded down.
+        ///
+        /// Returns `None` if `self` is negative.
+        ///
+        /// # Examples
+        ///
+        /// Basic usage:
+        /// ```
+        /// #![feature(isqrt)]
+        #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".checked_isqrt(), Some(3));")]
+        /// ```
+        #[unstable(feature = "isqrt", issue = "116226")]
+        #[rustc_const_unstable(feature = "isqrt", issue = "116226")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn checked_isqrt(self) -> Option<Self> {
+            if self < 0 {
+                None
+            } else {
+                Some((self as $UnsignedT).isqrt() as Self)
+            }
+        }
+
         /// Saturating integer addition. Computes `self + rhs`, saturating at the numeric
         /// bounds instead of overflowing.
         ///
@@ -2061,6 +2085,36 @@ macro_rules! int_impl {
             acc * base
         }
 
+        /// Returns the square root of the number, rounded down.
+        ///
+        /// # Panics
+        ///
+        /// This function will panic if `self` is negative.
+        ///
+        /// # Examples
+        ///
+        /// Basic usage:
+        /// ```
+        /// #![feature(isqrt)]
+        #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
+        /// ```
+        #[unstable(feature = "isqrt", issue = "116226")]
+        #[rustc_const_unstable(feature = "isqrt", issue = "116226")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn isqrt(self) -> Self {
+            // I would like to implement it as
+            // ```
+            // self.checked_isqrt().expect("argument of integer square root must be non-negative")
+            // ```
+            // but `expect` is not yet stable as a `const fn`.
+            match self.checked_isqrt() {
+                Some(sqrt) => sqrt,
+                None => panic!("argument of integer square root must be non-negative"),
+            }
+        }
+
         /// Calculates the quotient of Euclidean division of `self` by `rhs`.
         ///
         /// This computes the integer `q` such that `self = q * rhs + r`, with
diff --git a/library/core/src/num/uint_macros.rs b/library/core/src/num/uint_macros.rs
index 7cbef9e7793..a9c5312a1c0 100644
--- a/library/core/src/num/uint_macros.rs
+++ b/library/core/src/num/uint_macros.rs
@@ -1995,6 +1995,54 @@ macro_rules! uint_impl {
             acc * base
         }
 
+        /// Returns the square root of the number, rounded down.
+        ///
+        /// # Examples
+        ///
+        /// Basic usage:
+        /// ```
+        /// #![feature(isqrt)]
+        #[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
+        /// ```
+        #[unstable(feature = "isqrt", issue = "116226")]
+        #[rustc_const_unstable(feature = "isqrt", issue = "116226")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn isqrt(self) -> Self {
+            if self < 2 {
+                return self;
+            }
+
+            // The algorithm is based on the one presented in
+            // <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
+            // which cites as source the following C code:
+            // <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
+
+            let mut op = self;
+            let mut res = 0;
+            let mut one = 1 << (self.ilog2() & !1);
+
+            while one != 0 {
+                if op >= res + one {
+                    op -= res + one;
+                    res = (res >> 1) + one;
+                } else {
+                    res >>= 1;
+                }
+                one >>= 2;
+            }
+
+            // SAFETY: the result is positive and fits in an integer with half as many bits.
+            // Inform the optimizer about it.
+            unsafe {
+                intrinsics::assume(0 < res);
+                intrinsics::assume(res < 1 << (Self::BITS / 2));
+            }
+
+            res
+        }
+
         /// Performs Euclidean division.
         ///
         /// Since, for the positive integers, all common
diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs
index 773f2b955d8..e4003a208bc 100644
--- a/library/core/tests/lib.rs
+++ b/library/core/tests/lib.rs
@@ -56,6 +56,7 @@
 #![feature(min_specialization)]
 #![feature(numfmt)]
 #![feature(num_midpoint)]
+#![feature(isqrt)]
 #![feature(step_trait)]
 #![feature(str_internals)]
 #![feature(std_internals)]
diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs
index 439bbe66997..165d9a29617 100644
--- a/library/core/tests/num/int_macros.rs
+++ b/library/core/tests/num/int_macros.rs
@@ -291,6 +291,38 @@ macro_rules! int_module {
             }
 
             #[test]
+            fn test_isqrt() {
+                assert_eq!($T::MIN.checked_isqrt(), None);
+                assert_eq!((-1 as $T).checked_isqrt(), None);
+                assert_eq!((0 as $T).isqrt(), 0 as $T);
+                assert_eq!((1 as $T).isqrt(), 1 as $T);
+                assert_eq!((2 as $T).isqrt(), 1 as $T);
+                assert_eq!((99 as $T).isqrt(), 9 as $T);
+                assert_eq!((100 as $T).isqrt(), 10 as $T);
+            }
+
+            #[cfg(not(miri))] // Miri is too slow
+            #[test]
+            fn test_lots_of_isqrt() {
+                let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
+                for n in 0..=n_max {
+                    let isqrt: $T = n.isqrt();
+
+                    assert!(isqrt.pow(2) <= n);
+                    let (square, overflow) = (isqrt + 1).overflowing_pow(2);
+                    assert!(overflow || square > n);
+                }
+
+                for n in ($T::MAX - 127)..=$T::MAX {
+                    let isqrt: $T = n.isqrt();
+
+                    assert!(isqrt.pow(2) <= n);
+                    let (square, overflow) = (isqrt + 1).overflowing_pow(2);
+                    assert!(overflow || square > n);
+                }
+            }
+
+            #[test]
             fn test_div_floor() {
                 let a: $T = 8;
                 let b = 3;
diff --git a/library/core/tests/num/uint_macros.rs b/library/core/tests/num/uint_macros.rs
index 7d6203db0b9..955440647eb 100644
--- a/library/core/tests/num/uint_macros.rs
+++ b/library/core/tests/num/uint_macros.rs
@@ -207,6 +207,35 @@ macro_rules! uint_module {
             }
 
             #[test]
+            fn test_isqrt() {
+                assert_eq!((0 as $T).isqrt(), 0 as $T);
+                assert_eq!((1 as $T).isqrt(), 1 as $T);
+                assert_eq!((2 as $T).isqrt(), 1 as $T);
+                assert_eq!((99 as $T).isqrt(), 9 as $T);
+                assert_eq!((100 as $T).isqrt(), 10 as $T);
+                assert_eq!($T::MAX.isqrt(), (1 << ($T::BITS / 2)) - 1);
+            }
+
+            #[cfg(not(miri))] // Miri is too slow
+            #[test]
+            fn test_lots_of_isqrt() {
+                let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
+                for n in 0..=n_max {
+                    let isqrt: $T = n.isqrt();
+
+                    assert!(isqrt.pow(2) <= n);
+                    assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
+                }
+
+                for n in ($T::MAX - 255)..=$T::MAX {
+                    let isqrt: $T = n.isqrt();
+
+                    assert!(isqrt.pow(2) <= n);
+                    assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
+                }
+            }
+
+            #[test]
             fn test_div_floor() {
                 assert_eq!((8 as $T).div_floor(3), 2);
             }