about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/librustc_data_structures/bitvec.rs185
1 files changed, 176 insertions, 9 deletions
diff --git a/src/librustc_data_structures/bitvec.rs b/src/librustc_data_structures/bitvec.rs
index f2f4a69d882..a0e4f4a3f2d 100644
--- a/src/librustc_data_structures/bitvec.rs
+++ b/src/librustc_data_structures/bitvec.rs
@@ -15,26 +15,193 @@ pub struct BitVector {
 
 impl BitVector {
     pub fn new(num_bits: usize) -> BitVector {
-        let num_words = (num_bits + 63) / 64;
+        let num_words = u64s(num_bits);
         BitVector { data: vec![0; num_words] }
     }
 
-    fn word_mask(&self, bit: usize) -> (usize, u64) {
-        let word = bit / 64;
-        let mask = 1 << (bit % 64);
-        (word, mask)
-    }
-
     pub fn contains(&self, bit: usize) -> bool {
-        let (word, mask) = self.word_mask(bit);
+        let (word, mask) = word_mask(bit);
         (self.data[word] & mask) != 0
     }
 
     pub fn insert(&mut self, bit: usize) -> bool {
-        let (word, mask) = self.word_mask(bit);
+        let (word, mask) = word_mask(bit);
         let data = &mut self.data[word];
         let value = *data;
         *data = value | mask;
         (value | mask) != value
     }
+
+    pub fn insert_all(&mut self, all: &BitVector) -> bool {
+        assert!(self.data.len() == all.data.len());
+        let mut changed = false;
+        for (i, j) in self.data.iter_mut().zip(&all.data) {
+            let value = *i;
+            *i = value | *j;
+            if value != *i { changed = true; }
+        }
+        changed
+    }
+
+    pub fn grow(&mut self, num_bits: usize) {
+        let num_words = u64s(num_bits);
+        let extra_words = self.data.len() - num_words;
+        self.data.extend((0..extra_words).map(|_| 0));
+    }
+}
+
+/// A "bit matrix" is basically a square matrix of booleans
+/// represented as one gigantic bitvector. In other words, it is as if
+/// you have N bitvectors, each of length N.
+#[derive(Clone)]
+pub struct BitMatrix {
+    elements: usize,
+    vector: Vec<u64>,
+}
+
+impl BitMatrix {
+    pub fn new(elements: usize) -> BitMatrix {
+        // For every element, we need one bit for every other
+        // element. Round up to an even number of u64s.
+        let u64s_per_elem = u64s(elements);
+        BitMatrix {
+            elements: elements,
+            vector: vec![0; elements * u64s_per_elem]
+        }
+    }
+
+    /// The range of bits for a given element.
+    fn range(&self, element: usize) -> (usize, usize) {
+        let u64s_per_elem = u64s(self.elements);
+        let start = element * u64s_per_elem;
+        (start, start + u64s_per_elem)
+    }
+
+    pub fn add(&mut self, source: usize, target: usize) -> bool {
+        let (start, _) = self.range(source);
+        let (word, mask) = word_mask(target);
+        let mut vector = &mut self.vector[..];
+        let v1 = vector[start+word];
+        let v2 = v1 | mask;
+        vector[start+word] = v2;
+        v1 != v2
+    }
+
+    /// Do the bits from `source` contain `target`?
+    /// Put another way, can `source` reach `target`?
+    pub fn contains(&self, source: usize, target: usize) -> bool {
+        let (start, _) = self.range(source);
+        let (word, mask) = word_mask(target);
+        (self.vector[start+word] & mask) != 0
+    }
+
+    /// Returns those indices that are reachable from both source and
+    /// target. This is an O(n) operation where `n` is the number of
+    /// elements (somewhat independent from the actual size of the
+    /// intersection, in particular).
+    pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
+        let (a_start, a_end) = self.range(a);
+        let (b_start, b_end) = self.range(b);
+        let mut result = Vec::with_capacity(self.elements);
+        for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
+            let mut v = self.vector[i] & self.vector[j];
+            for bit in 0..64 {
+                if v == 0 { break; }
+                if v & 0x1 != 0 { result.push(base*64 + bit); }
+                v >>= 1;
+            }
+        }
+        result
+    }
+
+    /// Add the bits from source to the bits from destination,
+    /// return true if anything changed.
+    ///
+    /// This is used when computing reachability because if you have
+    /// an edge `destination -> source`, because in that case
+    /// `destination` can reach everything that `source` can (and
+    /// potentially more).
+    pub fn merge(&mut self, source: usize, destination: usize) -> bool {
+        let (source_start, source_end) = self.range(source);
+        let (destination_start, destination_end) = self.range(destination);
+        let vector = &mut self.vector[..];
+        let mut changed = false;
+        for (source_index, destination_index) in
+            (source_start..source_end).zip(destination_start..destination_end)
+        {
+            let v1 = vector[destination_index];
+            let v2 = v1 | vector[source_index];
+            vector[destination_index] = v2;
+            changed = changed | (v1 != v2);
+        }
+        changed
+    }
+}
+
+fn u64s(elements: usize) -> usize {
+    (elements + 63) / 64
+}
+
+fn word_mask(index: usize) -> (usize, u64) {
+    let word = index / 64;
+    let mask = 1 << (index % 64);
+    (word, mask)
+}
+
+#[test]
+fn union_two_vecs() {
+    let mut vec1 = BitVector::new(65);
+    let mut vec2 = BitVector::new(65);
+    assert!(vec1.insert(3));
+    assert!(!vec1.insert(3));
+    assert!(vec2.insert(5));
+    assert!(vec2.insert(64));
+    assert!(vec1.insert_all(&vec2));
+    assert!(!vec1.insert_all(&vec2));
+    assert!(vec1.contains(3));
+    assert!(!vec1.contains(4));
+    assert!(vec1.contains(5));
+    assert!(!vec1.contains(63));
+    assert!(vec1.contains(64));
+}
+
+#[test]
+fn grow() {
+    let mut vec1 = BitVector::new(65);
+    assert!(vec1.insert(3));
+    assert!(!vec1.insert(3));
+    assert!(vec1.insert(5));
+    assert!(vec1.insert(64));
+    vec1.grow(128);
+    assert!(vec1.contains(3));
+    assert!(vec1.contains(5));
+    assert!(vec1.contains(64));
+    assert!(!vec1.contains(126));
+}
+
+#[test]
+fn matrix_intersection() {
+    let mut vec1 = BitMatrix::new(200);
+
+    vec1.add(2, 3);
+    vec1.add(2, 6);
+    vec1.add(2, 10);
+    vec1.add(2, 64);
+    vec1.add(2, 65);
+    vec1.add(2, 130);
+    vec1.add(2, 160);
+
+    vec1.add(65, 2);
+    vec1.add(65, 8);
+    vec1.add(65, 10); // X
+    vec1.add(65, 64); // X
+    vec1.add(65, 68);
+    vec1.add(65, 133);
+    vec1.add(65, 160); // X
+
+    let intersection = vec1.intersection(2, 64);
+    assert!(intersection.is_empty());
+
+    let intersection = vec1.intersection(2, 65);
+    assert_eq!(intersection, vec![10, 64, 160]);
 }