about summary refs log tree commit diff
path: root/src/libstd
diff options
context:
space:
mode:
author Alex Crichton <alex@alexcrichton.com> 2013-02-17 20:01:47 -0500
committer Alex Crichton <alex@alexcrichton.com> 2013-02-17 23:09:21 -0500
commit bf8ed45adc485e0e8e678e7b43b0c67ff93392f5 (patch)
tree 074f021c34a2dcd43f23b104a694f46847ea39a6 /src/libstd
parent 393a4b41f60612f234394b58b8e3bf3261ca9566 (diff)
download rust-bf8ed45adc485e0e8e678e7b43b0c67ff93392f5.tar.gz
rust-bf8ed45adc485e0e8e678e7b43b0c67ff93392f5.zip
Implement Set container on top of a bit vector
Diffstat (limited to 'src/libstd')
-rw-r--r-- src/libstd/bitv.rs 510
1 file changed, 470 insertions, 40 deletions
diff --git a/src/libstd/bitv.rs b/src/libstd/bitv.rs
index c01e0d9d94c..955729ed2fe 100644
--- a/src/libstd/bitv.rs
+++ b/src/libstd/bitv.rs
@@ -8,10 +8,12 @@
 // option. This file may not be copied, modified, or distributed
 // except according to those terms.
 
+use core::container::{Container, Mutable, Set};
+use core::num::NumCast;
 use core::ops;
 use core::prelude::*;
 use core::uint;
