about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJan Verbeek <jan.verbeek@posteo.nl>2023-12-03 12:25:11 +0100
committerJan Verbeek <jan.verbeek@posteo.nl>2024-01-21 19:51:49 +0100
commit51a7396ad3d78d9326ee1537b9ff29ab3919556f (patch)
tree72b6d043edefb8425c4fb26e0e19fc5b938fdf8c
parentd9d89fd53dd18b7eeab0cc276353209eb8b073b2 (diff)
downloadrust-51a7396ad3d78d9326ee1537b9ff29ab3919556f.tar.gz
rust-51a7396ad3d78d9326ee1537b9ff29ab3919556f.zip
Move `OsStr::slice_encoded_bytes` validation to platform modules
On Windows and UEFI this improves performance and error messaging.

On other platforms we optimize the fast path a bit more.

This also prepares for later relaxing the checks on certain platforms.
-rw-r--r--library/std/src/ffi/mod.rs7
-rw-r--r--library/std/src/ffi/os_str.rs43
-rw-r--r--library/std/src/ffi/os_str/tests.rs68
-rw-r--r--library/std/src/sys/os_str/bytes.rs43
-rw-r--r--library/std/src/sys/os_str/wtf8.rs7
-rw-r--r--library/std/src/sys_common/wtf8.rs36
-rw-r--r--library/std/src/sys_common/wtf8/tests.rs62
7 files changed, 219 insertions, 47 deletions
diff --git a/library/std/src/ffi/mod.rs b/library/std/src/ffi/mod.rs
index 97e78d17786..818571ddaaa 100644
--- a/library/std/src/ffi/mod.rs
+++ b/library/std/src/ffi/mod.rs
@@ -127,6 +127,11 @@
 //! trait, which provides a [`from_wide`] method to convert a native Windows
 //! string (without the terminating nul character) to an [`OsString`].
 //!
+//! ## Other platforms
+//!
+//! Many other platforms provide their own extension traits in a
+//! `std::os::*::ffi` module.
+//!
 //! ## On all platforms
 //!
 //! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
@@ -135,6 +140,8 @@
 //! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
 //! [`OsStr::from_encoded_bytes_unchecked`].
 //!
+//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
+//!
 //! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
 //! [Unicode code point]: https://www.unicode.org/glossary/#code_point
 //! [`env::set_var()`]: crate::env::set_var "env::set_var"
diff --git a/library/std/src/ffi/os_str.rs b/library/std/src/ffi/os_str.rs
index 81973182148..28747ad8f34 100644
--- a/library/std/src/ffi/os_str.rs
+++ b/library/std/src/ffi/os_str.rs
@@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
 use crate::ops::{self, Range};
 use crate::rc::Rc;
 use crate::slice;
-use crate::str::{from_utf8 as str_from_utf8, FromStr};
+use crate::str::FromStr;
 use crate::sync::Arc;
 
 use crate::sys::os_str::{Buf, Slice};
@@ -997,42 +997,15 @@ impl OsStr {
     /// ```
     #[unstable(feature = "os_str_slice", issue = "118485")]
     pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
-        #[track_caller]
-        fn check_valid_boundary(bytes: &[u8], index: usize) {
-            if index == 0 || index == bytes.len() {
-                return;
-            }
-
-            // Fast path
-            if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
-                return;
-            }
-
-            let (before, after) = bytes.split_at(index);
-
-            // UTF-8 takes at most 4 bytes per codepoint, so we don't
-            // need to check more than that.
-            let after = after.get(..4).unwrap_or(after);
-            match str_from_utf8(after) {
-                Ok(_) => return,
-                Err(err) if err.valid_up_to() != 0 => return,
-                Err(_) => (),
-            }
-
-            for len in 2..=4.min(index) {
-                let before = &before[index - len..];
-                if str_from_utf8(before).is_ok() {
-                    return;
-                }
-            }
-
-            panic!("byte index {index} is not an OsStr boundary");
-        }
-
         let encoded_bytes = self.as_encoded_bytes();
         let Range { start, end } = slice::range(range, ..encoded_bytes.len());
-        check_valid_boundary(encoded_bytes, start);
-        check_valid_boundary(encoded_bytes, end);
+
+        // `check_public_boundary` should panic if the index does not lie on an
+        // `OsStr` boundary as described above. It's possible to do this in an
+        // encoding-agnostic way, but details of the internal encoding might
+        // permit a more efficient implementation.
+        self.inner.check_public_boundary(start);
+        self.inner.check_public_boundary(end);
 
         // SAFETY: `slice::range` ensures that `start` and `end` are valid
         let slice = unsafe { encoded_bytes.get_unchecked(start..end) };
