about summary refs log tree commit diff
diff options
context:
space:
mode:
authorSoveu <marx.tomasz@gmail.com>2024-09-26 17:58:23 +0200
committerSoveu <marx.tomasz@gmail.com>2025-06-13 22:32:15 +0200
commit8f77681656c47aa09d1b66f87dc249aae1694d07 (patch)
treec9eb6ec384becb0b6c62866a2b65411fb9fcc346
parent0d6ab209c525e276cbe7544cbd39a3c3619b6b18 (diff)
downloadrust-8f77681656c47aa09d1b66f87dc249aae1694d07.tar.gz
rust-8f77681656c47aa09d1b66f87dc249aae1694d07.zip
100% safe implementation of RepeatN
-rw-r--r--library/core/src/iter/sources/repeat_n.rs161
-rw-r--r--tests/codegen/iter-repeat-n-trivial-drop.rs7
2 files changed, 47 insertions, 121 deletions
diff --git a/library/core/src/iter/sources/repeat_n.rs b/library/core/src/iter/sources/repeat_n.rs
index ada37b9af4c..c29ab24a083 100644
--- a/library/core/src/iter/sources/repeat_n.rs
+++ b/library/core/src/iter/sources/repeat_n.rs
@@ -1,8 +1,7 @@
 use crate::fmt;
 use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
-use crate::mem::MaybeUninit;
 use crate::num::NonZero;
-use crate::ops::{NeverShortCircuit, Try};
+use crate::ops::Try;
 
 /// Creates a new iterator that repeats a single element a given number of times.
 ///
@@ -58,14 +57,20 @@ use crate::ops::{NeverShortCircuit, Try};
 #[inline]
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
 pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
-    let element = if count == 0 {
-        // `element` gets dropped eagerly.
-        MaybeUninit::uninit()
-    } else {
-        MaybeUninit::new(element)
-    };
+    RepeatN { inner: RepeatNInner::new(element, count) }
+}
 
-    RepeatN { element, count }
+#[derive(Clone, Copy)]
+struct RepeatNInner<T> {
+    count: NonZero<usize>,
+    element: T,
+}
+
+impl<T> RepeatNInner<T> {
+    fn new(element: T, count: usize) -> Option<Self> {
+        let count = NonZero::<usize>::new(count)?;
+        Some(Self { element, count })
+    }
 }
 
 /// An iterator that repeats an element an exact number of times.
@@ -73,63 +78,27 @@ pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
 /// This `struct` is created by the [`repeat_n()`] function.
 /// See its documentation for more.
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
+#[derive(Clone)]
 pub struct RepeatN<A> {
-    count: usize,
-    // Invariant: uninit iff count == 0.
-    element: MaybeUninit<A>,
+    inner: Option<RepeatNInner<A>>,
 }
 
 impl<A> RepeatN<A> {
-    /// Returns the element if it hasn't been dropped already.
-    fn element_ref(&self) -> Option<&A> {
-        if self.count > 0 {
-            // SAFETY: The count is non-zero, so it must be initialized.
-            Some(unsafe { self.element.assume_init_ref() })
-        } else {
-            None
-        }
-    }
     /// If we haven't already dropped the element, return it in an option.
-    ///
-    /// Clears the count so it won't be dropped again later.
     #[inline]
     fn take_element(&mut self) -> Option<A> {
-        if self.count > 0 {
-            self.count = 0;
-            // SAFETY: We just set count to zero so it won't be dropped again,
-            // and it used to be non-zero so it hasn't already been dropped.
-            let element = unsafe { self.element.assume_init_read() };
-            Some(element)
-        } else {
-            None
-        }
-    }
-}
-
-#[stable(feature = "iter_repeat_n", since = "1.82.0")]
-impl<A: Clone> Clone for RepeatN<A> {
-    fn clone(&self) -> RepeatN<A> {
-        RepeatN {
-            count: self.count,
-            element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
-        }
+        self.inner.take().map(|inner| inner.element)
     }
 }
 
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
 impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_struct("RepeatN")
-            .field("count", &self.count)
-            .field("element", &self.element_ref())
-            .finish()
-    }
-}
-
-#[stable(feature = "iter_repeat_n", since = "1.82.0")]
-impl<A> Drop for RepeatN<A> {
-    fn drop(&mut self) {
-        self.take_element();
+        let (count, element) = match self.inner.as_ref() {
+            Some(inner) => (inner.count.get(), Some(&inner.element)),
+            None => (0, None),
+        };
+        f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
     }
 }
 
