about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/iter/adapters/flatten.rs375
-rw-r--r--library/core/tests/iter/adapters/flatten.rs42
2 files changed, 280 insertions, 137 deletions
diff --git a/library/core/src/iter/adapters/flatten.rs b/library/core/src/iter/adapters/flatten.rs
index 15a120e35a2..307016c2690 100644
--- a/library/core/src/iter/adapters/flatten.rs
+++ b/library/core/src/iter/adapters/flatten.rs
@@ -1,6 +1,6 @@
 use crate::fmt;
 use crate::iter::{DoubleEndedIterator, Fuse, FusedIterator, Iterator, Map, TrustedLen};
-use crate::ops::Try;
+use crate::ops::{ControlFlow, Try};
 
 /// An iterator that maps each element to an iterator, and yields the elements
 /// of the produced iterators.
@@ -73,6 +73,21 @@ where
     {
         self.inner.fold(init, fold)
     }
+
+    #[inline]
+    fn advance_by(&mut self, n: usize) -> Result<(), usize> {
+        self.inner.advance_by(n)
+    }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.inner.count()
+    }
+
+    #[inline]
+    fn last(self) -> Option<Self::Item> {
+        self.inner.last()
+    }
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -103,6 +118,11 @@ where
     {
         self.inner.rfold(init, fold)
     }
+
+    #[inline]
+    fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
+        self.inner.advance_back_by(n)
+    }
 }
 
 #[stable(feature = "fused", since = "1.26.0")]
@@ -214,6 +234,21 @@ where
     {
         self.inner.fold(init, fold)
     }
+
+    #[inline]
+    fn advance_by(&mut self, n: usize) -> Result<(), usize> {
+        self.inner.advance_by(n)
+    }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.inner.count()
+    }
+
+    #[inline]
+    fn last(self) -> Option<Self::Item> {
+        self.inner.last()
+    }
 }
 
 #[stable(feature = "iterator_flatten", since = "1.29.0")]
@@ -244,6 +279,11 @@ where
     {
         self.inner.rfold(init, fold)
     }
+
+    #[inline]
+    fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
+        self.inner.advance_back_by(n)
+    }
 }
 
 #[stable(feature = "iterator_flatten", since = "1.29.0")]
@@ -280,6 +320,144 @@ where
     }
 }
 
