about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/libcore/slice/mod.rs54
-rw-r--r--src/libcore/tests/lib.rs1
-rw-r--r--src/libcore/tests/slice.rs40
3 files changed, 95 insertions, 0 deletions
diff --git a/src/libcore/slice/mod.rs b/src/libcore/slice/mod.rs
index 57bacdd99ef..6a50cdbc1d9 100644
--- a/src/libcore/slice/mod.rs
+++ b/src/libcore/slice/mod.rs
@@ -2663,6 +2663,60 @@ impl<T> [T] {
     {
         self.iter().is_sorted_by_key(f)
     }
+
+    /// Returns the index of the partition point according to the given predicate
+    /// (the index of the first element of the second partition).
+    ///
+    /// The slice is assumed to be partitioned according to the given predicate.
+    /// This means that all elements for which the predicate returns true are at the start of the slice
+    /// and all elements for which the predicate returns false are at the end.
+    /// For example, [7, 15, 3, 5, 4, 12, 6] is a partitioned under the predicate x % 2 != 0
+    /// (all odd numbers are at the start, all even at the end).
+    ///
+    /// If this slice is not partitioned, the returned result is unspecified and meaningless,
+    /// as this method performs a kind of binary search.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// #![feature(partition_point)]
+    ///
+    /// let v = [1, 2, 3, 3, 5, 6, 7];
+    /// let i = v.partition_point(|&x| x < 5);
+    ///
+    /// assert_eq!(i, 4);
+    /// assert!(v[..i].iter().all(|&x| x < 5));
+    /// assert!(v[i..].iter().all(|&x| !(x < 5)));
+    /// ```
+    #[unstable(feature = "partition_point", reason = "new API", issue = "73831")]
+    pub fn partition_point<P>(&self, mut pred: P) -> usize
+    where
+        P: FnMut(&T) -> bool,
+    {
+        let mut left = 0;
+        let mut right = self.len();
+
+        while left != right {
+            let mid = left + (right - left) / 2;
+            // SAFETY:
+            // When left < right, left <= mid < right.
+            // Therefore left always increases and right always decreases,
+            // and eigher of them is selected.
+            // In both cases left <= right is satisfied.
+            // Therefore if left < right in a step,
+            // left <= right is satisfied in the next step.
+            // Therefore as long as left != right, 0 <= left < right <= len is satisfied
+            // and if this case 0 <= mid < len is satisfied too.
+            let value = unsafe { self.get_unchecked(mid) };
+            if pred(value) {
+                left = mid + 1;
+            } else {
+                right = mid;
+            }
+        }
+
+        left
+    }
 }
 
 #[lang = "slice_u8"]
diff --git a/src/libcore/tests/lib.rs b/src/libcore/tests/lib.rs
index 4e55452a4c3..524c38a7ab4 100644
--- a/src/libcore/tests/lib.rs
+++ b/src/libcore/tests/lib.rs
@@ -44,6 +44,7 @@
 #![feature(const_forget)]
 #![feature(option_unwrap_none)]
 #![feature(peekable_next_if)]
+#![feature(partition_point)]
 
 extern crate test;
 
diff --git a/src/libcore/tests/slice.rs b/src/libcore/tests/slice.rs
index cd46117f763..fba73be92be 100644
--- a/src/libcore/tests/slice.rs
+++ b/src/libcore/tests/slice.rs
@@ -82,6 +82,46 @@ fn test_binary_search_implementation_details() {
 }
 
 #[test]
+fn test_partition_point() {
+    let b: [i32; 0] = [];
+    assert_eq!(b.partition_point(|&x| x < 5), 0);
+
+    let b = [4];
+    assert_eq!(b.partition_point(|&x| x < 3), 0);
+    assert_eq!(b.partition_point(|&x| x < 4), 0);
+    assert_eq!(b.partition_point(|&x| x < 5), 1);
+
+    let b = [1, 2, 4, 6, 8, 9];
+    assert_eq!(b.partition_point(|&x| x < 5), 3);
+    assert_eq!(b.partition_point(|&x| x < 6), 3);
+    assert_eq!(b.partition_point(|&x| x < 7), 4);
+    assert_eq!(b.partition_point(|&x| x < 8), 4);
+
+    let b = [1, 2, 4, 5, 6, 8];
+    assert_eq!(b.partition_point(|&x| x < 9), 6);
+
+    let b = [1, 2, 4, 6, 7, 8, 9];
+    assert_eq!(b.partition_point(|&x| x < 6), 3);
+    assert_eq!(b.partition_point(|&x| x < 5), 3);
+    assert_eq!(b.partition_point(|&x| x < 8), 5);
+
+    let b = [1, 2, 4, 5, 6, 8, 9];
+    assert_eq!(b.partition_point(|&x| x < 7), 5);
+    assert_eq!(b.partition_point(|&x| x < 0), 0);
+
+    let b = [1, 3, 3, 3, 7];
+    assert_eq!(b.partition_point(|&x| x < 0), 0);
+    assert_eq!(b.partition_point(|&x| x < 1), 0);
+    assert_eq!(b.partition_point(|&x| x < 2), 1);
+    assert_eq!(b.partition_point(|&x| x < 3), 1);
+    assert_eq!(b.partition_point(|&x| x < 4), 4);
+    assert_eq!(b.partition_point(|&x| x < 5), 4);
+    assert_eq!(b.partition_point(|&x| x < 6), 4);
+    assert_eq!(b.partition_point(|&x| x < 7), 4);
+    assert_eq!(b.partition_point(|&x| x < 8), 5);
+}
+
+#[test]
 fn test_iterator_nth() {
     let v: &[_] = &[0, 1, 2, 3, 4];
     for i in 0..v.len() {