about summary refs log tree commit diff
path: root/library
diff options
context:
space:
mode:
Diffstat (limited to 'library')
-rw-r--r--library/core/src/alloc/layout.rs43
-rw-r--r--library/core/src/mem/valid_align.rs7
-rw-r--r--library/core/tests/alloc.rs44
3 files changed, 84 insertions, 10 deletions
diff --git a/library/core/src/alloc/layout.rs b/library/core/src/alloc/layout.rs
index 59ebe5fbe02..3473ac09e95 100644
--- a/library/core/src/alloc/layout.rs
+++ b/library/core/src/alloc/layout.rs
@@ -72,9 +72,8 @@ impl Layout {
         Layout::from_size_valid_align(size, unsafe { ValidAlign::new_unchecked(align) })
     }
 
-    /// Internal helper constructor to skip revalidating alignment validity.
-    #[inline]
-    const fn from_size_valid_align(size: usize, align: ValidAlign) -> Result<Self, LayoutError> {
+    #[inline(always)]
+    const fn max_size_for_align(align: ValidAlign) -> usize {
         // (power-of-two implies align != 0.)
 
         // Rounded up size is:
@@ -89,7 +88,13 @@ impl Layout {
         //
         // Above implies that checking for summation overflow is both
         // necessary and sufficient.
-        if size > isize::MAX as usize - (align.as_nonzero().get() - 1) {
+        isize::MAX as usize - (align.as_usize() - 1)
+    }
+
+    /// Internal helper constructor to skip revalidating alignment validity.
+    #[inline]
+    const fn from_size_valid_align(size: usize, align: ValidAlign) -> Result<Self, LayoutError> {
+        if size > Self::max_size_for_align(align) {
             return Err(LayoutError);
         }
 
@@ -128,7 +133,7 @@ impl Layout {
                   without modifying the layout"]
     #[inline]
     pub const fn align(&self) -> usize {
-        self.align.as_nonzero().get()
+        self.align.as_usize()
     }
 
     /// Constructs a `Layout` suitable for holding a value of type `T`.
@@ -410,13 +415,33 @@ impl Layout {
 
     /// Creates a layout describing the record for a `[T; n]`.
     ///
-    /// On arithmetic overflow, returns `LayoutError`.
+    /// On arithmetic overflow or when the total size would exceed
+    /// `isize::MAX`, returns `LayoutError`.
     #[stable(feature = "alloc_layout_manipulation", since = "1.44.0")]
     #[inline]
     pub fn array<T>(n: usize) -> Result<Self, LayoutError> {
-        let array_size = mem::size_of::<T>().checked_mul(n).ok_or(LayoutError)?;
-        // The safe constructor is called here to enforce the isize size limit.
-        Layout::from_size_valid_align(array_size, ValidAlign::of::<T>())
+        // Reduce the amount of code we need to monomorphize per `T`.
+        return inner(mem::size_of::<T>(), ValidAlign::of::<T>(), n);
+
+        #[inline]
+        fn inner(element_size: usize, align: ValidAlign, n: usize) -> Result<Layout, LayoutError> {
+            // We need to check two things about the size:
+            //  - That the total size won't overflow a `usize`, and
+            //  - That the total size still fits in an `isize`.
+            // By using division we can check them both with a single threshold.
+            // That'd usually be a bad idea, but thankfully here the element size
+            // and alignment are constants, so the compiler will fold all of it.
+            if element_size != 0 && n > Layout::max_size_for_align(align) / element_size {
+                return Err(LayoutError);
+            }
+
+            let array_size = element_size * n;
+
+            // SAFETY: We just checked above that the `array_size` will not
+            // exceed `isize::MAX` even when rounded up to the alignment.
+            // And `ValidAlign` guarantees it's a power of two.
+            unsafe { Ok(Layout::from_size_align_unchecked(array_size, align.as_usize())) }
+        }
     }
 }
 
diff --git a/library/core/src/mem/valid_align.rs b/library/core/src/mem/valid_align.rs
index 4ce6d13cf90..b9ccc0b4c79 100644
--- a/library/core/src/mem/valid_align.rs
+++ b/library/core/src/mem/valid_align.rs
@@ -36,9 +36,14 @@ impl ValidAlign {
     }
 
     #[inline]
+    pub(crate) const fn as_usize(self) -> usize {
+        self.0 as usize
+    }
+
+    #[inline]
     pub(crate) const fn as_nonzero(self) -> NonZeroUsize {
         // SAFETY: All the discriminants are non-zero.
-        unsafe { NonZeroUsize::new_unchecked(self.0 as usize) }
+        unsafe { NonZeroUsize::new_unchecked(self.as_usize()) }
     }
 
     /// Returns the base 2 logarithm of the alignment.
diff --git a/library/core/tests/alloc.rs b/library/core/tests/alloc.rs
index 8a5a06b3440..3ceaeadcec6 100644
--- a/library/core/tests/alloc.rs
+++ b/library/core/tests/alloc.rs
@@ -1,4 +1,5 @@
 use core::alloc::Layout;
+use core::mem::size_of;
 use core::ptr::{self, NonNull};
 
 #[test]
@@ -13,6 +14,49 @@ fn const_unchecked_layout() {
 }
 
 #[test]
+fn layout_round_up_to_align_edge_cases() {
+    const MAX_SIZE: usize = isize::MAX as usize;
+
+    for shift in 0..usize::BITS {
+        let align = 1_usize << shift;
+        let edge = (MAX_SIZE + 1) - align;
+        let low = edge.saturating_sub(10);
+        let high = edge.saturating_add(10);
+        assert!(Layout::from_size_align(low, align).is_ok());
+        assert!(Layout::from_size_align(high, align).is_err());
+        for size in low..=high {
+            assert_eq!(
+                Layout::from_size_align(size, align).is_ok(),
+                size.next_multiple_of(align) <= MAX_SIZE,
+            );
+        }
+    }
+}
+
+#[test]
+fn layout_array_edge_cases() {
+    for_type::<i64>();
+    for_type::<[i32; 0b10101]>();
+    for_type::<[u8; 0b1010101]>();
+
+    // Make sure ZSTs don't lead to divide-by-zero
+    assert_eq!(Layout::array::<()>(usize::MAX).unwrap(), Layout::from_size_align(0, 1).unwrap());
+
+    fn for_type<T>() {
+        const MAX_SIZE: usize = isize::MAX as usize;
+
+        let edge = (MAX_SIZE + 1) / size_of::<T>();
+        let low = edge.saturating_sub(10);
+        let high = edge.saturating_add(10);
+        assert!(Layout::array::<T>(low).is_ok());
+        assert!(Layout::array::<T>(high).is_err());
+        for n in low..=high {
+            assert_eq!(Layout::array::<T>(n).is_ok(), n * size_of::<T>() <= MAX_SIZE);
+        }
+    }
+}
+
+#[test]
 fn layout_debug_shows_log2_of_alignment() {
     // `Debug` is not stable, but here's what it does right now
     let layout = Layout::from_size_align(24576, 8192).unwrap();