about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMikhail Zabaluev <mikhail.zabaluev@gmail.com>2024-08-13 08:21:16 +0300
committerMikhail Zabaluev <mikhail.zabaluev@gmail.com>2024-08-13 08:32:36 +0300
commitac88b330b875e8058589b1804ac5d95fcd40905d (patch)
treec63b65f1fd56800fb73d9ae78ee813daa6c4b198
parent2f235343529c39bdab47704ec9620d6784eeeb6d (diff)
downloadrust-ac88b330b875e8058589b1804ac5d95fcd40905d.tar.gz
rust-ac88b330b875e8058589b1804ac5d95fcd40905d.zip
Revert to original loop for const pow exponents
Give LLVM the for original, optimizable loop in pow and wrapped_pow
functions in the case when the exponent is statically known.
-rw-r--r--library/core/src/num/int_macros.rs135
-rw-r--r--library/core/src/num/uint_macros.rs135
2 files changed, 110 insertions, 160 deletions
diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs
index be0e6a2a03b..d8ef36f21ac 100644
--- a/library/core/src/num/int_macros.rs
+++ b/library/core/src/num/int_macros.rs
@@ -2174,54 +2174,41 @@ macro_rules! int_impl {
         #[inline]
         #[rustc_allow_const_fn_unstable(is_val_statically_known)]
         pub const fn wrapping_pow(self, mut exp: u32) -> Self {
+            if exp == 0 {
+                return 1;
+            }
             let mut base = self;
+            let mut acc: Self = 1;
 
             if intrinsics::is_val_statically_known(exp) {
-                // Unroll multiplications for small exponent values.
-                // This gives the optimizer a way to efficiently inline call sites
-                // for the most common use cases with constant exponents.
-                // Currently, LLVM is unable to unroll the loop below.
-                match exp {
-                    0 => return 1,
-                    1 => return base,
-                    2 => return base.wrapping_mul(base),
-                    3 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(base);
-                    }
-                    4 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(squared);
+                while exp > 1 {
+                    if (exp & 1) == 1 {
+                        acc = acc.wrapping_mul(base);
                     }
-                    5 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(squared).wrapping_mul(base);
-                    }
-                    6 => {
-                        let cubed = base.wrapping_mul(base).wrapping_mul(base);
-                        return cubed.wrapping_mul(cubed);
-                    }
-                    _ => {}
+                    exp /= 2;
+                    base = base.wrapping_mul(base);
                 }
-            } else {
-                if exp == 0 {
-                    return 1;
-                }
-            }
-            debug_assert!(exp != 0);
 
-            let mut acc: Self = 1;
-
-            loop {
-                if (exp & 1) == 1 {
-                    acc = acc.wrapping_mul(base);
-                    // since exp!=0, finally the exp must be 1.
-                    if exp == 1 {
-                        return acc;
+                // since exp!=0, finally the exp must be 1.
+                // Deal with the final bit of the exponent separately, since
+                // squaring the base afterwards is not necessary.
+                acc.wrapping_mul(base)
+            } else {
+                // This is faster than the above when the exponent is not known
+                // at compile time. We can't use the same code for the constant
+                // exponent case because LLVM is currently unable to unroll
+                // this loop.
+                loop {
+                    if (exp & 1) == 1 {
+                        acc = acc.wrapping_mul(base);
+                        // since exp!=0, finally the exp must be 1.
+                        if exp == 1 {
+                            return acc;
+                        }
                     }
+                    exp /= 2;
+                    base = base.wrapping_mul(base);
                 }
-                exp /= 2;
-                base = base.wrapping_mul(base);
             }
         }
 
@@ -2753,54 +2740,42 @@ macro_rules! int_impl {
         #[rustc_inherit_overflow_checks]
         #[rustc_allow_const_fn_unstable(is_val_statically_known)]
         pub const fn pow(self, mut exp: u32) -> Self {
+            if exp == 0 {
+                return 1;
+            }
             let mut base = self;
+            let mut acc = 1;
 
             if intrinsics::is_val_statically_known(exp) {
-                // Unroll multiplications for small exponent values.
-                // This gives the optimizer a way to efficiently inline call sites
-                // for the most common use cases with constant exponents.
-                // Currently, LLVM is unable to unroll the loop below.
-                match exp {
-                    0 => return 1,
-                    1 => return base,
-                    2 => return base * base,
-                    3 => {
-                        let squared = base * base;
-                        return squared * base;
-                    }
-                    4 => {
-                        let squared = base * base;
-                        return squared * squared;
+                while exp > 1 {
+                    if (exp & 1) == 1 {
+                        acc = acc * base;
                     }
-                    5 => {
-                        let squared = base * base;
-                        return squared * squared * base;
-                    }
-                    6 => {
-                        let cubed = base * base * base;
-                        return cubed * cubed;
-                    }
-                    _ => {}
+                    exp /= 2;
+                    base = base * base;
                 }
-            } else {
-                if exp == 0 {
-                    return 1;
-                }
-            }
-            debug_assert!(exp != 0);
 
-            let mut acc = 1;
-
-            loop {
-                if (exp & 1) == 1 {
-                    acc = acc * base;
-                    // since exp!=0, finally the exp must be 1.
-                    if exp == 1 {
-                        return acc;
+                // since exp!=0, finally the exp must be 1.
+                // Deal with the final bit of the exponent separately, since
+                // squaring the base afterwards is not necessary and may cause a
+                // needless overflow.
+                acc * base
+            } else {
+                // This is faster than the above when the exponent is not known
+                // at compile time. We can't use the same code for the constant
+                // exponent case because LLVM is currently unable to unroll
+                // this loop.
+                loop {
+                    if (exp & 1) == 1 {
+                        acc = acc * base;
+                        // since exp!=0, finally the exp must be 1.
+                        if exp == 1 {
+                            return acc;
+                        }
                     }
+                    exp /= 2;
+                    base = base * base;
                 }
-                exp /= 2;
-                base = base * base;
             }
         }
 
diff --git a/library/core/src/num/uint_macros.rs b/library/core/src/num/uint_macros.rs
index 24352593fca..5b3ef78d39a 100644
--- a/library/core/src/num/uint_macros.rs
+++ b/library/core/src/num/uint_macros.rs
@@ -2050,54 +2050,41 @@ macro_rules! uint_impl {
         #[inline]
         #[rustc_allow_const_fn_unstable(is_val_statically_known)]
         pub const fn wrapping_pow(self, mut exp: u32) -> Self {
+            if exp == 0 {
+                return 1;
+            }
             let mut base = self;
+            let mut acc: Self = 1;
 
             if intrinsics::is_val_statically_known(exp) {
-                // Unroll multiplications for small exponent values.
-                // This gives the optimizer a way to efficiently inline call sites
-                // for the most common use cases with constant exponents.
-                // Currently, LLVM is unable to unroll the loop below.
-                match exp {
-                    0 => return 1,
-                    1 => return base,
-                    2 => return base.wrapping_mul(base),
-                    3 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(base);
-                    }
-                    4 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(squared);
+                while exp > 1 {
+                    if (exp & 1) == 1 {
+                        acc = acc.wrapping_mul(base);
                     }
-                    5 => {
-                        let squared = base.wrapping_mul(base);
-                        return squared.wrapping_mul(squared).wrapping_mul(base);
-                    }
-                    6 => {
-                        let cubed = base.wrapping_mul(base).wrapping_mul(base);
-                        return cubed.wrapping_mul(cubed);
-                    }
-                    _ => {}
+                    exp /= 2;
+                    base = base.wrapping_mul(base);
                 }
-            } else {
-                if exp == 0 {
-                    return 1;
-                }
-            }
-            debug_assert!(exp != 0);
 
-            let mut acc: Self = 1;
-
-            loop {
-                if (exp & 1) == 1 {
-                    acc = acc.wrapping_mul(base);
-                    // since exp!=0, finally the exp must be 1.
-                    if exp == 1 {
-                        return acc;
+                // since exp!=0, finally the exp must be 1.
+                // Deal with the final bit of the exponent separately, since
+                // squaring the base afterwards is not necessary.
+                acc.wrapping_mul(base)
+            } else {
+                // This is faster than the above when the exponent is not known
+                // at compile time. We can't use the same code for the constant
+                // exponent case because LLVM is currently unable to unroll
+                // this loop.
+                loop {
+                    if (exp & 1) == 1 {
+                        acc = acc.wrapping_mul(base);
+                        // since exp!=0, finally the exp must be 1.
+                        if exp == 1 {
+                            return acc;
+                        }
                     }
+                    exp /= 2;
+                    base = base.wrapping_mul(base);
                 }
-                exp /= 2;
-                base = base.wrapping_mul(base);
             }
         }
 
@@ -2578,54 +2565,42 @@ macro_rules! uint_impl {
         #[rustc_inherit_overflow_checks]
         #[rustc_allow_const_fn_unstable(is_val_statically_known)]
         pub const fn pow(self, mut exp: u32) -> Self {
+            if exp == 0 {
+                return 1;
+            }
             let mut base = self;
+            let mut acc = 1;
 
             if intrinsics::is_val_statically_known(exp) {
-                // Unroll multiplications for small exponent values.
-                // This gives the optimizer a way to efficiently inline call sites
-                // for the most common use cases with constant exponents.
-                // Currently, LLVM is unable to unroll the loop below.
-                match exp {
-                    0 => return 1,
-                    1 => return base,
-                    2 => return base * base,
-                    3 => {
-                        let squared = base * base;
-                        return squared * base;
-                    }
-                    4 => {
-                        let squared = base * base;
-                        return squared * squared;
+                while exp > 1 {
+                    if (exp & 1) == 1 {
+                        acc = acc * base;
                     }
-                    5 => {
-                        let squared = base * base;
-                        return squared * squared * base;
-                    }
-                    6 => {
-                        let cubed = base * base * base;
-                        return cubed * cubed;
-                    }
-                    _ => {}
+                    exp /= 2;
+                    base = base * base;
                 }
-            } else {
-                if exp == 0 {
-                    return 1;
-                }
-            }
-            debug_assert!(exp != 0);
 
-            let mut acc = 1;
-
-            loop {
-                if (exp & 1) == 1 {
-                    acc = acc * base;
-                    // since exp!=0, finally the exp must be 1.
-                    if exp == 1 {
-                        return acc;
+                // since exp!=0, finally the exp must be 1.
+                // Deal with the final bit of the exponent separately, since
+                // squaring the base afterwards is not necessary and may cause a
+                // needless overflow.
+                acc * base
+            } else {
+                // This is faster than the above when the exponent is not known
+                // at compile time. We can't use the same code for the constant
+                // exponent case because LLVM is currently unable to unroll
+                // this loop.
+                loop {
+                    if (exp & 1) == 1 {
+                        acc = acc * base;
+                        // since exp!=0, finally the exp must be 1.
+                        if exp == 1 {
+                            return acc;
+                        }
                     }
+                    exp /= 2;
+                    base = base * base;
                 }
-                exp /= 2;
-                base = base * base;
             }
         }