about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFolyd <lyshuhow@gmail.com>2020-09-25 16:43:43 +0800
committerFolyd <lyshuhow@gmail.com>2021-01-30 14:11:47 +0800
commit18e44a1be4d8aecf41d0b904d63f6c72f51d6b0d (patch)
tree043cbe6ab9b8149dd86bf91f9f99681983282f7f
parent0248c6f178ab3a4d2ec702b7d418ff8375ab0515 (diff)
downloadrust-18e44a1be4d8aecf41d0b904d63f6c72f51d6b0d.tar.gz
rust-18e44a1be4d8aecf41d0b904d63f6c72f51d6b0d.zip
Improve slice.binary_search_by()'s best-case performance to O(1)
-rw-r--r--library/core/benches/slice.rs44
-rw-r--r--library/core/src/slice/mod.rs6
-rw-r--r--library/core/tests/slice.rs8
3 files changed, 47 insertions, 11 deletions
diff --git a/library/core/benches/slice.rs b/library/core/benches/slice.rs
index 06b37cb0844..dbab0085686 100644
--- a/library/core/benches/slice.rs
+++ b/library/core/benches/slice.rs
@@ -7,15 +7,21 @@ enum Cache {
     L3,
 }
 
+impl Cache {
+    fn size(&self) -> usize {
+        match self {
+            Cache::L1 => 1000,      // 8kb
+            Cache::L2 => 10_000,    // 80kb
+            Cache::L3 => 1_000_000, // 8Mb
+        }
+    }
+}
+
 fn binary_search<F>(b: &mut Bencher, cache: Cache, mapper: F)
 where
     F: Fn(usize) -> usize,
 {
-    let size = match cache {
-        Cache::L1 => 1000,      // 8kb
-        Cache::L2 => 10_000,    // 80kb
-        Cache::L3 => 1_000_000, // 8Mb
-    };
+    let size = cache.size();
     let v = (0..size).map(&mapper).collect::<Vec<_>>();
     let mut r = 0usize;
     b.iter(move || {
@@ -24,7 +30,18 @@ where
         // Lookup the whole range to get 50% hits and 50% misses.
         let i = mapper(r % size);
         black_box(v.binary_search(&i).is_ok());
-    })
+    });
+}
+
+fn binary_search_worst_case(b: &mut Bencher, cache: Cache) {
+    let size = cache.size();
+
+    let mut v = vec![0; size];
+    let i = 1;
+    v[size - 1] = i;
+    b.iter(move || {
+        black_box(v.binary_search(&i).is_ok());
+    });
 }
 
 #[bench]
@@ -57,6 +74,21 @@ fn binary_search_l3_with_dups(b: &mut Bencher) {
     binary_search(b, Cache::L3, |i| i / 16 * 16);
 }
 
+#[bench]
+fn binary_search_l1_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L1);
+}
+
+#[bench]
+fn binary_search_l2_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L2);
+}
+
+#[bench]
+fn binary_search_l3_worst_case(b: &mut Bencher) {
+    binary_search_worst_case(b, Cache::L3);
+}
+
 macro_rules! rotate {
     ($fn:ident, $n:expr, $mapper:expr) => {
         #[bench]
diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs
index b06b6e93373..e904f856d1e 100644
--- a/library/core/src/slice/mod.rs
+++ b/library/core/src/slice/mod.rs
@@ -2167,7 +2167,11 @@ impl<T> [T] {
             // - `mid >= 0`: by definition
             // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
             let cmp = f(unsafe { s.get_unchecked(mid) });
-            base = if cmp == Greater { base } else { mid };
+            if cmp == Equal {
+                return Ok(mid);
+            } else if cmp == Less {
+                base = mid
+            }
             size -= half;
         }
         // SAFETY: base is always in [0, size) because base <= mid.
diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs
index 9ccc5a08dcb..d9efa7ef20b 100644
--- a/library/core/tests/slice.rs
+++ b/library/core/tests/slice.rs
@@ -73,13 +73,13 @@ fn test_binary_search_implementation_details() {
     let b = [1, 1, 2, 2, 3, 3, 3];
     assert_eq!(b.binary_search(&1), Ok(1));
     assert_eq!(b.binary_search(&2), Ok(3));
-    assert_eq!(b.binary_search(&3), Ok(6));
+    assert_eq!(b.binary_search(&3), Ok(5));
     let b = [1, 1, 1, 1, 1, 3, 3, 3, 3];
     assert_eq!(b.binary_search(&1), Ok(4));
-    assert_eq!(b.binary_search(&3), Ok(8));
+    assert_eq!(b.binary_search(&3), Ok(6));
     let b = [1, 1, 1, 1, 3, 3, 3, 3, 3];
-    assert_eq!(b.binary_search(&1), Ok(3));
-    assert_eq!(b.binary_search(&3), Ok(8));
+    assert_eq!(b.binary_search(&1), Ok(2));
+    assert_eq!(b.binary_search(&3), Ok(4));
 }
 
 #[test]