about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/libcollections/bitv.rs170
1 files changed, 87 insertions, 83 deletions
diff --git a/src/libcollections/bitv.rs b/src/libcollections/bitv.rs
index 20d7c3ef2cf..b480b88b4d4 100644
--- a/src/libcollections/bitv.rs
+++ b/src/libcollections/bitv.rs
@@ -24,22 +24,6 @@ use std::hash;
 use {Collection, Mutable, Set, MutableSet};
 use vec::Vec;
 
-/**
- * A mask that has a 1 for each defined bit in the n'th element of a `Bitv`,
- * assuming n bits.
- */
-#[inline]
-fn big_mask(nbits: uint, elem: uint) -> uint {
-    let rmd = nbits % uint::BITS;
-    let nelems = (nbits + uint::BITS - 1) / uint::BITS;
-
-    if elem < nelems - 1 || rmd == 0 {
-        !0
-    } else {
-        (1 << rmd) - 1
-    }
-}
-
 /// The bitvector type
 ///
 /// # Example
@@ -75,35 +59,47 @@ pub struct Bitv {
     nbits: uint
 }
 
-struct Words<'a> {
+struct MaskWords<'a> {
     iter: slice::Items<'a, uint>,
+    next_word: Option<&'a uint>,
+    last_word_mask: uint,
     offset: uint
 }
 
-impl<'a> Iterator<(uint, uint)> for Words<'a> {
+impl<'a> Iterator<(uint, uint)> for MaskWords<'a> {
     /// Returns (offset, word)
     fn next<'a>(&'a mut self) -> Option<(uint, uint)> {
-        let ret = self.iter.next().map(|&n| (self.offset, n));
-        self.offset += 1;
-        ret
+        let ret = self.next_word;
+        match ret {
+            Some(&w) => {
+                self.next_word = self.iter.next();
+                self.offset += 1;
+                // The last word may need to be masked
+                if self.next_word.is_none() {
+                    Some((self.offset - 1, w & self.last_word_mask))
+                } else {
+                    Some((self.offset - 1, w))
+                }
+            },
+            None => None
+        }
     }
 }
 
 impl Bitv {
     #[inline]
-    fn process(&mut self, other: &Bitv, nbits: uint,
-               op: |uint, uint| -> uint) -> bool {
+    fn process(&mut self, other: &Bitv, op: |uint, uint| -> uint) -> bool {
         let len = other.storage.len();
         assert_eq!(self.storage.len(), len);
         let mut changed = false;
-        for (i, (a, b)) in self.storage.mut_iter()
-                               .zip(other.storage.iter())
-                               .enumerate() {
-            let mask = big_mask(nbits, i);
-            let w0 = *a & mask;
-            let w1 = *b & mask;
-            let w = op(w0, w1) & mask;
-            if w0 != w {
+        // Notice: `a` is *not* masked here, which is fine as long as
+        // `op` is a bitwise operation, since any bits that should've
+        // been masked were fine to change anyway. `b` is masked to
+        // make sure its unmasked bits do not cause damage.
+        for (a, (_, b)) in self.storage.mut_iter()
+                           .zip(other.mask_words(0)) {
+            let w = op(*a, b);
+            if *a != w {
                 changed = true;
                 *a = w;
             }
@@ -112,10 +108,20 @@ impl Bitv {
     }
 
     #[inline]
-    #[inline]
-    fn words<'a>(&'a self, start: uint) -> Words<'a> {
-        Words {
-          iter: self.storage.slice_from(start).iter(),
+    fn mask_words<'a>(&'a self, mut start: uint) -> MaskWords<'a> {
+        if start > self.storage.len() {
+            start = self.storage.len();
+        }
+        let mut iter = self.storage.slice_from(start).iter();
+        MaskWords {
+          next_word: iter.next(),
+          iter: iter,
+          last_word_mask: {
+              let rem = self.nbits % uint::BITS;
+              if rem > 0 {
+                  (1 << rem) - 1
+              } else { !0 }
+          },
           offset: start
         }
     }
@@ -124,15 +130,8 @@ impl Bitv {
     /// to `init`.
     pub fn new(nbits: uint, init: bool) -> Bitv {
         Bitv {
-            storage: {
-                let nelems = (nbits + uint::BITS - 1) / uint::BITS;
-                let mut v = Vec::from_elem(nelems, if init { !0u } else { 0u });
-                // Zero out any remainder bits
-                if nbits % uint::BITS > 0 {
-                    *v.get_mut(nelems - 1) &= (1 << nbits % uint::BITS) - 1;
-                }
-                v
-            },
+            storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
+                                    if init { !0u } else { 0u }),
             nbits: nbits
         }
     }
@@ -145,8 +144,7 @@ impl Bitv {
     */
     #[inline]
     pub fn union(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 | w2)
+        self.process(other, |w1, w2| w1 | w2)
     }
 
     /**
@@ -157,8 +155,7 @@ impl Bitv {
     */
     #[inline]
     pub fn intersect(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 & w2)
+        self.process(other, |w1, w2| w1 & w2)
     }
 
     /**
@@ -169,8 +166,7 @@ impl Bitv {
      */
     #[inline]
     pub fn assign(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |_, w| w)
+        self.process(other, |_, w| w)
     }
 
     /// Retrieve the value at index `i`
@@ -227,20 +223,18 @@ impl Bitv {
      */
     #[inline]
     pub fn difference(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 & !w2)
+        self.process(other, |w1, w2| w1 & !w2)
     }
 
     /// Returns `true` if all bits are 1
     #[inline]
     pub fn all(&self) -> bool {
-        for (i, &elem) in self.storage.iter().enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if elem & mask != mask {
-                return false;
-            }
-        }
-        true
+        let mut last_word = !0u;
+        // Check that every word but the last is all-ones...
+        self.mask_words(0).all(|(_, elem)|
+            { let tmp = last_word; last_word = elem; tmp == !0u }) &&
+        // ...and that the last word is ones as far as it needs to be
+        (last_word == ((1 << self.nbits % uint::BITS) - 1) || last_word == !0u)
     }
 
     /// Returns an iterator over the elements of the vector in order.
@@ -265,13 +259,7 @@ impl Bitv {
 
     /// Returns `true` if all bits are 0
     pub fn none(&self) -> bool {
-        for (i, &elem) in self.storage.iter().enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if elem & mask != 0 {
-                return false;
-            }
-        }
-        true
+        self.mask_words(0).all(|(_, w)| w == 0)
     }
 
     #[inline]
@@ -397,8 +385,8 @@ impl fmt::Show for Bitv {
 impl<S: hash::Writer> hash::Hash<S> for Bitv {
     fn hash(&self, state: &mut S) {
         self.nbits.hash(state);
-        for (i, elem) in self.storage.iter().enumerate() {
-            (elem & big_mask(self.nbits, i)).hash(state);
+        for (_, elem) in self.mask_words(0) {
+            elem.hash(state);
         }
     }
 }
@@ -409,13 +397,7 @@ impl cmp::PartialEq for Bitv {
         if self.nbits != other.nbits {
             return false;
         }
-        for (i, (&w1, &w2)) in self.storage.iter().zip(other.storage.iter()).enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if w1 & mask != w2 & mask {
-                return false;
-            }
-        }
-        true
+        self.mask_words(0).zip(other.mask_words(0)).all(|((_, w1), (_, w2))| w1 == w2)
     }
 }
 
@@ -546,7 +528,7 @@ impl BitvSet {
         // Unwrap Bitvs
         let &BitvSet(ref mut self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
-        for (i, w) in other_bitv.words(0) {
+        for (i, w) in other_bitv.mask_words(0) {
             let old = *self_bitv.storage.get(i);
             let new = f(old, w);
             *self_bitv.storage.get_mut(i) = new;
@@ -563,7 +545,7 @@ impl BitvSet {
         let n = bitv.storage.iter().rev().take_while(|&&n| n == 0).count();
         // Truncate
         let trunc_len = cmp::max(old_len - n, 1);
-        bitv.storage.truncate(cmp::max(old_len - n, 1));
+        bitv.storage.truncate(trunc_len);
         bitv.nbits = trunc_len * uint::BITS;
     }
 
@@ -710,6 +692,12 @@ impl MutableSet<uint> for BitvSet {
         }
         let &BitvSet(ref mut bitv) = self;
         if value >= bitv.nbits {
+            // If we are increasing nbits, make sure we mask out any previously-unconsidered bits
+            let old_rem = bitv.nbits % uint::BITS;
+            if old_rem != 0 {
+                let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
+                *bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
+            }
             bitv.nbits = value + 1;
         }
         bitv.set(value, true);
@@ -733,10 +721,10 @@ impl BitvSet {
     /// and w1/w2 are the words coming from the two vectors self, other.
     fn commons<'a>(&'a self, other: &'a BitvSet)
         -> Map<((uint, uint), (uint, uint)), (uint, uint, uint),
-               Zip<Words<'a>, Words<'a>>> {
+               Zip<MaskWords<'a>, MaskWords<'a>>> {
         let &BitvSet(ref self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
-        self_bitv.words(0).zip(other_bitv.words(0))
+        self_bitv.mask_words(0).zip(other_bitv.mask_words(0))
             .map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2))
     }
 
@@ -748,17 +736,17 @@ impl BitvSet {
     /// is true if the word comes from `self`, and `false` if it comes from
     /// `other`.
     fn outliers<'a>(&'a self, other: &'a BitvSet)
-        -> Map<(uint, uint), (bool, uint, uint), Words<'a>> {
+        -> Map<(uint, uint), (bool, uint, uint), MaskWords<'a>> {
         let slen = self.capacity() / uint::BITS;
         let olen = other.capacity() / uint::BITS;
         let &BitvSet(ref self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
 
         if olen < slen {
-            self_bitv.words(olen)
+            self_bitv.mask_words(olen)
                 .map(|(i, w)| (true, i * uint::BITS, w))
         } else {
-            other_bitv.words(slen)
+            other_bitv.mask_words(slen)
                 .map(|(i, w)| (false, i * uint::BITS, w))
         }
     }
@@ -1251,15 +1239,31 @@ mod tests {
     }
 
     #[test]
+    fn test_bitv_masking() {
+        let b = Bitv::new(140, true); 
+        let mut bs = BitvSet::from_bitv(b);
+        assert!(bs.contains(&139));
+        assert!(!bs.contains(&140));
+        assert!(bs.insert(150));
+        assert!(!bs.contains(&140));
+        assert!(!bs.contains(&149));
+        assert!(bs.contains(&150));
+        assert!(!bs.contains(&151));
+    }
+
+    #[test]
     fn test_bitv_set_basic() {
         let mut b = BitvSet::new();
         assert!(b.insert(3));
         assert!(!b.insert(3));
         assert!(b.contains(&3));
+        assert!(b.insert(4));
+        assert!(!b.insert(4));
+        assert!(b.contains(&3));
         assert!(b.insert(400));
         assert!(!b.insert(400));
         assert!(b.contains(&400));
-        assert_eq!(b.len(), 2);
+        assert_eq!(b.len(), 3);
     }
 
     #[test]