+impl<I, U> FlattenCompat<I, U>
+where
+    I: Iterator<Item: IntoIterator<IntoIter = U>>,
+{
+    /// Folds the inner iterators into an accumulator by applying an operation.
+    ///
+    /// Folds over the inner iterators, not over their elements. Is used by the `fold`, `count`,
+    /// and `last` methods.
+    #[inline]
+    fn iter_fold<Acc, Fold>(self, mut acc: Acc, mut fold: Fold) -> Acc
+    where
+        Fold: FnMut(Acc, U) -> Acc,
+    {
+        #[inline]
+        fn flatten<T: IntoIterator, Acc>(
+            fold: &mut impl FnMut(Acc, T::IntoIter) -> Acc,
+        ) -> impl FnMut(Acc, T) -> Acc + '_ {
+            move |acc, iter| fold(acc, iter.into_iter())
+        }
+
+        if let Some(iter) = self.frontiter {
+            acc = fold(acc, iter);
+        }
+
+        acc = self.iter.fold(acc, flatten(&mut fold));
+
+        if let Some(iter) = self.backiter {
+            acc = fold(acc, iter);
+        }
+
+        acc
+    }
+
+    /// Folds over the inner iterators as long as the given function returns successfully,
+    /// always storing the most recent inner iterator in `self.frontiter`.
+    ///
+    /// Folds over the inner iterators, not over their elements. Is used by the `try_fold` and
+    /// `advance_by` methods.
+    #[inline]
+    fn iter_try_fold<Acc, Fold, R>(&mut self, mut acc: Acc, mut fold: Fold) -> R
+    where
+        Fold: FnMut(Acc, &mut U) -> R,
+        R: Try<Output = Acc>,
+    {
+        #[inline]
+        fn flatten<'a, T: IntoIterator, Acc, R: Try<Output = Acc>>(
+            frontiter: &'a mut Option<T::IntoIter>,
+            fold: &'a mut impl FnMut(Acc, &mut T::IntoIter) -> R,
+        ) -> impl FnMut(Acc, T) -> R + 'a {
+            move |acc, iter| fold(acc, frontiter.insert(iter.into_iter()))
+        }
+
+        if let Some(iter) = &mut self.frontiter {
+            acc = fold(acc, iter)?;
+        }
+        self.frontiter = None;
+
+        acc = self.iter.try_fold(acc, flatten(&mut self.frontiter, &mut fold))?;
+        self.frontiter = None;
+
+        if let Some(iter) = &mut self.backiter {
+            acc = fold(acc, iter)?;
+        }
+        self.backiter = None;
+
+        try { acc }
+    }
+}
+
+impl<I, U> FlattenCompat<I, U>
+where
+    I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U>>,
+{
+    /// Folds the inner iterators into an accumulator by applying an operation, starting form the
+    /// back.
+    ///
+    /// Folds over the inner iterators, not over their elements. Is used by the `rfold` method.
+    #[inline]
+    fn iter_rfold<Acc, Fold>(self, mut acc: Acc, mut fold: Fold) -> Acc
+    where
+        Fold: FnMut(Acc, U) -> Acc,
+    {
+        #[inline]
+        fn flatten<T: IntoIterator, Acc>(
+            fold: &mut impl FnMut(Acc, T::IntoIter) -> Acc,
+        ) -> impl FnMut(Acc, T) -> Acc + '_ {
+            move |acc, iter| fold(acc, iter.into_iter())
+        }
+
+        if let Some(iter) = self.backiter {
+            acc = fold(acc, iter);
+        }
+
+        acc = self.iter.rfold(acc, flatten(&mut fold));
+
+        if let Some(iter) = self.frontiter {
+            acc = fold(acc, iter);
+        }
+
+        acc
+    }
+
+    /// Folds over the inner iterators in reverse order as long as the given function returns
+    /// successfully, always storing the most recent inner iterator in `self.backiter`.
+    ///
+    /// Folds over the inner iterators, not over their elements. Is used by the `try_rfold` and
+    /// `advance_back_by` methods.
+    #[inline]
+    fn iter_try_rfold<Acc, Fold, R>(&mut self, mut acc: Acc, mut fold: Fold) -> R
+    where
+        Fold: FnMut(Acc, &mut U) -> R,
+        R: Try<Output = Acc>,
+    {
+        #[inline]
+        fn flatten<'a, T: IntoIterator, Acc, R: Try>(
+            backiter: &'a mut Option<T::IntoIter>,
+            fold: &'a mut impl FnMut(Acc, &mut T::IntoIter) -> R,
+        ) -> impl FnMut(Acc, T) -> R + 'a {
+            move |acc, iter| fold(acc, backiter.insert(iter.into_iter()))
+        }
+
+        if let Some(iter) = &mut self.backiter {
+            acc = fold(acc, iter)?;
+        }
+        self.backiter = None;
+
+        acc = self.iter.try_rfold(acc, flatten(&mut self.backiter, &mut fold))?;
+        self.backiter = None;
+
+        if let Some(iter) = &mut self.frontiter {
+            acc = fold(acc, iter)?;
+        }
+        self.frontiter = None;
+
+        try { acc }
+    }
+}
+
 impl<I, U> Iterator for FlattenCompat<I, U>
 where
     I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
@@ -323,99 +501,74 @@ where
     }
 
     #[inline]
-    fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
+    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
         Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
         R: Try<Output = Acc>,
     {
         #[inline]
-        fn flatten<'a, T: IntoIterator, Acc, R: Try<Output = Acc>>(
-            frontiter: &'a mut Option<T::IntoIter>,
-            fold: &'a mut impl FnMut(Acc, T::Item) -> R,
-        ) -> impl FnMut(Acc, T) -> R + 'a {
-            move |acc, x| {
-                let mut mid = x.into_iter();
-                let r = mid.try_fold(acc, &mut *fold);
-                *frontiter = Some(mid);
-                r
-            }
-        }
-
-        if let Some(ref mut front) = self.frontiter {
-            init = front.try_fold(init, &mut fold)?;
+        fn flatten<U: Iterator, Acc, R: Try<Output = Acc>>(
+            mut fold: impl FnMut(Acc, U::Item) -> R,
+        ) -> impl FnMut(Acc, &mut U) -> R {
+            move |acc, iter| iter.try_fold(acc, &mut fold)
         }
-        self.frontiter = None;
 
-        init = self.iter.try_fold(init, flatten(&mut self.frontiter, &mut fold))?;
-        self.frontiter = None;
-
-        if let Some(ref mut back) = self.backiter {
-            init = back.try_fold(init, &mut fold)?;
-        }
-        self.backiter = None;
-
-        try { init }
+        self.iter_try_fold(init, flatten(fold))
     }
 
     #[inline]
