about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/iter/traits/iterator.rs36
1 files changed, 31 insertions, 5 deletions
diff --git a/library/core/src/iter/traits/iterator.rs b/library/core/src/iter/traits/iterator.rs
index cf85bdb1352..f296792b1dc 100644
--- a/library/core/src/iter/traits/iterator.rs
+++ b/library/core/src/iter/traits/iterator.rs
@@ -294,13 +294,39 @@ pub trait Iterator {
     #[inline]
     #[unstable(feature = "iter_advance_by", reason = "recently added", issue = "77404")]
     fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
-        for i in 0..n {
-            if self.next().is_none() {
-                // SAFETY: `i` is always less than `n`.
-                return Err(unsafe { NonZero::new_unchecked(n - i) });
+        /// Helper trait to specialize `advance_by` via `try_fold` for `Sized` iterators.
+        trait SpecAdvanceBy {
+            fn spec_advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>>;
+        }
+
+        impl<I: Iterator + ?Sized> SpecAdvanceBy for I {
+            default fn spec_advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+                for i in 0..n {
+                    if self.next().is_none() {
+                        // SAFETY: `i` is always less than `n`.
+                        return Err(unsafe { NonZero::new_unchecked(n - i) });
+                    }
+                }
+                Ok(())
+            }
+        }
+
+        impl<I: Iterator> SpecAdvanceBy for I {
+            fn spec_advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
+                let Some(n) = NonZero::new(n) else {
+                    return Ok(());
+                };
+
+                let res = self.try_fold(n, |n, _| NonZero::new(n.get() - 1));
+
+                match res {
+                    None => Ok(()),
+                    Some(n) => Err(n),
+                }
             }
         }
-        Ok(())
+
+        self.spec_advance_by(n)
     }
 
     /// Returns the `n`th element of the iterator.