about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-11-26 12:03:40 -0500
committerGitHub <noreply@github.com>2024-11-26 12:03:40 -0500
commit42459a7971be2d35f2b08c473ffb909cdc6427d2 (patch)
tree514fb2e27868eabc90548f3b86e4a12e812a9fef
parent9d6a11a435b5bcafc9225c2578ffbab1afcba198 (diff)
parent4a230bba746bafb0c152ee08a759990e6ed6a008 (diff)
downloadrust-42459a7971be2d35f2b08c473ffb909cdc6427d2.tar.gz
rust-42459a7971be2d35f2b08c473ffb909cdc6427d2.zip
Rollup merge of #133136 - ChayimFriedman2:get-many-mut, r=Amanieu
Support ranges in `<[T]>::get_many_mut()`

As per T-libs-api decision in #104642.

I implemented that with a separate trait and not within `SliceIndex`, because doing that via `SliceIndex` requires adding support for range types that are (almost) always overlapping e.g. `RangeFrom`, and also adding fake support code for `impl SliceIndex<str>`.

An inconvenience that I ran into was that slice indexing takes the index by value, but I only have it by reference. I could change slice indexing to take by ref, but this is pretty much the hottest code ever so I'm afraid to touch it. Instead I added a requirement for `Clone` (which all index types implement anyway) and cloned. This is an internal requirement the user won't see and the clone should always be optimized away.

I also implemented `Clone`, `PartialEq` and `Eq` for the error type, since I noticed it does not do that when writing the tests and other errors in std seem to implement them. I didn't implement `Copy` because maybe we will want to put something non-`Copy` there.
-rw-r--r--library/core/src/slice/mod.rs198
-rw-r--r--library/core/tests/slice.rs70
2 files changed, 249 insertions, 19 deletions
diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs
index c855f963771..b3defba5a98 100644
--- a/library/core/src/slice/mod.rs
+++ b/library/core/src/slice/mod.rs
@@ -10,10 +10,10 @@ use crate::cmp::Ordering::{self, Equal, Greater, Less};
 use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub};
 use crate::mem::{self, SizedTypeProperties};
 use crate::num::NonZero;
-use crate::ops::{Bound, OneSidedRange, Range, RangeBounds};
+use crate::ops::{Bound, OneSidedRange, Range, RangeBounds, RangeInclusive};
 use crate::simd::{self, Simd};
 use crate::ub_checks::assert_unsafe_precondition;
-use crate::{fmt, hint, ptr, slice};
+use crate::{fmt, hint, ptr, range, slice};
 
 #[unstable(
     feature = "slice_internals",
