about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/iter/sources/repeat_n.rs60
-rw-r--r--library/core/tests/iter/sources.rs24
2 files changed, 69 insertions, 15 deletions
diff --git a/library/core/src/iter/sources/repeat_n.rs b/library/core/src/iter/sources/repeat_n.rs
index 9c062193363..7e162ff387b 100644
--- a/library/core/src/iter/sources/repeat_n.rs
+++ b/library/core/src/iter/sources/repeat_n.rs
@@ -1,5 +1,6 @@
+use crate::fmt;
 use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
-use crate::mem::ManuallyDrop;
+use crate::mem::{self, MaybeUninit};
 use crate::num::NonZero;
 
 /// Creates a new iterator that repeats a single element a given number of times.
@@ -58,14 +59,12 @@ use crate::num::NonZero;
 #[inline]
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
 pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
-    let mut element = ManuallyDrop::new(element);
-
-    if count == 0 {
-        // SAFETY: we definitely haven't dropped it yet, since we only just got
-        // passed it in, and because the count is zero the instance we're about
-        // to create won't drop it, so to avoid leaking we need to now.
-        unsafe { ManuallyDrop::drop(&mut element) };
-    }
+    let element = if count == 0 {
+        // `element` gets dropped eagerly.
+        MaybeUninit::uninit()
+    } else {
+        MaybeUninit::new(element)
+    };
 
     RepeatN { element, count }
 }
@@ -74,15 +73,23 @@ pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
 ///
 /// This `struct` is created by the [`repeat_n()`] function.
 /// See its documentation for more.
-#[derive(Clone, Debug)]
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
 pub struct RepeatN<A> {
     count: usize,
-    // Invariant: has been dropped iff count == 0.
-    element: ManuallyDrop<A>,
+    // Invariant: uninit iff count == 0.
+    element: MaybeUninit<A>,
 }
 
 impl<A> RepeatN<A> {
+    /// Returns the element if it hasn't been dropped already.
+    fn element_ref(&self) -> Option<&A> {
+        if self.count > 0 {
+            // SAFETY: The count is non-zero, so it must be initialized.
+            Some(unsafe { self.element.assume_init_ref() })
+        } else {
+            None
+        }
+    }
     /// If we haven't already dropped the element, return it in an option.
     ///
     /// Clears the count so it won't be dropped again later.
@@ -90,9 +97,10 @@ impl<A> RepeatN<A> {
     fn take_element(&mut self) -> Option<A> {
         if self.count > 0 {
             self.count = 0;
+            let element = mem::replace(&mut self.element, MaybeUninit::uninit());
             // SAFETY: We just set count to zero so it won't be dropped again,
             // and it used to be non-zero so it hasn't already been dropped.
-            unsafe { Some(ManuallyDrop::take(&mut self.element)) }
+            unsafe { Some(element.assume_init()) }
         } else {
             None
         }
@@ -100,6 +108,26 @@ impl<A> RepeatN<A> {
 }
 
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
+impl<A: Clone> Clone for RepeatN<A> {
+    fn clone(&self) -> RepeatN<A> {
+        RepeatN {
+            count: self.count,
+            element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
+        }
+    }
+}
+
+#[stable(feature = "iter_repeat_n", since = "1.82.0")]
+impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("RepeatN")
+            .field("count", &self.count)
+            .field("element", &self.element_ref())
+            .finish()
+    }
+}
+
+#[stable(feature = "iter_repeat_n", since = "1.82.0")]
 impl<A> Drop for RepeatN<A> {
     fn drop(&mut self) {
         self.take_element();
@@ -194,9 +222,11 @@ impl<A: Clone> UncheckedIterator for RepeatN<A> {
             // SAFETY: the check above ensured that the count used to be non-zero,
             // so element hasn't been dropped yet, and we just lowered the count to
             // zero so it won't be dropped later, and thus it's okay to take it here.
-            unsafe { ManuallyDrop::take(&mut self.element) }
+            unsafe { mem::replace(&mut self.element, MaybeUninit::uninit()).assume_init() }
         } else {
-            A::clone(&self.element)
+            // SAFETY: the count is non-zero, so it must have not been dropped yet.
+            let element = unsafe { self.element.assume_init_ref() };
+            A::clone(element)
         }
     }
 }
diff --git a/library/core/tests/iter/sources.rs b/library/core/tests/iter/sources.rs
index eb8c80dd087..506febaa056 100644
--- a/library/core/tests/iter/sources.rs
+++ b/library/core/tests/iter/sources.rs
@@ -156,3 +156,27 @@ fn test_repeat_n_drop() {
     drop((x0, x1, x2));
     assert_eq!(count.get(), 3);
 }
+
+#[test]
+fn test_repeat_n_soundness() {
+    let x = std::iter::repeat_n(String::from("use after free"), 0);
+    println!("{x:?}");
+
+    pub struct PanicOnClone;
+
+    impl Clone for PanicOnClone {
+        fn clone(&self) -> Self {
+            unreachable!()
+        }
+    }
+
+    // `repeat_n` should drop the element immediately if `count` is zero.
+    // `Clone` should then not try to clone the element.
+    let x = std::iter::repeat_n(PanicOnClone, 0);
+    let _ = x.clone();
+
+    let mut y = std::iter::repeat_n(Box::new(0), 1);
+    let x = y.next().unwrap();
+    let _z = y;
+    assert_eq!(0, *x);
+}