about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDylan DPC <99973273+Dylan-DPC@users.noreply.github.com>2023-05-23 00:32:18 +0530
committerGitHub <noreply@github.com>2023-05-23 00:32:18 +0530
commit47fe1a3e1f625f0e15baeb742e476977de26e6f3 (patch)
tree532be334beb2ba89c84449d1a02ce8b54e46cca0
parentdf86200965c89d31063475ae8106637f66fdf889 (diff)
parent77481099ca4be8fd8515d5e5bb873f3674cb7a2b (diff)
downloadrust-47fe1a3e1f625f0e15baeb742e476977de26e6f3.tar.gz
rust-47fe1a3e1f625f0e15baeb742e476977de26e6f3.zip
Rollup merge of #111609 - LegionMammal978:internal-unsafe, r=thomcc
Mark internal functions and traits unsafe to reflect preconditions

No semantics are changed in this PR; I only mark some functions and and a trait `unsafe` which already had implicit preconditions. Although it seems somewhat redundant for `numfmt::Part::Copy` to contain a `&[u8]` instead of a `&str`, given that all of its current consumers ultimately expect valid UTF-8. Is the type also intended to work for byte-slice formatting in the future?
-rw-r--r--library/alloc/src/vec/in_place_collect.rs9
-rw-r--r--library/core/src/fmt/float.rs12
-rw-r--r--library/core/src/fmt/mod.rs42
-rw-r--r--library/core/src/fmt/num.rs15
-rw-r--r--library/std/src/path.rs11
5 files changed, 57 insertions, 32 deletions
diff --git a/library/alloc/src/vec/in_place_collect.rs b/library/alloc/src/vec/in_place_collect.rs
index 2f1ee8b0353..5ecd0479971 100644
--- a/library/alloc/src/vec/in_place_collect.rs
+++ b/library/alloc/src/vec/in_place_collect.rs
@@ -178,7 +178,8 @@ where
             )
         };
 
-        let len = SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end);
+        // SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
+        let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };
 
         let src = unsafe { iterator.as_inner().as_into_iter() };
         // check if SourceIter contract was upheld
@@ -239,7 +240,7 @@ trait SpecInPlaceCollect<T, I>: Iterator<Item = T> {
     /// `Iterator::__iterator_get_unchecked` calls with a `TrustedRandomAccessNoCoerce` bound
     /// on `I` which means the caller of this method must take the safety conditions
     /// of that trait into consideration.
-    fn collect_in_place(&mut self, dst: *mut T, end: *const T) -> usize;
+    unsafe fn collect_in_place(&mut self, dst: *mut T, end: *const T) -> usize;
 }
 
 impl<T, I> SpecInPlaceCollect<T, I> for I
@@ -247,7 +248,7 @@ where
     I: Iterator<Item = T>,
 {
     #[inline]
-    default fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
+    default unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
         // use try-fold since
         // - it vectorizes better for some iterator adapters
         // - unlike most internal iteration methods, it only takes a &mut self
@@ -265,7 +266,7 @@ where
     I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
 {
     #[inline]
-    fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
+    unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
         let len = self.size();
         let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
         for i in 0..len {
diff --git a/library/core/src/fmt/float.rs b/library/core/src/fmt/float.rs
index 89d5fac30d3..3bbf5d8770b 100644
--- a/library/core/src/fmt/float.rs
+++ b/library/core/src/fmt/float.rs
@@ -45,7 +45,8 @@ where
         &mut buf,
         &mut parts,
     );
-    fmt.pad_formatted_parts(&formatted)
+    // SAFETY: `to_exact_fixed_str` and `format_exact` produce only ASCII characters.
+    unsafe { fmt.pad_formatted_parts(&formatted) }
 }
 
 // Don't inline this so callers that call both this and the above won't wind
@@ -71,7 +72,8 @@ where
         &mut buf,
         &mut parts,
     );
-    fmt.pad_formatted_parts(&formatted)
+    // SAFETY: `to_shortest_str` and `format_shortest` produce only ASCII characters.
+    unsafe { fmt.pad_formatted_parts(&formatted) }
 }
 
 fn float_to_decimal_display<T>(fmt: &mut Formatter<'_>, num: &T) -> Result
@@ -116,7 +118,8 @@ where
         &mut buf,
         &mut parts,
     );