-    fn fold<Acc, Fold>(self, mut init: Acc, mut fold: Fold) -> Acc
+    fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
     where
         Fold: FnMut(Acc, Self::Item) -> Acc,
     {
         #[inline]
-        fn flatten<T: IntoIterator, Acc>(
-            fold: &mut impl FnMut(Acc, T::Item) -> Acc,
-        ) -> impl FnMut(Acc, T) -> Acc + '_ {
-            move |acc, x| x.into_iter().fold(acc, &mut *fold)
-        }
-
-        if let Some(front) = self.frontiter {
-            init = front.fold(init, &mut fold);
-        }
-
-        init = self.iter.fold(init, flatten(&mut fold));
-
-        if let Some(back) = self.backiter {
-            init = back.fold(init, &mut fold);
+        fn flatten<U: Iterator, Acc>(
+            mut fold: impl FnMut(Acc, U::Item) -> Acc,
+        ) -> impl FnMut(Acc, U) -> Acc {
+            move |acc, iter| iter.fold(acc, &mut fold)
         }
 
-        init
+        self.iter_fold(init, flatten(fold))
     }
 
     #[inline]
     #[rustc_inherit_overflow_checks]
     fn advance_by(&mut self, n: usize) -> Result<(), usize> {
-        let mut rem = n;
-        loop {
-            if let Some(ref mut front) = self.frontiter {
-                match front.advance_by(rem) {
-                    ret @ Ok(_) => return ret,
-                    Err(advanced) => rem -= advanced,
-                }
-            }
-            self.frontiter = match self.iter.next() {
-                Some(iterable) => Some(iterable.into_iter()),
-                _ => break,
+        #[inline]
+        #[rustc_inherit_overflow_checks]
+        fn advance<U: Iterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
+            match iter.advance_by(n) {
+                Ok(()) => ControlFlow::BREAK,
+                Err(advanced) => ControlFlow::Continue(n - advanced),
             }
         }
 
-        self.frontiter = None;
-
-        if let Some(ref mut back) = self.backiter {
-            match back.advance_by(rem) {
-                ret @ Ok(_) => return ret,
-                Err(advanced) => rem -= advanced,
-            }
+        match self.iter_try_fold(n, advance) {
+            ControlFlow::Continue(remaining) if remaining > 0 => Err(n - remaining),
+            _ => Ok(()),
         }
+    }
 
-        if rem > 0 {
-            return Err(n - rem);
+    #[inline]
+    fn count(self) -> usize {
+        #[inline]
+        #[rustc_inherit_overflow_checks]
+        fn count<U: Iterator>(acc: usize, iter: U) -> usize {
+            acc + iter.count()
         }
 
-        self.backiter = None;
+        self.iter_fold(0, count)
+    }
+
+    #[inline]
+    fn last(self) -> Option<Self::Item> {
+        #[inline]
+        fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
+            iter.last().or(last)
+        }
 
-        Ok(())
+        self.iter_fold(None, last)
     }
 }
 
@@ -438,105 +591,53 @@ where
     }
 
     #[inline]
-    fn try_rfold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
+    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
         Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
         R: Try<Output = Acc>,
     {
         #[inline]
-        fn flatten<'a, T: IntoIterator, Acc, R: Try<Output = Acc>>(
-            backiter: &'a mut Option<T::IntoIter>,
-            fold: &'a mut impl FnMut(Acc, T::Item) -> R,
-        ) -> impl FnMut(Acc, T) -> R + 'a
-        where
-            T::IntoIter: DoubleEndedIterator,
-        {
-            move |acc, x| {
-                let mut mid = x.into_iter();
-                let r = mid.try_rfold(acc, &mut *fold);
-                *backiter = Some(mid);
-                r
-            }
+        fn flatten<U: DoubleEndedIterator, Acc, R: Try<Output = Acc>>(
+            mut fold: impl FnMut(Acc, U::Item) -> R,
+        ) -> impl FnMut(Acc, &mut U) -> R {
+            move |acc, iter| iter.try_rfold(acc, &mut fold)
         }
 
-        if let Some(ref mut back) = self.backiter {
-            init = back.try_rfold(init, &mut fold)?;
-        }
-        self.backiter = None;
-
-        init = self.iter.try_rfold(init, flatten(&mut self.backiter, &mut fold))?;
-        self.backiter = None;
-
-        if let Some(ref mut front) = self.frontiter {
-            init = front.try_rfold(init, &mut fold)?;
-        }
-        self.frontiter = None;
-
-        try { init }
+        self.iter_try_rfold(init, flatten(fold))
     }
 
     #[inline]
