about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMazdak Farrokhzad <twingoow@gmail.com>2020-04-09 05:29:43 +0200
committerGitHub <noreply@github.com>2020-04-09 05:29:43 +0200
commitecc4e2a647533ddabb7a1b69f6ab1385c3f4003c (patch)
treede8e9f80a0c55f6f9e546ede1d3646fd0997c3bd
parentcefee7bd9a7e188e406cac0190cf4a19fcca27ff (diff)
parentce8abc63a7e1fcac4b69574e00a70352e583cec8 (diff)
downloadrust-ecc4e2a647533ddabb7a1b69f6ab1385c3f4003c.tar.gz
rust-ecc4e2a647533ddabb7a1b69f6ab1385c3f4003c.zip
Rollup merge of #70896 - cuviper:optional-chain, r=scottmcm
Implement Chain with Option fuses

The iterators are now "fused" with `Option` so we don't need separate state to track which part is already exhausted, and we may also get niche layout for `None`. We don't use the real `Fuse` adapter because its specialization for `FusedIterator` unconditionally descends into the iterator, and that could be expensive to keep revisiting stuff like nested chains. It also hurts compiler performance to add more iterator layers to `Chain`.

This change was inspired by the [proposal](https://internals.rust-lang.org/t/proposal-implement-iter-chain-using-fuse/12006) on the internals forum. This is an alternate to #70332, directly employing some of the same `Fuse` optimizations as #70366 and #70750.

r? @scottmcm
-rw-r--r--src/libcore/iter/adapters/chain.rs267
1 files changed, 113 insertions, 154 deletions
diff --git a/src/libcore/iter/adapters/chain.rs b/src/libcore/iter/adapters/chain.rs
index 3611a1aadad..2dd405ced20 100644
--- a/src/libcore/iter/adapters/chain.rs
+++ b/src/libcore/iter/adapters/chain.rs
@@ -1,8 +1,7 @@
+use crate::iter::{DoubleEndedIterator, FusedIterator, Iterator, TrustedLen};
 use crate::ops::Try;
 use crate::usize;
 
-use super::super::{DoubleEndedIterator, FusedIterator, Iterator, TrustedLen};
-
 /// An iterator that links two iterators together, in a chain.
 ///
 /// This `struct` is created by the [`chain`] method on [`Iterator`]. See its
@@ -14,37 +13,34 @@ use super::super::{DoubleEndedIterator, FusedIterator, Iterator, TrustedLen};
 #[must_use = "iterators are lazy and do nothing unless consumed"]
 #[stable(feature = "rust1", since = "1.0.0")]
 pub struct Chain<A, B> {
-    a: A,
-    b: B,
-    state: ChainState,
+    // These are "fused" with `Option` so we don't need separate state to track which part is
+    // already exhausted, and we may also get niche layout for `None`. We don't use the real `Fuse`
+    // adapter because its specialization for `FusedIterator` unconditionally descends into the
+    // iterator, and that could be expensive to keep revisiting stuff like nested chains. It also
+    // hurts compiler performance to add more iterator layers to `Chain`.
+    a: Option<A>,
+    b: Option<B>,
 }
 impl<A, B> Chain<A, B> {
     pub(in super::super) fn new(a: A, b: B) -> Chain<A, B> {
-        Chain { a, b, state: ChainState::Both }
+        Chain { a: Some(a), b: Some(b) }
     }
 }
 
-// The iterator protocol specifies that iteration ends with the return value
-// `None` from `.next()` (or `.next_back()`) and it is unspecified what
-// further calls return. The chain adaptor must account for this since it uses
-// two subiterators.
-//
-//  It uses three states:
-//
-//  - Both: `a` and `b` are remaining
-//  - Front: `a` remaining
-//  - Back: `b` remaining
-//
-//  The fourth state (neither iterator is remaining) only occurs after Chain has
-//  returned None once, so we don't need to store this state.
-#[derive(Clone, Debug)]
-enum ChainState {
-    // both front and back iterator are remaining
-    Both,
-    // only front is remaining
-    Front,
-    // only back is remaining
-    Back,
+/// Fuse the iterator if the expression is `None`.
+macro_rules! fuse {
+    ($self:ident . $iter:ident . $($call:tt)+) => {
+        match $self.$iter {
+            Some(ref mut iter) => match iter.$($call)+ {
+                None => {
+                    $self.$iter = None;
+                    None
+                }
+                item => item,
+            },
+            None => None,
+        }
+    };
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -57,88 +53,68 @@ where
 
     #[inline]
     fn next(&mut self) -> Option<A::Item> {
-        match self.state {
-            ChainState::Both => match self.a.next() {
-                elt @ Some(..) => elt,
-                None => {
-                    self.state = ChainState::Back;
-                    self.b.next()
-                }
-            },
-            ChainState::Front => self.a.next(),
-            ChainState::Back => self.b.next(),
+        match fuse!(self.a.next()) {
+            None => fuse!(self.b.next()),
+            item => item,
         }
     }
 
     #[inline]
     #[rustc_inherit_overflow_checks]
     fn count(self) -> usize {
-        match self.state {
-            ChainState::Both => self.a.count() + self.b.count(),
-            ChainState::Front => self.a.count(),
-            ChainState::Back => self.b.count(),
-        }
+        let a_count = match self.a {
+            Some(a) => a.count(),
+            None => 0,
+        };
+        let b_count = match self.b {
+            Some(b) => b.count(),
+            None => 0,
+        };
+        a_count + b_count
     }
 
-    fn try_fold<Acc, F, R>(&mut self, init: Acc, mut f: F) -> R
+    fn try_fold<Acc, F, R>(&mut self, mut acc: Acc, mut f: F) -> R
     where
         Self: Sized,
         F: FnMut(Acc, Self::Item) -> R,
         R: Try<Ok = Acc>,
     {
-        let mut accum = init;
-        match self.state {
-            ChainState::Both | ChainState::Front => {
-                accum = self.a.try_fold(accum, &mut f)?;
-                if let ChainState::Both = self.state {
-                    self.state = ChainState::Back;
-                }
-            }
-            _ => {}
+        if let Some(ref mut a) = self.a {
+            acc = a.try_fold(acc, &mut f)?;
+            self.a = None;
         }
-        if let ChainState::Back = self.state {
-            accum = self.b.try_fold(accum, &mut f)?;
+        if let Some(ref mut b) = self.b {
+            acc = b.try_fold(acc, f)?;
+            self.b = None;
         }
-        Try::from_ok(accum)
+        Try::from_ok(acc)
     }
 
-    fn fold<Acc, F>(self, init: Acc, mut f: F) -> Acc
+    fn fold<Acc, F>(self, mut acc: Acc, mut f: F) -> Acc
     where
         F: FnMut(Acc, Self::Item) -> Acc,
     {
-        let mut accum = init;
-        match self.state {
-            ChainState::Both | ChainState::Front => {
-                accum = self.a.fold(accum, &mut f);
-            }
-            _ => {}
+        if let Some(a) = self.a {
+            acc = a.fold(acc, &mut f);
         }
-        match self.state {
-            ChainState::Both | ChainState::Back => {
-                accum = self.b.fold(accum, &mut f);
-            }
-            _ => {}
+        if let Some(b) = self.b {
+            acc = b.fold(acc, f);
         }
-        accum
+        acc
     }
 
     #[inline]
     fn nth(&mut self, mut n: usize) -> Option<A::Item> {
-        match self.state {
-            ChainState::Both | ChainState::Front => {
-                for x in self.a.by_ref() {
-                    if n == 0 {
-                        return Some(x);
-                    }
-                    n -= 1;
-                }
-                if let ChainState::Both = self.state {
-                    self.state = ChainState::Back;
+        if let Some(ref mut a) = self.a {
+            while let Some(x) = a.next() {
+                if n == 0 {
+                    return Some(x);
                 }
+                n -= 1;
             }
-            ChainState::Back => {}
+            self.a = None;
         }
-        if let ChainState::Back = self.state { self.b.nth(n) } else { None }
+        fuse!(self.b.nth(n))
     }
 
     #[inline]
@@ -146,39 +122,32 @@ where
     where
         P: FnMut(&Self::Item) -> bool,
     {
-        match self.state {
-            ChainState::Both => match self.a.find(&mut predicate) {
-                None => {
-                    self.state = ChainState::Back;
-                    self.b.find(predicate)
-                }
-                v => v,
-            },
-            ChainState::Front => self.a.find(predicate),
-            ChainState::Back => self.b.find(predicate),
+        match fuse!(self.a.find(&mut predicate)) {
+            None => fuse!(self.b.find(predicate)),
+            item => item,
         }
     }
 
     #[inline]
     fn last(self) -> Option<A::Item> {
-        match self.state {
-            ChainState::Both => {
-                // Must exhaust a before b.
-                let a_last = self.a.last();
-                let b_last = self.b.last();
-                b_last.or(a_last)
-            }
-            ChainState::Front => self.a.last(),
-            ChainState::Back => self.b.last(),
-        }
+        // Must exhaust a before b.
+        let a_last = match self.a {
+            Some(a) => a.last(),
+            None => None,
+        };
+        let b_last = match self.b {
+            Some(b) => b.last(),
+            None => None,
+        };
+        b_last.or(a_last)
     }
 
     #[inline]
     fn size_hint(&self) -> (usize, Option<usize>) {
-        match self.state {
-            ChainState::Both => {
-                let (a_lower, a_upper) = self.a.size_hint();
-                let (b_lower, b_upper) = self.b.size_hint();
+        match self {
+            Chain { a: Some(a), b: Some(b) } => {
+                let (a_lower, a_upper) = a.size_hint();
+                let (b_lower, b_upper) = b.size_hint();
 
                 let lower = a_lower.saturating_add(b_lower);
 
@@ -189,8 +158,9 @@ where
 
                 (lower, upper)
             }
-            ChainState::Front => self.a.size_hint(),
-            ChainState::Back => self.b.size_hint(),
+            Chain { a: Some(a), b: None } => a.size_hint(),
+            Chain { a: None, b: Some(b) } => b.size_hint(),
+            Chain { a: None, b: None } => (0, Some(0)),
         }
     }
 }
@@ -203,82 +173,71 @@ where
 {
     #[inline]
     fn next_back(&mut self) -> Option<A::Item> {
-        match self.state {
-            ChainState::Both => match self.b.next_back() {
-                elt @ Some(..) => elt,
-                None => {
-                    self.state = ChainState::Front;
-                    self.a.next_back()
-                }
-            },
-            ChainState::Front => self.a.next_back(),
-            ChainState::Back => self.b.next_back(),
+        match fuse!(self.b.next_back()) {
+            None => fuse!(self.a.next_back()),
+            item => item,
         }
     }
 
     #[inline]
     fn nth_back(&mut self, mut n: usize) -> Option<A::Item> {
-        match self.state {
-            ChainState::Both | ChainState::Back => {
-                for x in self.b.by_ref().rev() {
-                    if n == 0 {
-                        return Some(x);
-                    }
-                    n -= 1;
-                }
-                if let ChainState::Both = self.state {
-                    self.state = ChainState::Front;
+        if let Some(ref mut b) = self.b {
+            while let Some(x) = b.next_back() {
+                if n == 0 {
+                    return Some(x);
                 }
+                n -= 1;
             }
-            ChainState::Front => {}
+            self.b = None;
         }
-        if let ChainState::Front = self.state { self.a.nth_back(n) } else { None }
+        fuse!(self.a.nth_back(n))
     }
 
-    fn try_rfold<Acc, F, R>(&mut self, init: Acc, mut f: F) -> R
+    #[inline]
+    fn rfind<P>(&mut self, mut predicate: P) -> Option<Self::Item>
+    where
+        P: FnMut(&Self::Item) -> bool,
+    {
+        match fuse!(self.b.rfind(&mut predicate)) {
+            None => fuse!(self.a.rfind(predicate)),
+            item => item,
+        }
+    }
+
+    fn try_rfold<Acc, F, R>(&mut self, mut acc: Acc, mut f: F) -> R
     where
         Self: Sized,
         F: FnMut(Acc, Self::Item) -> R,
         R: Try<Ok = Acc>,
     {
-        let mut accum = init;
-        match self.state {
-            ChainState::Both | ChainState::Back => {
-                accum = self.b.try_rfold(accum, &mut f)?;
-                if let ChainState::Both = self.state {
-                    self.state = ChainState::Front;
-                }
-            }
-            _ => {}
+        if let Some(ref mut b) = self.b {
+            acc = b.try_rfold(acc, &mut f)?;
+            self.b = None;
         }
-        if let ChainState::Front = self.state {
-            accum = self.a.try_rfold(accum, &mut f)?;
+        if let Some(ref mut a) = self.a {
+            acc = a.try_rfold(acc, f)?;
+            self.a = None;
         }
-        Try::from_ok(accum)
+        Try::from_ok(acc)
     }
 
-    fn rfold<Acc, F>(self, init: Acc, mut f: F) -> Acc
+    fn rfold<Acc, F>(self, mut acc: Acc, mut f: F) -> Acc
     where
         F: FnMut(Acc, Self::Item) -> Acc,
     {
-        let mut accum = init;
-        match self.state {
-            ChainState::Both | ChainState::Back => {
-                accum = self.b.rfold(accum, &mut f);
-            }
-            _ => {}
+        if let Some(b) = self.b {
+            acc = b.rfold(acc, &mut f);
         }
-        match self.state {
-            ChainState::Both | ChainState::Front => {
-                accum = self.a.rfold(accum, &mut f);
-            }
-            _ => {}
+        if let Some(a) = self.a {
+            acc = a.rfold(acc, f);
         }
-        accum
+        acc
     }
 }
 
 // Note: *both* must be fused to handle double-ended iterators.
+// Now that we "fuse" both sides, we *could* implement this unconditionally,
+// but we should be cautious about committing to that in the public API.
 #[stable(feature = "fused", since = "1.26.0")]
 impl<A, B> FusedIterator for Chain<A, B>
 where