diff options
| author | Scott McMurray <scottmcm@users.noreply.github.com> | 2022-04-02 14:29:41 -0700 |
|---|---|---|
| committer | Scott McMurray <scottmcm@users.noreply.github.com> | 2022-04-02 14:29:41 -0700 |
| commit | 83595f9242ad9e8a7da091f65d450e44e4434f89 (patch) | |
| tree | 98404a3bad2a0be64721abdd38e705e45b6ab209 | |
| parent | eb82facb1626166188d49599a3313fc95201f556 (diff) | |
| download | rust-83595f9242ad9e8a7da091f65d450e44e4434f89.tar.gz rust-83595f9242ad9e8a7da091f65d450e44e4434f89.zip | |
Fix `array::IntoIter::fold` to use the optimized `Range::fold`
It was using `Iterator::by_ref` in the implementation, which ended up pessimizing it enough that, for example, it didn't vectorize when we tried it in the <https://rust-lang.zulipchat.com/#narrow/stream/257879-project-portable-simd/topic/Reducing.20sum.20into.20wider.20types> conversation. Demonstration that the codegen test doesn't pass on the current nightly: <https://rust.godbolt.org/z/Taxev5eMn>
| -rw-r--r-- | library/core/src/array/iter.rs | 16 | ||||
| -rw-r--r-- | library/core/src/iter/adapters/by_ref_sized.rs | 29 | ||||
| -rw-r--r-- | library/core/tests/array.rs | 32 | ||||
| -rw-r--r-- | src/test/codegen/simd-wide-sum.rs | 54 |
4 files changed, 130 insertions, 1 deletions
diff --git a/library/core/src/array/iter.rs b/library/core/src/array/iter.rs index e5024c215be..baf2f2d6c97 100644 --- a/library/core/src/array/iter.rs +++ b/library/core/src/array/iter.rs @@ -266,7 +266,7 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> { Fold: FnMut(Acc, Self::Item) -> Acc, { let data = &mut self.data; - self.alive.by_ref().fold(init, |acc, idx| { + iter::ByRefSized(&mut self.alive).fold(init, |acc, idx| { // SAFETY: idx is obtained by folding over the `alive` range, which implies the // value is currently considered alive but as the range is being consumed each value // we read here will only be read once and then considered dead. @@ -323,6 +323,20 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> { }) } + #[inline] + fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc + where + Fold: FnMut(Acc, Self::Item) -> Acc, + { + let data = &mut self.data; + iter::ByRefSized(&mut self.alive).rfold(init, |acc, idx| { + // SAFETY: idx is obtained by folding over the `alive` range, which implies the + // value is currently considered alive but as the range is being consumed each value + // we read here will only be read once and then considered dead. + rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() }) + }) + } + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { let len = self.len(); diff --git a/library/core/src/iter/adapters/by_ref_sized.rs b/library/core/src/iter/adapters/by_ref_sized.rs index 0b5e2a89ef3..890faa5975f 100644 --- a/library/core/src/iter/adapters/by_ref_sized.rs +++ b/library/core/src/iter/adapters/by_ref_sized.rs @@ -40,3 +40,32 @@ impl<I: Iterator> Iterator for ByRefSized<'_, I> { self.0.try_fold(init, f) } } + +impl<I: DoubleEndedIterator> DoubleEndedIterator for ByRefSized<'_, I> { + fn next_back(&mut self) -> Option<Self::Item> { + self.0.next_back() + } + + fn advance_back_by(&mut self, n: usize) -> Result<(), usize> { + self.0.advance_back_by(n) + } + + fn nth_back(&mut self, n: usize) -> Option<Self::Item> { + self.0.nth_back(n) + } + + fn rfold<B, F>(self, init: B, f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + self.0.rfold(init, f) + } + + fn try_rfold<B, F, R>(&mut self, init: B, f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: Try<Output = B>, + { + self.0.try_rfold(init, f) + } +} diff --git a/library/core/tests/array.rs b/library/core/tests/array.rs index a778779c0fd..ee7ff012ec1 100644 --- a/library/core/tests/array.rs +++ b/library/core/tests/array.rs @@ -668,3 +668,35 @@ fn array_mixed_equality_nans() { assert!(!(mut3 == array3)); assert!(mut3 != array3); } + +#[test] +fn array_into_iter_fold() { + // Strings to help MIRI catch if we double-free or something + let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()]; + let mut s = "s".to_string(); + a.into_iter().for_each(|b| s += &b); + assert_eq!(s, "sAaBbCc"); + + let a = [1, 2, 3, 4, 5, 6]; + let mut it = a.into_iter(); + it.advance_by(1).unwrap(); + it.advance_back_by(2).unwrap(); + let s = it.fold(10, |a, b| 10 * a + b); + assert_eq!(s, 10234); +} + +#[test] +fn array_into_iter_rfold() { + // Strings to help MIRI catch if we double-free or something + let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()]; + let mut s = "s".to_string(); + a.into_iter().rev().for_each(|b| s += &b); + assert_eq!(s, "sCcBbAa"); + + let a = [1, 2, 3, 4, 5, 6]; + let mut it = a.into_iter(); + it.advance_by(1).unwrap(); + it.advance_back_by(2).unwrap(); + let s = it.rfold(10, |a, b| 10 * a + b); + assert_eq!(s, 10432); +} diff --git a/src/test/codegen/simd-wide-sum.rs b/src/test/codegen/simd-wide-sum.rs new file mode 100644 index 00000000000..fde9b0fcd8a --- /dev/null +++ b/src/test/codegen/simd-wide-sum.rs @@ -0,0 +1,54 @@ +// compile-flags: -C opt-level=3 --edition=2021 +// only-x86_64 +// ignore-debug: the debug assertions get in the way + +#![crate_type = "lib"] +#![feature(portable_simd)] + +use std::simd::Simd; +const N: usize = 8; + +#[no_mangle] +// CHECK-LABEL: @wider_reduce_simd +pub fn wider_reduce_simd(x: Simd<u8, N>) -> u16 { + // CHECK: zext <8 x i8> + // CHECK-SAME: to <8 x i16> + // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> + let x: Simd<u16, N> = x.cast(); + x.reduce_sum() +} + +#[no_mangle] +// CHECK-LABEL: @wider_reduce_loop +pub fn wider_reduce_loop(x: Simd<u8, N>) -> u16 { + // CHECK: zext <8 x i8> + // CHECK-SAME: to <8 x i16> + // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> + let mut sum = 0_u16; + for i in 0..N { + sum += u16::from(x[i]); + } + sum +} + +#[no_mangle] +// CHECK-LABEL: @wider_reduce_iter +pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 { + // CHECK: zext <8 x i8> + // CHECK-SAME: to <8 x i16> + // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> + x.as_array().iter().copied().map(u16::from).sum() +} + +// This iterator one is the most interesting, as it's the one +// which used to not auto-vectorize due to a suboptimality in the +// `<array::IntoIter as Iterator>::fold` implementation. + +#[no_mangle] +// CHECK-LABEL: @wider_reduce_into_iter +pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 { + // CHECK: zext <8 x i8> + // CHECK-SAME: to <8 x i16> + // CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> + x.to_array().into_iter().map(u16::from).sum() +} |
