about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEduardo Sánchez Muñoz <esm@eduardosm.net>2021-10-06 22:26:36 +0200
committerEduardo Sánchez Muñoz <esm@eduardosm.net>2021-11-21 17:05:55 +0100
commit23637e20cdf3f7b8e01b42dbaf25357e5d3c31ae (patch)
tree46020c09f2853966b1fd3cd36273be410464af20
parent3bfde2f1f4fc9409ecb63dfe1370df66171cf861 (diff)
downloadrust-23637e20cdf3f7b8e01b42dbaf25357e5d3c31ae.tar.gz
rust-23637e20cdf3f7b8e01b42dbaf25357e5d3c31ae.zip
libcore: assume the input of `next_code_point` and `next_code_point_reverse` is UTF-8-like
The functions are now `unsafe` and they use `Option::unwrap_unchecked` instead of `unwrap_or_0`

`unwrap_or_0` was added in 42357d772b8a3a1ce4395deeac0a5cf1f66e951d. I guess `unwrap_unchecked` was not available back then.

Given this example:

```rust
pub fn first_char(s: &str) -> Option<char> {
    s.chars().next()
}
```

Previously, the following assembly was produced:

```asm
_ZN7example10first_char17ha056ddea6bafad1cE:
	.cfi_startproc
	test	rsi, rsi
	je	.LBB0_1
	movzx	edx, byte ptr [rdi]
	test	dl, dl
	js	.LBB0_3
	mov	eax, edx
	ret
.LBB0_1:
	mov	eax, 1114112
	ret
.LBB0_3:
	lea	r8, [rdi + rsi]
	xor	eax, eax
	mov	r9, r8
	cmp	rsi, 1
	je	.LBB0_5
	movzx	eax, byte ptr [rdi + 1]
	add	rdi, 2
	and	eax, 63
	mov	r9, rdi
.LBB0_5:
	mov	ecx, edx
	and	ecx, 31
	cmp	dl, -33
	jbe	.LBB0_6
	cmp	r9, r8
	je	.LBB0_9
	movzx	esi, byte ptr [r9]
	add	r9, 1
	and	esi, 63
	shl	eax, 6
	or	eax, esi
	cmp	dl, -16
	jb	.LBB0_12
.LBB0_13:
	cmp	r9, r8
	je	.LBB0_14
	movzx	edx, byte ptr [r9]
	and	edx, 63
	jmp	.LBB0_16
.LBB0_6:
	shl	ecx, 6
	or	eax, ecx
	ret
.LBB0_9:
	xor	esi, esi
	mov	r9, r8
	shl	eax, 6
	or	eax, esi
	cmp	dl, -16
	jae	.LBB0_13
.LBB0_12:
	shl	ecx, 12
	or	eax, ecx
	ret
.LBB0_14:
	xor	edx, edx
.LBB0_16:
	and	ecx, 7
	shl	ecx, 18
	shl	eax, 6
	or	eax, ecx
	or	eax, edx
	ret
```

After this change, the assembly is reduced to:

