about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/iter/adapters/take.rs102
-rw-r--r--tests/codegen/lib-optimizations/iter-sum.rs14
2 files changed, 97 insertions, 19 deletions
diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs
index ce18bffe714..70252e075b9 100644
--- a/library/core/src/iter/adapters/take.rs
+++ b/library/core/src/iter/adapters/take.rs
@@ -1,5 +1,7 @@
 use crate::cmp;
-use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
+use crate::iter::{
+    adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen, TrustedRandomAccess,
+};
 use crate::num::NonZeroUsize;
 use crate::ops::{ControlFlow, Try};
 
@@ -98,26 +100,18 @@ where
         }
     }
 
-    impl_fold_via_try_fold! { fold -> try_fold }
-
     #[inline]
-    fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
-        // The default implementation would use a unit accumulator, so we can
-        // avoid a stateful closure by folding over the remaining number
-        // of items we wish to return instead.
-        fn check<'a, Item>(
-            mut action: impl FnMut(Item) + 'a,
-        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
-            move |more, x| {
-                action(x);
-                more.checked_sub(1)
-            }
-        }
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        Self::spec_fold(self, init, f)
+    }
 
-        let remaining = self.n;
-        if remaining > 0 {
-            self.iter.try_fold(remaining - 1, check(f));
-        }
+    #[inline]
+    fn for_each<F: FnMut(Self::Item)>(self, f: F) {
+        Self::spec_for_each(self, f)
     }
 
     #[inline]
@@ -249,3 +243,73 @@ impl<I> FusedIterator for Take<I> where I: FusedIterator {}
 
 #[unstable(feature = "trusted_len", issue = "37572")]
 unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
+
+trait SpecTake: Iterator {
+    fn spec_fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B;
+
+    fn spec_for_each<F: FnMut(Self::Item)>(self, f: F);
+}
+
+impl<I: Iterator> SpecTake for Take<I> {
+    #[inline]
+    default fn spec_fold<B, F>(mut self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        use crate::ops::NeverShortCircuit;
+        self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
+    }
+
+    #[inline]
+    default fn spec_for_each<F: FnMut(Self::Item)>(mut self, f: F) {
+        // The default implementation would use a unit accumulator, so we can
+        // avoid a stateful closure by folding over the remaining number
+        // of items we wish to return instead.
+        fn check<'a, Item>(
+            mut action: impl FnMut(Item) + 'a,
+        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
+            move |more, x| {
+                action(x);
+                more.checked_sub(1)
+            }
+        }
+
+        let remaining = self.n;
+        if remaining > 0 {
+            self.iter.try_fold(remaining - 1, check(f));
+        }
+    }
+}
+
+impl<I: Iterator + TrustedRandomAccess> SpecTake for Take<I> {
+    #[inline]
+    fn spec_fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        let mut acc = init;
+        let end = self.n.min(self.iter.size());
+        for i in 0..end {
+            // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
+            let val = unsafe { self.iter.__iterator_get_unchecked(i) };
+            acc = f(acc, val);
+        }
+        acc
+    }
+
+    #[inline]
+    fn spec_for_each<F: FnMut(Self::Item)>(self, f: F) {
+        // Based on the Iterator trait default impl.
+        #[inline]
+        fn call<T>(mut f: impl FnMut(T)) -> impl FnMut((), T) {
+            move |(), item| f(item)
+        }
+
+        self.spec_fold((), call(f));
+    }
+}
diff --git a/tests/codegen/lib-optimizations/iter-sum.rs b/tests/codegen/lib-optimizations/iter-sum.rs
new file mode 100644
index 00000000000..d6ea4cd74d5
--- /dev/null
+++ b/tests/codegen/lib-optimizations/iter-sum.rs
@@ -0,0 +1,14 @@
+// ignore-debug: the debug assertions get in the way
+// compile-flags: -O
+#![crate_type = "lib"]
+
+
+// Ensure that slice + take + sum gets vectorized.
+// Currently this relies on the slice::Iter::try_fold implementation
+// CHECK-LABEL: @slice_take_sum
+#[no_mangle]
+pub fn slice_take_sum(s: &[u64], l: usize) -> u64 {
+    // CHECK: vector.body:
+    // CHECK: ret
+    s.iter().take(l).sum()
+}