about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-10-17 16:20:02 +0000
committerbors <bors@rust-lang.org>2024-10-17 16:20:02 +0000
commit86bd45979a964678b40b79156744f0057759d840 (patch)
treee84e17b3147fe27e6b879dcda130975b97d90ad2
parent3a85d3fa785d95a7b7bcf4f160b67bffba7afd4a (diff)
parent4484085b18df4b10243b503a21602bb71836e8b3 (diff)
downloadrust-86bd45979a964678b40b79156744f0057759d840.tar.gz
rust-86bd45979a964678b40b79156744f0057759d840.zip
Auto merge of #130223 - LaihoE:faster_str_replace, r=thomcc
optimize str.replace

Adds a fast path for str.replace for the ascii to ascii case. This allows for autovectorizing the code. Also should this instead be done with specialization? This way we could remove one branch. I think it is the kind of branch that is easy to predict though.

Benchmark for the fast path (replace all "a" with "b" in the rust wikipedia article, using criterion) :
| N        | Speedup | Time New (ns) | Time Old (ns) |
|----------|---------|---------------|---------------|
| 2        | 2.03    | 13.567        | 27.576        |
| 8        | 1.73    | 17.478        | 30.259        |
| 11       | 2.46    | 18.296        | 45.055        |
| 16       | 2.71    | 17.181        | 46.526        |
| 37       | 4.43    | 18.526        | 81.997        |
| 64       | 8.54    | 18.670        | 159.470       |
| 200      | 9.82    | 29.634        | 291.010       |
| 2000     | 24.34   | 81.114        | 1974.300      |
| 20000    | 30.61   | 598.520       | 18318.000     |
| 1000000  | 29.31   | 33458.000     | 980540.000    |
-rw-r--r--library/alloc/src/str.rs25
-rw-r--r--library/alloc/src/string.rs7
-rw-r--r--library/core/src/str/pattern.rs33
3 files changed, 63 insertions, 2 deletions
diff --git a/library/alloc/src/str.rs b/library/alloc/src/str.rs
index 42501f9c315..52ceb8b45f9 100644
--- a/library/alloc/src/str.rs
+++ b/library/alloc/src/str.rs
@@ -20,7 +20,7 @@ pub use core::str::SplitInclusive;
 pub use core::str::SplitWhitespace;
 #[stable(feature = "rust1", since = "1.0.0")]
 pub use core::str::pattern;
-use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
+use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher, Utf8Pattern};
 #[stable(feature = "rust1", since = "1.0.0")]
 pub use core::str::{Bytes, CharIndices, Chars, from_utf8, from_utf8_mut};
 #[stable(feature = "str_escape", since = "1.34.0")]
@@ -269,6 +269,18 @@ impl str {
     #[stable(feature = "rust1", since = "1.0.0")]
     #[inline]
     pub fn replace<P: Pattern>(&self, from: P, to: &str) -> String {
+        // Fast path for ASCII to ASCII case.
+
+        if let Some(from_byte) = match from.as_utf8_pattern() {
+            Some(Utf8Pattern::StringPattern([from_byte])) => Some(*from_byte),
+            Some(Utf8Pattern::CharPattern(c)) => c.as_ascii().map(|ascii_char| ascii_char.to_u8()),
+            _ => None,
+        } {
+            if let [to_byte] = to.as_bytes() {
+                return unsafe { replace_ascii(self.as_bytes(), from_byte, *to_byte) };
+            }
+        }
+
         let mut result = String::new();
         let mut last_end = 0;
         for (start, part) in self.match_indices(from) {
@@ -686,3 +698,14 @@ pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
         (ascii_string, rest)
     }
 }
+#[inline]
+#[cfg(not(test))]
+#[cfg(not(no_global_oom_handling))]
+#[allow(dead_code)]
+/// Faster implementation of string replacement for ASCII to ASCII cases.
+/// Should produce fast vectorized code.
+unsafe fn replace_ascii(utf8_bytes: &[u8], from: u8, to: u8) -> String {
+    let result: Vec<u8> = utf8_bytes.iter().map(|b| if *b == from { to } else { *b }).collect();
+    // SAFETY: We replaced ascii with ascii on valid utf8 strings.
+    unsafe { String::from_utf8_unchecked(result) }
+}
diff --git a/library/alloc/src/string.rs b/library/alloc/src/string.rs
index 82dbf030608..b042720933b 100644
--- a/library/alloc/src/string.rs
+++ b/library/alloc/src/string.rs
@@ -53,7 +53,7 @@ use core::ops::AddAssign;
 #[cfg(not(no_global_oom_handling))]
 use core::ops::Bound::{Excluded, Included, Unbounded};
 use core::ops::{self, Range, RangeBounds};
-use core::str::pattern::Pattern;
+use core::str::pattern::{Pattern, Utf8Pattern};
 use core::{fmt, hash, ptr, slice};
 
 #[cfg(not(no_global_oom_handling))]
@@ -2436,6 +2436,11 @@ impl<'b> Pattern for &'b String {
     {
         self[..].strip_suffix_of(haystack)
     }
+
+    #[inline]
+    fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
+        Some(Utf8Pattern::StringPattern(self.as_bytes()))
+    }
 }
 
 macro_rules! impl_eq {
diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs
index 9f1294d7606..eb60effe813 100644
--- a/library/core/src/str/pattern.rs
+++ b/library/core/src/str/pattern.rs
@@ -160,6 +160,19 @@ pub trait Pattern: Sized {
             None
         }
     }
+
+    /// Returns the pattern as utf-8 bytes if possible.
+    fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>>;
+}
+/// Result of calling [`Pattern::as_utf8_pattern()`].
+/// Can be used for inspecting the contents of a [`Pattern`] in cases
+/// where the underlying representation can be represented as UTF-8.
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+pub enum Utf8Pattern<'a> {
+    /// Type returned by String and str types.
+    StringPattern(&'a [u8]),
+    /// Type returned by char types.
+    CharPattern(char),
 }
 
 // Searcher
@@ -599,6 +612,11 @@ impl Pattern for char {
     {
         self.encode_utf8(&mut [0u8; 4]).strip_suffix_of(haystack)
     }
+
+    #[inline]
+    fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
+        Some(Utf8Pattern::CharPattern(*self))
+    }
 }
 
 /////////////////////////////////////////////////////////////////////////////
@@ -657,6 +675,11 @@ impl<C: MultiCharEq> Pattern for MultiCharEqPattern<C> {
     fn into_searcher(self, haystack: &str) -> MultiCharEqSearcher<'_, C> {
         MultiCharEqSearcher { haystack, char_eq: self.0, char_indices: haystack.char_indices() }
     }
+
+    #[inline]
+    fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
+        None
+    }
 }
 
 unsafe impl<'a, C: MultiCharEq> Searcher<'a> for MultiCharEqSearcher<'a, C> {
@@ -747,6 +770,11 @@ macro_rules! pattern_methods {
         {
             ($pmap)(self).strip_suffix_of(haystack)
         }
+
+        #[inline]
+        fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
+            None
+        }
     };
 }
 
@@ -1022,6 +1050,11 @@ impl<'b> Pattern for &'b str {
             None
         }
     }
+
+    #[inline]
+    fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
+        Some(Utf8Pattern::StringPattern(self.as_bytes()))
+    }
 }
 
 /////////////////////////////////////////////////////////////////////////////