```asm
_ZN7example10first_char17h4318683472f884ccE:
	.cfi_startproc
	test	rsi, rsi
	je	.LBB0_1
	movzx	ecx, byte ptr [rdi]
	test	cl, cl
	js	.LBB0_3
	mov	eax, ecx
	ret
.LBB0_1:
	mov	eax, 1114112
	ret
.LBB0_3:
	mov	eax, ecx
	and	eax, 31
	movzx	esi, byte ptr [rdi + 1]
	and	esi, 63
	cmp	cl, -33
	jbe	.LBB0_4
	movzx	edx, byte ptr [rdi + 2]
	shl	esi, 6
	and	edx, 63
	or	edx, esi
	cmp	cl, -16
	jb	.LBB0_7
	movzx	ecx, byte ptr [rdi + 3]
	and	eax, 7
	shl	eax, 18
	shl	edx, 6
	and	ecx, 63
	or	ecx, edx
	or	eax, ecx
	ret
.LBB0_4:
	shl	eax, 6
	or	eax, esi
	ret
.LBB0_7:
	shl	eax, 12
	or	eax, edx
	ret
```
-rw-r--r--library/core/src/str/iter.rs14
-rw-r--r--library/core/src/str/validations.rs44
-rw-r--r--library/std/src/sys_common/wtf8.rs3
3 files changed, 36 insertions, 25 deletions
diff --git a/library/core/src/str/iter.rs b/library/core/src/str/iter.rs
index 94a534c6e79..48410446716 100644
--- a/library/core/src/str/iter.rs
+++ b/library/core/src/str/iter.rs
@@ -39,10 +39,9 @@ impl<'a> Iterator for Chars<'a> {
 
     #[inline]
     fn next(&mut self) -> Option<char> {
-        next_code_point(&mut self.iter).map(|ch| {
-            // SAFETY: `str` invariant says `ch` is a valid Unicode Scalar Value.
-            unsafe { char::from_u32_unchecked(ch) }
-        })
+        // SAFETY: `str` invariant says `self.iter` is a valid UTF-8 string and
+        // the resulting `ch` is a valid Unicode Scalar Value.
+        unsafe { next_code_point(&mut self.iter).map(|ch| char::from_u32_unchecked(ch)) }
     }
 
     #[inline]
@@ -81,10 +80,9 @@ impl fmt::Debug for Chars<'_> {
 impl<'a> DoubleEndedIterator for Chars<'a> {
     #[inline]
     fn next_back(&mut self) -> Option<char> {
-        next_code_point_reverse(&mut self.iter).map(|ch| {
-            // SAFETY: `str` invariant says `ch` is a valid Unicode Scalar Value.
-            unsafe { char::from_u32_unchecked(ch) }
-        })
+        // SAFETY: `str` invariant says `self.iter` is a valid UTF-8 string and
+        // the resulting `ch` is a valid Unicode Scalar Value.
+        unsafe { next_code_point_reverse(&mut self.iter).map(|ch| char::from_u32_unchecked(ch)) }
     }
 }
 
diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs
index e362d5c05c1..be9c41a491b 100644
--- a/library/core/src/str/validations.rs
+++ b/library/core/src/str/validations.rs
@@ -25,19 +25,15 @@ pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool {
     (byte as i8) < -64
 }
 
-#[inline]
-const fn unwrap_or_0(opt: Option<&u8>) -> u8 {
-    match opt {
-        Some(&byte) => byte,
-        None => 0,
-    }
-}
-
 /// Reads the next code point out of a byte iterator (assuming a
 /// UTF-8-like encoding).
+///
+/// # Safety
+///
+/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
 #[unstable(feature = "str_internals", issue = "none")]
 #[inline]
-pub fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<u32> {
+pub unsafe fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<u32> {
     // Decode UTF-8
     let x = *bytes.next()?;
     if x < 128 {
@@ -48,18 +44,24 @@ pub fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<
     // Decode from a byte combination out of: [[[x y] z] w]
     // NOTE: Performance is sensitive to the exact formulation here
     let init = utf8_first_byte(x, 2);
-    let y = unwrap_or_0(bytes.next());
+    // SAFETY: `bytes` produces an UTF-8-like string,
+    // so the iterator must produce a value here.
+    let y = unsafe { *bytes.next().unwrap_unchecked() };
     let mut ch = utf8_acc_cont_byte(init, y);
     if x >= 0xE0 {
         // [[x y z] w] case
         // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
-        let z = unwrap_or_0(bytes.next());
+        // SAFETY: `bytes` produces an UTF-8-like string,
+        // so the iterator must produce a value here.
+        let z = unsafe { *bytes.next().unwrap_unchecked() };
         let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z);
         ch = init << 12 | y_z;
         if x >= 0xF0 {
             // [x y z w] case
             // use only the lower 3 bits of `init`
-            let w = unwrap_or_0(bytes.next());
+            // SAFETY: `bytes` produces an UTF-8-like string,
+            // so the iterator must produce a value here.
+            let w = unsafe { *bytes.next().unwrap_unchecked() };
             ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w);
         }
     }
@@ -69,8 +71,12 @@ pub fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<
 
 /// Reads the last code point out of a byte iterator (assuming a
 /// UTF-8-like encoding).
+///
+/// # Safety
+///
+/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
 #[inline]
-pub(super) fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option<u32>
+pub(super) unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option<u32>
 where
     I: DoubleEndedIterator<Item = &'a u8>,
 {
@@ -83,13 +89,19 @@ where
     // Multibyte case follows
     // Decode from a byte combination out of: [x [y [z w]]]
     let mut ch;
-    let z = unwrap_or_0(bytes.next_back());
+    // SAFETY: `bytes` produces an UTF-8-like string,
+    // so the iterator must produce a value here.
+    let z = unsafe { *bytes.next_back().unwrap_unchecked() };
     ch = utf8_first_byte(z, 2);
     if utf8_is_cont_byte(z) {
-        let y = unwrap_or_0(bytes.next_back());
+        // SAFETY: `bytes` produces an UTF-8-like string,
+        // so the iterator must produce a value here.
+        let y = unsafe { *bytes.next_back().unwrap_unchecked() };
         ch = utf8_first_byte(y, 3);
         if utf8_is_cont_byte(y) {
-            let x = unwrap_or_0(bytes.next_back());
+            // SAFETY: `bytes` produces an UTF-8-like string,
+            // so the iterator must produce a value here.
+            let x = unsafe { *bytes.next_back().unwrap_unchecked() };
             ch = utf8_first_byte(x, 4);
             ch = utf8_acc_cont_byte(ch, y);
         }
diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs
index 0629859bd9d..6e29bc61454 100644
--- a/library/std/src/sys_common/wtf8.rs
+++ b/library/std/src/sys_common/wtf8.rs
@@ -809,7 +809,8 @@ impl<'a> Iterator for Wtf8CodePoints<'a> {
 
     #[inline]
     fn next(&mut self) -> Option<CodePoint> {
-        next_code_point(&mut self.bytes).map(|c| CodePoint { value: c })
+        // SAFETY: `self.bytes` has been created from a WTF-8 string
+        unsafe { next_code_point(&mut self.bytes).map(|c| CodePoint { value: c }) }
     }
 
     #[inline]