about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/alloc/src/vec/mod.rs9
-rw-r--r--library/alloc/src/vec/set_len_on_drop.rs5
-rw-r--r--library/core/src/iter/adapters/take.rs21
-rw-r--r--library/core/src/iter/sources/repeat_with.rs17
-rw-r--r--library/core/tests/iter/adapters/take.rs20
-rw-r--r--src/test/codegen/repeat-trusted-len.rs7
6 files changed, 73 insertions, 6 deletions
diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs
index 94ebcb863a4..e147af2ce39 100644
--- a/library/alloc/src/vec/mod.rs
+++ b/library/alloc/src/vec/mod.rs
@@ -2874,13 +2874,12 @@ impl<T, A: Allocator> Vec<T, A> {
             );
             self.reserve(additional);
             unsafe {
-                let mut ptr = self.as_mut_ptr().add(self.len());
+                let ptr = self.as_mut_ptr();
                 let mut local_len = SetLenOnDrop::new(&mut self.len);
                 iterator.for_each(move |element| {
-                    ptr::write(ptr, element);
-                    ptr = ptr.add(1);
-                    // Since the loop executes user code which can panic we have to bump the pointer
-                    // after each step.
+                    ptr::write(ptr.add(local_len.current_len()), element);
+                    // Since the loop executes user code which can panic we have to update
+                    // the length every step to correctly drop what we've written.
                     // NB can't overflow since we would have had to alloc the address space
                     local_len.increment_len(1);
                 });
diff --git a/library/alloc/src/vec/set_len_on_drop.rs b/library/alloc/src/vec/set_len_on_drop.rs
index 8b66bc81212..6ce5a3a9f54 100644
--- a/library/alloc/src/vec/set_len_on_drop.rs
+++ b/library/alloc/src/vec/set_len_on_drop.rs
@@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> {
     pub(super) fn increment_len(&mut self, increment: usize) {
         self.local_len += increment;
     }
+
+    #[inline]
+    pub(super) fn current_len(&self) -> usize {
+        self.local_len
+    }
 }
 
 impl Drop for SetLenOnDrop<'_> {
diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs
index 58a0b9d7bbe..d947c7b0e30 100644
--- a/library/core/src/iter/adapters/take.rs
+++ b/library/core/src/iter/adapters/take.rs
@@ -75,7 +75,6 @@ where
     #[inline]
     fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
-        Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
         R: Try<Output = Acc>,
     {
@@ -101,6 +100,26 @@ where
     impl_fold_via_try_fold! { fold -> try_fold }
 
     #[inline]
+    fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
+        // The default implementation would use a unit accumulator, so we can
+        // avoid a stateful closure by folding over the remaining number
+        // of items we wish to return instead.
+        fn check<'a, Item>(
+            mut action: impl FnMut(Item) + 'a,
+        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
+            move |more, x| {
+                action(x);
+                more.checked_sub(1)
+            }
+        }
+
+        let remaining = self.n;
+        if remaining > 0 {
+            self.iter.try_fold(remaining - 1, check(f));
+        }
+    }
+
+    #[inline]
     #[rustc_inherit_overflow_checks]
     fn advance_by(&mut self, n: usize) -> Result<(), usize> {
         let min = self.n.min(n);
diff --git a/library/core/src/iter/sources/repeat_with.rs b/library/core/src/iter/sources/repeat_with.rs
index 6f62662d880..ab2d0472b47 100644
--- a/library/core/src/iter/sources/repeat_with.rs
+++ b/library/core/src/iter/sources/repeat_with.rs
@@ -1,4 +1,5 @@
 use crate::iter::{FusedIterator, TrustedLen};
+use crate::ops::Try;
 
 /// Creates a new iterator that repeats elements of type `A` endlessly by
 /// applying the provided closure, the repeater, `F: FnMut() -> A`.
@@ -89,6 +90,22 @@ impl<A, F: FnMut() -> A> Iterator for RepeatWith<F> {
     fn size_hint(&self) -> (usize, Option<usize>) {
         (usize::MAX, None)
     }
+
+    #[inline]
+    fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
+    where
+        Fold: FnMut(Acc, Self::Item) -> R,
+        R: Try<Output = Acc>,
+    {
+        // This override isn't strictly needed, but avoids the need to optimize
+        // away the `next`-always-returns-`Some` and emphasizes that the `?`
+        // is the only way to exit the loop.
+
+        loop {
+            let item = (self.repeater)();
+            init = fold(init, item)?;
+        }
+    }
 }
 
 #[stable(feature = "iterator_repeat_with", since = "1.28.0")]
diff --git a/library/core/tests/iter/adapters/take.rs b/library/core/tests/iter/adapters/take.rs
index bfb659f0a83..3e26b43a2ed 100644
--- a/library/core/tests/iter/adapters/take.rs
+++ b/library/core/tests/iter/adapters/take.rs
@@ -146,3 +146,23 @@ fn test_take_try_folds() {
     assert_eq!(iter.try_for_each(Err), Err(2));
     assert_eq!(iter.try_for_each(Err), Ok(()));
 }
+
+#[test]
+fn test_byref_take_consumed_items() {
+    let mut inner = 10..90;
+
+    let mut count = 0;
+    inner.by_ref().take(0).for_each(|_| count += 1);
+    assert_eq!(count, 0);
+    assert_eq!(inner, 10..90);
+
+    let mut count = 0;
+    inner.by_ref().take(10).for_each(|_| count += 1);
+    assert_eq!(count, 10);
+    assert_eq!(inner, 20..90);
+
+    let mut count = 0;
+    inner.by_ref().take(100).for_each(|_| count += 1);
+    assert_eq!(count, 70);
+    assert_eq!(inner, 90..90);
+}
diff --git a/src/test/codegen/repeat-trusted-len.rs b/src/test/codegen/repeat-trusted-len.rs
index 7aebd3ec7df..87c8fe1354d 100644
--- a/src/test/codegen/repeat-trusted-len.rs
+++ b/src/test/codegen/repeat-trusted-len.rs
@@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec<u8> {
 // CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false)
     iter::repeat(42).take(100000).collect()
 }
+
+// CHECK-LABEL: @repeat_with_take_collect
+#[no_mangle]
+pub fn repeat_with_take_collect() -> Vec<u8> {
+// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false)
+    iter::repeat_with(|| 13).take(12345).collect()
+}