-use core::vec::{cast_to_mut, from_elem};
+use core::vec::from_elem;
 use core::vec;
 
 struct SmallBitv {
@@ -133,18 +135,15 @@ impl BigBitv {
         let len = b.storage.len();
         assert (self.storage.len() == len);
         let mut changed = false;
-        do uint::range(0, len) |i| {
+        for uint::range(0, len) |i| {
             let mask = big_mask(nbits, i);
             let w0 = self.storage[i] & mask;
             let w1 = b.storage[i] & mask;
             let w = op(w0, w1) & mask;
             if w0 != w {
-                unsafe {
-                    changed = true;
-                    self.storage[i] = w;
-                }
+                changed = true;
+                self.storage[i] = w;
             }
-            true
         }
         changed
     }
@@ -556,13 +555,314 @@ pub fn from_fn(len: uint, f: fn(index: uint) -> bool) -> Bitv {
     bitv
 }
 
+impl ops::Index<uint,bool> for Bitv { // lets callers write `bitv[i]`
+    pure fn index(&self, i: uint) -> bool {
+        self.get(i) // indexing simply forwards to `get`
+    }
+}
 
+#[inline(always)] // calls f(base + i) for every set bit i of `bits`
+pure fn iterate_bits(base: uint, bits: uint, f: fn(uint) -> bool) -> bool {
+    if bits == 0 { // fast path: no bits set, nothing to visit
+        return true;
+    }
+    for uint::range(0, uint::bits) |i| { // test each bit position in turn
+        if bits & (1 << i) != 0 {
+            if !f(base + i) { // f returned false: stop and report it
+                return false;
+            }
+        }
+    }
+    return true; // true means f never asked to stop
+}
 
+/// An implementation of a set using a bit vector as an underlying
+/// representation for holding numerical elements.
+///
+/// It should also be noted that the amount of storage necessary for holding a
+/// set of objects is proportional to the largest of the objects when viewed
+/// as a uint, rather than to the number of objects stored.
+pub struct BitvSet {
+    priv size: uint, // number of elements currently in the set
+
+    // In theory this could be a Bitv instead of always a BigBitv, but knowing
+    // that there's an array of storage makes our lives a whole lot easier
+    // when performing union/intersection/etc operations
+    priv bitv: BigBitv
+}
 
+impl BitvSet {
+    /// Creates a new bit vector set with no contents (size 0, storage `~[0]`)
+    static fn new() -> BitvSet {
+        BitvSet{ size: 0, bitv: BigBitv::new(~[0]) }
+    }
 
-impl ops::Index<uint,bool> for Bitv {
-    pure fn index(&self, i: uint) -> bool {
-        self.get(i)
+    /// Creates a new set from the given bit vector, reusing its storage
+    static fn from_bitv(bitv: Bitv) -> BitvSet {
+        let mut size = 0; // becomes the number of 1 bits in `bitv`
+        for bitv.ones |_| {
+            size += 1;
+        }
+        let Bitv{rep, _} = bitv; // take the representation out of the Bitv
+        match rep {
+            Big(~b) => BitvSet{ size: size, bitv: b },
+            Small(~SmallBitv{bits}) => // wrap the single word in a vector
+                BitvSet{ size: size, bitv: BigBitv{ storage: ~[bits] } },
+        }
+    }
+
+    /// Returns the capacity in bits (storage words * uint::bits). Inserting
+    /// any element less than this amount will not trigger a resizing.
+    pure fn capacity(&self) -> uint { self.bitv.storage.len() * uint::bits }
+
+    /// Consumes this set, returning its storage as a Bitv of capacity() bits
+    fn unwrap(self) -> Bitv {
+        let cap = self.capacity(); // read before `self` is taken apart
+        let BitvSet{bitv, _} = self;
+        return Bitv{ nbits:cap, rep: Big(~bitv) };
+    }
+
+    #[inline(always)] // self[i] = f(self[i], other[i]) for every word of self
+    priv fn other_op(&mut self, other: &BitvSet, f: fn(uint, uint) -> uint) {
+        fn nbits(mut w: uint) -> uint { // counts the 1 bits in w
+            let mut bits = 0;
+            for uint::bits.times {
+                if w == 0 {
+                    break;
+                }
+                bits += w & 1;
+                w >>= 1;
+            }
+            return bits;
+        }
+        let mine = self.bitv.storage.len(); // grow() appends: add the gap only
+        let theirs = other.bitv.storage.len();
+        if mine < theirs { self.bitv.storage.grow(theirs - mine, &0); }
+        for uint::range(0, uint::max(mine, theirs)) |i| {
+            let old = self.bitv.storage[i]; // a word other lacks counts as 0
+            let w = if i < theirs { other.bitv.storage[i] } else { 0 };
+            self.bitv.storage[i] = f(old, w); // add before subtracting below
+            self.size = self.size + nbits(self.bitv.storage[i]) - nbits(old);
+        }
+    }
+
+    /// In-place union: adds every element of `other` to this set
+    fn union_with(&mut self, other: &BitvSet) {
+        self.other_op(other, |w1, w2| w1 | w2);
+    }
+
+    /// In-place intersection with `other` (word-wise `&` via other_op)
+    fn intersect_with(&mut self, other: &BitvSet) {
+        self.other_op(other, |w1, w2| w1 & w2);
+    }
+
+    /// In-place difference: removes every element of `other` from this set
+    fn difference_with(&mut self, other: &BitvSet) {
+        self.other_op(other, |w1, w2| w1 & !w2);
+    }
+
+    /// In-place symmetric difference: keeps elements in exactly one set
+    fn symmetric_difference_with(&mut self, other: &BitvSet) {
+        self.other_op(other, |w1, w2| w1 ^ w2);
+    }
+}
+
+impl BaseIter<uint> for BitvSet { // iteration over the set's elements
+    pure fn size_hint(&self) -> Option<uint> { Some(self.len()) }
+
+    pure fn each(&self, blk: fn(v: &uint) -> bool) {
+        for self.bitv.storage.eachi |i, &w| { // one storage word at a time
+            if !iterate_bits(i * uint::bits, w, |b| blk(&b)) {
+                return; // blk returned false, so stop iterating
+            }
+        }
+    }
+}
+
+impl cmp::Eq for BitvSet { // equality ignores trailing all-zero words
+    pure fn eq(&self, other: &BitvSet) -> bool {
+        if self.size != other.size { // different element counts
+            return false;
+        }
+        for self.each_common(other) |_, w1, w2| { // shared words must match
+            if w1 != w2 {
+                return false;
+            }
+        }
+        for self.each_outlier(other) |_, _, w| { // extra words must be empty
+            if w != 0 {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    pure fn ne(&self, other: &BitvSet) -> bool { !self.eq(other) }
+}
+
+impl Container for BitvSet { // both answers come from the cached `size`
+    pure fn len(&self) -> uint { self.size }
+    pure fn is_empty(&self) -> bool { self.size == 0 }
+}
+
+impl Mutable for BitvSet {
+    fn clear(&mut self) { // empties the set; storage length is unchanged
+        for self.bitv.each_storage |w| { *w = 0; }
+        self.size = 0;
+    }
+}
+
+impl Set<uint> for BitvSet {
+    pure fn contains(&self, value: &uint) -> bool { // range check, then bit
+        *value < self.bitv.storage.len() * uint::bits && self.bitv.get(*value)
+    }
+
+    fn insert(&mut self, value: uint) -> bool { // true if newly added
+        if self.contains(&value) {
+            return false;
+        }
+        let nbits = self.capacity(); // == storage.len() * uint::bits
+        if value >= nbits { // grow() appends, so add only the shortfall
+            let newsize = uint::max(value, nbits * 2) / uint::bits + 1;
+            assert newsize > self.bitv.storage.len(); // target word count
+            self.bitv.storage.grow(newsize - nbits / uint::bits, &0);
+        }
+        self.size += 1;
+        self.bitv.set(value, true);
+        return true;
+    }
+
+    fn remove(&mut self, value: &uint) -> bool { // true if it was present
+        if !self.contains(value) {
+            return false;
+        }
+        self.size -= 1;
+        self.bitv.set(*value, false);
+
+        // Drop trailing all-zero words, always keeping at least one word
+        let mut i = self.bitv.storage.len();
+        while i > 1 && self.bitv.storage[i - 1] == 0 {
+            i -= 1;
+        }
+        self.bitv.storage.truncate(i);
+
+        return true;
+    }
+
+    pure fn is_disjoint(&self, other: &BitvSet) -> bool {
+        for self.intersection(other) |_| { // any shared element decides it
+            return false;
+        }
+        return true;
+    }
+
+    pure fn is_subset(&self, other: &BitvSet) -> bool {
+        for self.each_common(other) |_, w1, w2| {
+            if w1 & w2 != w1 { // we hold a bit that other lacks
+                return false;
+            }
+        }
+        /* If the leftover words are not ours, then all of our words were
+           checked above, so we're definitely a subset. Otherwise any stray
+           ones that 'other' doesn't have mean we're not a subset. */
+        for self.each_outlier(other) |mine, _, w| {
+            if !mine {
+                return true;
+            } else if w != 0 {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    pure fn is_superset(&self, other: &BitvSet) -> bool {
+        other.is_subset(self) // same test with the roles swapped
+    }
+
+    pure fn difference(&self, other: &BitvSet, f: fn(&uint) -> bool) {
+        for self.each_common(other) |i, w1, w2| {
+            if !iterate_bits(i, w1 & !w2, |b| f(&b)) { // ours, not theirs
+                return;
+            }
+        }
+        /* our words beyond other's storage are all in the difference */
+        self.each_outlier(other, |mine, i, w|
+            !mine || iterate_bits(i, w, |b| f(&b)) // skip other's words
+        );
+    }
+
+    pure fn symmetric_difference(&self, other: &BitvSet,
+                                 f: fn(&uint) -> bool) {
+        for self.each_common(other) |i, w1, w2| {
+            if !iterate_bits(i, w1 ^ w2, |b| f(&b)) { // in exactly one set
+                return;
+            }
+        }
+        self.each_outlier(other, |_, i, w| // leftovers belong to one set
+            iterate_bits(i, w, |b| f(&b))
+        );
+    }
+
+    pure fn intersection(&self, other: &BitvSet, f: fn(&uint) -> bool) {
+        for self.each_common(other) |i, w1, w2| {
+            if !iterate_bits(i, w1 & w2, |b| f(&b)) { // in both sets
+                return;
+            }
+        }
+    }
+
+    pure fn union(&self, other: &BitvSet, f: fn(&uint) -> bool) {
+        for self.each_common(other) |i, w1, w2| {
+            if !iterate_bits(i, w1 | w2, |b| f(&b)) { // in either set
+                return;
+            }
+        }
+        self.each_outlier(other, |_, i, w| // then the longer set's tail
+            iterate_bits(i, w, |b| f(&b))
+        );
+    }
+}
+
+priv impl BitvSet {
+    /// Visits each pair of words that the two bit vectors (self and other)
+    /// both have in common. The three yielded arguments are (bit location,
+    /// w1, w2) where the bit location is the word index times uint::bits,
+    /// and w1/w2 are the words coming from the two vectors self, other.
+    pure fn each_common(&self, other: &BitvSet,
+                        f: fn(uint, uint, uint) -> bool) {
+        let min = uint::min(self.bitv.storage.len(),
+                            other.bitv.storage.len());
+        for self.bitv.storage.view(0, min).eachi |i, &w| {
+            if !f(i * uint::bits, w, other.bitv.storage[i]) {
+                return; // f returned false, so stop visiting
+            }
+        }
+    }
+
+    /// Visits each word in self or other that extends beyond the other. This
+    /// will iterate through at most one of the vectors, and it only iterates
+    /// over the portion that doesn't overlap with the other one.
+    ///
+    /// The yielded arguments are a bool, the bit offset, and a word. The bool
+    /// is true if the word comes from 'self', and false if it comes from
+    /// 'other'.
+    pure fn each_outlier(&self, other: &BitvSet,
+                         f: fn(bool, uint, uint) -> bool) {
+        let len1 = self.bitv.storage.len();
+        let len2 = other.bitv.storage.len();
+        let min = uint::min(len1, len2);
+
+        /* at most one of these loops will execute and that's the point */
+        for self.bitv.storage.view(min, len1).eachi |i, &w| {
+            if !f(true, (i + min) * uint::bits, w) { // i counts from min
+                return;
+            }
+        }
+        for other.bitv.storage.view(min, len2).eachi |i, &w| {
+            if !f(false, (i + min) * uint::bits, w) { // i counts from min
+                return;
+            }
+        }
     }
 }
 
@@ -946,48 +1246,178 @@ mod tests {
 
     #[test]
     pub fn test_small_difference() {
-      let mut b1 = Bitv::new(3, false);
-      let mut b2 = Bitv::new(3, false);
-      b1.set(0, true);
-      b1.set(1, true);
-      b2.set(1, true);
-      b2.set(2, true);
-      assert b1.difference(&b2);
-      assert b1[0];
-      assert !b1[1];
-      assert !b1[2];
+        let mut b1 = Bitv::new(3, false);
+        let mut b2 = Bitv::new(3, false);
+        b1.set(0, true);
+        b1.set(1, true); // b1 now has bits 0 and 1 set
+        b2.set(1, true);
+        b2.set(2, true); // b2 now has bits 1 and 2 set
+        assert b1.difference(&b2);
+        assert b1[0]; // only bit 0 was unique to b1
+        assert !b1[1];
+        assert !b1[2];
     }
 
     #[test]
     pub fn test_big_difference() {
-      let mut b1 = Bitv::new(100, false);
-      let mut b2 = Bitv::new(100, false);
-      b1.set(0, true);
-      b1.set(40, true);
-      b2.set(40, true);
-      b2.set(80, true);
-      assert b1.difference(&b2);
-      assert b1[0];
-      assert !b1[40];
-      assert !b1[80];
+        let mut b1 = Bitv::new(100, false);
+        let mut b2 = Bitv::new(100, false);
+        b1.set(0, true);
+        b1.set(40, true); // b1 now has bits 0 and 40 set
+        b2.set(40, true);
+        b2.set(80, true); // b2 now has bits 40 and 80 set
+        assert b1.difference(&b2);
+        assert b1[0]; // only bit 0 was unique to b1
+        assert !b1[40];
+        assert !b1[80];
     }
 
     #[test]
     pub fn test_small_clear() {
-      let mut b = Bitv::new(14, true);
-      b.clear();
-      for b.ones |i| {
-          fail!(fmt!("found 1 at %?", i));
-      }
+        let mut b = Bitv::new(14, true);
+        b.clear();
+        for b.ones |i| { // body must never run: clear() leaves no 1 bits
+            fail!(fmt!("found 1 at %?", i));
+        }
     }
 
     #[test]
     pub fn test_big_clear() {
-      let mut b = Bitv::new(140, true);
-      b.clear();
-      for b.ones |i| {
-          fail!(fmt!("found 1 at %?", i));
-      }
+        let mut b = Bitv::new(140, true);
+        b.clear();
+        for b.ones |i| { // body must never run: clear() leaves no 1 bits
+            fail!(fmt!("found 1 at %?", i));
+        }
+    }
+
+    #[test]
+    pub fn test_bitv_set_basic() {
+        let mut b = BitvSet::new();
+        assert b.insert(3);
+        assert !b.insert(3); // a duplicate insert reports false
+        assert b.contains(&3);
+        assert b.insert(400); // past the initial one-word storage
+        assert !b.insert(400);
+        assert b.contains(&400);
+        assert b.len() == 2; // duplicates were not counted
+    }
+
+    #[test]
+    fn test_bitv_set_intersection() {
+        let mut a = BitvSet::new();
+        let mut b = BitvSet::new();
+
+        assert a.insert(11);
+        assert a.insert(1);
+        assert a.insert(3);
+        assert a.insert(77);
+        assert a.insert(103);
+        assert a.insert(5);
+
+        assert b.insert(2);
+        assert b.insert(11);
+        assert b.insert(77);
+        assert b.insert(5);
+        assert b.insert(3);
+
+        let mut i = 0; // index of the next expected element
+        let expected = [3, 5, 11, 77]; // common elements, ascending
+        for a.intersection(&b) |x| {
+            assert *x == expected[i];
+            i += 1
+        }
+        assert i == expected.len(); // nothing was skipped
+    }
+
+    #[test]
+    fn test_bitv_set_difference() {
+        let mut a = BitvSet::new();
+        let mut b = BitvSet::new();
+
+        assert a.insert(1);
+        assert a.insert(3);
+        assert a.insert(5);
+        assert a.insert(200);
+        assert a.insert(500);
+
+        assert b.insert(3);
+        assert b.insert(200);
+
+        let mut i = 0; // index of the next expected element
+        let expected = [1, 5, 500]; // in a but not in b, ascending
+        for a.difference(&b) |x| {
+            assert *x == expected[i];
+            i += 1
+        }
+        assert i == expected.len(); // nothing was skipped
+    }
+
+    #[test]
+    fn test_bitv_set_symmetric_difference() {
+        let mut a = BitvSet::new();
+        let mut b = BitvSet::new();
+
+        assert a.insert(1);
+        assert a.insert(3);
+        assert a.insert(5);
+        assert a.insert(9);
+        assert a.insert(11);
+
+        assert b.insert(3);
+        assert b.insert(9);
+        assert b.insert(14);
+        assert b.insert(220);
+
+        let mut i = 0; // index of the next expected element
+        let expected = [1, 5, 11, 14, 220]; // in exactly one set, ascending
+        for a.symmetric_difference(&b) |x| {
+            assert *x == expected[i];
+            i += 1
+        }
+        assert i == expected.len(); // nothing was skipped
+    }
+
+    #[test]
+    pub fn test_bitv_set_union() {
+        let mut a = BitvSet::new();
+        let mut b = BitvSet::new();
+        assert a.insert(1);
+        assert a.insert(3);
+        assert a.insert(5);
+        assert a.insert(9);
+        assert a.insert(11);
+        assert a.insert(160);
+        assert a.insert(19);
+        assert a.insert(24);
+
+        assert b.insert(1);
+        assert b.insert(5);
+        assert b.insert(9);
+        assert b.insert(13);
+        assert b.insert(19);
+
+        let mut i = 0; // index of the next expected element
+        let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160]; // each value once
+        for a.union(&b) |x| {
+            assert *x == expected[i];
+            i += 1
+        }
+        assert i == expected.len(); // nothing was skipped
+    }
+
+    #[test]
+    pub fn test_bitv_remove() {
+        let mut a = BitvSet::new();
+
+        assert a.insert(1);
+        assert a.remove(&1); // removing a present element reports true
+
+        assert a.insert(100);
+        assert a.remove(&100);
+
+        assert a.insert(1000);
+        assert a.remove(&1000);
+        assert a.capacity() == uint::bits; // storage is back to one word
+    }
 }