about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/iter/adapters/array_chunks.rs89
1 files changed, 86 insertions, 3 deletions
diff --git a/library/core/src/iter/adapters/array_chunks.rs b/library/core/src/iter/adapters/array_chunks.rs
index d4fb886101f..3f0fad4ed33 100644
--- a/library/core/src/iter/adapters/array_chunks.rs
+++ b/library/core/src/iter/adapters/array_chunks.rs
@@ -1,6 +1,8 @@
 use crate::array;
-use crate::iter::{ByRefSized, FusedIterator, Iterator};
-use crate::ops::{ControlFlow, Try};
+use crate::const_closure::ConstFnMutClosure;
+use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
+use crate::mem::{self, MaybeUninit};
+use crate::ops::{ControlFlow, NeverShortCircuit, Try};
 
 /// An iterator over `N` elements of the iterator at a time.
 ///
@@ -82,7 +84,13 @@ where
         }
     }
 
-    impl_fold_via_try_fold! { fold -> try_fold }
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        <Self as SpecFold>::fold(self, init, f)
+    }
 }
 
 #[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
@@ -168,3 +176,78 @@ where
         self.iter.len() < N
     }
 }
+
+trait SpecFold: Iterator {
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B;
+}
+
+impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
+where
+    I: Iterator,
+{
+    #[inline]
+    default fn fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp);
+        self.try_fold(init, fold).0
+    }
+}
+
+impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
+where
+    I: Iterator + TrustedRandomAccessNoCoerce,
+{
+    #[inline]
+    fn fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        if self.remainder.is_some() {
+            return init;
+        }
+
+        let mut accum = init;
+        let inner_len = self.iter.size();
+        let mut i = 0;
+        // Use a while loop because (0..len).step_by(N) doesn't optimize well.
+        while inner_len - i >= N {
+            let mut chunk = MaybeUninit::uninit_array();
+            let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 };
+            for j in 0..N {
+                // SAFETY: The method consumes the iterator and the loop condition ensures that
+                // all accesses are in bounds and only happen once.
+                guard.array_mut[j].write(unsafe { self.iter.__iterator_get_unchecked(i + j) });
+                guard.initialized = j + 1;
+            }
+            mem::forget(guard);
+            // SAFETY: The loop above initialized all elements
+            let chunk = unsafe { MaybeUninit::array_assume_init(chunk) };
+            accum = f(accum, chunk);
+            i += N;
+        }
+
+        let remainder = inner_len % N;
+
+        let mut tail = MaybeUninit::uninit_array();
+        let mut guard = array::Guard { array_mut: &mut tail, initialized: 0 };
+        for i in 0..remainder {
+            // SAFETY: the remainder was not visited by the previous loop, so we're still only
+            // accessing each element once
+            let val = unsafe { self.iter.__iterator_get_unchecked(inner_len - remainder + i) };
+            guard.array_mut[i].write(val);
+            guard.initialized = i + 1;
+        }
+        mem::forget(guard);
+        // SAFETY: the loop above initialized elements up to the `remainder` index
+        self.remainder = Some(unsafe { array::IntoIter::new_unchecked(tail, 0..remainder) });
+
+        accum
+    }
+}