-    fmt.pad_formatted_parts(&formatted)
+    // SAFETY: `to_exact_exp_str` and `format_exact` produce only ASCII characters.
+    unsafe { fmt.pad_formatted_parts(&formatted) }
 }
 
 // Don't inline this so callers that call both this and the above won't wind
@@ -143,7 +146,8 @@ where
         &mut buf,
         &mut parts,
     );
-    fmt.pad_formatted_parts(&formatted)
+    // SAFETY: `to_shortest_exp_str` and `format_shortest` produce only ASCII characters.
+    unsafe { fmt.pad_formatted_parts(&formatted) }
 }
 
 // Common code of floating point LowerExp and UpperExp.
diff --git a/library/core/src/fmt/mod.rs b/library/core/src/fmt/mod.rs
index bac2f31878b..1786b309c5b 100644
--- a/library/core/src/fmt/mod.rs
+++ b/library/core/src/fmt/mod.rs
@@ -1415,7 +1415,11 @@ impl<'a> Formatter<'a> {
     /// Takes the formatted parts and applies the padding.
     /// Assumes that the caller already has rendered the parts with required precision,
     /// so that `self.precision` can be ignored.
-    fn pad_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
+    ///
+    /// # Safety
+    ///
+    /// Any `numfmt::Part::Copy` parts in `formatted` must contain valid UTF-8.
+    unsafe fn pad_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
         if let Some(mut width) = self.width {
             // for the sign-aware zero padding, we render the sign first and
             // behave as if we had no sign from the beginning.
@@ -1438,10 +1442,14 @@ impl<'a> Formatter<'a> {
             let len = formatted.len();
             let ret = if width <= len {
                 // no padding
-                self.write_formatted_parts(&formatted)
+                // SAFETY: Per the precondition.
+                unsafe { self.write_formatted_parts(&formatted) }
             } else {
                 let post_padding = self.padding(width - len, Alignment::Right)?;
-                self.write_formatted_parts(&formatted)?;
+                // SAFETY: Per the precondition.
+                unsafe {
+                    self.write_formatted_parts(&formatted)?;
+                }
                 post_padding.write(self)
             };
             self.fill = old_fill;
@@ -1449,20 +1457,20 @@ impl<'a> Formatter<'a> {
             ret
         } else {
             // this is the common case and we take a shortcut
-            self.write_formatted_parts(formatted)
+            // SAFETY: Per the precondition.
+            unsafe { self.write_formatted_parts(formatted) }
         }
     }
 
-    fn write_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
-        fn write_bytes(buf: &mut dyn Write, s: &[u8]) -> Result {
+    /// # Safety
+    ///
+    /// Any `numfmt::Part::Copy` parts in `formatted` must contain valid UTF-8.
+    unsafe fn write_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
+        unsafe fn write_bytes(buf: &mut dyn Write, s: &[u8]) -> Result {
             // SAFETY: This is used for `numfmt::Part::Num` and `numfmt::Part::Copy`.
             // It's safe to use for `numfmt::Part::Num` since every char `c` is between
-            // `b'0'` and `b'9'`, which means `s` is valid UTF-8.
-            // It's also probably safe in practice to use for `numfmt::Part::Copy(buf)`
-            // since `buf` should be plain ASCII, but it's possible for someone to pass
-            // in a bad value for `buf` into `numfmt::to_shortest_str` since it is a
-            // public function.
-            // FIXME: Determine whether this could result in UB.
+            // `b'0'` and `b'9'`, which means `s` is valid UTF-8. It's safe to use for
+            // `numfmt::Part::Copy` due to this function's precondition.
             buf.write_str(unsafe { str::from_utf8_unchecked(s) })
         }
 
@@ -1489,11 +1497,15 @@ impl<'a> Formatter<'a> {
                         *c = b'0' + (v % 10) as u8;
                         v /= 10;
                     }
-                    write_bytes(self.buf, &s[..len])?;
+                    // SAFETY: Per the precondition.
+                    unsafe {
+                        write_bytes(self.buf, &s[..len])?;
+                    }
                 }
-                numfmt::Part::Copy(buf) => {
+                // SAFETY: Per the precondition.
+                numfmt::Part::Copy(buf) => unsafe {
                     write_bytes(self.buf, buf)?;
-                }
+                },
             }
         }
         Ok(())
