-rw-r--r--  library/core/src/str/pattern.rs | 234
1 file changed, 182 insertions(+), 52 deletions(-)
diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs
index def11ca45c0..c5be32861f9 100644
--- a/library/core/src/str/pattern.rs
+++ b/library/core/src/str/pattern.rs
@@ -956,15 +956,20 @@ impl<'a, 'b> Pattern<'a> for &'b str {
 
         match self.len().cmp(&haystack.len()) {
             Ordering::Less => {
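+                // fast path: a single-byte needle only needs a byte search over the
+                // haystack bytes, no substring searcher setup required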
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
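+                // `simd_contains` returns `None` when it bails out (e.g. on degenerate
+                // needles); in that case we fall through to the two-way searcher below.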
                 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-                if self.as_bytes().len() <= 8 {
-                    return simd_contains(self, haystack);
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
                 }
 
                 self.into_searcher(haystack).next_match().is_some()
             }
-            Ordering::Equal => self == haystack,
-            Ordering::Greater => false,
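+            // a needle longer than the haystack can never equal it, so the equality
+            // check also covers the Ordering::Greater case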
+            _ => self == haystack,
         }
     }
 
@@ -1707,82 +1712,207 @@ impl TwoWayStrategy for RejectAndMatch {
     }
 }
 
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
+/// does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicates potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std builds for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search, so this implementation should not be called with larger needles.
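+///
+/// A minimal scalar sketch of the probe idea (illustrative only; `two_probe_scalar` is a
+/// hypothetical helper that assumes a non-empty needle and is not part of this change):
+///
+/// ```ignore
+/// fn two_probe_scalar(needle: &[u8], haystack: &[u8]) -> bool {
+///     let (first, last) = (needle[0], needle[needle.len() - 1]);
+///     haystack.windows(needle.len()).any(|w| {
+///         // cheap probes first, full comparison only when both probes hit
+///         w[0] == first && w[w.len() - 1] == last && w == needle
+///     })
+/// }
+/// ```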
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
 #[inline]
-fn simd_contains(needle: &str, haystack: &str) -> bool {
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     let needle = needle.as_bytes();
     let haystack = haystack.as_bytes();
 
-    if needle.len() == 1 {
-        return haystack.contains(&needle[0]);
-    }
-
-    const CHUNK: usize = 16;
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset at which the second probe vector is loaded
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) =
+            (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe)
+        else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
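+    // e.g. for needle b"aaab" the probes end up at offsets 0 and 3, while for b"aaaa" no
+    // distinct byte exists in the probed window and we already bailed out above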
 
-    // do a naive search if if the haystack is too small to fit
-    if haystack.len() < CHUNK + needle.len() - 1 {
-        return haystack.windows(needle.len()).any(|c| c == needle);
+    // do a naive search if the haystack is too small to fit
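+    // ("too small" means there is no room for even one full-width load at both probe offsets,
+    // which the right-aligned tail load at the end of this function also relies on)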
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }
 
-    use crate::arch::x86_64::{
-        __m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
-    };
-
-    // SAFETY: no preconditions other than sse2 being available
-    let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) };
-    // SAFETY: no preconditions other than sse2 being available
-    let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) };
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // The first byte is already checked by the outer loop, so to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
 
+    // this #[cold] is load-bearing, benchmark before removing it...
     let check_mask = #[cold]
-    |idx, mut mask: u32| -> bool {
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
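+        // e.g. a mask of 0b0100 means both probes matched for the window starting at idx + 2;
+        // the first needle byte is already known to match there, so only trimmed_needle is
+        // compared, starting at idx + 3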
         while mask != 0 {
             let trailing = mask.trailing_zeros();
             let offset = idx + trailing as usize + 1;
-            let sub = &haystack[offset..][..needle.len() - 2];
-            let trimmed_needle = &needle[1..needle.len() - 1];
-
-            if sub == trimmed_needle {
-                return true;
+            // SAFETY: mask has between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop.
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
             }
             mask &= !(1 << trailing);
         }
         return false;
     };
 
-    let test_chunk = |i| -> bool {
-        // SAFETY: this requires at least CHUNK bytes being readable at offset i
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
         // that is ensured by the loop ranges (see comments below)
-        let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) };
-        let b: __m128i =
-            // SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i
-            unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) };
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32;
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
 
-        if mask != 0 {
-            return check_mask(i, mask);
-        }
-        return false;
+        return mask;
     };
 
     let mut i = 0;
     let mut result = false;
-    while !result && i + CHUNK + needle.len() <= haystack.len() {
-        result |= test_chunk(i);
-        i += CHUNK;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by second_probe_offset.
+    const UNROLL: usize = 4;
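+    // The unrolled loop below computes all UNROLL probe masks before checking any of them; once
+    // a match is found, `result` is passed as the `skip` flag so the remaining mask checks in
+    // the same iteration return early (presumably to keep the hot loop free of extra branches).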
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
     }
 
-    // process the tail that didn't fit into CHUNK-sized steps
-    // this simply repeats the same procedure but as right-aligned chunk instead
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
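+    // For example, with a 100 byte haystack, LANES == 16 and second_probe_offset == 5 the tail
+    // chunk starts at index 79, so the second probe's 16 byte load covers bytes 84..100 and is
+    // flush with the end.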
-    result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK);
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
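+///
+/// # Example
+///
+/// A hedged usage sketch (illustrative only, not a doctest for this private helper):
+///
+/// ```ignore
+/// // SAFETY: both byte string literals have length 4
+/// assert!(unsafe { small_slice_eq(b"abcd", b"abcd") });
+/// assert!(!unsafe { small_slice_eq(b"abcd", b"abce") });
+/// ```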
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
 
-    return result;
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
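+    // For example, with x.len() == 6 the loop compares bytes 0..4 and the final comparison
+    // covers bytes 2..6, re-checking bytes 2..4.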
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
 }