diff options
| -rw-r--r-- | library/core/src/slice/sort.rs | 112 |
1 files changed, 67 insertions, 45 deletions
diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 6bb53b16e61..227db51a0b4 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -1196,52 +1196,37 @@ pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>( let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn); - // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a - // strange decision, but consider the fact that merges more often go in the opposite direction - // (forwards). According to benchmarks, merging forwards is slightly faster than merging - // backwards. To conclude, identifying runs by traversing backwards improves performance. - let mut end = len; - while end > 0 { - // Find the next natural run, and reverse it if it's strictly descending. - let mut start = end - 1; - if start > 0 { - start -= 1; - - // SAFETY: The v.get_unchecked must be fed with correct inbound indicies. - unsafe { - if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { - while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { - start -= 1; - } - v[start..end].reverse(); - } else { - while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) - { - start -= 1; - } - } - } + let mut end = 0; + let mut start = 0; + + // Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the + // code-gen is usually better. For the most sensitive types such as integers, these are merged + // bidirectionally at once. So there is no benefit in scanning backwards. + while end < len { + let (streak_end, was_reversed) = find_streak(&v[start..], is_less); + end += streak_end; + if was_reversed { + v[start..end].reverse(); } // Insert some more elements into the run if it's too short. Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. 
- start = provide_sorted_batch(v, start, end, is_less); + end = provide_sorted_batch(v, start, end, is_less); // Push this run onto the stack. runs.push(TimSortRun { start, len: end - start }); - end = start; + start = end; // Merge some pairs of adjacent runs to satisfy the invariants. - while let Some(r) = collapse(runs.as_slice()) { - let left = runs[r + 1]; - let right = runs[r]; - // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and - // neither side may be on length 0. + while let Some(r) = collapse(runs.as_slice(), len) { + let left = runs[r]; + let right = runs[r + 1]; + let merge_slice = &mut v[left.start..right.start + right.len]; unsafe { - merge(&mut v[left.start..right.start + right.len], left.len, buf_ptr, is_less); + merge(merge_slice, left.len, buf_ptr, is_less); } - runs[r] = TimSortRun { start: left.start, len: left.len + right.len }; - runs.remove(r + 1); + runs[r + 1] = TimSortRun { start: left.start, len: left.len + right.len }; + runs.remove(r); } } @@ -1263,10 +1248,10 @@ pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>( // run starts at index 0, it will always demand a merge operation until the stack is fully // collapsed, in order to complete the sort. 
#[inline] - fn collapse(runs: &[TimSortRun]) -> Option<usize> { + fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> { let n = runs.len(); if n >= 2 - && (runs[n - 1].start == 0 + && (runs[n - 1].start + runs[n - 1].len == stop || runs[n - 2].len <= runs[n - 1].len || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) @@ -1454,14 +1439,15 @@ pub struct TimSortRun { start: usize, } -/// Takes a range as denoted by start and end, that is already sorted and extends it to the left if +/// Takes a range as denoted by start and end, that is already sorted and extends it to the right if /// necessary with sorts optimized for smaller ranges such as insertion sort. #[cfg(not(no_global_oom_handling))] -fn provide_sorted_batch<T, F>(v: &mut [T], mut start: usize, end: usize, is_less: &mut F) -> usize +fn provide_sorted_batch<T, F>(v: &mut [T], start: usize, mut end: usize, is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool, { - debug_assert!(end > start); + let len = v.len(); + assert!(end >= start && end <= len); // This value is a balance between least comparisons and best performance, as // influenced by for example cache locality. @@ -1469,18 +1455,54 @@ where // Insert some more elements into the run if it's too short. Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - let start_found = start; let start_end_diff = end - start; - if start_end_diff < MIN_INSERTION_RUN && start != 0 { + if start_end_diff < MIN_INSERTION_RUN && end < len { // v[start_found..end] are elements that are already sorted in the input. We want to extend // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is // more efficient that trying to push those already sorted elements to the left. 
+ end = cmp::min(start + MIN_INSERTION_RUN, len); + let presorted_start = cmp::max(start_end_diff, 1); - start = if end >= MIN_INSERTION_RUN { end - MIN_INSERTION_RUN } else { 0 }; + insertion_sort_shift_left(&mut v[start..end], presorted_start, is_less); + } - insertion_sort_shift_right(&mut v[start..end], start_found - start, is_less); + end +} + +/// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first +/// value that is not part of said streak, and a bool denoting whether the streak was reversed. +/// Streaks can be increasing or decreasing. +fn find_streak<T, F>(v: &[T], is_less: &mut F) -> (usize, bool) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + if len < 2 { + return (len, false); } - start + let mut end = 2; + + // SAFETY: See below specific. + unsafe { + // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices. + let assume_reverse = is_less(v.get_unchecked(1), v.get_unchecked(0)); + + // SAFETY: We know end >= 2 and check end < len. + // From that follows that accessing v at end and end - 1 is safe. + if assume_reverse { + while end < len && is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + + (end, true) + } else { + while end < len && !is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + (end, false) + } + } } |
