about summary refs log tree commit diff
diff options
context:
space:
mode:
authorScott McMurray <scottmcm@users.noreply.github.com>2022-04-02 14:29:41 -0700
committerScott McMurray <scottmcm@users.noreply.github.com>2022-04-02 14:29:41 -0700
commit83595f9242ad9e8a7da091f65d450e44e4434f89 (patch)
tree98404a3bad2a0be64721abdd38e705e45b6ab209
parenteb82facb1626166188d49599a3313fc95201f556 (diff)
downloadrust-83595f9242ad9e8a7da091f65d450e44e4434f89.tar.gz
rust-83595f9242ad9e8a7da091f65d450e44e4434f89.zip
Fix `array::IntoIter::fold` to use the optimized `Range::fold`
It was using `Iterator::by_ref` in the implementation, which ended up pessimizing it enough that, for example, it didn't vectorize when we tried it in the <https://rust-lang.zulipchat.com/#narrow/stream/257879-project-portable-simd/topic/Reducing.20sum.20into.20wider.20types> conversation.

Demonstration that the codegen test doesn't pass on the current nightly: <https://rust.godbolt.org/z/Taxev5eMn>
-rw-r--r--library/core/src/array/iter.rs16
-rw-r--r--library/core/src/iter/adapters/by_ref_sized.rs29
-rw-r--r--library/core/tests/array.rs32
-rw-r--r--src/test/codegen/simd-wide-sum.rs54
4 files changed, 130 insertions, 1 deletions
diff --git a/library/core/src/array/iter.rs b/library/core/src/array/iter.rs
index e5024c215be..baf2f2d6c97 100644
--- a/library/core/src/array/iter.rs
+++ b/library/core/src/array/iter.rs
@@ -266,7 +266,7 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
         Fold: FnMut(Acc, Self::Item) -> Acc,
     {
         let data = &mut self.data;
-        self.alive.by_ref().fold(init, |acc, idx| {
+        iter::ByRefSized(&mut self.alive).fold(init, |acc, idx| {
             // SAFETY: idx is obtained by folding over the `alive` range, which implies the
             // value is currently considered alive but as the range is being consumed each value
             // we read here will only be read once and then considered dead.
@@ -323,6 +323,20 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
         })
     }
 
+    #[inline]
+    fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc
+    where
+        Fold: FnMut(Acc, Self::Item) -> Acc,
+    {
+        let data = &mut self.data;
+        iter::ByRefSized(&mut self.alive).rfold(init, |acc, idx| {
+            // SAFETY: idx is obtained by folding over the `alive` range, which implies the
+            // value is currently considered alive but as the range is being consumed each value
+            // we read here will only be read once and then considered dead.
+            rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
+        })
+    }
+
     fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
         let len = self.len();
 
diff --git a/library/core/src/iter/adapters/by_ref_sized.rs b/library/core/src/iter/adapters/by_ref_sized.rs
index 0b5e2a89ef3..890faa5975f 100644
--- a/library/core/src/iter/adapters/by_ref_sized.rs
+++ b/library/core/src/iter/adapters/by_ref_sized.rs
@@ -40,3 +40,32 @@ impl<I: Iterator> Iterator for ByRefSized<'_, I> {
         self.0.try_fold(init, f)
     }
 }
+
+impl<I: DoubleEndedIterator> DoubleEndedIterator for ByRefSized<'_, I> {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        self.0.next_back()
+    }
+
+    fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
+        self.0.advance_back_by(n)
+    }
+
+    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
+        self.0.nth_back(n)
+    }
+
+    fn rfold<B, F>(self, init: B, f: F) -> B
+    where
+        F: FnMut(B, Self::Item) -> B,
+    {
+        self.0.rfold(init, f)
+    }
+
+    fn try_rfold<B, F, R>(&mut self, init: B, f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: Try<Output = B>,
+    {
+        self.0.try_rfold(init, f)
+    }
+}
diff --git a/library/core/tests/array.rs b/library/core/tests/array.rs
index a778779c0fd..ee7ff012ec1 100644
--- a/library/core/tests/array.rs
+++ b/library/core/tests/array.rs
@@ -668,3 +668,35 @@ fn array_mixed_equality_nans() {
     assert!(!(mut3 == array3));
     assert!(mut3 != array3);
 }
+
+#[test]
+fn array_into_iter_fold() {
+    // Strings to help MIRI catch if we double-free or something
+    let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
+    let mut s = "s".to_string();
+    a.into_iter().for_each(|b| s += &b);
+    assert_eq!(s, "sAaBbCc");
+
+    let a = [1, 2, 3, 4, 5, 6];
+    let mut it = a.into_iter();
+    it.advance_by(1).unwrap();
+    it.advance_back_by(2).unwrap();
+    let s = it.fold(10, |a, b| 10 * a + b);
+    assert_eq!(s, 10234);
+}
+
+#[test]
+fn array_into_iter_rfold() {
+    // Strings to help MIRI catch if we double-free or something
+    let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
+    let mut s = "s".to_string();
+    a.into_iter().rev().for_each(|b| s += &b);
+    assert_eq!(s, "sCcBbAa");
+
+    let a = [1, 2, 3, 4, 5, 6];
+    let mut it = a.into_iter();
+    it.advance_by(1).unwrap();
+    it.advance_back_by(2).unwrap();
+    let s = it.rfold(10, |a, b| 10 * a + b);
+    assert_eq!(s, 10432);
+}
diff --git a/src/test/codegen/simd-wide-sum.rs b/src/test/codegen/simd-wide-sum.rs
new file mode 100644
index 00000000000..fde9b0fcd8a
--- /dev/null
+++ b/src/test/codegen/simd-wide-sum.rs
@@ -0,0 +1,54 @@
+// compile-flags: -C opt-level=3 --edition=2021
+// only-x86_64
+// ignore-debug: the debug assertions get in the way
+
+#![crate_type = "lib"]
+#![feature(portable_simd)]
+
+use std::simd::Simd;
+const N: usize = 8;
+
+#[no_mangle]
+// CHECK-LABEL: @wider_reduce_simd
+pub fn wider_reduce_simd(x: Simd<u8, N>) -> u16 {
+    // CHECK: zext <8 x i8>
+    // CHECK-SAME: to <8 x i16>
+    // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
+    let x: Simd<u16, N> = x.cast();
+    x.reduce_sum()
+}
+
+#[no_mangle]
+// CHECK-LABEL: @wider_reduce_loop
+pub fn wider_reduce_loop(x: Simd<u8, N>) -> u16 {
+    // CHECK: zext <8 x i8>
+    // CHECK-SAME: to <8 x i16>
+    // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
+    let mut sum = 0_u16;
+    for i in 0..N {
+        sum += u16::from(x[i]);
+    }
+    sum
+}
+
+#[no_mangle]
+// CHECK-LABEL: @wider_reduce_iter
+pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 {
+    // CHECK: zext <8 x i8>
+    // CHECK-SAME: to <8 x i16>
+    // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
+    x.as_array().iter().copied().map(u16::from).sum()
+}
+
+// This iterator one is the most interesting, as it's the one
+// which used to not auto-vectorize due to a suboptimality in the
+// `<array::IntoIter as Iterator>::fold` implementation.
+
+#[no_mangle]
+// CHECK-LABEL: @wider_reduce_into_iter
+pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 {
+    // CHECK: zext <8 x i8>
+    // CHECK-SAME: to <8 x i16>
+    // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
+    x.to_array().into_iter().map(u16::from).sum()
+}