about summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Desjardins <erikdesjardins@users.noreply.github.com>2021-08-10 19:29:18 -0400
committerErik Desjardins <erikdesjardins@users.noreply.github.com>2021-08-25 17:49:28 -0400
commit1eaccab24e536f5708bef8538cfe0dca367ed544 (patch)
tree4a851ff38536d53b96311f72dfb665d2cc70594d
parentc9599c4cacb4c5b9f8a2c353ac062885525960c9 (diff)
downloadrust-1eaccab24e536f5708bef8538cfe0dca367ed544.tar.gz
rust-1eaccab24e536f5708bef8538cfe0dca367ed544.zip
optimize initialization checks
-rw-r--r--compiler/rustc_middle/src/mir/interpret/allocation.rs111
1 files changed, 102 insertions, 9 deletions
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs
index 71580bcc06d..4040f4a112e 100644
--- a/compiler/rustc_middle/src/mir/interpret/allocation.rs
+++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs
@@ -1,7 +1,7 @@
 //! The virtual memory representation of the MIR interpreter.
 
 use std::borrow::Cow;
-use std::convert::TryFrom;
+use std::convert::{TryFrom, TryInto};
 use std::iter;
 use std::ops::{Deref, Range};
 use std::ptr;
@@ -720,13 +720,12 @@ impl InitMask {
             return Err(self.len..end);
         }
 
-        // FIXME(oli-obk): optimize this for allocations larger than a block.
-        let idx = (start..end).find(|&i| !self.get(i));
+        let uninit_start = find_bit(self, start, end, false);
 
-        match idx {
-            Some(idx) => {
-                let uninit_end = (idx..end).find(|&i| self.get(i)).unwrap_or(end);
-                Err(idx..uninit_end)
+        match uninit_start {
+            Some(uninit_start) => {
+                let uninit_end = find_bit(self, uninit_start, end, true).unwrap_or(end);
+                Err(uninit_start..uninit_end)
             }
             None => Ok(()),
         }
@@ -863,9 +862,8 @@ impl<'a> Iterator for InitChunkIter<'a> {
         }
 
         let is_init = self.init_mask.get(self.start);
-        // FIXME(oli-obk): optimize this for allocations larger than a block.
         let end_of_chunk =
-            (self.start..self.end).find(|&i| self.init_mask.get(i) != is_init).unwrap_or(self.end);
+            find_bit(&self.init_mask, self.start, self.end, !is_init).unwrap_or(self.end);
         let range = self.start..end_of_chunk;
 
         self.start = end_of_chunk;
@@ -874,6 +872,94 @@ impl<'a> Iterator for InitChunkIter<'a> {
     }
 }
 
+/// Returns the index of the first bit in `start..end` (end-exclusive) that is equal to is_init.
+fn find_bit(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
+    fn find_bit_fast(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
+        fn search_block(
+            bits: Block,
+            block: usize,
+            start_bit: usize,
+            is_init: bool,
+        ) -> Option<Size> {
+            // invert bits so we're always looking for the first set bit
+            let bits = if is_init { bits } else { !bits };
+            // mask off unused start bits
+            let bits = bits & (!0 << start_bit);
+            // find set bit, if any
+            if bits == 0 {
+                None
+            } else {
+                let bit = bits.trailing_zeros();
+                Some(size_from_bit_index(block, bit))
+            }
+        }
+
+        if start >= end {
+            return None;
+        }
+
+        let (start_block, start_bit) = bit_index(start);
+        let (end_block, end_bit) = bit_index(end);
+
+        // handle first block: need to skip `start_bit` bits
+        if let Some(i) =
+            search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
+        {
+            if i < end {
+                return Some(i);
+            } else {
+                // if the range is less than a block, we may find a matching bit after `end`
+                return None;
+            }
+        }
+
+        let one_block_past_the_end = if end_bit > 0 {
+            // if `end_bit` > 0, then the range overlaps `end_block`
+            end_block + 1
+        } else {
+            end_block
+        };
+
+        // handle remaining blocks
+        if start_block < one_block_past_the_end {
+            for (&bits, block) in init_mask.blocks[start_block + 1..one_block_past_the_end]
+                .iter()
+                .zip(start_block + 1..)
+            {
+                if let Some(i) = search_block(bits, block, 0, is_init) {
+                    if i < end {
+                        return Some(i);
+                    } else {
+                        // if this is the last block, we may find a matching bit after `end`
+                        return None;
+                    }
+                }
+            }
+        }
+
+        None
+    }
+
+    #[cfg_attr(not(debug_assertions), allow(dead_code))]
+    fn find_bit_slow(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
+        (start..end).find(|&i| init_mask.get(i) == is_init)
+    }
+
+    let result = find_bit_fast(init_mask, start, end, is_init);
+
+    debug_assert_eq!(
+        result,
+        find_bit_slow(init_mask, start, end, is_init),
+        "optimized implementation of find_bit is wrong for start={:?} end={:?} is_init={} init_mask={:#?}",
+        start,
+        end,
+        is_init,
+        init_mask
+    );
+
+    result
+}
+
 #[inline]
 fn bit_index(bits: Size) -> (usize, usize) {
     let bits = bits.bytes();
@@ -881,3 +967,10 @@ fn bit_index(bits: Size) -> (usize, usize) {
     let b = bits % InitMask::BLOCK_SIZE;
     (usize::try_from(a).unwrap(), usize::try_from(b).unwrap())
 }
+
+#[inline]
+fn size_from_bit_index(block: impl TryInto<u64>, bit: impl TryInto<u64>) -> Size {
+    let block = block.try_into().ok().unwrap();
+    let bit = bit.try_into().ok().unwrap();
+    Size::from_bytes(block * InitMask::BLOCK_SIZE + bit)
+}