diff options
| author | Matthias Krüger <matthias.krueger@famsik.de> | 2024-02-18 18:54:32 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-18 18:54:32 +0100 |
| commit | 99560a428a8713c371ac2ef8f46a4d0aec3605c3 (patch) | |
| tree | 6839174b5f86104678fe93494804d2c224bdc22b | |
| parent | 8b21296b5db6d5724d6b8440dcf459fa82fd88b5 (diff) | |
| parent | 51a7396ad3d78d9326ee1537b9ff29ab3919556f (diff) | |
| download | rust-99560a428a8713c371ac2ef8f46a4d0aec3605c3.tar.gz rust-99560a428a8713c371ac2ef8f46a4d0aec3605c3.zip | |
Rollup merge of #118569 - blyxxyz:platform-os-str-slice, r=Mark-Simulacrum
Move `OsStr::slice_encoded_bytes` validation to platform modules This delegates OS string slicing (`OsStr::slice_encoded_bytes`) validation to the underlying platform implementation. For now that results in increased performance and better error messages on Windows without any changes to semantics. In the future we may want to provide different semantics for different platforms. The existing implementation is still used on Unix and most other platforms and is now optimized a little better. Tracking issue: https://github.com/rust-lang/rust/issues/118485 cc `@epage,` `@BurntSushi`
| -rw-r--r-- | library/std/src/ffi/mod.rs | 7 | ||||
| -rw-r--r-- | library/std/src/ffi/os_str.rs | 43 | ||||
| -rw-r--r-- | library/std/src/ffi/os_str/tests.rs | 68 | ||||
| -rw-r--r-- | library/std/src/sys/os_str/bytes.rs | 43 | ||||
| -rw-r--r-- | library/std/src/sys/os_str/wtf8.rs | 7 | ||||
| -rw-r--r-- | library/std/src/sys_common/wtf8.rs | 36 | ||||
| -rw-r--r-- | library/std/src/sys_common/wtf8/tests.rs | 62 |
7 files changed, 219 insertions, 47 deletions
diff --git a/library/std/src/ffi/mod.rs b/library/std/src/ffi/mod.rs index 97e78d17786..818571ddaaa 100644 --- a/library/std/src/ffi/mod.rs +++ b/library/std/src/ffi/mod.rs @@ -127,6 +127,11 @@ //! trait, which provides a [`from_wide`] method to convert a native Windows //! string (without the terminating nul character) to an [`OsString`]. //! +//! ## Other platforms +//! +//! Many other platforms provide their own extension traits in a +//! `std::os::*::ffi` module. +//! //! ## On all platforms //! //! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of @@ -135,6 +140,8 @@ //! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and //! [`OsStr::from_encoded_bytes_unchecked`]. //! +//! For basic string processing, see [`OsStr::slice_encoded_bytes`]. +//! //! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value //! [Unicode code point]: https://www.unicode.org/glossary/#code_point //! [`env::set_var()`]: crate::env::set_var "env::set_var" diff --git a/library/std/src/ffi/os_str.rs b/library/std/src/ffi/os_str.rs index 81973182148..28747ad8f34 100644 --- a/library/std/src/ffi/os_str.rs +++ b/library/std/src/ffi/os_str.rs @@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher}; use crate::ops::{self, Range}; use crate::rc::Rc; use crate::slice; -use crate::str::{from_utf8 as str_from_utf8, FromStr}; +use crate::str::FromStr; use crate::sync::Arc; use crate::sys::os_str::{Buf, Slice}; @@ -997,42 +997,15 @@ impl OsStr { /// ``` #[unstable(feature = "os_str_slice", issue = "118485")] pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self { - #[track_caller] - fn check_valid_boundary(bytes: &[u8], index: usize) { - if index == 0 || index == bytes.len() { - return; - } - - // Fast path - if bytes[index - 1].is_ascii() || bytes[index].is_ascii() { - return; - } - - let (before, after) = bytes.split_at(index); - - // UTF-8 takes at most 4 bytes per codepoint, so we don't - // need to check more than that. - let after = after.get(..4).unwrap_or(after); - match str_from_utf8(after) { - Ok(_) => return, - Err(err) if err.valid_up_to() != 0 => return, - Err(_) => (), - } - - for len in 2..=4.min(index) { - let before = &before[index - len..]; - if str_from_utf8(before).is_ok() { - return; - } - } - - panic!("byte index {index} is not an OsStr boundary"); - } - let encoded_bytes = self.as_encoded_bytes(); let Range { start, end } = slice::range(range, ..encoded_bytes.len()); - check_valid_boundary(encoded_bytes, start); - check_valid_boundary(encoded_bytes, end); + + // `check_public_boundary` should panic if the index does not lie on an + // `OsStr` boundary as described above. It's possible to do this in an + // encoding-agnostic way, but details of the internal encoding might + // permit a more efficient implementation. + self.inner.check_public_boundary(start); + self.inner.check_public_boundary(end); // SAFETY: `slice::range` ensures that `start` and `end` are valid let slice = unsafe { encoded_bytes.get_unchecked(start..end) }; diff --git a/library/std/src/ffi/os_str/tests.rs b/library/std/src/ffi/os_str/tests.rs index 60cde376d32..b020e05eaab 100644 --- a/library/std/src/ffi/os_str/tests.rs +++ b/library/std/src/ffi/os_str/tests.rs @@ -194,15 +194,65 @@ fn slice_encoded_bytes() { } #[test] -#[should_panic(expected = "byte index 2 is not an OsStr boundary")] +#[should_panic] +fn slice_out_of_bounds() { + let crab = OsStr::new("🦀"); + let _ = crab.slice_encoded_bytes(..5); +} + +#[test] +#[should_panic] fn slice_mid_char() { let crab = OsStr::new("🦀"); let _ = crab.slice_encoded_bytes(..2); } +#[cfg(unix)] +#[test] +#[should_panic(expected = "byte index 1 is not an OsStr boundary")] +fn slice_invalid_data() { + use crate::os::unix::ffi::OsStrExt; + + let os_string = OsStr::from_bytes(b"\xFF\xFF"); + let _ = os_string.slice_encoded_bytes(1..); +} + +#[cfg(unix)] +#[test] +#[should_panic(expected = "byte index 1 is not an OsStr boundary")] +fn slice_partial_utf8() { + use crate::os::unix::ffi::{OsStrExt, OsStringExt}; + + let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]); + let mut os_string = OsString::from_vec(vec![0xFF]); + os_string.push(part_crab); + let _ = os_string.slice_encoded_bytes(1..); +} + +#[cfg(unix)] +#[test] +fn slice_invalid_edge() { + use crate::os::unix::ffi::{OsStrExt, OsStringExt}; + + let os_string = OsStr::from_bytes(b"a\xFFa"); + assert_eq!(os_string.slice_encoded_bytes(..1), "a"); + assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa")); + assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF")); + assert_eq!(os_string.slice_encoded_bytes(2..), "a"); + + let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]); + assert_eq!(os_string.slice_encoded_bytes(..3), "abc"); + assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6")); + + let mut os_string = OsString::from_vec(vec![0xFF]); + os_string.push("🦀"); + assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF")); + assert_eq!(os_string.slice_encoded_bytes(1..), "🦀"); +} + #[cfg(windows)] #[test] -#[should_panic(expected = "byte index 3 is not an OsStr boundary")] +#[should_panic(expected = "byte index 3 lies between surrogate codepoints")] fn slice_between_surrogates() { use crate::os::windows::ffi::OsStringExt; @@ -216,10 +266,14 @@ fn slice_between_surrogates() { fn slice_surrogate_edge() { use crate::os::windows::ffi::OsStringExt; - let os_string = OsString::from_wide(&[0xD800]); - let mut with_crab = os_string.clone(); - with_crab.push("🦀"); + let surrogate = OsString::from_wide(&[0xD800]); + let mut pre_crab = surrogate.clone(); + pre_crab.push("🦀"); + assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate); + assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀"); - assert_eq!(with_crab.slice_encoded_bytes(..3), os_string); - assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀"); + let mut post_crab = OsString::from("🦀"); + post_crab.push(&surrogate); + assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀"); + assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate); } diff --git a/library/std/src/sys/os_str/bytes.rs b/library/std/src/sys/os_str/bytes.rs index 3a75ce9ebb7..4ca3f1cd185 100644 --- a/library/std/src/sys/os_str/bytes.rs +++ b/library/std/src/sys/os_str/bytes.rs @@ -211,6 +211,49 @@ impl Slice { unsafe { mem::transmute(s) } } + #[track_caller] + #[inline] + pub fn check_public_boundary(&self, index: usize) { + if index == 0 || index == self.inner.len() { + return; + } + if index < self.inner.len() + && (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii()) + { + return; + } + + slow_path(&self.inner, index); + + /// We're betting that typical splits will involve an ASCII character. + /// + /// Putting the expensive checks in a separate function generates notably + /// better assembly. + #[track_caller] + #[inline(never)] + fn slow_path(bytes: &[u8], index: usize) { + let (before, after) = bytes.split_at(index); + + // UTF-8 takes at most 4 bytes per codepoint, so we don't + // need to check more than that. + let after = after.get(..4).unwrap_or(after); + match str::from_utf8(after) { + Ok(_) => return, + Err(err) if err.valid_up_to() != 0 => return, + Err(_) => (), + } + + for len in 2..=4.min(index) { + let before = &before[index - len..]; + if str::from_utf8(before).is_ok() { + return; + } + } + + panic!("byte index {index} is not an OsStr boundary"); + } + } + #[inline] pub fn from_str(s: &str) -> &Slice { unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) } diff --git a/library/std/src/sys/os_str/wtf8.rs b/library/std/src/sys/os_str/wtf8.rs index 237854fac4e..352bd735903 100644 --- a/library/std/src/sys/os_str/wtf8.rs +++ b/library/std/src/sys/os_str/wtf8.rs @@ -6,7 +6,7 @@ use crate::fmt; use crate::mem; use crate::rc::Rc; use crate::sync::Arc; -use crate::sys_common::wtf8::{Wtf8, Wtf8Buf}; +use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf}; use crate::sys_common::{AsInner, FromInner, IntoInner}; #[derive(Clone, Hash)] @@ -171,6 +171,11 @@ impl Slice { mem::transmute(Wtf8::from_bytes_unchecked(s)) } + #[track_caller] + pub fn check_public_boundary(&self, index: usize) { + check_utf8_boundary(&self.inner, index); + } + #[inline] pub fn from_str(s: &str) -> &Slice { unsafe { mem::transmute(Wtf8::from_str(s)) } diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs index 67db5ebd89c..2dbd19d7171 100644 --- a/library/std/src/sys_common/wtf8.rs +++ b/library/std/src/sys_common/wtf8.rs @@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char { unsafe { char::from_u32_unchecked(code_point) } } -/// Copied from core::str::StrPrelude::is_char_boundary +/// Copied from str::is_char_boundary #[inline] pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { - if index == slice.len() { + if index == 0 { return true; } match slice.bytes.get(index) { - None => false, - Some(&b) => b < 128 || b >= 192, + None => index == slice.len(), + Some(&b) => (b as i8) >= -0x40, + } +} + +/// Verify that `index` is at the edge of either a valid UTF-8 codepoint +/// (i.e. a codepoint that's not a surrogate) or of the whole string. +/// +/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`. +/// Splitting between surrogates is valid as far as WTF-8 is concerned, but +/// we do not permit it in the public API because WTF-8 is considered an +/// implementation detail. +#[track_caller] +#[inline] +pub fn check_utf8_boundary(slice: &Wtf8, index: usize) { + if index == 0 { + return; + } + match slice.bytes.get(index) { + Some(0xED) => (), // Might be a surrogate + Some(&b) if (b as i8) >= -0x40 => return, + Some(_) => panic!("byte index {index} is not a codepoint boundary"), + None if index == slice.len() => return, + None => panic!("byte index {index} is out of bounds"), + } + if slice.bytes[index + 1] >= 0xA0 { + // There's a surrogate after index. Now check before index. + if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 { + panic!("byte index {index} lies between surrogate codepoints"); + } } } diff --git a/library/std/src/sys_common/wtf8/tests.rs b/library/std/src/sys_common/wtf8/tests.rs index 28a426648e5..6a1cc41a8fb 100644 --- a/library/std/src/sys_common/wtf8/tests.rs +++ b/library/std/src/sys_common/wtf8/tests.rs @@ -663,3 +663,65 @@ fn wtf8_to_owned() { assert_eq!(string.bytes, b"\xED\xA0\x80"); assert!(!string.is_known_utf8); } + +#[test] +fn wtf8_valid_utf8_boundaries() { + let mut string = Wtf8Buf::from_str("aé 💩"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 0); + check_utf8_boundary(&string, 1); + check_utf8_boundary(&string, 3); + check_utf8_boundary(&string, 4); + check_utf8_boundary(&string, 8); + check_utf8_boundary(&string, 14); + assert_eq!(string.len(), 14); + + string.push_char('a'); + check_utf8_boundary(&string, 14); + check_utf8_boundary(&string, 15); + + let mut string = Wtf8Buf::from_str("a"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 1); + + let mut string = Wtf8Buf::from_str("\u{D7FF}"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 3); + + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push_char('\u{D7FF}'); + check_utf8_boundary(&string, 3); +} + +#[test] +#[should_panic(expected = "byte index 4 is out of bounds")] +fn wtf8_utf8_boundary_out_of_bounds() { + let string = Wtf8::from_str("aé"); + check_utf8_boundary(&string, 4); +} + +#[test] +#[should_panic(expected = "byte index 1 is not a codepoint boundary")] +fn wtf8_utf8_boundary_inside_codepoint() { + let string = Wtf8::from_str("é"); + check_utf8_boundary(&string, 1); +} + +#[test] +#[should_panic(expected = "byte index 1 is not a codepoint boundary")] +fn wtf8_utf8_boundary_inside_surrogate() { + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 1); +} + +#[test] +#[should_panic(expected = "byte index 3 lies between surrogate codepoints")] +fn wtf8_utf8_boundary_between_surrogates() { + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 3); +} |
