about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-05-26 15:29:01 +0000
committerbors <bors@rust-lang.org>2022-05-26 15:29:01 +0000
commit1851f0802e148bb7fa0bfd7dabcb7397bf371b0b (patch)
tree7e6f122ecd392e2ab7698fe3dab91de707fb27d1
parent1ab98933fa75d72e882b86feac1a0be3a5b02cb0 (diff)
parentd0f993070929428f7a484cfce66d5690e709006d (diff)
downloadrust-1851f0802e148bb7fa0bfd7dabcb7397bf371b0b.tar.gz
rust-1851f0802e148bb7fa0bfd7dabcb7397bf371b0b.zip
Auto merge of #97046 - conradludgate:faster-ascii-case-conv-path, r=thomcc
improve case conversion happy path

Someone shared the source code for [Go's string case conversion](https://github.com/golang/go/blob/19156a54741d4f353c9e8e0860197ca95a6ee6ca/src/strings/strings.go#L558-L616).

It features a hot path for ascii-only strings (although I assume for reasons specific to go, they've opted for a read safe hot loop).

I've borrowed these ideas and also kept our existing code to provide a fast path + seamless utf-8 correct path fallback.

(Naive) Benchmarks can be found here https://github.com/conradludgate/case-conv

For the cases where non-ascii is found near the start, the performance of this algorithm does fall back to original speeds and has not had any measurable speed loss
-rw-r--r--library/alloc/src/str.rs74
-rw-r--r--library/alloc/tests/str.rs14
2 files changed, 83 insertions, 5 deletions
diff --git a/library/alloc/src/str.rs b/library/alloc/src/str.rs
index 0eaa2639863..39dfd98ddcc 100644
--- a/library/alloc/src/str.rs
+++ b/library/alloc/src/str.rs
@@ -383,15 +383,23 @@ impl str {
                   without modifying the original"]
     #[stable(feature = "unicode_case_mapping", since = "1.2.0")]
     pub fn to_lowercase(&self) -> String {
-        let mut s = String::with_capacity(self.len());
-        for (i, c) in self[..].char_indices() {
+        let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
+
+        // Safety: we know this is a valid char boundary since
+        // out.len() is only progressed if ascii bytes are found
+        let rest = unsafe { self.get_unchecked(out.len()..) };
+
+        // Safety: We have written only valid ASCII to our vec
+        let mut s = unsafe { String::from_utf8_unchecked(out) };
+
+        for (i, c) in rest[..].char_indices() {
             if c == 'Σ' {
                 // Σ maps to σ, except at the end of a word where it maps to ς.
                 // This is the only conditional (contextual) but language-independent mapping
                 // in `SpecialCasing.txt`,
                 // so hard-code it rather than have a generic "condition" mechanism.
                 // See https://github.com/rust-lang/rust/issues/26035
-                map_uppercase_sigma(self, i, &mut s)
+                map_uppercase_sigma(rest, i, &mut s)
             } else {
                 match conversions::to_lower(c) {
                     [a, '\0', _] => s.push(a),
@@ -466,8 +474,16 @@ impl str {
                   without modifying the original"]
     #[stable(feature = "unicode_case_mapping", since = "1.2.0")]
     pub fn to_uppercase(&self) -> String {
-        let mut s = String::with_capacity(self.len());
-        for c in self[..].chars() {
+        let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
+
+        // Safety: we know this is a valid char boundary since
+        // out.len() is only progressed if ascii bytes are found
+        let rest = unsafe { self.get_unchecked(out.len()..) };
+
+        // Safety: We have written only valid ASCII to our vec
+        let mut s = unsafe { String::from_utf8_unchecked(out) };
+
+        for c in rest.chars() {
             match conversions::to_upper(c) {
                 [a, '\0', _] => s.push(a),
                 [a, b, '\0'] => {
@@ -619,3 +635,51 @@ impl str {
 pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
     unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
 }
+
+/// Converts the bytes while the bytes are still ascii.
+/// For better average performance, this is happens in chunks of `2*size_of::<usize>()`.
+/// Returns a vec with the converted bytes.
+#[inline]
+#[cfg(not(test))]
+#[cfg(not(no_global_oom_handling))]
+fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
+    let mut out = Vec::with_capacity(b.len());
+
+    const USIZE_SIZE: usize = mem::size_of::<usize>();
+    const MAGIC_UNROLL: usize = 2;
+    const N: usize = USIZE_SIZE * MAGIC_UNROLL;
+    const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
+
+    let mut i = 0;
+    unsafe {
+        while i + N <= b.len() {
+            // Safety: we have checks the sizes `b` and `out` to know that our
+            let in_chunk = b.get_unchecked(i..i + N);
+            let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
+
+            let mut bits = 0;
+            for j in 0..MAGIC_UNROLL {
+                // read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
+                // safety: in_chunk is valid bytes in the range
+                bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
+            }
+            // if our chunks aren't ascii, then return only the prior bytes as init
+            if bits & NONASCII_MASK != 0 {
+                break;
+            }
+
+            // perform the case conversions on N bytes (gets heavily autovec'd)
+            for j in 0..N {
+                // safety: in_chunk and out_chunk is valid bytes in the range
+                let out = out_chunk.get_unchecked_mut(j);
+                out.write(convert(in_chunk.get_unchecked(j)));
+            }
+
+            // mark these bytes as initialised
+            i += N;
+        }
+        out.set_len(i);
+    }
+
+    out
+}
diff --git a/library/alloc/tests/str.rs b/library/alloc/tests/str.rs
index 273b39aa45a..7379569dd68 100644
--- a/library/alloc/tests/str.rs
+++ b/library/alloc/tests/str.rs
@@ -1772,6 +1772,20 @@ fn to_lowercase() {
     assert_eq!("ΑΣΑ".to_lowercase(), "ασα");
     assert_eq!("ΑΣ'Α".to_lowercase(), "ασ'α");
     assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");
+
+    // a really long string that has it's lowercase form
+    // even longer. this tests that implementations don't assume
+    // an incorrect upper bound on allocations
+    let upper = str::repeat("İ", 512);
+    let lower = str::repeat("i̇", 512);
+    assert_eq!(upper.to_lowercase(), lower);
+
+    // a really long ascii-only string.
+    // This test that the ascii hot-path
+    // functions correctly
+    let upper = str::repeat("A", 511);
+    let lower = str::repeat("a", 511);
+    assert_eq!(upper.to_lowercase(), lower);
 }
 
 #[test]