diff options
| author | Erik Desjardins <erikdesjardins@users.noreply.github.com> | 2021-08-10 19:29:18 -0400 |
|---|---|---|
| committer | Erik Desjardins <erikdesjardins@users.noreply.github.com> | 2021-08-25 17:49:28 -0400 |
| commit | 1eaccab24e536f5708bef8538cfe0dca367ed544 (patch) | |
| tree | 4a851ff38536d53b96311f72dfb665d2cc70594d | |
| parent | c9599c4cacb4c5b9f8a2c353ac062885525960c9 (diff) | |
| download | rust-1eaccab24e536f5708bef8538cfe0dca367ed544.tar.gz rust-1eaccab24e536f5708bef8538cfe0dca367ed544.zip | |
optimize initialization checks
| -rw-r--r-- | compiler/rustc_middle/src/mir/interpret/allocation.rs | 111 |
1 files changed, 102 insertions, 9 deletions
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs index 71580bcc06d..4040f4a112e 100644 --- a/compiler/rustc_middle/src/mir/interpret/allocation.rs +++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs @@ -1,7 +1,7 @@ //! The virtual memory representation of the MIR interpreter. use std::borrow::Cow; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::iter; use std::ops::{Deref, Range}; use std::ptr; @@ -720,13 +720,12 @@ impl InitMask { return Err(self.len..end); } - // FIXME(oli-obk): optimize this for allocations larger than a block. - let idx = (start..end).find(|&i| !self.get(i)); + let uninit_start = find_bit(self, start, end, false); - match idx { - Some(idx) => { - let uninit_end = (idx..end).find(|&i| self.get(i)).unwrap_or(end); - Err(idx..uninit_end) + match uninit_start { + Some(uninit_start) => { + let uninit_end = find_bit(self, uninit_start, end, true).unwrap_or(end); + Err(uninit_start..uninit_end) } None => Ok(()), } @@ -863,9 +862,8 @@ impl<'a> Iterator for InitChunkIter<'a> { } let is_init = self.init_mask.get(self.start); - // FIXME(oli-obk): optimize this for allocations larger than a block. let end_of_chunk = - (self.start..self.end).find(|&i| self.init_mask.get(i) != is_init).unwrap_or(self.end); + find_bit(&self.init_mask, self.start, self.end, !is_init).unwrap_or(self.end); let range = self.start..end_of_chunk; self.start = end_of_chunk; @@ -874,6 +872,94 @@ impl<'a> Iterator for InitChunkIter<'a> { } } +/// Returns the index of the first bit in `start..end` (end-exclusive) that is equal to is_init. +fn find_bit(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> { + fn find_bit_fast(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> { + fn search_block( + bits: Block, + block: usize, + start_bit: usize, + is_init: bool, + ) -> Option<Size> { + // invert bits so we're always looking for the first set bit + let bits = if is_init { bits } else { !bits }; + // mask off unused start bits + let bits = bits & (!0 << start_bit); + // find set bit, if any + if bits == 0 { + None + } else { + let bit = bits.trailing_zeros(); + Some(size_from_bit_index(block, bit)) + } + } + + if start >= end { + return None; + } + + let (start_block, start_bit) = bit_index(start); + let (end_block, end_bit) = bit_index(end); + + // handle first block: need to skip `start_bit` bits + if let Some(i) = + search_block(init_mask.blocks[start_block], start_block, start_bit, is_init) + { + if i < end { + return Some(i); + } else { + // if the range is less than a block, we may find a matching bit after `end` + return None; + } + } + + let one_block_past_the_end = if end_bit > 0 { + // if `end_bit` > 0, then the range overlaps `end_block` + end_block + 1 + } else { + end_block + }; + + // handle remaining blocks + if start_block < one_block_past_the_end { + for (&bits, block) in init_mask.blocks[start_block + 1..one_block_past_the_end] + .iter() + .zip(start_block + 1..) + { + if let Some(i) = search_block(bits, block, 0, is_init) { + if i < end { + return Some(i); + } else { + // if this is the last block, we may find a matching bit after `end` + return None; + } + } + } + } + + None + } + + #[cfg_attr(not(debug_assertions), allow(dead_code))] + fn find_bit_slow(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> { + (start..end).find(|&i| init_mask.get(i) == is_init) + } + + let result = find_bit_fast(init_mask, start, end, is_init); + + debug_assert_eq!( + result, + find_bit_slow(init_mask, start, end, is_init), + "optimized implementation of find_bit is wrong for start={:?} end={:?} is_init={} init_mask={:#?}", + start, + end, + is_init, + init_mask + ); + + result +} + #[inline] fn bit_index(bits: Size) -> (usize, usize) { let bits = bits.bytes(); @@ -881,3 +967,10 @@ fn bit_index(bits: Size) -> (usize, usize) { let b = bits % InitMask::BLOCK_SIZE; (usize::try_from(a).unwrap(), usize::try_from(b).unwrap()) } + +#[inline] +fn size_from_bit_index(block: impl TryInto<u64>, bit: impl TryInto<u64>) -> Size { + let block = block.try_into().ok().unwrap(); + let bit = bit.try_into().ok().unwrap(); + Size::from_bytes(block * InitMask::BLOCK_SIZE + bit) +} |