@@ -139,12 +108,17 @@ impl<A: Clone> Iterator for RepeatN<A> {
 
     #[inline]
     fn next(&mut self) -> Option<A> {
-        if self.count > 0 {
-            // SAFETY: Just checked it's not empty
-            unsafe { Some(self.next_unchecked()) }
-        } else {
-            None
+        let inner = self.inner.as_mut()?;
+        let count = inner.count.get();
+
+        if let Some(decremented) = NonZero::<usize>::new(count - 1) {
+            // Order of these is important for optimization
+            let tmp = inner.element.clone();
+            inner.count = decremented;
+            return Some(tmp);
         }
+
+        return self.take_element();
     }
 
     #[inline]
@@ -155,52 +129,19 @@ impl<A: Clone> Iterator for RepeatN<A> {
 
     #[inline]
     fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
-        let len = self.count;
+        let Some(inner) = self.inner.as_mut() else {
+            return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
+        };
 
-        if skip >= len {
-            self.take_element();
-        }
+        let len = inner.count.get();
 
-        if skip > len {
-            // SAFETY: we just checked that the difference is positive
-            Err(unsafe { NonZero::new_unchecked(skip - len) })
-        } else {
-            self.count = len - skip;
-            Ok(())
+        if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
+            inner.count = new_len;
+            return Ok(());
         }
-    }
 
-    fn try_fold<B, F, R>(&mut self, mut acc: B, mut f: F) -> R
-    where
-        F: FnMut(B, A) -> R,
-        R: Try<Output = B>,
-    {
-        if self.count > 0 {
-            while self.count > 1 {
-                self.count -= 1;
-                // SAFETY: the count was larger than 1, so the element is
-                // initialized and hasn't been dropped.
-                acc = f(acc, unsafe { self.element.assume_init_ref().clone() })?;
-            }
-
-            // We could just set the count to zero directly, but doing it this
-            // way should make it easier for the optimizer to fold this tail
-            // into the loop when `clone()` is equivalent to copying.
-            self.count -= 1;
-            // SAFETY: we just set the count to zero from one, so the element
-            // is still initialized, has not been dropped yet and will not be
-            // accessed by future calls.
-            f(acc, unsafe { self.element.assume_init_read() })
-        } else {
-            try { acc }
-        }
-    }
-
-    fn fold<B, F>(mut self, init: B, f: F) -> B
-    where
-        F: FnMut(B, A) -> B,
-    {
-        self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
+        self.inner = None;
+        return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
     }
 
     #[inline]
@@ -217,7 +158,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
 impl<A: Clone> ExactSizeIterator for RepeatN<A> {
     fn len(&self) -> usize {
-        self.count
+        self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
     }
 }
 
@@ -262,20 +203,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
 #[unstable(feature = "trusted_len", issue = "37572")]
 unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
 #[stable(feature = "iter_repeat_n", since = "1.82.0")]
-impl<A: Clone> UncheckedIterator for RepeatN<A> {
-    #[inline]
-    unsafe fn next_unchecked(&mut self) -> Self::Item {
-        // SAFETY: The caller promised the iterator isn't empty
-        self.count = unsafe { self.count.unchecked_sub(1) };
-        if self.count == 0 {
-            // SAFETY: the check above ensured that the count used to be non-zero,
-            // so element hasn't been dropped yet, and we just lowered the count to
-            // zero so it won't be dropped later, and thus it's okay to take it here.
-            unsafe { self.element.assume_init_read() }
-        } else {
-            // SAFETY: the count is non-zero, so it must have not been dropped yet.
-            let element = unsafe { self.element.assume_init_ref() };
-            A::clone(element)
-        }
-    }
-}
+impl<A: Clone> UncheckedIterator for RepeatN<A> {}
diff --git a/tests/codegen/iter-repeat-n-trivial-drop.rs b/tests/codegen/iter-repeat-n-trivial-drop.rs
index 3bb942d11d5..28173530324 100644
--- a/tests/codegen/iter-repeat-n-trivial-drop.rs
+++ b/tests/codegen/iter-repeat-n-trivial-drop.rs
@@ -1,5 +1,6 @@
 //@ compile-flags: -C opt-level=3
 //@ only-x86_64
+//@ needs-deterministic-layouts
 
 #![crate_type = "lib"]
 #![feature(iter_repeat_n)]
@@ -25,10 +26,10 @@ pub fn iter_repeat_n_next(it: &mut std::iter::RepeatN<NotCopy>) -> Option<NotCop
     // CHECK-NEXT: br i1 %[[COUNT_ZERO]], label %[[EMPTY:.+]], label %[[NOT_EMPTY:.+]]
 
     // CHECK: [[NOT_EMPTY]]:
-    // CHECK-NEXT: %[[DEC:.+]] = add i64 %[[COUNT]], -1
-    // CHECK-NEXT: store i64 %[[DEC]]
     // CHECK-NOT: br
-    // CHECK: %[[VAL:.+]] = load i16
+    // CHECK: %[[DEC:.+]] = add i64 %[[COUNT]], -1
+    // CHECK-NEXT: %[[VAL:.+]] = load i16
+    // CHECK-NEXT: store i64 %[[DEC]]
     // CHECK-NEXT: br label %[[EMPTY]]
 
     // CHECK: [[EMPTY]]: