about summary refs log tree commit diff
path: root/src/librustc_data_structures
diff options
context:
space:
mode:
authorNiko Matsakis <niko@alum.mit.edu>2016-08-05 20:12:53 -0400
committerNiko Matsakis <niko@alum.mit.edu>2016-08-09 08:26:06 -0400
commit9978cbc8f42247ab75093b355b54e74b3efbcbf8 (patch)
tree3e92b7a45fb0098c3a634f4ea307b503ab0f29a3 /src/librustc_data_structures
parent8150494ac24f430f986cdb093c4018a7c2bff7fd (diff)
downloadrust-9978cbc8f42247ab75093b355b54e74b3efbcbf8.tar.gz
rust-9978cbc8f42247ab75093b355b54e74b3efbcbf8.zip
generalize BitMatrix to be NxM and not just NxN
Diffstat (limited to 'src/librustc_data_structures')
-rw-r--r--src/librustc_data_structures/bitvec.rs83
-rw-r--r--src/librustc_data_structures/transitive_relation.rs3
2 files changed, 69 insertions, 17 deletions
diff --git a/src/librustc_data_structures/bitvec.rs b/src/librustc_data_structures/bitvec.rs
index 536cefbbe3f..0dab230f47a 100644
--- a/src/librustc_data_structures/bitvec.rs
+++ b/src/librustc_data_structures/bitvec.rs
@@ -124,32 +124,32 @@ impl FromIterator<bool> for BitVector {
     }
 }
 
-/// A "bit matrix" is basically a square matrix of booleans
-/// represented as one gigantic bitvector. In other words, it is as if
-/// you have N bitvectors, each of length N. Note that `elements` here is `N`/
+/// A "bit matrix" is basically a matrix of booleans represented as
+/// one gigantic bitvector. In other words, it is as if you have
+/// `rows` bitvectors, each of length `columns`.
 #[derive(Clone)]
 pub struct BitMatrix {
-    elements: usize,
+    columns: usize,
     vector: Vec<u64>,
 }
 
 impl BitMatrix {
-    // Create a new `elements x elements` matrix, initially empty.
-    pub fn new(elements: usize) -> BitMatrix {
+    // Create a new `rows x columns` matrix, initially empty.
+    pub fn new(rows: usize, columns: usize) -> BitMatrix {
         // For every element, we need one bit for every other
         // element. Round up to an even number of u64s.
-        let u64s_per_elem = u64s(elements);
+        let u64s_per_row = u64s(columns);
         BitMatrix {
-            elements: elements,
-            vector: vec![0; elements * u64s_per_elem],
+            columns: columns,
+            vector: vec![0; rows * u64s_per_row],
         }
     }
 
-    /// The range of bits for a given element.
-    fn range(&self, element: usize) -> (usize, usize) {
-        let u64s_per_elem = u64s(self.elements);
-        let start = element * u64s_per_elem;
-        (start, start + u64s_per_elem)
+    /// The range of bits for a given row.
+    fn range(&self, row: usize) -> (usize, usize) {
+        let u64s_per_row = u64s(self.columns);
+        let start = row * u64s_per_row;
+        (start, start + u64s_per_row)
     }
 
     pub fn add(&mut self, source: usize, target: usize) -> bool {
@@ -179,7 +179,7 @@ impl BitMatrix {
     pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
         let (a_start, a_end) = self.range(a);
         let (b_start, b_end) = self.range(b);
-        let mut result = Vec::with_capacity(self.elements);
+        let mut result = Vec::with_capacity(self.columns);
         for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
             let mut v = self.vector[i] & self.vector[j];
             for bit in 0..64 {
@@ -215,6 +215,15 @@ impl BitMatrix {
         }
         changed
     }
+
+    pub fn iter<'a>(&'a self, row: usize) -> BitVectorIter<'a> {
+        let (start, end) = self.range(row);
+        BitVectorIter {
+            iter: self.vector[start..end].iter(),
+            current: 0,
+            idx: 0,
+        }
+    }
 }
 
 fn u64s(elements: usize) -> usize {
@@ -300,7 +309,7 @@ fn grow() {
 
 #[test]
 fn matrix_intersection() {
-    let mut vec1 = BitMatrix::new(200);
+    let mut vec1 = BitMatrix::new(200, 200);
 
     // (*) Elements reachable from both 2 and 65.
 
@@ -328,3 +337,45 @@ fn matrix_intersection() {
     let intersection = vec1.intersection(2, 65);
     assert_eq!(intersection, &[10, 64, 160]);
 }
+
+#[test]
+fn matrix_iter() {
+    let mut matrix = BitMatrix::new(64, 100);
+    matrix.add(3, 22);
+    matrix.add(3, 75);
+    matrix.add(2, 99);
+    matrix.add(4, 0);
+    matrix.merge(3, 5);
+
+    let expected = [99];
+    let mut iter = expected.iter();
+    for i in matrix.iter(2) {
+        let j = *iter.next().unwrap();
+        assert_eq!(i, j);
+    }
+    assert!(iter.next().is_none());
+
+    let expected = [22, 75];
+    let mut iter = expected.iter();
+    for i in matrix.iter(3) {
+        let j = *iter.next().unwrap();
+        assert_eq!(i, j);
+    }
+    assert!(iter.next().is_none());
+
+    let expected = [0];
+    let mut iter = expected.iter();
+    for i in matrix.iter(4) {
+        let j = *iter.next().unwrap();
+        assert_eq!(i, j);
+    }
+    assert!(iter.next().is_none());
+
+    let expected = [22, 75];
+    let mut iter = expected.iter();
+    for i in matrix.iter(5) {
+        let j = *iter.next().unwrap();
+        assert_eq!(i, j);
+    }
+    assert!(iter.next().is_none());
+}
diff --git a/src/librustc_data_structures/transitive_relation.rs b/src/librustc_data_structures/transitive_relation.rs
index c3a2f978e1a..e09e260afc8 100644
--- a/src/librustc_data_structures/transitive_relation.rs
+++ b/src/librustc_data_structures/transitive_relation.rs
@@ -252,7 +252,8 @@ impl<T: Debug + PartialEq> TransitiveRelation<T> {
     }
 
     fn compute_closure(&self) -> BitMatrix {
-        let mut matrix = BitMatrix::new(self.elements.len());
+        let mut matrix = BitMatrix::new(self.elements.len(),
+                                        self.elements.len());
         let mut changed = true;
         while changed {
             changed = false;