diff --git a/library/std/src/ffi/os_str/tests.rs b/library/std/src/ffi/os_str/tests.rs
index 60cde376d32..b020e05eaab 100644
--- a/library/std/src/ffi/os_str/tests.rs
+++ b/library/std/src/ffi/os_str/tests.rs
@@ -194,15 +194,65 @@ fn slice_encoded_bytes() {
 }
 
 #[test]
-#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
+#[should_panic]
+fn slice_out_of_bounds() {
+    let crab = OsStr::new("🦀");
+    let _ = crab.slice_encoded_bytes(..5);
+}
+
+#[test]
+#[should_panic]
 fn slice_mid_char() {
     let crab = OsStr::new("🦀");
     let _ = crab.slice_encoded_bytes(..2);
 }
 
+#[cfg(unix)]
+#[test]
+#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
+fn slice_invalid_data() {
+    use crate::os::unix::ffi::OsStrExt;
+
+    let os_string = OsStr::from_bytes(b"\xFF\xFF");
+    let _ = os_string.slice_encoded_bytes(1..);
+}
+
+#[cfg(unix)]
+#[test]
+#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
+fn slice_partial_utf8() {
+    use crate::os::unix::ffi::{OsStrExt, OsStringExt};
+
+    let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]);
+    let mut os_string = OsString::from_vec(vec![0xFF]);
+    os_string.push(part_crab);
+    let _ = os_string.slice_encoded_bytes(1..);
+}
+
+#[cfg(unix)]
+#[test]
+fn slice_invalid_edge() {
+    use crate::os::unix::ffi::{OsStrExt, OsStringExt};
+
+    let os_string = OsStr::from_bytes(b"a\xFFa");
+    assert_eq!(os_string.slice_encoded_bytes(..1), "a");
+    assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
+    assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
+    assert_eq!(os_string.slice_encoded_bytes(2..), "a");
+
+    let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
+    assert_eq!(os_string.slice_encoded_bytes(..3), "abc");
+    assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));
+
+    let mut os_string = OsString::from_vec(vec![0xFF]);
+    os_string.push("🦀");
+    assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
+    assert_eq!(os_string.slice_encoded_bytes(1..), "🦀");
+}
+
 #[cfg(windows)]
 #[test]
