about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-02-17 20:18:54 +0000
committerbors <bors@rust-lang.org>2024-02-17 20:18:54 +0000
commit6672c16afcd4db8acdf08a6984fd4107bf07632c (patch)
tree069eb81740230dc0ecc830c8683b789524eddfa8
parentcabdf3ad257693aa79ffcc4b7dd1fdab41dc209e (diff)
parentc36ae932f901d452e78e1dd69400f6b2169c8fa7 (diff)
downloadrust-6672c16afcd4db8acdf08a6984fd4107bf07632c.tar.gz
rust-6672c16afcd4db8acdf08a6984fd4107bf07632c.zip
Auto merge of #121204 - cuviper:flatten-one-shot, r=the8472
Specialize flattening iterators with only one inner item

For iterators like `Once` and `option::IntoIter` that only ever have a
single item at most, the front and back iterator states in `FlatMap` and
`Flatten` are a waste, as they're always consumed already. We can use
specialization for these types to simplify the iterator methods.

It's a somewhat common pattern to use `flatten()` for options and
results, even recommended by [multiple][1] [clippy][2] [lints][3]. The
implementation is more efficient with `filter_map`, as mentioned in
[clippy#9377], but this new specialization should close some of that
gap for existing code that flattens.

[1]: https://rust-lang.github.io/rust-clippy/master/#filter_map_identity
[2]: https://rust-lang.github.io/rust-clippy/master/#option_filter_map
[3]: https://rust-lang.github.io/rust-clippy/master/#result_filter_map
[clippy#9377]: https://github.com/rust-lang/rust-clippy/issues/9377
-rw-r--r--library/core/src/iter/adapters/flatten.rs221
-rw-r--r--library/core/tests/iter/adapters/flatten.rs66
2 files changed, 275 insertions, 12 deletions
diff --git a/library/core/src/iter/adapters/flatten.rs b/library/core/src/iter/adapters/flatten.rs
index 99344a88efc..145c9d3dacc 100644
--- a/library/core/src/iter/adapters/flatten.rs
+++ b/library/core/src/iter/adapters/flatten.rs
@@ -3,7 +3,7 @@ use crate::iter::{
     Cloned, Copied, Filter, FilterMap, Fuse, FusedIterator, InPlaceIterable, Map, TrustedFused,
     TrustedLen,
 };
-use crate::iter::{Once, OnceWith};
+use crate::iter::{Empty, Once, OnceWith};
 use crate::num::NonZero;
 use crate::ops::{ControlFlow, Try};
 use crate::result;
@@ -593,6 +593,7 @@ where
     }
 }
 
+// See also the `OneShot` specialization below.
 impl<I, U> Iterator for FlattenCompat<I, U>
 where
     I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
@@ -601,7 +602,7 @@ where
     type Item = U::Item;
 
     #[inline]
