about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/ptr/mod.rs55
1 files changed, 38 insertions, 17 deletions
diff --git a/library/core/src/ptr/mod.rs b/library/core/src/ptr/mod.rs
index 5f028f9ea76..78308f97461 100644
--- a/library/core/src/ptr/mod.rs
+++ b/library/core/src/ptr/mod.rs
@@ -1166,6 +1166,10 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
 /// Any questions go to @nagisa.
 #[lang = "align_offset"]
 pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
+    // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
+    // 1, where the method versions of these operations are not inlined.
+    use intrinsics::{unchecked_shl, unchecked_shr, unchecked_sub, wrapping_mul, wrapping_sub};
+
     /// Calculate multiplicative modular inverse of `x` modulo `m`.
     ///
     /// This implementation is tailored for align_offset and has following preconditions:
@@ -1175,7 +1179,7 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
     ///
     /// Implementation of this function shall not panic. Ever.
     #[inline]
-    fn mod_inv(x: usize, m: usize) -> usize {
+    unsafe fn mod_inv(x: usize, m: usize) -> usize {
         /// Multiplicative modular inverse table modulo 2⁴ = 16.
         ///
         /// Note, that this table does not contain values where inverse does not exist (i.e., for
@@ -1187,8 +1191,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
 
         let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
+        // SAFETY: `m` is required to be a power-of-two, hence non-zero.
+        let m_minus_one = unsafe { unchecked_sub(m, 1) };
         if m <= INV_TABLE_MOD {
-            table_inverse & (m - 1)
+            table_inverse & m_minus_one
         } else {
             // We iterate "up" using the following formula:
             //
@@ -1204,17 +1210,18 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
                 // uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
                 // usize::MAX` instead, because we take the result `mod n` at the end
                 // anyway.
-                inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
+                inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
                 if going_mod >= m {
-                    return inverse & (m - 1);
+                    return inverse & m_minus_one;
                 }
-                going_mod = going_mod.wrapping_mul(going_mod);
+                going_mod = wrapping_mul(going_mod, going_mod);
             }
         }
     }
 
     let stride = mem::size_of::<T>();
-    let a_minus_one = a.wrapping_sub(1);
+    // SAFETY: `a` is a power-of-two, hence non-zero.
+    let a_minus_one = unsafe { unchecked_sub(a, 1) };
     let pmoda = p as usize & a_minus_one;
 
     if pmoda == 0 {
@@ -1228,16 +1235,18 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
             // elements will ever align the pointer.
             !0
         } else {
-            a.wrapping_sub(pmoda)
+            wrapping_sub(a, pmoda)
         };
     }
 
     let smoda = stride & a_minus_one;
-    // SAFETY: a is power-of-two so cannot be 0. stride = 0 is handled above.
+    // SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
     let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
-    let gcd = 1usize << gcdpow;
+    // SAFETY: gcdpow has an upper-bound that’s at most the number of bits in an usize.
+    let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
 
-    if p as usize & (gcd.wrapping_sub(1)) == 0 {
+    // SAFETY: gcd is always greater or equal to 1.
+    if p as usize & unsafe { unchecked_sub(gcd, 1) } == 0 {
         // This branch solves for the following linear congruence equation:
         //
         // ` p + so = 0 mod a `
@@ -1245,8 +1254,8 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         // `p` here is the pointer value, `s` - stride of `T`, `o` offset in `T`s, and `a` - the
         // requested alignment.
         //
-        // With `g = gcd(a, s)`, and the above asserting that `p` is also divisible by `g`, we can
-        // denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
+        // With `g = gcd(a, s)`, and the above condition asserting that `p` is also divisible by
+        // `g`, we can denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
         //
         // ` p' + s'o = 0 mod a' `
         // ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
@@ -1259,11 +1268,23 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         //
         // Furthermore, the result produced by this solution is not "minimal", so it is necessary
         // to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
-        let a2 = a >> gcdpow;
-        let a2minus1 = a2.wrapping_sub(1);
-        let s2 = smoda >> gcdpow;
-        let minusp2 = a2.wrapping_sub(pmoda >> gcdpow);
-        return (minusp2.wrapping_mul(mod_inv(s2, a2))) & a2minus1;
+
+        // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
+        // `a`.
+        let a2 = unsafe { unchecked_shr(a, gcdpow) };
+        // SAFETY: `a2` is non-zero. Shifting `a` by `gcdpow` cannot shift out any of the set bits
+        // in `a` (of which it has exactly one).
+        let a2minus1 = unsafe { unchecked_sub(a2, 1) };
+        // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
+        // `a`.
+        let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
+        // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
+        // `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
+        // always be strictly greater than `(p % a) >> gcdpow`.
+        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
+        // SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
+        // because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
+        return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
     }
 
     // Cannot be aligned at all.