about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-11-27 00:58:50 +0000
committerbors <bors@rust-lang.org>2022-11-27 00:58:50 +0000
commitfaf1891deb2633fe4040de8b71fd7b2045c45dc5 (patch)
treebaf896d46501fa2b94a56210391db2523a791748
parentc0e9c86b3f3e96267ba2cd80f95f362ef0cce40b (diff)
parent9d68a1a74c65245c9ae7b5db2c552c995697e8ef (diff)
downloadrust-faf1891deb2633fe4040de8b71fd7b2045c45dc5.tar.gz
rust-faf1891deb2633fe4040de8b71fd7b2045c45dc5.zip
Auto merge of #104818 - scottmcm:refactor-extend-func, r=the8472
Stop peeling the last iteration of the loop in `Vec::resize_with`

`resize_with` uses the `ExtendWith` code that peels the last iteration:
https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2525-L2529

But that's kinda weird for `ExtendFunc` because it does the same thing on the last iteration anyway:
https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2494-L2502

So this just has it use the normal `extend`-from-`TrustedLen` code instead.

r? `@ghost`
-rw-r--r--library/alloc/src/vec/mod.rs46
-rw-r--r--library/alloc/src/vec/set_len_on_drop.rs5
-rw-r--r--library/alloc/src/vec/spec_extend.rs34
-rw-r--r--library/core/src/iter/adapters/take.rs21
-rw-r--r--library/core/src/iter/sources/repeat_with.rs17
-rw-r--r--library/core/tests/iter/adapters/take.rs20
-rw-r--r--src/test/codegen/repeat-trusted-len.rs7
7 files changed, 106 insertions, 44 deletions
diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs
index 766006939fa..e147af2ce39 100644
--- a/library/alloc/src/vec/mod.rs
+++ b/library/alloc/src/vec/mod.rs
@@ -2163,7 +2163,7 @@ impl<T, A: Allocator> Vec<T, A> {
     {
         let len = self.len();
         if new_len > len {
-            self.extend_with(new_len - len, ExtendFunc(f));
+            self.extend_trusted(iter::repeat_with(f).take(new_len - len));
         } else {
             self.truncate(new_len);
         }
@@ -2491,16 +2491,6 @@ impl<T: Clone> ExtendWith<T> for ExtendElement<T> {
     }
 }
 
-struct ExtendFunc<F>(F);
-impl<T, F: FnMut() -> T> ExtendWith<T> for ExtendFunc<F> {
-    fn next(&mut self) -> T {
-        (self.0)()
-    }
-    fn last(mut self) -> T {
-        (self.0)()
-    }
-}
-
 impl<T, A: Allocator> Vec<T, A> {
     #[cfg(not(no_global_oom_handling))]
     /// Extend the vector by `n` values, using the given generator.
@@ -2870,6 +2860,40 @@ impl<T, A: Allocator> Vec<T, A> {
         }
     }
 
+    // specific extend for `TrustedLen` iterators, called both by the specializations
+    // and internal places where resolving specialization makes compilation slower
+    #[cfg(not(no_global_oom_handling))]
+    fn extend_trusted(&mut self, iterator: impl iter::TrustedLen<Item = T>) {
+        let (low, high) = iterator.size_hint();
+        if let Some(additional) = high {
+            debug_assert_eq!(
+                low,
+                additional,
+                "TrustedLen iterator's size hint is not exact: {:?}",
+                (low, high)
+            );
+            self.reserve(additional);
+            unsafe {
+                let ptr = self.as_mut_ptr();
+                let mut local_len = SetLenOnDrop::new(&mut self.len);
+                iterator.for_each(move |element| {
+                    ptr::write(ptr.add(local_len.current_len()), element);
+                    // Since the loop executes user code which can panic we have to update
+                    // the length every step to correctly drop what we've written.
+                    // NB can't overflow since we would have had to alloc the address space
+                    local_len.increment_len(1);
+                });
+            }
+        } else {
+            // Per TrustedLen contract a `None` upper bound means that the iterator length
+            // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
+            // Since the other branch already panics eagerly (via `reserve()`) we do the same here.
+            // This avoids additional codegen for a fallback code path which would eventually
+            // panic anyway.
+            panic!("capacity overflow");
+        }
+    }
+
     /// Creates a splicing iterator that replaces the specified range in the vector
     /// with the given `replace_with` iterator and yields the removed items.
     /// `replace_with` does not need to be the same length as `range`.
diff --git a/library/alloc/src/vec/set_len_on_drop.rs b/library/alloc/src/vec/set_len_on_drop.rs
index 8b66bc81212..6ce5a3a9f54 100644
--- a/library/alloc/src/vec/set_len_on_drop.rs
+++ b/library/alloc/src/vec/set_len_on_drop.rs
@@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> {
     pub(super) fn increment_len(&mut self, increment: usize) {
         self.local_len += increment;
     }
+
+    #[inline]
+    pub(super) fn current_len(&self) -> usize {
+        self.local_len
+    }
 }
 
 impl Drop for SetLenOnDrop<'_> {
diff --git a/library/alloc/src/vec/spec_extend.rs b/library/alloc/src/vec/spec_extend.rs
index 1ea9c827afd..56065ce565b 100644
--- a/library/alloc/src/vec/spec_extend.rs
+++ b/library/alloc/src/vec/spec_extend.rs
@@ -1,9 +1,8 @@
 use crate::alloc::Allocator;
 use core::iter::TrustedLen;
-use core::ptr::{self};
 use core::slice::{self};
 
-use super::{IntoIter, SetLenOnDrop, Vec};
+use super::{IntoIter, Vec};
 
 // Specialization trait used for Vec::extend
 pub(super) trait SpecExtend<T, I> {
@@ -24,36 +23,7 @@ where
     I: TrustedLen<Item = T>,
 {
     default fn spec_extend(&mut self, iterator: I) {
-        // This is the case for a TrustedLen iterator.
-        let (low, high) = iterator.size_hint();
-        if let Some(additional) = high {
-            debug_assert_eq!(
-                low,
-                additional,
-                "TrustedLen iterator's size hint is not exact: {:?}",
-                (low, high)
-            );
-            self.reserve(additional);
-            unsafe {
-                let mut ptr = self.as_mut_ptr().add(self.len());
-                let mut local_len = SetLenOnDrop::new(&mut self.len);
-                iterator.for_each(move |element| {
-                    ptr::write(ptr, element);
-                    ptr = ptr.add(1);
-                    // Since the loop executes user code which can panic we have to bump the pointer
-                    // after each step.
-                    // NB can't overflow since we would have had to alloc the address space
-                    local_len.increment_len(1);
-                });
-            }
-        } else {
-            // Per TrustedLen contract a `None` upper bound means that the iterator length
-            // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
-            // Since the other branch already panics eagerly (via `reserve()`) we do the same here.
-            // This avoids additional codegen for a fallback code path which would eventually
-            // panic anyway.
-            panic!("capacity overflow");
-        }
+        self.extend_trusted(iterator)
     }
 }
 
diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs
index 58a0b9d7bbe..d947c7b0e30 100644
--- a/library/core/src/iter/adapters/take.rs
+++ b/library/core/src/iter/adapters/take.rs
@@ -75,7 +75,6 @@ where
     #[inline]
     fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
-        Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
         R: Try<Output = Acc>,
     {
@@ -101,6 +100,26 @@ where
     impl_fold_via_try_fold! { fold -> try_fold }
 
     #[inline]
+    fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
+        // The default implementation would use a unit accumulator, so we can
+        // avoid a stateful closure by folding over the remaining number
+        // of items we wish to return instead.
+        fn check<'a, Item>(
+            mut action: impl FnMut(Item) + 'a,
+        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
+            move |more, x| {
+                action(x);
+                more.checked_sub(1)
+            }
+        }
+
+        let remaining = self.n;
+        if remaining > 0 {
+            self.iter.try_fold(remaining - 1, check(f));
+        }
+    }
+
+    #[inline]
     #[rustc_inherit_overflow_checks]
     fn advance_by(&mut self, n: usize) -> Result<(), usize> {
         let min = self.n.min(n);
diff --git a/library/core/src/iter/sources/repeat_with.rs b/library/core/src/iter/sources/repeat_with.rs
index 6f62662d880..ab2d0472b47 100644
--- a/library/core/src/iter/sources/repeat_with.rs
+++ b/library/core/src/iter/sources/repeat_with.rs
@@ -1,4 +1,5 @@
 use crate::iter::{FusedIterator, TrustedLen};
+use crate::ops::Try;
 
 /// Creates a new iterator that repeats elements of type `A` endlessly by
 /// applying the provided closure, the repeater, `F: FnMut() -> A`.
@@ -89,6 +90,22 @@ impl<A, F: FnMut() -> A> Iterator for RepeatWith<F> {
     fn size_hint(&self) -> (usize, Option<usize>) {
         (usize::MAX, None)
     }
+
+    #[inline]
+    fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
+    where
+        Fold: FnMut(Acc, Self::Item) -> R,
+        R: Try<Output = Acc>,
+    {
+        // This override isn't strictly needed, but avoids the need to optimize
+        // away the `next`-always-returns-`Some` and emphasizes that the `?`
+        // is the only way to exit the loop.
+
+        loop {
+            let item = (self.repeater)();
+            init = fold(init, item)?;
+        }
+    }
 }
 
 #[stable(feature = "iterator_repeat_with", since = "1.28.0")]
diff --git a/library/core/tests/iter/adapters/take.rs b/library/core/tests/iter/adapters/take.rs
index bfb659f0a83..3e26b43a2ed 100644
--- a/library/core/tests/iter/adapters/take.rs
+++ b/library/core/tests/iter/adapters/take.rs
@@ -146,3 +146,23 @@ fn test_take_try_folds() {
     assert_eq!(iter.try_for_each(Err), Err(2));
     assert_eq!(iter.try_for_each(Err), Ok(()));
 }
+
+#[test]
+fn test_byref_take_consumed_items() {
+    let mut inner = 10..90;
+
+    let mut count = 0;
+    inner.by_ref().take(0).for_each(|_| count += 1);
+    assert_eq!(count, 0);
+    assert_eq!(inner, 10..90);
+
+    let mut count = 0;
+    inner.by_ref().take(10).for_each(|_| count += 1);
+    assert_eq!(count, 10);
+    assert_eq!(inner, 20..90);
+
+    let mut count = 0;
+    inner.by_ref().take(100).for_each(|_| count += 1);
+    assert_eq!(count, 70);
+    assert_eq!(inner, 90..90);
+}
diff --git a/src/test/codegen/repeat-trusted-len.rs b/src/test/codegen/repeat-trusted-len.rs
index 7aebd3ec7df..87c8fe1354d 100644
--- a/src/test/codegen/repeat-trusted-len.rs
+++ b/src/test/codegen/repeat-trusted-len.rs
@@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec<u8> {
 // CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false)
     iter::repeat(42).take(100000).collect()
 }
+
+// CHECK-LABEL: @repeat_with_take_collect
+#[no_mangle]
+pub fn repeat_with_take_collect() -> Vec<u8> {
+// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false)
+    iter::repeat_with(|| 13).take(12345).collect()
+}