about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/benches/slice.rs44
-rw-r--r--library/core/src/slice/mod.rs44
-rw-r--r--library/core/tests/slice.rs21
3 files changed, 80 insertions, 29 deletions
diff --git a/library/core/benches/slice.rs b/library/core/benches/slice.rs
index 06b37cb0844..dbab0085686 100644
--- a/library/core/benches/slice.rs
+++ b/library/core/benches/slice.rs
@@ -7,15 +7,21 @@ enum Cache {
     L3,
 }
 
+impl Cache {
+    fn size(&self) -> usize {
+        match self {
+            Cache::L1 => 1000,      // 8kb
+            Cache::L2 => 10_000,    // 80kb
+            Cache::L3 => 1_000_000, // 8Mb
+        }
+    }
+}
+
 fn binary_search<F>(b: &mut Bencher, cache: Cache, mapper: F)
 where
     F: Fn(usize) -> usize,
 {
-    let size = match cache {
-        Cache::L1 => 1000,      // 8kb
-        Cache::L2 => 10_000,    // 80kb
-        Cache::L3 => 1_000_000, // 8Mb
-    };
+    let size = cache.size();
     let v = (0..size).map(&mapper).collect::<Vec<_>>();
     let mut r = 0usize;
     b.iter(move || {
@@ -24,7 +30,18 @@ where
         // Lookup the whole range to get 50% hits and 50% misses.
         let i = mapper(r % size);
         black_box(v.binary_search(&i).is_ok());
-    })
+    });
+}
+
+fn binary_search_worst_case(b: &mut Bencher, cache: Cache) {
+    let size = cache.size();
+
+    let mut v = vec![0; size];
+    let i = 1;
+    v[size - 1] = i;
+    b.iter(move || {
+        black_box(v.binary_search(&i).is_ok());
+    });
 }
 
 #[bench]
@@ -57,6 +74,21 @@ fn binary_search_l3_with_dups(b: &mut Bencher) {
     binary_search(b, Cache::L3, |i| i / 16 * 16);
 }
 
+#[bench]
+fn binary_search_l1_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L1);
+}
+
+#[bench]
+fn binary_search_l2_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L2);
+}
+
+#[bench]
+fn binary_search_l3_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L3);
+}
+
 macro_rules! rotate {
     ($fn:ident, $n:expr, $mapper:expr) => {
         #[bench]
diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs
index b2c5c292f45..5510bb0257e 100644
--- a/library/core/src/slice/mod.rs
+++ b/library/core/src/slice/mod.rs
@@ -8,7 +8,7 @@
 
 #![stable(feature = "rust1", since = "1.0.0")]
 
-use crate::cmp::Ordering::{self, Equal, Greater, Less};
+use crate::cmp::Ordering::{self, Greater, Less};
 use crate::marker::Copy;
 use crate::mem;
 use crate::num::NonZeroUsize;
@@ -2185,25 +2185,31 @@ impl<T> [T] {
     where
         F: FnMut(&'a T) -> Ordering,
     {
-        let s = self;
-        let mut size = s.len();
-        if size == 0 {
-            return Err(0);
-        }
-        let mut base = 0usize;
-        while size > 1 {
-            let half = size / 2;
-            let mid = base + half;
-            // SAFETY: the call is made safe by the following inconstants:
-            // - `mid >= 0`: by definition
-            // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
-            let cmp = f(unsafe { s.get_unchecked(mid) });
-            base = if cmp == Greater { base } else { mid };
-            size -= half;
+        let mut size = self.len();
+        let mut left = 0;
+        let mut right = size;
+        while left < right {
+            let mid = left + size / 2;
+
+            // SAFETY: the call is made safe by the following invariants:
+            // - `mid >= 0`
+            // - `mid < size`: `mid` is limited by `[left; right)` bound.
+            let cmp = f(unsafe { self.get_unchecked(mid) });
+
+            // The reason why we use if/else control flow rather than match
+            // is because match reorders comparison operations, which is perf sensitive.
+            // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
+            if cmp == Less {
+                left = mid + 1;
+            } else if cmp == Greater {
+                right = mid;
+            } else {
+                return Ok(mid);
+            }
+
+            size = right - left;
         }
-        // SAFETY: base is always in [0, size) because base <= mid.
-        let cmp = f(unsafe { s.get_unchecked(base) });
-        if cmp == Equal { Ok(base) } else { Err(base + (cmp == Less) as usize) }
+        Err(left)
     }
 
     /// Binary searches this sorted slice with a key extraction function.
diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs
index 9ccc5a08dcb..7e198631cc7 100644
--- a/library/core/tests/slice.rs
+++ b/library/core/tests/slice.rs
@@ -1,4 +1,5 @@
 use core::cell::Cell;
+use core::cmp::Ordering;
 use core::result::Result::{Err, Ok};
 
 #[test]
@@ -64,6 +65,17 @@ fn test_binary_search() {
     assert_eq!(b.binary_search(&6), Err(4));
     assert_eq!(b.binary_search(&7), Ok(4));
     assert_eq!(b.binary_search(&8), Err(5));
+
+    let b = [(); usize::MAX];
+    assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2));
+}
+
+#[test]
+fn test_binary_search_by_overflow() {
+    let b = [(); usize::MAX];
+    assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2));
+    assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0));
+    assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX));
 }
 
 #[test]
@@ -73,13 +85,13 @@ fn test_binary_search_implementation_details() {
     let b = [1, 1, 2, 2, 3, 3, 3];
     assert_eq!(b.binary_search(&1), Ok(1));
     assert_eq!(b.binary_search(&2), Ok(3));
-    assert_eq!(b.binary_search(&3), Ok(6));
+    assert_eq!(b.binary_search(&3), Ok(5));
     let b = [1, 1, 1, 1, 1, 3, 3, 3, 3];
     assert_eq!(b.binary_search(&1), Ok(4));
-    assert_eq!(b.binary_search(&3), Ok(8));
+    assert_eq!(b.binary_search(&3), Ok(7));
     let b = [1, 1, 1, 1, 3, 3, 3, 3, 3];
-    assert_eq!(b.binary_search(&1), Ok(3));
-    assert_eq!(b.binary_search(&3), Ok(8));
+    assert_eq!(b.binary_search(&1), Ok(2));
+    assert_eq!(b.binary_search(&3), Ok(4));
 }
 
 #[test]
@@ -1982,6 +1994,7 @@ fn test_copy_within_panics_dest_too_long() {
     // The length is only 13, so a slice of length 4 starting at index 10 is out of bounds.
     bytes.copy_within(0..4, 10);
 }
+
 #[test]
 #[should_panic(expected = "slice index starts at 2 but ends at 1")]
 fn test_copy_within_panics_src_inverted() {