-#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
+#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
 fn slice_between_surrogates() {
     use crate::os::windows::ffi::OsStringExt;
 
@@ -216,10 +266,14 @@ fn slice_between_surrogates() {
 fn slice_surrogate_edge() {
     use crate::os::windows::ffi::OsStringExt;
 
-    let os_string = OsString::from_wide(&[0xD800]);
-    let mut with_crab = os_string.clone();
-    with_crab.push("🦀");
+    let surrogate = OsString::from_wide(&[0xD800]);
+    let mut pre_crab = surrogate.clone();
+    pre_crab.push("🦀");
+    assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
+    assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");
 
-    assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
-    assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
+    let mut post_crab = OsString::from("🦀");
+    post_crab.push(&surrogate);
+    assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
+    assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
 }
diff --git a/library/std/src/sys/os_str/bytes.rs b/library/std/src/sys/os_str/bytes.rs
index 3a75ce9ebb7..4ca3f1cd185 100644
--- a/library/std/src/sys/os_str/bytes.rs
+++ b/library/std/src/sys/os_str/bytes.rs
@@ -211,6 +211,49 @@ impl Slice {
         unsafe { mem::transmute(s) }
     }
 
+    #[track_caller]
+    #[inline]
+    pub fn check_public_boundary(&self, index: usize) {
+        if index == 0 || index == self.inner.len() {
+            return;
+        }
+        if index < self.inner.len()
+            && (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii())
+        {
+            return;
+        }
+
+        slow_path(&self.inner, index);
+
+        /// We're betting that typical splits will involve an ASCII character.
+        ///
+        /// Putting the expensive checks in a separate function generates notably
+        /// better assembly.
+        #[track_caller]
+        #[inline(never)]
+        fn slow_path(bytes: &[u8], index: usize) {
+            let (before, after) = bytes.split_at(index);
+
+            // UTF-8 takes at most 4 bytes per codepoint, so we don't
+            // need to check more than that.
+            let after = after.get(..4).unwrap_or(after);
+            match str::from_utf8(after) {
+                Ok(_) => return,
+                Err(err) if err.valid_up_to() != 0 => return,
+                Err(_) => (),
+            }
+
+            for len in 2..=4.min(index) {
+                let before = &before[index - len..];
+                if str::from_utf8(before).is_ok() {
+                    return;
+                }
+            }
+
+            panic!("byte index {index} is not an OsStr boundary");
+        }
+    }
+
     #[inline]
     pub fn from_str(s: &str) -> &Slice {
         unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }
diff --git a/library/std/src/sys/os_str/wtf8.rs b/library/std/src/sys/os_str/wtf8.rs
index 237854fac4e..352bd735903 100644
--- a/library/std/src/sys/os_str/wtf8.rs
+++ b/library/std/src/sys/os_str/wtf8.rs
@@ -6,7 +6,7 @@ use crate::fmt;
 use crate::mem;
 use crate::rc::Rc;
 use crate::sync::Arc;
-use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
+use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
 use crate::sys_common::{AsInner, FromInner, IntoInner};
 
 #[derive(Clone, Hash)]
@@ -171,6 +171,11 @@ impl Slice {
         mem::transmute(Wtf8::from_bytes_unchecked(s))
     }
 
+    #[track_caller]
+    pub fn check_public_boundary(&self, index: usize) {
+        check_utf8_boundary(&self.inner, index);
+    }
+
     #[inline]
     pub fn from_str(s: &str) -> &Slice {
         unsafe { mem::transmute(Wtf8::from_str(s)) }
diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs
index 67db5ebd89c..2dbd19d7171 100644
--- a/library/std/src/sys_common/wtf8.rs
+++ b/library/std/src/sys_common/wtf8.rs
@@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
     unsafe { char::from_u32_unchecked(code_point) }
 }
 
-/// Copied from core::str::StrPrelude::is_char_boundary
+/// Copied from str::is_char_boundary
 #[inline]
 pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
-    if index == slice.len() {
+    if index == 0 {
         return true;
     }
     match slice.bytes.get(index) {
-        None => false,
-        Some(&b) => b < 128 || b >= 192,
+        None => index == slice.len(),
+        Some(&b) => (b as i8) >= -0x40,
+    }
+}
+
+/// Verify that `index` is at the edge of either a valid UTF-8 codepoint
+/// (i.e. a codepoint that's not a surrogate) or of the whole string.
+///
+/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`.
+/// Splitting between surrogates is valid as far as WTF-8 is concerned, but
+/// we do not permit it in the public API because WTF-8 is considered an
+/// implementation detail.
+#[track_caller]
+#[inline]
+pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
+    if index == 0 {
+        return;
+    }
+    match slice.bytes.get(index) {
+        Some(0xED) => (), // Might be a surrogate
+        Some(&b) if (b as i8) >= -0x40 => return,
+        Some(_) => panic!("byte index {index} is not a codepoint boundary"),
+        None if index == slice.len() => return,
+        None => panic!("byte index {index} is out of bounds"),
+    }
+    if slice.bytes[index + 1] >= 0xA0 {
+        // There's a surrogate after index. Now check before index.
+        if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 {
+            panic!("byte index {index} lies between surrogate codepoints");
+        }
     }
 }
 
diff --git a/library/std/src/sys_common/wtf8/tests.rs b/library/std/src/sys_common/wtf8/tests.rs
index 28a426648e5..6a1cc41a8fb 100644
--- a/library/std/src/sys_common/wtf8/tests.rs
+++ b/library/std/src/sys_common/wtf8/tests.rs
@@ -663,3 +663,65 @@ fn wtf8_to_owned() {
     assert_eq!(string.bytes, b"\xED\xA0\x80");
     assert!(!string.is_known_utf8);
 }
+
+#[test]
+fn wtf8_valid_utf8_boundaries() {
+    let mut string = Wtf8Buf::from_str("aé 💩");
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    check_utf8_boundary(&string, 0);
+    check_utf8_boundary(&string, 1);
+    check_utf8_boundary(&string, 3);
+    check_utf8_boundary(&string, 4);
+    check_utf8_boundary(&string, 8);
+    check_utf8_boundary(&string, 14);
+    assert_eq!(string.len(), 14);
+
+    string.push_char('a');
+    check_utf8_boundary(&string, 14);
+    check_utf8_boundary(&string, 15);
+
+    let mut string = Wtf8Buf::from_str("a");
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    check_utf8_boundary(&string, 1);
+
+    let mut string = Wtf8Buf::from_str("\u{D7FF}");
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    check_utf8_boundary(&string, 3);
+
+    let mut string = Wtf8Buf::new();
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    string.push_char('\u{D7FF}');
+    check_utf8_boundary(&string, 3);
+}
+
+#[test]
+#[should_panic(expected = "byte index 4 is out of bounds")]
+fn wtf8_utf8_boundary_out_of_bounds() {
+    let string = Wtf8::from_str("aé");
+    check_utf8_boundary(&string, 4);
+}
+
+#[test]
+#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
+fn wtf8_utf8_boundary_inside_codepoint() {
+    let string = Wtf8::from_str("é");
+    check_utf8_boundary(&string, 1);
+}
+
+#[test]
+#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
+fn wtf8_utf8_boundary_inside_surrogate() {
+    let mut string = Wtf8Buf::new();
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    check_utf8_boundary(&string, 1);
+}
+
+#[test]
+#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
+fn wtf8_utf8_boundary_between_surrogates() {
+    let mut string = Wtf8Buf::new();
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    check_utf8_boundary(&string, 3);
+}