about summary refs log tree commit diff
diff options
context:
space:
mode:
authorkennytm <kennytm@gmail.com>2018-02-01 02:34:15 +0800
committerkennytm <kennytm@gmail.com>2018-02-01 02:34:15 +0800
commitaf95302d3cb2f5cd59eb38b88eab6a054d965e9f (patch)
treef26b25cc9404a7af1260814d824a1c88f454fd57
parent86eb7259539a5dd0763a5703fb948bae6c6b064d (diff)
parent4a0da4cf2c7a2b5903fd1b8bc124f8963ce1b535 (diff)
downloadrust-af95302d3cb2f5cd59eb38b88eab6a054d965e9f.tar.gz
rust-af95302d3cb2f5cd59eb38b88eab6a054d965e9f.zip
Rollup merge of #47552 - oberien:stepby-nth, r=dtolnay
Specialize StepBy::nth

This allows optimizations of implementations of the inner iterator's `.nth` method.
-rw-r--r--src/libcore/iter/mod.rs44
-rw-r--r--src/libcore/tests/iter.rs62
2 files changed, 106 insertions, 0 deletions
diff --git a/src/libcore/iter/mod.rs b/src/libcore/iter/mod.rs
index 06c29b47bf9..7314fac282b 100644
--- a/src/libcore/iter/mod.rs
+++ b/src/libcore/iter/mod.rs
@@ -307,6 +307,7 @@ use fmt;
 use iter_private::TrustedRandomAccess;
 use ops::Try;
 use usize;
+use intrinsics;
 
 #[stable(feature = "rust1", since = "1.0.0")]
 pub use self::iterator::Iterator;
@@ -694,6 +695,49 @@ impl<I> Iterator for StepBy<I> where I: Iterator {
             (f(inner_hint.0), inner_hint.1.map(f))
         }
     }
+
+    #[inline]
+    fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
+        if self.first_take {
+            self.first_take = false;
+            let first = self.iter.next();
+            if n == 0 {
+                return first;
+            }
+            n -= 1;
+        }
+        // n and self.step are indices, we need to add 1 to get the amount of elements
+        // When calling `.nth`, we need to subtract 1 again to convert back to an index
+        // step + 1 can't overflow because `.step_by` sets `self.step` to `step - 1`
+        let mut step = self.step + 1;
+        // n + 1 could overflow
+        // thus, if n is usize::MAX, instead of adding one, we call .nth(step)
+        if n == usize::MAX {
+            self.iter.nth(step - 1);
+        } else {
+            n += 1;
+        }
+
+        // overflow handling
+        loop {
+            let mul = n.checked_mul(step);
+            if unsafe { intrinsics::likely(mul.is_some()) } {
+                return self.iter.nth(mul.unwrap() - 1);
+            }
+            let div_n = usize::MAX / n;
+            let div_step = usize::MAX / step;
+            let nth_n = div_n * n;
+            let nth_step = div_step * step;
+            let nth = if nth_n > nth_step {
+                step -= div_n;
+                nth_n
+            } else {
+                n -= div_step;
+                nth_step
+            };
+            self.iter.nth(nth - 1);
+        }
+    }
 }
 
 // StepBy can only make the iterator shorter, so the len will still fit.
diff --git a/src/libcore/tests/iter.rs b/src/libcore/tests/iter.rs
index 8997cf9c6bf..e52e119ff59 100644
--- a/src/libcore/tests/iter.rs
+++ b/src/libcore/tests/iter.rs
@@ -162,6 +162,68 @@ fn test_iterator_step_by() {
 }
 
 #[test]
+fn test_iterator_step_by_nth() {
+    let mut it = (0..16).step_by(5);
+    assert_eq!(it.nth(0), Some(0));
+    assert_eq!(it.nth(0), Some(5));
+    assert_eq!(it.nth(0), Some(10));
+    assert_eq!(it.nth(0), Some(15));
+    assert_eq!(it.nth(0), None);
+
+    let it = (0..18).step_by(5);
+    assert_eq!(it.clone().nth(0), Some(0));
+    assert_eq!(it.clone().nth(1), Some(5));
+    assert_eq!(it.clone().nth(2), Some(10));
+    assert_eq!(it.clone().nth(3), Some(15));
+    assert_eq!(it.clone().nth(4), None);
+    assert_eq!(it.clone().nth(42), None);
+}
+
+#[test]
+fn test_iterator_step_by_nth_overflow() {
+    #[cfg(target_pointer_width = "8")]
+    type Bigger = u16;
+    #[cfg(target_pointer_width = "16")]
+    type Bigger = u32;
+    #[cfg(target_pointer_width = "32")]
+    type Bigger = u64;
+    #[cfg(target_pointer_width = "64")]
+    type Bigger = u128;
+
+    #[derive(Clone)]
+    struct Test(Bigger);
+    impl<'a> Iterator for &'a mut Test {
+        type Item = i32;
+        fn next(&mut self) -> Option<Self::Item> { Some(21) }
+        fn nth(&mut self, n: usize) -> Option<Self::Item> {
+            self.0 += n as Bigger + 1;
+            Some(42)
+        }
+    }
+
+    let mut it = Test(0);
+    let root = usize::MAX >> (::std::mem::size_of::<usize>() * 8 / 2);
+    let n = root + 20;
+    (&mut it).step_by(n).nth(n);
+    assert_eq!(it.0, n as Bigger * n as Bigger);
+
+    // large step
+    let mut it = Test(0);
+    (&mut it).step_by(usize::MAX).nth(5);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 5);
+
+    // n + 1 overflows
+    let mut it = Test(0);
+    (&mut it).step_by(2).nth(usize::MAX);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 2);
+
+    // n + 1 overflows
+    let mut it = Test(0);
+    (&mut it).step_by(1).nth(usize::MAX);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 1);
+}
+
+#[test]
 #[should_panic]
 fn test_iterator_step_by_zero() {
     let mut it = (0..).step_by(0);