about summary refs log tree commit diff
path: root/library/core/src
diff options
context:
space:
mode:
authorThe 8472 <git@infinite-source.de>2022-10-23 19:19:37 +0200
committerThe 8472 <git@infinite-source.de>2022-11-07 21:44:25 +0100
commitcfcce8e684c5e1bb2f9a74e55debf801ef27706f (patch)
treee54523861d32ee7289d8d4cf1ffcdf14ea267e43 /library/core/src
parenteb3f001d3739e5e9f35e1a4b4e600889cf9980c6 (diff)
downloadrust-cfcce8e684c5e1bb2f9a74e55debf801ef27706f.tar.gz
rust-cfcce8e684c5e1bb2f9a74e55debf801ef27706f.zip
specialize iter::ArrayChunks::fold for TrustedRandomAccess iters
This is fairly safe use of TRA since it consumes the iterator so
no struct in an unsafe state will be left exposed to user code
Diffstat (limited to 'library/core/src')
-rw-r--r--library/core/src/iter/adapters/array_chunks.rs89
1 files changed, 86 insertions, 3 deletions
diff --git a/library/core/src/iter/adapters/array_chunks.rs b/library/core/src/iter/adapters/array_chunks.rs
index d4fb886101f..3f0fad4ed33 100644
--- a/library/core/src/iter/adapters/array_chunks.rs
+++ b/library/core/src/iter/adapters/array_chunks.rs
@@ -1,6 +1,8 @@
 use crate::array;
-use crate::iter::{ByRefSized, FusedIterator, Iterator};
-use crate::ops::{ControlFlow, Try};
+use crate::const_closure::ConstFnMutClosure;
+use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
+use crate::mem::{self, MaybeUninit};
+use crate::ops::{ControlFlow, NeverShortCircuit, Try};
 
 /// An iterator over `N` elements of the iterator at a time.
 ///
@@ -82,7 +84,13 @@ where
         }
     }
 
-    impl_fold_via_try_fold! { fold -> try_fold }
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        <Self as SpecFold>::fold(self, init, f)
+    }
 }
 
 #[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
@@ -168,3 +176,78 @@ where
         self.iter.len() < N
     }
 }
+
+trait SpecFold: Iterator {
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B;
+}
+
+impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
+where
+    I: Iterator,
+{
+    #[inline]
+    default fn fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp);
+        self.try_fold(init, fold).0
+    }
+}
+
+impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
+where
+    I: Iterator + TrustedRandomAccessNoCoerce,
+{
+    #[inline]
+    fn fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        if self.remainder.is_some() {
+            return init;
+        }
+
+        let mut accum = init;
+        let inner_len = self.iter.size();
+        let mut i = 0;
+        // Use a while loop because (0..len).step_by(N) doesn't optimize well.
+        while inner_len - i >= N {
+            let mut chunk = MaybeUninit::uninit_array();
+            let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 };
+            for j in 0..N {
+                // SAFETY: The method consumes the iterator and the loop condition ensures that
+                // all accesses are in bounds and only happen once.
+                guard.array_mut[j].write(unsafe { self.iter.__iterator_get_unchecked(i + j) });
+                guard.initialized = j + 1;
+            }
+            mem::forget(guard);
+            // SAFETY: The loop above initialized all elements
+            let chunk = unsafe { MaybeUninit::array_assume_init(chunk) };
+            accum = f(accum, chunk);
+            i += N;
+        }
+
+        let remainder = inner_len % N;
+
+        let mut tail = MaybeUninit::uninit_array();
+        let mut guard = array::Guard { array_mut: &mut tail, initialized: 0 };
+        for i in 0..remainder {
+            // SAFETY: the remainder was not visited by the previous loop, so we're still only
+            // accessing each element once
+            let val = unsafe { self.iter.__iterator_get_unchecked(inner_len - remainder + i) };
+            guard.array_mut[i].write(val);
+            guard.initialized = i + 1;
+        }
+        mem::forget(guard);
+        // SAFETY: the loop above initialized elements up to the `remainder` index
+        self.remainder = Some(unsafe { array::IntoIter::new_unchecked(tail, 0..remainder) });
+
+        accum
+    }
+}