about summary refs log tree commit diff
diff options
context:
space:
mode:
authoroberien <jaro.fietz@gmx.de>2018-01-19 21:07:01 +0100
committeroberien <jaro.fietz@gmx.de>2018-01-19 21:07:01 +0100
commitf08dec114f6008cb7a906ffd55c221fd30d70987 (patch)
tree933de65aec9c7b95e4e090032a86779b5d84f662
parentd33cc12eed3df459db3c9ae2dd89df9cc6e45dd6 (diff)
downloadrust-f08dec114f6008cb7a906ffd55c221fd30d70987.tar.gz
rust-f08dec114f6008cb7a906ffd55c221fd30d70987.zip
Handle Overflow
-rw-r--r--src/libcore/iter/mod.rs31
-rw-r--r--src/libcore/tests/iter.rs44
2 files changed, 72 insertions, 3 deletions
diff --git a/src/libcore/iter/mod.rs b/src/libcore/iter/mod.rs
index 33c00989fd2..57e7e03a6ce 100644
--- a/src/libcore/iter/mod.rs
+++ b/src/libcore/iter/mod.rs
@@ -705,9 +705,34 @@ impl<I> Iterator for StepBy<I> where I: Iterator {
             }
             n -= 1;
         }
-        // n and self.step are indices, thus we need to add 1 before multiplying.
-        // After that we need to subtract 1 from the result to convert it back to an index.
-        self.iter.nth((n + 1) * (self.step + 1) - 1)
+        // n and self.step are indices, we need to add 1 to get the amount of elements
+        // When calling `.nth`, we need to subtract 1 again to convert back to an index
+        // step + 1 can't overflow because `.step_by` sets `self.step` to `step - 1`
+        let mut step = self.step + 1;
+        // n + 1 could overflow
+        // thus, if n is usize::MAX, instead of adding one, we call .nth(step)
+        if n == usize::MAX {
+            self.iter.nth(step - 1);
+        } else {
+            n += 1;
+        }
+
+        // overflow handling
+        while n.checked_mul(step).is_none() {
+            let div_n = usize::MAX / n;
+            let div_step = usize::MAX / step;
+            let nth_n = div_n * n;
+            let nth_step = div_step * step;
+            let nth = if nth_n > nth_step {
+                step -= div_n;
+                nth_n
+            } else {
+                n -= div_step;
+                nth_step
+            };
+            self.iter.nth(nth - 1);
+        }
+        self.iter.nth(n * step - 1)
     }
 }
 
diff --git a/src/libcore/tests/iter.rs b/src/libcore/tests/iter.rs
index e8874991659..e52e119ff59 100644
--- a/src/libcore/tests/iter.rs
+++ b/src/libcore/tests/iter.rs
@@ -180,6 +180,50 @@ fn test_iterator_step_by_nth() {
 }
 
 #[test]
+fn test_iterator_step_by_nth_overflow() {
+    #[cfg(target_pointer_width = "8")]
+    type Bigger = u16;
+    #[cfg(target_pointer_width = "16")]
+    type Bigger = u32;
+    #[cfg(target_pointer_width = "32")]
+    type Bigger = u64;
+    #[cfg(target_pointer_width = "64")]
+    type Bigger = u128;
+
+    #[derive(Clone)]
+    struct Test(Bigger);
+    impl<'a> Iterator for &'a mut Test {
+        type Item = i32;
+        fn next(&mut self) -> Option<Self::Item> { Some(21) }
+        fn nth(&mut self, n: usize) -> Option<Self::Item> {
+            self.0 += n as Bigger + 1;
+            Some(42)
+        }
+    }
+
+    let mut it = Test(0);
+    let root = usize::MAX >> (::std::mem::size_of::<usize>() * 8 / 2);
+    let n = root + 20;
+    (&mut it).step_by(n).nth(n);
+    assert_eq!(it.0, n as Bigger * n as Bigger);
+
+    // large step
+    let mut it = Test(0);
+    (&mut it).step_by(usize::MAX).nth(5);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 5);
+
+    // n + 1 overflows
+    let mut it = Test(0);
+    (&mut it).step_by(2).nth(usize::MAX);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 2);
+
+    // n + 1 overflows
+    let mut it = Test(0);
+    (&mut it).step_by(1).nth(usize::MAX);
+    assert_eq!(it.0, (usize::MAX as Bigger) * 1);
+}
+
+#[test]
 #[should_panic]
 fn test_iterator_step_by_zero() {
     let mut it = (0..).step_by(0);