-    fn rfold<Acc, Fold>(self, mut init: Acc, mut fold: Fold) -> Acc
+    fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
     where
         Fold: FnMut(Acc, Self::Item) -> Acc,
     {
         #[inline]
-        fn flatten<T: IntoIterator, Acc>(
-            fold: &mut impl FnMut(Acc, T::Item) -> Acc,
-        ) -> impl FnMut(Acc, T) -> Acc + '_
-        where
-            T::IntoIter: DoubleEndedIterator,
-        {
-            move |acc, x| x.into_iter().rfold(acc, &mut *fold)
-        }
-
-        if let Some(back) = self.backiter {
-            init = back.rfold(init, &mut fold);
+        fn flatten<U: DoubleEndedIterator, Acc>(
+            mut fold: impl FnMut(Acc, U::Item) -> Acc,
+        ) -> impl FnMut(Acc, U) -> Acc {
+            move |acc, iter| iter.rfold(acc, &mut fold)
         }
 
-        init = self.iter.rfold(init, flatten(&mut fold));
-
-        if let Some(front) = self.frontiter {
-            init = front.rfold(init, &mut fold);
-        }
-
-        init
+        self.iter_rfold(init, flatten(fold))
     }
 
     #[inline]
     #[rustc_inherit_overflow_checks]
     fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
-        let mut rem = n;
-        loop {
-            if let Some(ref mut back) = self.backiter {
-                match back.advance_back_by(rem) {
-                    ret @ Ok(_) => return ret,
-                    Err(advanced) => rem -= advanced,
-                }
-            }
-            match self.iter.next_back() {
-                Some(iterable) => self.backiter = Some(iterable.into_iter()),
-                _ => break,
-            }
-        }
-
-        self.backiter = None;
-
-        if let Some(ref mut front) = self.frontiter {
-            match front.advance_back_by(rem) {
-                ret @ Ok(_) => return ret,
-                Err(advanced) => rem -= advanced,
+        #[inline]
+        #[rustc_inherit_overflow_checks]
+        fn advance<U: DoubleEndedIterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
+            match iter.advance_back_by(n) {
+                Ok(()) => ControlFlow::BREAK,
+                Err(advanced) => ControlFlow::Continue(n - advanced),
             }
         }
 
-        if rem > 0 {
-            return Err(n - rem);
+        match self.iter_try_rfold(n, advance) {
+            ControlFlow::Continue(remaining) if remaining > 0 => Err(n - remaining),
+            _ => Ok(()),
         }
-
-        self.frontiter = None;
-
-        Ok(())
     }
 }
 
diff --git a/library/core/tests/iter/adapters/flatten.rs b/library/core/tests/iter/adapters/flatten.rs
index f8ab8c9d444..690fd0c2197 100644
--- a/library/core/tests/iter/adapters/flatten.rs
+++ b/library/core/tests/iter/adapters/flatten.rs
@@ -168,3 +168,45 @@ fn test_trusted_len_flatten() {
     assert_trusted_len(&iter);
     assert_eq!(iter.size_hint(), (20, Some(20)));
 }
+
+#[test]
+fn test_flatten_count() {
+    let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
+
+    assert_eq!(it.clone().count(), 40);
+    it.advance_by(5).unwrap();
+    assert_eq!(it.clone().count(), 35);
+    it.advance_back_by(5).unwrap();
+    assert_eq!(it.clone().count(), 30);
+    it.advance_by(10).unwrap();
+    assert_eq!(it.clone().count(), 20);
+    it.advance_back_by(8).unwrap();
+    assert_eq!(it.clone().count(), 12);
+    it.advance_by(4).unwrap();
+    assert_eq!(it.clone().count(), 8);
+    it.advance_back_by(5).unwrap();
+    assert_eq!(it.clone().count(), 3);
+    it.advance_by(3).unwrap();
+    assert_eq!(it.clone().count(), 0);
+}
+
+#[test]
+fn test_flatten_last() {
+    let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
+
+    assert_eq!(it.clone().last(), Some(39));
+    it.advance_by(5).unwrap(); // 5..40
+    assert_eq!(it.clone().last(), Some(39));
+    it.advance_back_by(5).unwrap(); // 5..35
+    assert_eq!(it.clone().last(), Some(34));
+    it.advance_by(10).unwrap(); // 15..35
+    assert_eq!(it.clone().last(), Some(34));
+    it.advance_back_by(8).unwrap(); // 15..27
+    assert_eq!(it.clone().last(), Some(26));
+    it.advance_by(4).unwrap(); // 19..27
+    assert_eq!(it.clone().last(), Some(26));
+    it.advance_back_by(5).unwrap(); // 19..22
+    assert_eq!(it.clone().last(), Some(21));
+    it.advance_by(3).unwrap(); // 22..22
+    assert_eq!(it.clone().last(), None);
+}