about summary refs log tree commit diff
path: root/library
diff options
context:
space:
mode:
Diffstat (limited to 'library')
-rw-r--r--library/alloc/benches/str.rs65
-rw-r--r--library/alloc/tests/str.rs26
-rw-r--r--library/core/src/str/pattern.rs232
3 files changed, 311 insertions, 12 deletions
diff --git a/library/alloc/benches/str.rs b/library/alloc/benches/str.rs
index 391475bc0c7..54af389dedc 100644
--- a/library/alloc/benches/str.rs
+++ b/library/alloc/benches/str.rs
@@ -1,3 +1,4 @@
+use core::iter::Iterator;
 use test::{black_box, Bencher};
 
 #[bench]
@@ -122,14 +123,13 @@ fn bench_contains_short_short(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
     let needle = "sit";
 
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(haystack.contains(needle));
+        assert!(black_box(haystack).contains(black_box(needle)));
     })
 }
 
-#[bench]
-fn bench_contains_short_long(b: &mut Bencher) {
-    let haystack = "\
+static LONG_HAYSTACK: &str = "\
 Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse quis lorem sit amet dolor \
 ultricies condimentum. Praesent iaculis purus elit, ac malesuada quam malesuada in. Duis sed orci \
 eros. Suspendisse sit amet magna mollis, mollis nunc luctus, imperdiet mi. Integer fringilla non \
@@ -164,10 +164,48 @@ feugiat. Etiam quis mauris vel risus luctus mattis a a nunc. Nullam orci quam, i
 vehicula in, porttitor ut nibh. Duis sagittis adipiscing nisl vitae congue. Donec mollis risus eu \
 leo suscipit, varius porttitor nulla porta. Pellentesque ut sem nec nisi euismod vehicula. Nulla \
 malesuada sollicitudin quam eu fermentum.";
+
+#[bench]
+fn bench_contains_2b_repeated_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "::";
+
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_short_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
     let needle = "english";
 
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_16b_in_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "english language";
+
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_32b_in_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "the english language sample text";
+
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!haystack.contains(needle));
+        assert!(!black_box(haystack).contains(black_box(needle)));
     })
 }
 
@@ -176,8 +214,20 @@ fn bench_contains_bad_naive(b: &mut Bencher) {
     let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
     let needle = "aaaaaaaab";
 
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_bad_simd(b: &mut Bencher) {
+    let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+    let needle = "aaabaaaa";
+
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!haystack.contains(needle));
+        assert!(!black_box(haystack).contains(black_box(needle)));
     })
 }
 
@@ -186,8 +236,9 @@ fn bench_contains_equal(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
     let needle = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
 
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(haystack.contains(needle));
+        assert!(black_box(haystack).contains(black_box(needle)));
     })
 }
 
diff --git a/library/alloc/tests/str.rs b/library/alloc/tests/str.rs
index e30329aa1cb..9689196ef21 100644
--- a/library/alloc/tests/str.rs
+++ b/library/alloc/tests/str.rs
@@ -1590,11 +1590,27 @@ fn test_bool_from_str() {
     assert_eq!("not even a boolean".parse::<bool>().ok(), None);
 }
 
-fn check_contains_all_substrings(s: &str) {
-    assert!(s.contains(""));
-    for i in 0..s.len() {
-        for j in i + 1..=s.len() {
-            assert!(s.contains(&s[i..j]));
+fn check_contains_all_substrings(haystack: &str) {
+    let mut modified_needle = String::new();
+
+    for i in 0..haystack.len() {
+        // check different haystack lengths since we special-case short haystacks.
+        let haystack = &haystack[0..i];
+        assert!(haystack.contains(""));
+        for j in 0..haystack.len() {
+            for k in j + 1..=haystack.len() {
+                let needle = &haystack[j..k];
+                assert!(haystack.contains(needle));
+                modified_needle.clear();
+                modified_needle.push_str(needle);
+                modified_needle.replace_range(0..1, "\0");
+                assert!(!haystack.contains(&modified_needle));
+
+                modified_needle.clear();
+                modified_needle.push_str(needle);
+                modified_needle.replace_range(needle.len() - 1..needle.len(), "\0");
+                assert!(!haystack.contains(&modified_needle));
+            }
         }
     }
 }
diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs
index ec2cb429e67..c5be32861f9 100644
--- a/library/core/src/str/pattern.rs
+++ b/library/core/src/str/pattern.rs
@@ -39,6 +39,7 @@
 )]
 
 use crate::cmp;
+use crate::cmp::Ordering;
 use crate::fmt;
 use crate::slice::memchr;
 
@@ -946,6 +947,32 @@ impl<'a, 'b> Pattern<'a> for &'b str {
         haystack.as_bytes().starts_with(self.as_bytes())
     }
 
+    /// Checks whether the pattern matches anywhere in the haystack
+    #[inline]
+    fn is_contained_in(self, haystack: &'a str) -> bool {
+        if self.len() == 0 {
+            return true;
+        }
+
+        match self.len().cmp(&haystack.len()) {
+            Ordering::Less => {
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
+                #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
+                }
+
+                self.into_searcher(haystack).next_match().is_some()
+            }
+            _ => self == haystack,
+        }
+    }
+
     /// Removes the pattern from the front of haystack, if it matches.
     #[inline]
     fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
@@ -1684,3 +1711,208 @@ impl TwoWayStrategy for RejectAndMatch {
         SearchStep::Match(a, b)
     }
 }
+
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
+/// does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std with for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+#[inline]
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
+    let needle = needle.as_bytes();
+    let haystack = haystack.as_bytes();
+
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
+
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
+    }
+
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // first byte are already checked by the outer loop. to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
+
+    // this #[cold] is load-bearing, benchmark before removing it...
+    let check_mask = #[cold]
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
+        while mask != 0 {
+            let trailing = mask.trailing_zeros();
+            let offset = idx + trailing as usize + 1;
+            // SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
+            }
+            mask &= !(1 << trailing);
+        }
+        return false;
+    };
+
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
+        // that is ensured by the loop ranges (see comments below)
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + block_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
+
+        return mask;
+    };
+
+    let mut i = 0;
+    let mut result = false;
+    // The loop condition must ensure that there's enough headroom to read LANE bytes,
+    // and not only at the current index but also at the index shifted by block_offset
+    const UNROLL: usize = 4;
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
+    }
+
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as right-aligned chunk instead
+    // of a left-aligned one. The last byte must be exactly flush with the string end so
+    // we don't miss a single byte or read out of bounds.
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp which is faster on long slices
+/// due to SIMD optimizations but it incurs a function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
+
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, derefencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
+}