diff options
| author | Alex Saveau <saveau.alexandre@gmail.com> | 2022-02-13 15:04:11 -0800 |
|---|---|---|
| committer | Alex Saveau <saveau.alexandre@gmail.com> | 2022-02-16 18:34:17 -0800 |
| commit | 897c8d0ab9d3cb2cf1c112f31c6ac3e93d9884bc (patch) | |
| tree | c5225865d65254954293e18fd5fd9f3a0e4c66cc | |
| parent | 1e12aef3fab243407f9d71ba9956cb2a1bf105d5 (diff) | |
| download | rust-897c8d0ab9d3cb2cf1c112f31c6ac3e93d9884bc.tar.gz rust-897c8d0ab9d3cb2cf1c112f31c6ac3e93d9884bc.zip | |
Add debug asserts to validate NUL terminator in c strings
Signed-off-by: Alex Saveau <saveau.alexandre@gmail.com>
| -rw-r--r-- | library/std/src/ffi/c_str.rs | 41 | ||||
| -rw-r--r-- | library/std/src/ffi/c_str/tests.rs | 8 |
2 files changed, 29 insertions, 20 deletions
diff --git a/library/std/src/ffi/c_str.rs b/library/std/src/ffi/c_str.rs index c3f024026ef..b52cc97504a 100644 --- a/library/std/src/ffi/c_str.rs +++ b/library/std/src/ffi/c_str.rs @@ -382,7 +382,7 @@ impl CString { let bytes: Vec<u8> = self.into(); match memchr::memchr(0, &bytes) { Some(i) => Err(NulError(i, bytes)), - None => Ok(unsafe { CString::from_vec_unchecked(bytes) }), + None => Ok(unsafe { CString::_from_vec_unchecked(bytes) }), } } } @@ -405,7 +405,7 @@ impl CString { // This allows better optimizations if lto enabled. match memchr::memchr(0, bytes) { Some(i) => Err(NulError(i, buffer)), - None => Ok(unsafe { CString::from_vec_unchecked(buffer) }), + None => Ok(unsafe { CString::_from_vec_unchecked(buffer) }), } } @@ -451,10 +451,15 @@ impl CString { /// ``` #[must_use] #[stable(feature = "rust1", since = "1.0.0")] - pub unsafe fn from_vec_unchecked(mut v: Vec<u8>) -> CString { + pub unsafe fn from_vec_unchecked(v: Vec<u8>) -> Self { + debug_assert!(memchr::memchr(0, &v).is_none()); + unsafe { Self::_from_vec_unchecked(v) } + } + + unsafe fn _from_vec_unchecked(mut v: Vec<u8>) -> Self { v.reserve_exact(1); v.push(0); - CString { inner: v.into_boxed_slice() } + Self { inner: v.into_boxed_slice() } } /// Retakes ownership of a `CString` that was transferred to C via @@ -578,7 +583,7 @@ impl CString { pub fn into_string(self) -> Result<String, IntoStringError> { String::from_utf8(self.into_bytes()).map_err(|e| IntoStringError { error: e.utf8_error(), - inner: unsafe { CString::from_vec_unchecked(e.into_bytes()) }, + inner: unsafe { Self::_from_vec_unchecked(e.into_bytes()) }, }) } @@ -735,6 +740,11 @@ impl CString { #[must_use] #[stable(feature = "cstring_from_vec_with_nul", since = "1.58.0")] pub unsafe fn from_vec_with_nul_unchecked(v: Vec<u8>) -> Self { + debug_assert!(memchr::memchr(0, &v).unwrap() + 1 == v.len()); + unsafe { Self::_from_vec_with_nul_unchecked(v) } + } + + unsafe fn _from_vec_with_nul_unchecked(v: Vec<u8>) -> Self { Self { inner: v.into_boxed_slice() } } @@ -778,7 +788,7 @@ impl CString { Some(nul_pos) if nul_pos + 1 == v.len() => { // SAFETY: We know there is only one nul byte, at the end // of the vec. - Ok(unsafe { Self::from_vec_with_nul_unchecked(v) }) + Ok(unsafe { Self::_from_vec_with_nul_unchecked(v) }) } Some(nul_pos) => Err(FromVecWithNulError { error_kind: FromBytesWithNulErrorKind::InteriorNul(nul_pos), @@ -811,7 +821,7 @@ impl ops::Deref for CString { #[inline] fn deref(&self) -> &CStr { - unsafe { CStr::from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) } + unsafe { CStr::_from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) } } } @@ -922,7 +932,7 @@ impl From<Vec<NonZeroU8>> for CString { }; // SAFETY: `v` cannot contain null bytes, given the type-level // invariant of `NonZeroU8`. - CString::from_vec_unchecked(v) + Self::_from_vec_unchecked(v) } } } @@ -1215,7 +1225,7 @@ impl CStr { unsafe { let len = sys::strlen(ptr); let ptr = ptr as *const u8; - CStr::from_bytes_with_nul_unchecked(slice::from_raw_parts(ptr, len as usize + 1)) + Self::_from_bytes_with_nul_unchecked(slice::from_raw_parts(ptr, len as usize + 1)) } } @@ -1258,7 +1268,7 @@ impl CStr { Some(nul_pos) if nul_pos + 1 == bytes.len() => { // SAFETY: We know there is only one nul byte, at the end // of the byte slice. - Ok(unsafe { Self::from_bytes_with_nul_unchecked(bytes) }) + Ok(unsafe { Self::_from_bytes_with_nul_unchecked(bytes) }) } Some(nul_pos) => Err(FromBytesWithNulError::interior_nul(nul_pos)), None => Err(FromBytesWithNulError::not_nul_terminated()), @@ -1287,12 +1297,19 @@ impl CStr { #[stable(feature = "cstr_from_bytes", since = "1.10.0")] #[rustc_const_stable(feature = "const_cstr_unchecked", since = "1.59.0")] pub const unsafe fn from_bytes_with_nul_unchecked(bytes: &[u8]) -> &CStr { + // We're in a const fn, so this is the best we can do + debug_assert!(!bytes.is_empty() && bytes[bytes.len() - 1] == 0); + unsafe { Self::_from_bytes_with_nul_unchecked(bytes) } + } + + #[inline] + const unsafe fn _from_bytes_with_nul_unchecked(bytes: &[u8]) -> &Self { // SAFETY: Casting to CStr is safe because its internal representation // is a [u8] too (safe only inside std). // Dereferencing the obtained pointer is safe because it comes from a // reference. Making a reference is then safe because its lifetime // is bound by the lifetime of the given `bytes`. - unsafe { &*(bytes as *const [u8] as *const CStr) } + unsafe { &*(bytes as *const [u8] as *const Self) } } /// Returns the inner pointer to this C string. @@ -1555,7 +1572,7 @@ impl ops::Index<ops::RangeFrom<usize>> for CStr { // byte, since otherwise we could get an empty string that doesn't end // in a null. if index.start < bytes.len() { - unsafe { CStr::from_bytes_with_nul_unchecked(&bytes[index.start..]) } + unsafe { CStr::_from_bytes_with_nul_unchecked(&bytes[index.start..]) } } else { panic!( "index out of bounds: the len is {} but the index is {}", diff --git a/library/std/src/ffi/c_str/tests.rs b/library/std/src/ffi/c_str/tests.rs index 4f7ba9ad437..00ba5460821 100644 --- a/library/std/src/ffi/c_str/tests.rs +++ b/library/std/src/ffi/c_str/tests.rs @@ -33,14 +33,6 @@ fn build_with_zero2() { } #[test] -fn build_with_zero3() { - unsafe { - let s = CString::from_vec_unchecked(vec![0]); - assert_eq!(s.as_bytes(), b"\0"); - } -} - -#[test] fn formatted() { let s = CString::new(&b"abc\x01\x02\n\xE2\x80\xA6\xFF"[..]).unwrap(); assert_eq!(format!("{:?}", s), r#""abc\x01\x02\n\xe2\x80\xa6\xff""#); |