@@ -4469,6 +4469,12 @@ impl<T> [T] {
 
     /// Returns mutable references to many indices at once, without doing any checks.
     ///
+    /// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note
+    /// that this method takes an array, so all indices must be of the same type.
+    /// If passed an array of `usize`s this method gives back an array of mutable references
+    /// to single elements, while if passed an array of ranges it gives back an array of
+    /// mutable references to slices.
+    ///
     /// For a safe alternative see [`get_many_mut`].
     ///
     /// # Safety
@@ -4489,30 +4495,49 @@ impl<T> [T] {
     ///     *b *= 100;
     /// }
     /// assert_eq!(x, &[10, 2, 400]);
+    ///
+    /// unsafe {
+    ///     let [a, b] = x.get_many_unchecked_mut([0..1, 1..3]);
+    ///     a[0] = 8;
+    ///     b[0] = 88;
+    ///     b[1] = 888;
+    /// }
+    /// assert_eq!(x, &[8, 88, 888]);
+    ///
+    /// unsafe {
+    ///     let [a, b] = x.get_many_unchecked_mut([1..=2, 0..=0]);
+    ///     a[0] = 11;
+    ///     a[1] = 111;
+    ///     b[0] = 1;
+    /// }
+    /// assert_eq!(x, &[1, 11, 111]);
     /// ```
     ///
     /// [`get_many_mut`]: slice::get_many_mut
     /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
     #[unstable(feature = "get_many_mut", issue = "104642")]
     #[inline]
-    pub unsafe fn get_many_unchecked_mut<const N: usize>(
+    pub unsafe fn get_many_unchecked_mut<I, const N: usize>(
         &mut self,
-        indices: [usize; N],
-    ) -> [&mut T; N] {
+        indices: [I; N],
+    ) -> [&mut I::Output; N]
+    where
+        I: GetManyMutIndex + SliceIndex<Self>,
+    {
         // NB: This implementation is written as it is because any variation of
         // `indices.map(|i| self.get_unchecked_mut(i))` would make miri unhappy,
         // or generate worse code otherwise. This is also why we need to go
         // through a raw pointer here.
         let slice: *mut [T] = self;
-        let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit();
+        let mut arr: mem::MaybeUninit<[&mut I::Output; N]> = mem::MaybeUninit::uninit();
         let arr_ptr = arr.as_mut_ptr();
 
         // SAFETY: We expect `indices` to contain disjunct values that are
         // in bounds of `self`.
         unsafe {
             for i in 0..N {
-                let idx = *indices.get_unchecked(i);
-                *(*arr_ptr).get_unchecked_mut(i) = &mut *slice.get_unchecked_mut(idx);
+                let idx = indices.get_unchecked(i).clone();
+                arr_ptr.cast::<&mut I::Output>().add(i).write(&mut *slice.get_unchecked_mut(idx));
             }
             arr.assume_init()
         }
@@ -4520,8 +4545,18 @@ impl<T> [T] {
 
     /// Returns mutable references to many indices at once.
     ///
-    /// Returns an error if any index is out-of-bounds, or if the same index was
-    /// passed more than once.
+    /// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note
+    /// that this method takes an array, so all indices must be of the same type.
+    /// If passed an array of `usize`s this method gives back an array of mutable references
+    /// to single elements, while if passed an array of ranges it gives back an array of
+    /// mutable references to slices.
+    ///
+    /// Returns an error if any index is out-of-bounds, or if there are overlapping indices.
+    /// An empty range is not considered to overlap if it is located at the beginning or at
+    /// the end of another range, but is considered to overlap if it is located in the middle.
+    ///
+    /// This method does a O(n^2) check to check that there are no overlapping indices, so be careful
+    /// when passing many indices.
     ///
     /// # Examples
     ///
@@ -4534,13 +4569,30 @@ impl<T> [T] {
     ///     *b = 612;
     /// }
     /// assert_eq!(v, &[413, 2, 612]);
+    ///
+    /// if let Ok([a, b]) = v.get_many_mut([0..1, 1..3]) {
+    ///     a[0] = 8;
+    ///     b[0] = 88;
+    ///     b[1] = 888;
+    /// }
+    /// assert_eq!(v, &[8, 88, 888]);
+    ///
+    /// if let Ok([a, b]) = v.get_many_mut([1..=2, 0..=0]) {
+    ///     a[0] = 11;
+    ///     a[1] = 111;
+    ///     b[0] = 1;
+    /// }
+    /// assert_eq!(v, &[1, 11, 111]);
     /// ```
     #[unstable(feature = "get_many_mut", issue = "104642")]
     #[inline]
-    pub fn get_many_mut<const N: usize>(
+    pub fn get_many_mut<I, const N: usize>(
         &mut self,
-        indices: [usize; N],
-    ) -> Result<[&mut T; N], GetManyMutError<N>> {
+        indices: [I; N],
+    ) -> Result<[&mut I::Output; N], GetManyMutError<N>>
+    where
+        I: GetManyMutIndex + SliceIndex<Self>,
+    {
         if !get_many_check_valid(&indices, self.len()) {
             return Err(GetManyMutError { _private: () });
         }
@@ -4885,14 +4937,15 @@ impl<T, const N: usize> SlicePattern for [T; N] {
 ///
 /// This will do `binomial(N + 1, 2) = N * (N + 1) / 2 = 0, 1, 3, 6, 10, ..`
 /// comparison operations.
-fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> bool {
+#[inline]
+fn get_many_check_valid<I: GetManyMutIndex, const N: usize>(indices: &[I; N], len: usize) -> bool {
     // NB: The optimizer should inline the loops into a sequence
     // of instructions without additional branching.
     let mut valid = true;
-    for (i, &idx) in indices.iter().enumerate() {
-        valid &= idx < len;
-        for &idx2 in &indices[..i] {
-            valid &= idx != idx2;
+    for (i, idx) in indices.iter().enumerate() {
+        valid &= idx.is_in_bounds(len);
+        for idx2 in &indices[..i] {
+            valid &= !idx.is_overlapping(idx2);
         }
     }
     valid
@@ -4916,6 +4969,7 @@ fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> boo
 #[unstable(feature = "get_many_mut", issue = "104642")]
 // NB: The N here is there to be forward-compatible with adding more details
 // to the error type at a later point
+#[derive(Clone, PartialEq, Eq)]
 pub struct GetManyMutError<const N: usize> {
     _private: (),
 }
@@ -4933,3 +4987,111 @@ impl<const N: usize> fmt::Display for GetManyMutError<N> {
         fmt::Display::fmt("an index is out of bounds or appeared multiple times in the array", f)
     }
 }
+
+mod private_get_many_mut_index {
+    use super::{Range, RangeInclusive, range};
+
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    pub trait Sealed {}
+
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    impl Sealed for usize {}
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    impl Sealed for Range<usize> {}
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    impl Sealed for RangeInclusive<usize> {}
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    impl Sealed for range::Range<usize> {}
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    impl Sealed for range::RangeInclusive<usize> {}
+}
+
+/// A helper trait for `<[T]>::get_many_mut()`.
+///
+/// # Safety
+///
+/// If `is_in_bounds()` returns `true` and `is_overlapping()` returns `false`,
+/// it must be safe to index the slice with the indices.
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+pub unsafe trait GetManyMutIndex: Clone + private_get_many_mut_index::Sealed {
+    /// Returns `true` if `self` is in bounds for `len` slice elements.
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    fn is_in_bounds(&self, len: usize) -> bool;
+
+    /// Returns `true` if `self` overlaps with `other`.
+    ///
+    /// Note that we don't consider zero-length ranges to overlap at the beginning or the end,
+    /// but do consider them to overlap in the middle.
+    #[unstable(feature = "get_many_mut_helpers", issue = "none")]
+    fn is_overlapping(&self, other: &Self) -> bool;
+}
+
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
+unsafe impl GetManyMutIndex for usize {
+    #[inline]
+    fn is_in_bounds(&self, len: usize) -> bool {
+        *self < len
+    }
+
+    #[inline]
+    fn is_overlapping(&self, other: &Self) -> bool {
+        *self == *other
+    }
+}
+
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
+unsafe impl GetManyMutIndex for Range<usize> {
+    #[inline]
+    fn is_in_bounds(&self, len: usize) -> bool {
+        (self.start <= self.end) & (self.end <= len)
+    }
+
+    #[inline]
+    fn is_overlapping(&self, other: &Self) -> bool {
+        (self.start < other.end) & (other.start < self.end)
+    }
+}
+
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
+unsafe impl GetManyMutIndex for RangeInclusive<usize> {
+    #[inline]
+    fn is_in_bounds(&self, len: usize) -> bool {
+        (self.start <= self.end) & (self.end < len)
+    }
+
+    #[inline]
+    fn is_overlapping(&self, other: &Self) -> bool {
+        (self.start <= other.end) & (other.start <= self.end)
+    }
+}
+
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
+unsafe impl GetManyMutIndex for range::Range<usize> {
+    #[inline]
+    fn is_in_bounds(&self, len: usize) -> bool {
+        Range::from(*self).is_in_bounds(len)
+    }
+
+    #[inline]
+    fn is_overlapping(&self, other: &Self) -> bool {
+        Range::from(*self).is_overlapping(&Range::from(*other))
+    }
+}
+
+#[unstable(feature = "get_many_mut_helpers", issue = "none")]
+// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
+unsafe impl GetManyMutIndex for range::RangeInclusive<usize> {
+    #[inline]
+    fn is_in_bounds(&self, len: usize) -> bool {
+        RangeInclusive::from(*self).is_in_bounds(len)
+    }
+
+    #[inline]
+    fn is_overlapping(&self, other: &Self) -> bool {
+        RangeInclusive::from(*self).is_overlapping(&RangeInclusive::from(*other))
+    }
+}
diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs
index 9ae2bcc8526..510dd4967c9 100644
--- a/library/core/tests/slice.rs
+++ b/library/core/tests/slice.rs
@@ -2,6 +2,7 @@ use core::cell::Cell;
 use core::cmp::Ordering;
 use core::mem::MaybeUninit;
 use core::num::NonZero;
+use core::ops::{Range, RangeInclusive};
 use core::slice;
 
 #[test]
@@ -2553,6 +2554,14 @@ fn test_get_many_mut_normal_2() {
     *a += 10;
     *b += 100;
     assert_eq!(v, vec![101, 2, 3, 14, 5]);
+
+    let [a, b] = v.get_many_mut([0..=1, 2..=2]).unwrap();
+    assert_eq!(a, &mut [101, 2][..]);
+    assert_eq!(b, &mut [3][..]);
+    a[0] += 10;
+    a[1] += 20;
+    b[0] += 100;
+    assert_eq!(v, vec![111, 22, 103, 14, 5]);
 }
 
 #[test]
@@ -2563,12 +2572,23 @@ fn test_get_many_mut_normal_3() {
     *b += 100;
     *c += 1000;
     assert_eq!(v, vec![11, 2, 1003, 4, 105]);
+
+    let [a, b, c] = v.get_many_mut([0..1, 4..5, 1..4]).unwrap();
+    assert_eq!(a, &mut [11][..]);
+    assert_eq!(b, &mut [105][..]);
+    assert_eq!(c, &mut [2, 1003, 4][..]);
+    a[0] += 10;
+    b[0] += 100;
+    c[0] += 1000;
+    assert_eq!(v, vec![21, 1002, 1003, 4, 205]);
 }
 
 #[test]
 fn test_get_many_mut_empty() {
     let mut v = vec![1, 2, 3, 4, 5];
-    let [] = v.get_many_mut([]).unwrap();
+    let [] = v.get_many_mut::<usize, 0>([]).unwrap();
+    let [] = v.get_many_mut::<RangeInclusive<usize>, 0>([]).unwrap();
+    let [] = v.get_many_mut::<Range<usize>, 0>([]).unwrap();
     assert_eq!(v, vec![1, 2, 3, 4, 5]);
 }
 
@@ -2607,6 +2627,54 @@ fn test_get_many_mut_duplicate() {
 }
 
 #[test]
+fn test_get_many_mut_range_oob() {
+    let mut v = vec![1, 2, 3, 4, 5];
+    assert!(v.get_many_mut([0..6]).is_err());
+    assert!(v.get_many_mut([5..6]).is_err());
+    assert!(v.get_many_mut([6..6]).is_err());
+    assert!(v.get_many_mut([0..=5]).is_err());
+    assert!(v.get_many_mut([0..=6]).is_err());
+    assert!(v.get_many_mut([5..=5]).is_err());
+}
+
+#[test]
+fn test_get_many_mut_range_overlapping() {
+    let mut v = vec![1, 2, 3, 4, 5];
+    assert!(v.get_many_mut([0..1, 0..2]).is_err());
+    assert!(v.get_many_mut([0..1, 1..2, 0..1]).is_err());
+    assert!(v.get_many_mut([0..3, 1..1]).is_err());
+    assert!(v.get_many_mut([0..3, 1..2]).is_err());
+    assert!(v.get_many_mut([0..=0, 2..=2, 0..=1]).is_err());
+    assert!(v.get_many_mut([0..=4, 0..=0]).is_err());
+    assert!(v.get_many_mut([4..=4, 0..=0, 3..=4]).is_err());
+}
+
+#[test]
+fn test_get_many_mut_range_empty_at_edge() {
+    let mut v = vec![1, 2, 3, 4, 5];
+    assert_eq!(
+        v.get_many_mut([0..0, 0..5, 5..5]),
+        Ok([&mut [][..], &mut [1, 2, 3, 4, 5], &mut []]),
+    );
+    assert_eq!(
+        v.get_many_mut([0..0, 0..1, 1..1, 1..2, 2..2, 2..3, 3..3, 3..4, 4..4, 4..5, 5..5]),
+        Ok([
+            &mut [][..],
+            &mut [1],
+            &mut [],
+            &mut [2],
+            &mut [],
+            &mut [3],
+            &mut [],
+            &mut [4],
+            &mut [],
+            &mut [5],
+            &mut [],
+        ]),
+    );
+}
+
+#[test]
 fn test_slice_from_raw_parts_in_const() {
     static FANCY: i32 = 4;
     static FANCY_SLICE: &[i32] = unsafe { std::slice::from_raw_parts(&FANCY, 1) };