about summary refs log tree commit diff
path: root/src/libstd
diff options
context:
space:
mode:
authorStein Somers <git@steinsomers.be>2018-12-21 14:56:52 +0100
committerStein Somers <git@steinsomers.be>2018-12-21 14:57:56 +0100
commitf9f71cc32497ee4e3cbc7d9795bcf358a9268c13 (patch)
tree7704d0a8ade79d354e6c7f0f623ae562abaaab80 /src/libstd
parent01c6ea2f37326674bf6ca64de55bf5fd90e45015 (diff)
downloadrust-f9f71cc32497ee4e3cbc7d9795bcf358a9268c13.tar.gz
rust-f9f71cc32497ee4e3cbc7d9795bcf358a9268c13.zip
Fix poor worst case performance of set intersection (and union, somewhat) on asymmetrically sized sets and extend unit tests slightly beyond that
Diffstat (limited to 'src/libstd')
-rw-r--r--src/libstd/collections/hash/set.rs62
1 files changed, 56 insertions, 6 deletions
diff --git a/src/libstd/collections/hash/set.rs b/src/libstd/collections/hash/set.rs
index d3267e4e801..8b1aafaa99a 100644
--- a/src/libstd/collections/hash/set.rs
+++ b/src/libstd/collections/hash/set.rs
@@ -420,9 +420,16 @@ impl<T, S> HashSet<T, S>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
-        Intersection {
-            iter: self.iter(),
-            other,
+        if self.len() <= other.len() {
+            Intersection {
+                iter: self.iter(),
+                other,
+            }
+        } else {
+            Intersection {
+                iter: other.iter(),
+                other: self,
+            }
         }
     }
 
@@ -446,7 +453,15 @@ impl<T, S> HashSet<T, S>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
-        Union { iter: self.iter().chain(other.difference(self)) }
+        if self.len() <= other.len() {
+            Union {
+                iter: self.iter().chain(other.difference(self)),
+            }
+        } else {
+            Union {
+                iter: other.iter().chain(self.difference(other)),
+            }
+        }
     }
 
     /// Returns the number of elements in the set.
@@ -1504,6 +1519,8 @@ mod test_set {
     fn test_intersection() {
         let mut a = HashSet::new();
         let mut b = HashSet::new();
+        assert!(a.intersection(&b).next().is_none());
+        assert!(b.intersection(&a).next().is_none());
 
         assert!(a.insert(11));
         assert!(a.insert(1));
@@ -1528,6 +1545,22 @@ mod test_set {
             i += 1
         }
         assert_eq!(i, expected.len());
+
+        assert!(a.insert(9)); // make a bigger than b
+
+        i = 0;
+        for x in a.intersection(&b) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
+
+        i = 0;
+        for x in b.intersection(&a) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
     }
 
     #[test]
@@ -1583,11 +1616,11 @@ mod test_set {
     fn test_union() {
         let mut a = HashSet::new();
         let mut b = HashSet::new();
+        assert!(a.union(&b).next().is_none());
+        assert!(b.union(&a).next().is_none());
 
         assert!(a.insert(1));
         assert!(a.insert(3));
-        assert!(a.insert(5));
-        assert!(a.insert(9));
         assert!(a.insert(11));
         assert!(a.insert(16));
         assert!(a.insert(19));
@@ -1607,6 +1640,23 @@ mod test_set {
             i += 1
         }
         assert_eq!(i, expected.len());
+
+        assert!(a.insert(9)); // make a bigger than b
+        assert!(a.insert(5));
+
+        i = 0;
+        for x in a.union(&b) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
+
+        i = 0;
+        for x in b.union(&a) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
     }
 
     #[test]