diff --git a/library/core/src/fmt/num.rs b/library/core/src/fmt/num.rs
index d8365ae9bf9..4f42f73ebba 100644
--- a/library/core/src/fmt/num.rs
+++ b/library/core/src/fmt/num.rs
@@ -52,8 +52,12 @@ impl_int! { i8 i16 i32 i64 i128 isize }
 impl_uint! { u8 u16 u32 u64 u128 usize }
 
 /// A type that represents a specific radix
+///
+/// # Safety
+///
+/// `digit` must return an ASCII character.
 #[doc(hidden)]
-trait GenericRadix: Sized {
+unsafe trait GenericRadix: Sized {
     /// The number of digits.
     const BASE: u8;
 
@@ -129,7 +133,7 @@ struct UpperHex;
 
 macro_rules! radix {
     ($T:ident, $base:expr, $prefix:expr, $($x:pat => $conv:expr),+) => {
-        impl GenericRadix for $T {
+        unsafe impl GenericRadix for $T {
             const BASE: u8 = $base;
             const PREFIX: &'static str = $prefix;
             fn digit(x: u8) -> u8 {
@@ -407,7 +411,7 @@ macro_rules! impl_Exp {
             let parts = &[
                 numfmt::Part::Copy(buf_slice),
                 numfmt::Part::Zero(added_precision),
-                numfmt::Part::Copy(exp_slice)
+                numfmt::Part::Copy(exp_slice),
             ];
             let sign = if !is_nonnegative {
                 "-"
@@ -416,8 +420,9 @@ macro_rules! impl_Exp {
             } else {
                 ""
             };
-            let formatted = numfmt::Formatted{sign, parts};
-            f.pad_formatted_parts(&formatted)
+            let formatted = numfmt::Formatted { sign, parts };
+            // SAFETY: `buf_slice` and `exp_slice` contain only ASCII characters.
+            unsafe { f.pad_formatted_parts(&formatted) }
         }
 
         $(
diff --git a/library/std/src/path.rs b/library/std/src/path.rs
index 43203c5824d..febdeb51463 100644
--- a/library/std/src/path.rs
+++ b/library/std/src/path.rs
@@ -733,8 +733,9 @@ impl<'a> Components<'a> {
         }
     }
 
-    // parse a given byte sequence into the corresponding path component
-    fn parse_single_component<'b>(&self, comp: &'b [u8]) -> Option<Component<'b>> {
+    // parse a given byte sequence following the OsStr encoding into the
+    // corresponding path component
+    unsafe fn parse_single_component<'b>(&self, comp: &'b [u8]) -> Option<Component<'b>> {
         match comp {
             b"." if self.prefix_verbatim() => Some(Component::CurDir),
             b"." => None, // . components are normalized away, except at
@@ -754,7 +755,8 @@ impl<'a> Components<'a> {
             None => (0, self.path),
             Some(i) => (1, &self.path[..i]),
         };
-        (comp.len() + extra, self.parse_single_component(comp))
+        // SAFETY: `comp` is a valid substring, since it is split on a separator.
+        (comp.len() + extra, unsafe { self.parse_single_component(comp) })
     }
 
     // parse a component from the right, saying how many bytes to consume to
@@ -766,7 +768,8 @@ impl<'a> Components<'a> {
             None => (0, &self.path[start..]),
             Some(i) => (1, &self.path[start + i + 1..]),
         };
-        (comp.len() + extra, self.parse_single_component(comp))
+        // SAFETY: `comp` is a valid substring, since it is split on a separator.
+        (comp.len() + extra, unsafe { self.parse_single_component(comp) })
     }
 
     // trim away repeated separators (i.e., empty components) on the left