about summary refs log tree commit diff
path: root/src/libstd
diff options
context:
space:
mode:
authorMazdak Farrokhzad <twingoow@gmail.com>2019-01-14 20:31:51 +0100
committerGitHub <noreply@github.com>2019-01-14 20:31:51 +0100
commit5bc95de47d960f7a4082798237f438ac8d9f225b (patch)
treeb4bee7a59c0c904ccf620e68db1ca36838131c1b /src/libstd
parentd10680818b2a0aabb76e6a07098e031b31707fcc (diff)
parentcef2e2f3d53795787085bf63a6c2a8563e7ba9c9 (diff)
downloadrust-5bc95de47d960f7a4082798237f438ac8d9f225b.tar.gz
rust-5bc95de47d960f7a4082798237f438ac8d9f225b.zip
Rollup merge of #57043 - ssomers:master, r=alexcrichton
Fix poor worst case performance of set intersection

Specifically, intersection of asymmetrically sized sets when the large set is on the left. See also the [latest answer on stackoverflow](https://stackoverflow.com/questions/35439376/python-set-intersection-is-faster-then-rust-hashset-intersection).

Also applied to the union member, where the effect is much less but still measurable.

Formatted the changed code only, does not increase the error count reported by tidy check, and tried to adhere to the spirit of the unit tests.
Diffstat (limited to 'src/libstd')
-rw-r--r--src/libstd/collections/hash/set.rs67
1 files changed, 60 insertions, 7 deletions
diff --git a/src/libstd/collections/hash/set.rs b/src/libstd/collections/hash/set.rs
index 92145907b95..c55dd049ec6 100644
--- a/src/libstd/collections/hash/set.rs
+++ b/src/libstd/collections/hash/set.rs
@@ -410,9 +410,16 @@ impl<T, S> HashSet<T, S>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
-        Intersection {
-            iter: self.iter(),
-            other,
+        if self.len() <= other.len() {
+            Intersection {
+                iter: self.iter(),
+                other,
+            }
+        } else {
+            Intersection {
+                iter: other.iter(),
+                other: self,
+            }
         }
     }
 
@@ -436,7 +443,15 @@ impl<T, S> HashSet<T, S>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
-        Union { iter: self.iter().chain(other.difference(self)) }
+        if self.len() <= other.len() {
+            Union {
+                iter: self.iter().chain(other.difference(self)),
+            }
+        } else {
+            Union {
+                iter: other.iter().chain(self.difference(other)),
+            }
+        }
     }
 
     /// Returns the number of elements in the set.
@@ -584,7 +599,11 @@ impl<T, S> HashSet<T, S>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn is_disjoint(&self, other: &HashSet<T, S>) -> bool {
-        self.iter().all(|v| !other.contains(v))
+        if self.len() <= other.len() {
+            self.iter().all(|v| !other.contains(v))
+        } else {
+            other.iter().all(|v| !self.contains(v))
+        }
     }
 
     /// Returns `true` if the set is a subset of another,
@@ -1494,6 +1513,7 @@ mod test_set {
     fn test_intersection() {
         let mut a = HashSet::new();
         let mut b = HashSet::new();
+        assert!(a.intersection(&b).next().is_none());
 
         assert!(a.insert(11));
         assert!(a.insert(1));
@@ -1518,6 +1538,22 @@ mod test_set {
             i += 1
         }
         assert_eq!(i, expected.len());
+
+        assert!(a.insert(9)); // make a bigger than b
+
+        i = 0;
+        for x in a.intersection(&b) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
+
+        i = 0;
+        for x in b.intersection(&a) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
     }
 
     #[test]
@@ -1573,11 +1609,11 @@ mod test_set {
     fn test_union() {
         let mut a = HashSet::new();
         let mut b = HashSet::new();
+        assert!(a.union(&b).next().is_none());
+        assert!(b.union(&a).next().is_none());
 
         assert!(a.insert(1));
         assert!(a.insert(3));
-        assert!(a.insert(5));
-        assert!(a.insert(9));
         assert!(a.insert(11));
         assert!(a.insert(16));
         assert!(a.insert(19));
@@ -1597,6 +1633,23 @@ mod test_set {
             i += 1
         }
         assert_eq!(i, expected.len());
+
+        assert!(a.insert(9)); // make a bigger than b
+        assert!(a.insert(5));
+
+        i = 0;
+        for x in a.union(&b) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
+
+        i = 0;
+        for x in b.union(&a) {
+            assert!(expected.contains(x));
+            i += 1
+        }
+        assert_eq!(i, expected.len());
     }
 
     #[test]