about summary refs log tree commit diff
diff options
context:
space:
mode:
author bors <bors@rust-lang.org> 2022-11-22 20:47:17 +0000
committer bors <bors@rust-lang.org> 2022-11-22 20:47:17 +0000
commit ff8c8dfbe66701531e3e5e335c28c544d0fbc945 (patch)
tree 5179c28a5532cf17ceb11c39cdb45c35e88511d0
parent e221616639fb87de9dca21e252ee8a2565ec51d0 (diff)
parent 3ed8fccff5295e1d92419e7c67502e434ff1e98f (diff)
download rust-ff8c8dfbe66701531e3e5e335c28c544d0fbc945.tar.gz
rust-ff8c8dfbe66701531e3e5e335c28c544d0fbc945.zip
Auto merge of #104735 - the8472:simd-contains-fix, r=thomcc
Simd contains fix

Fixes #104726

The bug was introduced by an improvement late in the original PR (#103779) which added the backtracking when the first and last bytes of the needle were the same. That changed the meaning of the variable for the last probe offset, which I should have split into the last byte offset and last probe offset. Not doing so led to incorrect loop conditions.
-rw-r--r-- library/alloc/tests/str.rs 12
-rw-r--r-- library/core/src/str/pattern.rs 10
2 files changed, 18 insertions, 4 deletions
diff --git a/library/alloc/tests/str.rs b/library/alloc/tests/str.rs
index 9689196ef21..4d182be02c9 100644
--- a/library/alloc/tests/str.rs
+++ b/library/alloc/tests/str.rs
@@ -1632,6 +1632,18 @@ fn strslice_issue_16878() {
 }
 
 #[test]
+fn strslice_issue_104726() {
+    // Edge-case in the simd_contains impl.
+    // The first and last byte are the same so it backtracks by one byte
+    // which aligns with the end of the string. Previously incorrect offset calculations
+    // led to out-of-bounds slicing.
+    #[rustfmt::skip]
+    let needle =                        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaba";
+    let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab";
+    assert!(!haystack.contains(needle));
+}
+
+#[test]
 #[cfg_attr(miri, ignore)] // Miri is too slow
 fn test_strslice_contains() {
     let x = "There are moments, Jeeves, when one asks oneself, 'Do trousers matter?'";
diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs
index c5be32861f9..d76d6f8b2a2 100644
--- a/library/core/src/str/pattern.rs
+++ b/library/core/src/str/pattern.rs
@@ -1741,6 +1741,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     use crate::simd::{SimdPartialEq, ToBitMask};
 
     let first_probe = needle[0];
+    let last_byte_offset = needle.len() - 1;
 
     // the offset used for the 2nd vector
     let second_probe_offset = if needle.len() == 2 {
@@ -1758,7 +1759,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     };
 
     // do a naive search if the haystack is too small to fit
-    if haystack.len() < Block::LANES + second_probe_offset {
+    if haystack.len() < Block::LANES + last_byte_offset {
         return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }
 
@@ -1815,7 +1816,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     // The loop condition must ensure that there's enough headroom to read LANE bytes,
     // and not only at the current index but also at the index shifted by block_offset
     const UNROLL: usize = 4;
-    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+    while i + last_byte_offset + UNROLL * Block::LANES < haystack.len() && !result {
         let mut masks = [0u16; UNROLL];
         for j in 0..UNROLL {
             masks[j] = test_chunk(i + j * Block::LANES);
@@ -1828,7 +1829,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
         }
         i += UNROLL * Block::LANES;
     }
-    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+    while i + last_byte_offset + Block::LANES < haystack.len() && !result {
         let mask = test_chunk(i);
         if mask != 0 {
             result |= check_mask(i, mask, result);
@@ -1840,7 +1841,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
-    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let i = haystack.len() - last_byte_offset - Block::LANES;
     let mask = test_chunk(i);
     if mask != 0 {
         result |= check_mask(i, mask, result);
@@ -1860,6 +1861,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
 #[inline]
 unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    debug_assert_eq!(x.len(), y.len());
     // This function is adapted from
     // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32