about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/ptr/mod.rs83
-rw-r--r--library/core/tests/ptr.rs63
-rw-r--r--src/test/assembly/align_offset.rs48
3 files changed, 137 insertions, 57 deletions
diff --git a/library/core/src/ptr/mod.rs b/library/core/src/ptr/mod.rs
index 2fcff3dfc5c..509854225c4 100644
--- a/library/core/src/ptr/mod.rs
+++ b/library/core/src/ptr/mod.rs
@@ -1557,11 +1557,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
     // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
     // 1, where the method versions of these operations are not inlined.
     use intrinsics::{
-        unchecked_shl, unchecked_shr, unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
+        cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
+        wrapping_add, wrapping_mul, wrapping_sub,
     };
 
-    let addr = p.addr();
-
     /// Calculate multiplicative modular inverse of `x` modulo `m`.
     ///
     /// This implementation is tailored for `align_offset` and has following preconditions:
@@ -1611,36 +1610,61 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         }
     }
 
+    let addr = p.addr();
     let stride = mem::size_of::<T>();
     // SAFETY: `a` is a power-of-two, therefore non-zero.
     let a_minus_one = unsafe { unchecked_sub(a, 1) };
-    if stride == 1 {
-        // `stride == 1` case can be computed more simply through `-p (mod a)`, but doing so
-        // inhibits LLVM's ability to select instructions like `lea`. Instead we compute
+
+    if stride == 0 {
+        // SPECIAL_CASE: handle 0-sized types. No matter how many times we step, the address will
+        // stay the same, so no offset will be able to align the pointer unless it is already
+        // aligned. This branch _will_ be optimized out as `stride` is known at compile-time.
+        let p_mod_a = addr & a_minus_one;
+        return if p_mod_a == 0 { 0 } else { usize::MAX };
+    }
+
+    // SAFETY: `stride == 0` case has been handled by the special case above.
+    let a_mod_stride = unsafe { unchecked_rem(a, stride) };
+    if a_mod_stride == 0 {
+        // SPECIAL_CASE: In cases where the `a` is divisible by `stride`, byte offset to align a
+        // pointer can be computed more simply through `-p (mod a)`. In the off-chance the byte
+        // offset is not a multiple of `stride`, the input pointer was misaligned and no pointer
+        // offset will be able to produce a `p` aligned to the specified `a`.
         //
-        //    round_up_to_next_alignment(p, a) - p
+        // The naive `-p (mod a)` equation  inhibits LLVM's ability to select instructions
+        // like `lea`. We compute `(round_up_to_next_alignment(p, a) - p)` instead. This
+        // redistributes operations around the load-bearing, but pessimizing `and` instruction
+        // sufficiently for LLVM to be able to utilize the various optimizations it knows about.
         //
-        // which distributes operations around the load-bearing, but pessimizing `and` sufficiently
-        // for LLVM to be able to utilize the various optimizations it knows about.
-        return wrapping_sub(wrapping_add(addr, a_minus_one) & wrapping_sub(0, a), addr);
-    }
+        // LLVM handles the branch here particularly nicely. If this branch needs to be evaluated
+        // at runtime, it will produce a mask `if addr_mod_stride == 0 { 0 } else { usize::MAX }`
+        // in a branch-free way and then bitwise-OR it with whatever result the `-p mod a`
+        // computation produces.
+
+        // SAFETY: `stride == 0` case has been handled by the special case above.
+        let addr_mod_stride = unsafe { unchecked_rem(addr, stride) };
 
-    let pmoda = addr & a_minus_one;
-    if pmoda == 0 {
-        // Already aligned. Yay!
-        return 0;
-    } else if stride == 0 {
-        // If the pointer is not aligned, and the element is zero-sized, then no amount of
-        // elements will ever align the pointer.
-        return usize::MAX;
+        return if addr_mod_stride == 0 {
+            let aligned_address = wrapping_add(addr, a_minus_one) & wrapping_sub(0, a);
+            let byte_offset = wrapping_sub(aligned_address, addr);
+            // SAFETY: `stride` is non-zero. This is guaranteed to divide exactly as well, because
+            // addr has been verified to be aligned to the original type’s alignment requirements.
+            unsafe { exact_div(byte_offset, stride) }
+        } else {
+            usize::MAX
+        };
     }
 
-    let smoda = stride & a_minus_one;
+    // GENERAL_CASE: From here on we’re handling the very general case where `addr` may be
+    // misaligned, there isn’t an obvious relationship between `stride` and `a` that we can take an
+    // advantage of, etc. This case produces machine code that isn’t particularly high quality,
+    // compared to the special cases above. The code produced here is still within the realm of
+    // miracles, given the situations this case has to deal with.
+
     // SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
-    let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
+    let gcdpow = unsafe { cttz_nonzero(stride).min(cttz_nonzero(a)) };
     // SAFETY: gcdpow has an upper-bound that’s at most the number of bits in a usize.
     let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
-
     // SAFETY: gcd is always greater or equal to 1.
     if addr & unsafe { unchecked_sub(gcd, 1) } == 0 {
         // This branch solves for the following linear congruence equation:
@@ -1656,14 +1680,13 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         // ` p' + s'o = 0 mod a' `
         // ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
         //
-        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the second
-        // term is "how does incrementing `p` by `s` bytes change the relative alignment of `p`" (again
-        // divided by `g`).
-        // Division by `g` is necessary to make the inverse well formed if `a` and `s` are not
-        // co-prime.
+        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the
+        // second term is "how does incrementing `p` by `s` bytes change the relative alignment of
+        // `p`" (again divided by `g`). Division by `g` is necessary to make the inverse well
+        // formed if `a` and `s` are not co-prime.
         //
         // Furthermore, the result produced by this solution is not "minimal", so it is necessary
-        // to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
+        // to take the result `o mod lcm(s, a)`. This `lcm(s, a)` is the same as `a'`.
 
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`.
@@ -1673,11 +1696,11 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         let a2minus1 = unsafe { unchecked_sub(a2, 1) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`.
-        let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
+        let s2 = unsafe { unchecked_shr(stride & a_minus_one, gcdpow) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
         // always be strictly greater than `(p % a) >> gcdpow`.
-        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
+        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(addr & a_minus_one, gcdpow)) };
         // SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
         // because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
         return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
diff --git a/library/core/tests/ptr.rs b/library/core/tests/ptr.rs
index bab2b1792f6..12861794c2d 100644
--- a/library/core/tests/ptr.rs
+++ b/library/core/tests/ptr.rs
@@ -359,7 +359,7 @@ fn align_offset_zst() {
 }
 
 #[test]
-fn align_offset_stride1() {
+fn align_offset_stride_one() {
     // For pointers of stride = 1, the pointer can always be aligned. The offset is equal to
     // number of bytes.
     let mut align = 1;
@@ -380,24 +380,8 @@ fn align_offset_stride1() {
 }
 
 #[test]
-fn align_offset_weird_strides() {
-    #[repr(packed)]
-    struct A3(u16, u8);
-    struct A4(u32);
-    #[repr(packed)]
-    struct A5(u32, u8);
-    #[repr(packed)]
-    struct A6(u32, u16);
-    #[repr(packed)]
-    struct A7(u32, u16, u8);
-    #[repr(packed)]
-    struct A8(u32, u32);
-    #[repr(packed)]
-    struct A9(u32, u32, u8);
-    #[repr(packed)]
-    struct A10(u32, u32, u16);
-
-    unsafe fn test_weird_stride<T>(ptr: *const T, align: usize) -> bool {
+fn align_offset_various_strides() {
+    unsafe fn test_stride<T>(ptr: *const T, align: usize) -> bool {
         let numptr = ptr as usize;
         let mut expected = usize::MAX;
         // Naive but definitely correct way to find the *first* aligned element of stride::<T>.
@@ -431,14 +415,39 @@ fn align_offset_weird_strides() {
     while align < limit {
         for ptr in 1usize..4 * align {
             unsafe {
-                x |= test_weird_stride::<A3>(ptr::invalid::<A3>(ptr), align);
-                x |= test_weird_stride::<A4>(ptr::invalid::<A4>(ptr), align);
-                x |= test_weird_stride::<A5>(ptr::invalid::<A5>(ptr), align);
-                x |= test_weird_stride::<A6>(ptr::invalid::<A6>(ptr), align);
-                x |= test_weird_stride::<A7>(ptr::invalid::<A7>(ptr), align);
-                x |= test_weird_stride::<A8>(ptr::invalid::<A8>(ptr), align);
-                x |= test_weird_stride::<A9>(ptr::invalid::<A9>(ptr), align);
-                x |= test_weird_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+                #[repr(packed)]
+                struct A3(u16, u8);
+                x |= test_stride::<A3>(ptr::invalid::<A3>(ptr), align);
+
+                struct A4(u32);
+                x |= test_stride::<A4>(ptr::invalid::<A4>(ptr), align);
+
+                #[repr(packed)]
+                struct A5(u32, u8);
+                x |= test_stride::<A5>(ptr::invalid::<A5>(ptr), align);
+
+                #[repr(packed)]
+                struct A6(u32, u16);
+                x |= test_stride::<A6>(ptr::invalid::<A6>(ptr), align);
+
+                #[repr(packed)]
+                struct A7(u32, u16, u8);
+                x |= test_stride::<A7>(ptr::invalid::<A7>(ptr), align);
+
+                #[repr(packed)]
+                struct A8(u32, u32);
+                x |= test_stride::<A8>(ptr::invalid::<A8>(ptr), align);
+
+                #[repr(packed)]
+                struct A9(u32, u32, u8);
+                x |= test_stride::<A9>(ptr::invalid::<A9>(ptr), align);
+
+                #[repr(packed)]
+                struct A10(u32, u32, u16);
+                x |= test_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+
+                x |= test_stride::<u32>(ptr::invalid::<u32>(ptr), align);
+                x |= test_stride::<u128>(ptr::invalid::<u128>(ptr), align);
             }
         }
         align = (align + 1).next_power_of_two();
diff --git a/src/test/assembly/align_offset.rs b/src/test/assembly/align_offset.rs
new file mode 100644
index 00000000000..c5eefca3467
--- /dev/null
+++ b/src/test/assembly/align_offset.rs
@@ -0,0 +1,48 @@
+// assembly-output: emit-asm
+// compile-flags: -Copt-level=1
+// only-x86_64
+// min-llvm-version: 14.0
+#![crate_type="rlib"]
+
+// CHECK-LABEL: align_offset_byte_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_ptr(ptr: *const u8) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_byte_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_slice(slice: &[u8]) -> usize {
+    slice.as_ptr().align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// This `ptr` is not known to be aligned, so it is required to check if it is at all possible to
+// align. LLVM applies a simple mask.
+// CHECK: orq
+#[no_mangle]
+pub fn align_offset_word_ptr(ptr: *const u32) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// `slice` is known to be aligned, so `!0` is not possible as a return
+// CHECK-NOT: orq
+#[no_mangle]
+pub fn align_offset_word_slice(slice: &[u32]) -> usize {
+    slice.as_ptr().align_offset(32)
+}