From f9f71cc32497ee4e3cbc7d9795bcf358a9268c13 Mon Sep 17 00:00:00 2001 From: Stein Somers Date: Fri, 21 Dec 2018 14:56:52 +0100 Subject: Fix poor worst case performance of set intersection (and union, somewhat) on asymmetrically sized sets and extend unit tests slightly beyond that --- src/libstd/collections/hash/set.rs | 62 ++++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 6 deletions(-) (limited to 'src/libstd') diff --git a/src/libstd/collections/hash/set.rs b/src/libstd/collections/hash/set.rs index d3267e4e801..8b1aafaa99a 100644 --- a/src/libstd/collections/hash/set.rs +++ b/src/libstd/collections/hash/set.rs @@ -420,9 +420,16 @@ impl HashSet /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn intersection<'a>(&'a self, other: &'a HashSet) -> Intersection<'a, T, S> { - Intersection { - iter: self.iter(), - other, + if self.len() <= other.len() { + Intersection { + iter: self.iter(), + other, + } + } else { + Intersection { + iter: other.iter(), + other: self, + } } } @@ -446,7 +453,15 @@ impl HashSet /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn union<'a>(&'a self, other: &'a HashSet) -> Union<'a, T, S> { - Union { iter: self.iter().chain(other.difference(self)) } + if self.len() <= other.len() { + Union { + iter: self.iter().chain(other.difference(self)), + } + } else { + Union { + iter: other.iter().chain(self.difference(other)), + } + } } /// Returns the number of elements in the set. @@ -1504,6 +1519,8 @@ mod test_set { fn test_intersection() { let mut a = HashSet::new(); let mut b = HashSet::new(); + assert!(a.intersection(&b).next().is_none()); + assert!(b.intersection(&a).next().is_none()); assert!(a.insert(11)); assert!(a.insert(1)); @@ -1528,6 +1545,22 @@ mod test_set { i += 1 } assert_eq!(i, expected.len()); + + assert!(a.insert(9)); // make a bigger than b + + i = 0; + for x in a.intersection(&b) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); + + i = 0; + for x in b.intersection(&a) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); } #[test] @@ -1583,11 +1616,11 @@ mod test_set { fn test_union() { let mut a = HashSet::new(); let mut b = HashSet::new(); + assert!(a.union(&b).next().is_none()); + assert!(b.union(&a).next().is_none()); assert!(a.insert(1)); assert!(a.insert(3)); - assert!(a.insert(5)); - assert!(a.insert(9)); assert!(a.insert(11)); assert!(a.insert(16)); assert!(a.insert(19)); @@ -1607,6 +1640,23 @@ mod test_set { i += 1 } assert_eq!(i, expected.len()); + + assert!(a.insert(9)); // make a bigger than b + assert!(a.insert(5)); + + i = 0; + for x in a.union(&b) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); + + i = 0; + for x in b.union(&a) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); } #[test] -- cgit 1.4.1-3-g733a5