-    fn next(&mut self) -> Option<U::Item> {
+    default fn next(&mut self) -> Option<U::Item> {
         loop {
             if let elt @ Some(_) = and_then_or_clear(&mut self.frontiter, Iterator::next) {
                 return elt;
@@ -614,7 +615,7 @@ where
     }
 
     #[inline]
-    fn size_hint(&self) -> (usize, Option<usize>) {
+    default fn size_hint(&self) -> (usize, Option<usize>) {
         let (flo, fhi) = self.frontiter.as_ref().map_or((0, Some(0)), U::size_hint);
         let (blo, bhi) = self.backiter.as_ref().map_or((0, Some(0)), U::size_hint);
         let lo = flo.saturating_add(blo);
@@ -636,7 +637,7 @@ where
     }
 
     #[inline]
-    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
+    default fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
         Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
@@ -653,7 +654,7 @@ where
     }
 
     #[inline]
-    fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
+    default fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
     where
         Fold: FnMut(Acc, Self::Item) -> Acc,
     {
@@ -669,7 +670,7 @@ where
 
     #[inline]
     #[rustc_inherit_overflow_checks]
-    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+    default fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
         #[inline]
         #[rustc_inherit_overflow_checks]
         fn advance<U: Iterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
@@ -686,7 +687,7 @@ where
     }
 
     #[inline]
-    fn count(self) -> usize {
+    default fn count(self) -> usize {
         #[inline]
         #[rustc_inherit_overflow_checks]
         fn count<U: Iterator>(acc: usize, iter: U) -> usize {
@@ -697,7 +698,7 @@ where
     }
 
     #[inline]
-    fn last(self) -> Option<Self::Item> {
+    default fn last(self) -> Option<Self::Item> {
         #[inline]
         fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
             iter.last().or(last)
@@ -707,13 +708,14 @@ where
     }
 }
 
+// See also the `OneShot` specialization below.
 impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
 where
     I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
     U: DoubleEndedIterator,
 {
     #[inline]
-    fn next_back(&mut self) -> Option<U::Item> {
+    default fn next_back(&mut self) -> Option<U::Item> {
         loop {
             if let elt @ Some(_) = and_then_or_clear(&mut self.backiter, |b| b.next_back()) {
                 return elt;
@@ -726,7 +728,7 @@ where
     }
 
     #[inline]
-    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
+    default fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
     where
         Self: Sized,
         Fold: FnMut(Acc, Self::Item) -> R,
@@ -743,7 +745,7 @@ where
     }
 
     #[inline]
-    fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
+    default fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
     where
         Fold: FnMut(Acc, Self::Item) -> Acc,
     {
@@ -759,7 +761,7 @@ where
 
     #[inline]
     #[rustc_inherit_overflow_checks]
-    fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+    default fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
         #[inline]
         #[rustc_inherit_overflow_checks]
         fn advance<U: DoubleEndedIterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
@@ -841,3 +843,198 @@ fn and_then_or_clear<T, U>(opt: &mut Option<T>, f: impl FnOnce(&mut T) -> Option
     }
     x
 }
+
+/// Specialization trait for iterator types that never return more than one item.
+///
+/// Note that we still have to deal with the possibility that the iterator was
+/// already exhausted before it came into our control.
+#[rustc_specialization_trait]
+trait OneShot {}
+
+// These all have exactly one item, if not already consumed.
+impl<T> OneShot for Once<T> {}
+impl<F> OneShot for OnceWith<F> {}
+impl<T> OneShot for array::IntoIter<T, 1> {}
+impl<T> OneShot for option::IntoIter<T> {}
+impl<T> OneShot for option::Iter<'_, T> {}
+impl<T> OneShot for option::IterMut<'_, T> {}
+impl<T> OneShot for result::IntoIter<T> {}
+impl<T> OneShot for result::Iter<'_, T> {}
+impl<T> OneShot for result::IterMut<'_, T> {}
+
+// These are always empty, which is fine to optimize too.
+impl<T> OneShot for Empty<T> {}
+impl<T> OneShot for array::IntoIter<T, 0> {}
+
+// These adaptors never increase the number of items.
+// (There are more possible, but for now this matches BoundedSize above.)
+impl<I: OneShot> OneShot for Cloned<I> {}
+impl<I: OneShot> OneShot for Copied<I> {}
+impl<I: OneShot, P> OneShot for Filter<I, P> {}
+impl<I: OneShot, P> OneShot for FilterMap<I, P> {}
+impl<I: OneShot, F> OneShot for Map<I, F> {}
+
+// Blanket impls pass this property through as well
+// (but we can't do `Box<I>` unless we expose this trait to alloc)
+impl<I: OneShot> OneShot for &mut I {}
+
+#[inline]
+fn into_item<I>(inner: I) -> Option<I::Item>
+where
+    I: IntoIterator<IntoIter: OneShot>,
+{
+    inner.into_iter().next()
+}
+
+#[inline]
+fn flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc>(
+    mut fold: impl FnMut(Acc, I::Item) -> Acc,
+) -> impl FnMut(Acc, I) -> Acc {
+    move |acc, inner| match inner.into_iter().next() {
+        Some(item) => fold(acc, item),
+        None => acc,
+    }
+}
+
+#[inline]
+fn try_flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc, R: Try<Output = Acc>>(
+    mut fold: impl FnMut(Acc, I::Item) -> R,
+) -> impl FnMut(Acc, I) -> R {
+    move |acc, inner| match inner.into_iter().next() {
+        Some(item) => fold(acc, item),
+        None => try { acc },
+    }
+}
+
+#[inline]
+fn advance_by_one<I>(n: NonZero<usize>, inner: I) -> Option<NonZero<usize>>
+where
+    I: IntoIterator<IntoIter: OneShot>,
+{
+    match inner.into_iter().next() {
+        Some(_) => NonZero::new(n.get() - 1),
+        None => Some(n),
+    }
+}
+
+// Specialization: When the inner iterator `U` never returns more than one item, the `frontiter` and
+// `backiter` states are a waste, because they'll always have already consumed their item. So in
+// this impl, we completely ignore them and just focus on `self.iter`, and we only call the inner
+// `U::next()` one time.
+//
+// It's mostly fine if we accidentally mix this with the more generic impls, e.g. by forgetting to
+// specialize one of the methods. If the other impl did set the front or back, we wouldn't see it
+// here, but it would be empty anyway; and if the other impl looked for a front or back that we
+// didn't bother setting, it would just see `None` (or a previous empty) and move on.
+//
+// An exception to that is `advance_by(0)` and `advance_back_by(0)`, where the generic impls may set
+// `frontiter` or `backiter` without consuming the item, so we **must** override those.
+impl<I, U> Iterator for FlattenCompat<I, U>
+where
+    I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
+    U: Iterator + OneShot,
+{
+    #[inline]
+    fn next(&mut self) -> Option<U::Item> {
+        while let Some(inner) = self.iter.next() {
+            if let item @ Some(_) = inner.into_iter().next() {
+                return item;
+            }
+        }
+        None
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let (lower, upper) = self.iter.size_hint();
+        match <I::Item as ConstSizeIntoIterator>::size() {
+            Some(0) => (0, Some(0)),
+            Some(1) => (lower, upper),
+            _ => (0, upper),
+        }
+    }
+
+    #[inline]
+    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
+    where
+        Self: Sized,
+        Fold: FnMut(Acc, Self::Item) -> R,
+        R: Try<Output = Acc>,
+    {
+        self.iter.try_fold(init, try_flatten_one(fold))
+    }
+
+    #[inline]
+    fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
+    where
+        Fold: FnMut(Acc, Self::Item) -> Acc,
+    {
+        self.iter.fold(init, flatten_one(fold))
+    }
+
+    #[inline]
+    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+        if let Some(n) = NonZero::new(n) {
+            self.iter.try_fold(n, advance_by_one).map_or(Ok(()), Err)
+        } else {
+            // Just advance the outer iterator
+            self.iter.advance_by(0)
+        }
+    }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.iter.filter_map(into_item).count()
+    }
+
+    #[inline]
+    fn last(self) -> Option<Self::Item> {
+        self.iter.filter_map(into_item).last()
+    }
+}
+
+// Note: We don't actually care about `U: DoubleEndedIterator`, since forward and backward are the
+// same for a one-shot iterator, but we have to keep that to match the default specialization.
+impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
+where
+    I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
+    U: DoubleEndedIterator + OneShot,
+{
+    #[inline]
+    fn next_back(&mut self) -> Option<U::Item> {
+        while let Some(inner) = self.iter.next_back() {
+            if let item @ Some(_) = inner.into_iter().next() {
+                return item;
+            }
+        }
+        None
+    }
+
+    #[inline]
+    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
+    where
+        Self: Sized,
+        Fold: FnMut(Acc, Self::Item) -> R,
+        R: Try<Output = Acc>,
+    {
+        self.iter.try_rfold(init, try_flatten_one(fold))
+    }
+
+    #[inline]
+    fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
+    where
+        Fold: FnMut(Acc, Self::Item) -> Acc,
+    {
+        self.iter.rfold(init, flatten_one(fold))
+    }
+
+    #[inline]
+    fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+        if let Some(n) = NonZero::new(n) {
+            self.iter.try_rfold(n, advance_by_one).map_or(Ok(()), Err)
+        } else {
+            // Just advance the outer iterator
+            self.iter.advance_back_by(0)
+        }
+    }
+}
diff --git a/library/core/tests/iter/adapters/flatten.rs b/library/core/tests/iter/adapters/flatten.rs
index 2af7e0c388a..1f953f2aa01 100644
--- a/library/core/tests/iter/adapters/flatten.rs
+++ b/library/core/tests/iter/adapters/flatten.rs
@@ -212,3 +212,69 @@ fn test_flatten_last() {
     assert_eq!(it.advance_by(3), Ok(())); // 22..22
     assert_eq!(it.clone().last(), None);
 }
+
+#[test]
+fn test_flatten_one_shot() {
+    // This could be `filter_map`, but people often do flatten options.
+    let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7));
+    assert_eq!(it.size_hint(), (0, Some(10)));
+    assert_eq!(it.clone().count(), 8);
+    assert_eq!(it.clone().last(), NonZero::new(2));
+
+    // sum -> fold
+    let sum: i8 = it.clone().map(|n| n.get()).sum();
+    assert_eq!(sum, 24);
+
+    // the product overflows at 6, remaining are 7,8,9 -> 1,2
+    let one = NonZero::new(1i8).unwrap();
+    let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
+    assert_eq!(product, None);
+    assert_eq!(it.size_hint(), (0, Some(3)));
+    assert_eq!(it.clone().count(), 2);
+
+    assert_eq!(it.advance_by(0), Ok(()));
+    assert_eq!(it.clone().next(), NonZero::new(1));
+    assert_eq!(it.advance_by(1), Ok(()));
+    assert_eq!(it.clone().next(), NonZero::new(2));
+    assert_eq!(it.advance_by(100), Err(NonZero::new(99).unwrap()));
+    assert_eq!(it.next(), None);
+}
+
+#[test]
+fn test_flatten_one_shot_rev() {
+    let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7)).rev();
+    assert_eq!(it.size_hint(), (0, Some(10)));
+    assert_eq!(it.clone().count(), 8);
+    assert_eq!(it.clone().last(), NonZero::new(1));
+
+    // sum -> Rev fold -> rfold
+    let sum: i8 = it.clone().map(|n| n.get()).sum();
+    assert_eq!(sum, 24);
+
+    // Rev try_fold -> try_rfold
+    // the product overflows at 4, remaining are 3,2,1,0 -> 3,2,1
+    let one = NonZero::new(1i8).unwrap();
+    let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
+    assert_eq!(product, None);
+    assert_eq!(it.size_hint(), (0, Some(4)));
+    assert_eq!(it.clone().count(), 3);
+
+    // Rev advance_by -> advance_back_by
+    assert_eq!(it.advance_by(0), Ok(()));
+    assert_eq!(it.clone().next(), NonZero::new(3));
+    assert_eq!(it.advance_by(1), Ok(()));
+    assert_eq!(it.clone().next(), NonZero::new(2));
+    assert_eq!(it.advance_by(100), Err(NonZero::new(98).unwrap()));
+    assert_eq!(it.next(), None);
+}
+
+#[test]
+fn test_flatten_one_shot_arrays() {
+    let it = (0..10).flat_map(|i| [i]);
+    assert_eq!(it.size_hint(), (10, Some(10)));
+    assert_eq!(it.sum::<i32>(), 45);
+
+    let mut it = (0..10).flat_map(|_| -> [i32; 0] { [] });
+    assert_eq!(it.size_hint(), (0, Some(0)));
+    assert_eq!(it.next(), None);
+}