diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src/coverage/counters')
7 files changed, 651 insertions, 41 deletions
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs new file mode 100644 index 00000000000..c108f96a564 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs @@ -0,0 +1,133 @@ +//! A control-flow graph can be said to have “balanced flow” if the flow +//! (execution count) of each node is equal to the sum of its in-edge flows, +//! and also equal to the sum of its out-edge flows. +//! +//! Control-flow graphs typically have one or more nodes that don't satisfy the +//! balanced-flow property, e.g.: +//! - The start node has out-edges, but no in-edges. +//! - Return nodes have in-edges, but no out-edges. +//! - `Yield` nodes can have an out-flow that is less than their in-flow. +//! - Inescapable loops cause the in-flow/out-flow relationship to break down. +//! +//! Balanced-flow graphs are nevertheless useful for analysis, so this module +//! provides a wrapper type ([`BalancedFlowGraph`]) that imposes balanced flow +//! on an underlying graph. This is done by non-destructively adding synthetic +//! nodes and edges as necessary. + +use rustc_data_structures::graph; +use rustc_data_structures::graph::iterate::DepthFirstSearch; +use rustc_data_structures::graph::reversed::ReversedGraph; +use rustc_index::Idx; +use rustc_index::bit_set::DenseBitSet; + +use crate::coverage::counters::iter_nodes::IterNodes; + +/// A view of an underlying graph that has been augmented to have “balanced flow”. +/// This means that the flow (execution count) of each node is equal to the +/// sum of its in-edge flows, and also equal to the sum of its out-edge flows. +/// +/// To achieve this, a synthetic "sink" node is non-destructively added to the +/// graph, with synthetic in-edges from these nodes: +/// - Any node that has no out-edges. +/// - Any node that explicitly requires a sink edge, as indicated by a +/// caller-supplied `force_sink_edge` function. +/// - Any node that would otherwise be unable to reach the sink, because it is +/// part of an inescapable loop. +/// +/// To make the graph fully balanced, there is also a synthetic edge from the +/// sink node back to the start node. +/// +/// --- +/// The benefit of having a balanced-flow graph is that it can be subsequently +/// transformed in ways that are guaranteed to preserve balanced flow +/// (e.g. merging nodes together), which is useful for discovering relationships +/// between the node flows of different nodes in the graph. +pub(crate) struct BalancedFlowGraph<G: graph::DirectedGraph> { + graph: G, + sink_edge_nodes: DenseBitSet<G::Node>, + pub(crate) sink: G::Node, +} + +impl<G: graph::DirectedGraph> BalancedFlowGraph<G> { + /// Creates a balanced view of an underlying graph, by adding a synthetic + /// sink node that has in-edges from nodes that need or request such an edge, + /// and a single out-edge to the start node. + /// + /// Assumes that all nodes in the underlying graph are reachable from the + /// start node. + pub(crate) fn for_graph(graph: G, force_sink_edge: impl Fn(G::Node) -> bool) -> Self + where + G: graph::ControlFlowGraph, + { + let mut sink_edge_nodes = DenseBitSet::new_empty(graph.num_nodes()); + let mut dfs = DepthFirstSearch::new(ReversedGraph::new(&graph)); + + // First, determine the set of nodes that explicitly request or require + // an out-edge to the sink. + for node in graph.iter_nodes() { + if force_sink_edge(node) || graph.successors(node).next().is_none() { + sink_edge_nodes.insert(node); + dfs.push_start_node(node); + } + } + + // Next, find all nodes that are currently not reverse-reachable from + // `sink_edge_nodes`, and add them to the set as well. + dfs.complete_search(); + sink_edge_nodes.union_not(dfs.visited_set()); + + // The sink node is 1 higher than the highest real node. + let sink = G::Node::new(graph.num_nodes()); + + BalancedFlowGraph { graph, sink_edge_nodes, sink } + } +} + +impl<G> graph::DirectedGraph for BalancedFlowGraph<G> +where + G: graph::DirectedGraph, +{ + type Node = G::Node; + + /// Returns the number of nodes in this balanced-flow graph, which is 1 + /// more than the number of nodes in the underlying graph, to account for + /// the synthetic sink node. + fn num_nodes(&self) -> usize { + // The sink node's index is already the size of the underlying graph, + // so just add 1 to that instead. + self.sink.index() + 1 + } +} + +impl<G> graph::StartNode for BalancedFlowGraph<G> +where + G: graph::StartNode, +{ + fn start_node(&self) -> Self::Node { + self.graph.start_node() + } +} + +impl<G> graph::Successors for BalancedFlowGraph<G> +where + G: graph::StartNode + graph::Successors, +{ + fn successors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { + let real_edges; + let sink_edge; + + if node == self.sink { + // The sink node has no real out-edges, and one synthetic out-edge + // to the start node. + real_edges = None; + sink_edge = Some(self.graph.start_node()); + } else { + // Real nodes have their real out-edges, and possibly one synthetic + // out-edge to the sink node. + real_edges = Some(self.graph.successors(node)); + sink_edge = self.sink_edge_nodes.contains(node).then_some(self.sink); + } + + real_edges.into_iter().flatten().chain(sink_edge) + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs b/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs new file mode 100644 index 00000000000..9d87f7af1b0 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs @@ -0,0 +1,16 @@ +use rustc_data_structures::graph; +use rustc_index::Idx; + +pub(crate) trait IterNodes: graph::DirectedGraph { + /// Iterates over all nodes of a graph in ascending numeric order. + /// Assumes that nodes are densely numbered, i.e. every index in + /// `0..num_nodes` is a valid node. + /// + /// FIXME: Can this just be part of [`graph::DirectedGraph`]? + fn iter_nodes( + &self, + ) -> impl Iterator<Item = Self::Node> + DoubleEndedIterator + ExactSizeIterator { + (0..self.num_nodes()).map(<Self::Node as Idx>::new) + } +} +impl<G: graph::DirectedGraph> IterNodes for G {} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs new file mode 100644 index 00000000000..5e5d6624959 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs @@ -0,0 +1,290 @@ +//! For each node in a control-flow graph, determines whether that node should +//! have a physical counter, or a counter expression that is derived from the +//! physical counters of other nodes. +//! +//! Based on the algorithm given in +//! "Optimal measurement points for program frequency counts" +//! (Knuth & Stevenson, 1973). + +use rustc_data_structures::graph; +use rustc_index::bit_set::DenseBitSet; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::coverage::Op; +use smallvec::SmallVec; + +use crate::coverage::counters::iter_nodes::IterNodes; +use crate::coverage::counters::union_find::{FrozenUnionFind, UnionFind}; + +#[cfg(test)] +mod tests; + +/// View of some underlying graph, in which each node's successors have been +/// merged into a single "supernode". +/// +/// The resulting supernodes have no obvious meaning on their own. +/// However, merging successor nodes means that a node's out-edges can all +/// be combined into a single out-edge, whose flow is the same as the flow +/// (execution count) of its corresponding node in the original graph. +/// +/// With all node flows now in the original graph now represented as edge flows +/// in the merged graph, it becomes possible to analyze the original node flows +/// using techniques for analyzing edge flows. +#[derive(Debug)] +pub(crate) struct MergedNodeFlowGraph<Node: Idx> { + /// Maps each node to the supernode that contains it, indicated by some + /// arbitrary "root" node that is part of that supernode. + supernodes: FrozenUnionFind<Node>, + /// For each node, stores the single supernode that all of its successors + /// have been merged into. + /// + /// (Note that each node in a supernode can potentially have a _different_ + /// successor supernode from its peers.) + succ_supernodes: IndexVec<Node, Node>, +} + +impl<Node: Idx> MergedNodeFlowGraph<Node> { + /// Creates a "merged" view of an underlying graph. + /// + /// The given graph is assumed to have [“balanced flow”](balanced-flow), + /// though it does not necessarily have to be a `BalancedFlowGraph`. + /// + /// [balanced-flow]: `crate::coverage::counters::balanced_flow::BalancedFlowGraph`. + pub(crate) fn for_balanced_graph<G>(graph: G) -> Self + where + G: graph::DirectedGraph<Node = Node> + graph::Successors, + { + let mut supernodes = UnionFind::<G::Node>::new(graph.num_nodes()); + + // For each node, merge its successors into a single supernode, and + // arbitrarily choose one of those successors to represent all of them. + let successors = graph + .iter_nodes() + .map(|node| { + graph + .successors(node) + .reduce(|a, b| supernodes.unify(a, b)) + .expect("each node in a balanced graph must have at least one out-edge") + }) + .collect::<IndexVec<G::Node, G::Node>>(); + + // Now that unification is complete, freeze the supernode forest, + // and resolve each arbitrarily-chosen successor to its canonical root. + // (This avoids having to explicitly resolve them later.) + let supernodes = supernodes.freeze(); + let succ_supernodes = successors.into_iter().map(|succ| supernodes.find(succ)).collect(); + + Self { supernodes, succ_supernodes } + } + + fn num_nodes(&self) -> usize { + self.succ_supernodes.len() + } + + fn is_supernode(&self, node: Node) -> bool { + self.supernodes.find(node) == node + } + + /// Using the information in this merged graph, together with a given + /// permutation of all nodes in the graph, to create physical counters and + /// counter expressions for each node in the underlying graph. + /// + /// The given list must contain exactly one copy of each node in the + /// underlying balanced-flow graph. The order of nodes is used as a hint to + /// influence counter allocation: + /// - Earlier nodes are more likely to receive counter expressions. + /// - Later nodes are more likely to receive physical counters. + pub(crate) fn make_node_counters(&self, all_nodes_permutation: &[Node]) -> NodeCounters<Node> { + let mut builder = SpantreeBuilder::new(self); + + for &node in all_nodes_permutation { + builder.visit_node(node); + } + + NodeCounters { counter_exprs: builder.finish() } + } +} + +/// End result of allocating physical counters and counter expressions for the +/// nodes of a graph. +#[derive(Debug)] +pub(crate) struct NodeCounters<Node: Idx> { + counter_exprs: IndexVec<Node, CounterExprVec<Node>>, +} + +impl<Node: Idx> NodeCounters<Node> { + /// For the given node, returns the finished list of terms that represent + /// its physical counter or counter expression. Always non-empty. + /// + /// If a node was given a physical counter, its "expression" will contain + /// that counter as its sole element. + pub(crate) fn counter_expr(&self, this: Node) -> &[CounterTerm<Node>] { + self.counter_exprs[this].as_slice() + } +} + +#[derive(Debug)] +struct SpantreeEdge<Node> { + /// If true, this edge in the spantree has been reversed an odd number of + /// times, so all physical counters added to its node's counter expression + /// need to be negated. + is_reversed: bool, + /// Each spantree edge is "claimed" by the (regular) node that caused it to + /// be created. When a node with a physical counter traverses this edge, + /// that counter is added to the claiming node's counter expression. + claiming_node: Node, + /// Supernode at the other end of this spantree edge. Transitively points + /// to the "root" of this supernode's spantree component. + span_parent: Node, +} + +/// Part of a node's counter expression, which is a sum of counter terms. +#[derive(Debug)] +pub(crate) struct CounterTerm<Node> { + /// Whether to add or subtract the value of the node's physical counter. + pub(crate) op: Op, + /// The node whose physical counter is represented by this term. + pub(crate) node: Node, +} + +/// Stores the list of counter terms that make up a node's counter expression. +type CounterExprVec<Node> = SmallVec<[CounterTerm<Node>; 2]>; + +#[derive(Debug)] +struct SpantreeBuilder<'a, Node: Idx> { + graph: &'a MergedNodeFlowGraph<Node>, + is_unvisited: DenseBitSet<Node>, + /// Links supernodes to each other, gradually forming a spanning tree of + /// the merged-flow graph. + /// + /// A supernode without a span edge is the root of its component of the + /// spantree. Nodes that aren't supernodes cannot have a spantree edge. + span_edges: IndexVec<Node, Option<SpantreeEdge<Node>>>, + /// An in-progress counter expression for each node. Each expression is + /// initially empty, and will be filled in as relevant nodes are visited. + counter_exprs: IndexVec<Node, CounterExprVec<Node>>, +} + +impl<'a, Node: Idx> SpantreeBuilder<'a, Node> { + fn new(graph: &'a MergedNodeFlowGraph<Node>) -> Self { + let num_nodes = graph.num_nodes(); + Self { + graph, + is_unvisited: DenseBitSet::new_filled(num_nodes), + span_edges: IndexVec::from_fn_n(|_| None, num_nodes), + counter_exprs: IndexVec::from_fn_n(|_| SmallVec::new(), num_nodes), + } + } + + /// Given a supernode, finds the supernode that is the "root" of its + /// spantree component. Two nodes that have the same spantree root are + /// connected in the spantree. + fn spantree_root(&self, this: Node) -> Node { + debug_assert!(self.graph.is_supernode(this)); + + match self.span_edges[this] { + None => this, + Some(SpantreeEdge { span_parent, .. }) => self.spantree_root(span_parent), + } + } + + /// Rotates edges in the spantree so that `this` is the root of its + /// spantree component. + fn yank_to_spantree_root(&mut self, this: Node) { + debug_assert!(self.graph.is_supernode(this)); + + // Temporarily remove this supernode (any any spantree-children) from its + // spantree component, by disconnecting the edge to its spantree-parent. + let Some(SpantreeEdge { is_reversed, claiming_node, span_parent }) = + self.span_edges[this].take() + else { + // This supernode has no spantree-parent edge, so it is already the + // root of its spantree component. + return; + }; + + // Recursively make our immediate spantree-parent the root of what's + // left of its component, so that only one more edge rotation is needed. + self.yank_to_spantree_root(span_parent); + + // Recreate the removed edge, but in the opposite direction. + // Now `this` is the root of its spantree component. + self.span_edges[span_parent] = + Some(SpantreeEdge { is_reversed: !is_reversed, claiming_node, span_parent: this }); + } + + /// Must be called exactly once for each node in the balanced-flow graph. + fn visit_node(&mut self, this: Node) { + // Assert that this node was unvisited, and mark it visited. + assert!(self.is_unvisited.remove(this), "node has already been visited: {this:?}"); + + // Get the supernode containing `this`, and make it the root of its + // component of the spantree. + let this_supernode = self.graph.supernodes.find(this); + self.yank_to_spantree_root(this_supernode); + + // Get the supernode containing all of this's successors. + let succ_supernode = self.graph.succ_supernodes[this]; + debug_assert!(self.graph.is_supernode(succ_supernode)); + + // If two supernodes are already connected in the spantree, they will + // have the same spantree root. (Each supernode is connected to itself.) + if this_supernode != self.spantree_root(succ_supernode) { + // Adding this node's flow edge to the spantree would cause two + // previously-disconnected supernodes to become connected, so add + // it. That spantree-edge is now "claimed" by this node. + // + // Claiming a spantree-edge means that this node will get a counter + // expression instead of a physical counter. That expression is + // currently empty, but will be built incrementally as the other + // nodes are visited. + self.span_edges[this_supernode] = Some(SpantreeEdge { + is_reversed: false, + claiming_node: this, + span_parent: succ_supernode, + }); + } else { + // This node's flow edge would join two supernodes that are already + // connected in the spantree (or are the same supernode). That would + // create a cycle in the spantree, so don't add an edge. + // + // Instead, create a physical counter for this node, and add that + // counter to all expressions on the path from `succ_supernode` to + // `this_supernode`. + + // Instead of setting `this.measure = true` as in the original paper, + // we just add the node's ID to its own "expression". + self.counter_exprs[this].push(CounterTerm { node: this, op: Op::Add }); + + // Walk the spantree from `this.successor` back to `this`. For each + // spantree edge along the way, add this node's physical counter to + // the counter expression of the node that claimed the spantree edge. + let mut curr = succ_supernode; + while curr != this_supernode { + let &SpantreeEdge { is_reversed, claiming_node, span_parent } = + self.span_edges[curr].as_ref().unwrap(); + let op = if is_reversed { Op::Subtract } else { Op::Add }; + self.counter_exprs[claiming_node].push(CounterTerm { node: this, op }); + + curr = span_parent; + } + } + } + + /// Asserts that all nodes have been visited, and returns the computed + /// counter expressions (made up of physical counters) for each node. + fn finish(self) -> IndexVec<Node, CounterExprVec<Node>> { + let Self { graph, is_unvisited, span_edges, counter_exprs } = self; + assert!(is_unvisited.is_empty(), "some nodes were never visited: {is_unvisited:?}"); + debug_assert!( + span_edges + .iter_enumerated() + .all(|(node, span_edge)| { span_edge.is_some() <= graph.is_supernode(node) }), + "only supernodes can have a span edge", + ); + debug_assert!( + counter_exprs.iter().all(|expr| !expr.is_empty()), + "after visiting all nodes, every node should have a non-empty expression", + ); + counter_exprs + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs new file mode 100644 index 00000000000..9e7f754523d --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs @@ -0,0 +1,64 @@ +use itertools::Itertools; +use rustc_data_structures::graph; +use rustc_data_structures::graph::vec_graph::VecGraph; +use rustc_index::Idx; +use rustc_middle::mir::coverage::Op; + +use super::{CounterTerm, MergedNodeFlowGraph, NodeCounters}; + +fn merged_node_flow_graph<G: graph::Successors>(graph: G) -> MergedNodeFlowGraph<G::Node> { + MergedNodeFlowGraph::for_balanced_graph(graph) +} + +fn make_graph<Node: Idx + Ord>(num_nodes: usize, edge_pairs: Vec<(Node, Node)>) -> VecGraph<Node> { + VecGraph::new(num_nodes, edge_pairs) +} + +/// Example used in "Optimal Measurement Points for Program Frequency Counts" +/// (Knuth & Stevenson, 1973), but with 0-based node IDs. +#[test] +fn example_driver() { + let graph = make_graph::<u32>(5, vec![ + (0, 1), + (0, 3), + (1, 0), + (1, 2), + (2, 1), + (2, 4), + (3, 3), + (3, 4), + (4, 0), + ]); + + let merged = merged_node_flow_graph(&graph); + let counters = merged.make_node_counters(&[3, 1, 2, 0, 4]); + + assert_eq!(format_counter_expressions(&counters), &[ + // (comment to force vertical formatting for clarity) + "[0]: +c0", + "[1]: +c0 +c2 -c4", + "[2]: +c2", + "[3]: +c3", + "[4]: +c4", + ]); +} + +fn format_counter_expressions<Node: Idx>(counters: &NodeCounters<Node>) -> Vec<String> { + let format_item = |&CounterTerm { node, op }| { + let op = match op { + Op::Subtract => '-', + Op::Add => '+', + }; + format!("{op}c{node:?}") + }; + + counters + .counter_exprs + .indices() + .map(|node| { + let mut expr = counters.counter_expr(node).iter().collect::<Vec<_>>(); + expr.sort_by_key(|item| item.node.index()); + format!("[{node:?}]: {}", expr.into_iter().map(format_item).join(" ")) + }) + .collect() +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/tests.rs deleted file mode 100644 index 794d4358f82..00000000000 --- a/compiler/rustc_mir_transform/src/coverage/counters/tests.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::fmt::Debug; - -use super::sort_and_cancel; - -fn flatten<T>(input: Vec<Option<T>>) -> Vec<T> { - input.into_iter().flatten().collect() -} - -fn sort_and_cancel_and_flatten<T: Clone + Ord>(pos: Vec<T>, neg: Vec<T>) -> (Vec<T>, Vec<T>) { - let (pos_actual, neg_actual) = sort_and_cancel(pos, neg); - (flatten(pos_actual), flatten(neg_actual)) -} - -#[track_caller] -fn check_test_case<T: Clone + Debug + Ord>( - pos: Vec<T>, - neg: Vec<T>, - pos_expected: Vec<T>, - neg_expected: Vec<T>, -) { - eprintln!("pos = {pos:?}; neg = {neg:?}"); - let output = sort_and_cancel_and_flatten(pos, neg); - assert_eq!(output, (pos_expected, neg_expected)); -} - -#[test] -fn cancellation() { - let cases: &[(Vec<u32>, Vec<u32>, Vec<u32>, Vec<u32>)] = &[ - (vec![], vec![], vec![], vec![]), - (vec![4, 2, 1, 5, 3], vec![], vec![1, 2, 3, 4, 5], vec![]), - (vec![5, 5, 5, 5, 5], vec![5], vec![5, 5, 5, 5], vec![]), - (vec![1, 1, 2, 2, 3, 3], vec![1, 2, 3], vec![1, 2, 3], vec![]), - (vec![1, 1, 2, 2, 3, 3], vec![2, 4, 2], vec![1, 1, 3, 3], vec![4]), - ]; - - for (pos, neg, pos_expected, neg_expected) in cases { - check_test_case(pos.to_vec(), neg.to_vec(), pos_expected.to_vec(), neg_expected.to_vec()); - // Same test case, but with its inputs flipped and its outputs flipped. - check_test_case(neg.to_vec(), pos.to_vec(), neg_expected.to_vec(), pos_expected.to_vec()); - } -} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs new file mode 100644 index 00000000000..2da4f5f5fce --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs @@ -0,0 +1,116 @@ +use std::cmp::Ordering; +use std::mem; + +use rustc_index::{Idx, IndexVec}; + +#[cfg(test)] +mod tests; + +/// Simple implementation of a union-find data structure, i.e. a disjoint-set +/// forest. +#[derive(Debug)] +pub(crate) struct UnionFind<Key: Idx> { + table: IndexVec<Key, UnionFindEntry<Key>>, +} + +#[derive(Debug)] +struct UnionFindEntry<Key> { + /// Transitively points towards the "root" of the set containing this key. + /// + /// Invariant: A root key is its own parent. + parent: Key, + /// When merging two "root" keys, their ranks determine which key becomes + /// the new root, to prevent the parent tree from becoming unnecessarily + /// tall. See [`UnionFind::unify`] for details. + rank: u32, +} + +impl<Key: Idx> UnionFind<Key> { + /// Creates a new disjoint-set forest containing the keys `0..num_keys`. + /// Initially, every key is part of its own one-element set. + pub(crate) fn new(num_keys: usize) -> Self { + // Initially, every key is the root of its own set, so its parent is itself. + Self { table: IndexVec::from_fn_n(|key| UnionFindEntry { parent: key, rank: 0 }, num_keys) } + } + + /// Returns the "root" key of the disjoint-set containing the given key. + /// If two keys have the same root, they belong to the same set. + /// + /// Also updates internal data structures to make subsequent `find` + /// operations faster. + pub(crate) fn find(&mut self, key: Key) -> Key { + // Loop until we find a key that is its own parent. + let mut curr = key; + while let parent = self.table[curr].parent + && curr != parent + { + // Perform "path compression" by peeking one layer ahead, and + // setting the current key's parent to that value. + // (This works even when `parent` is the root of its set, because + // of the invariant that a root is its own parent.) + let parent_parent = self.table[parent].parent; + self.table[curr].parent = parent_parent; + + // Advance by one step and continue. + curr = parent; + } + curr + } + + /// Merges the set containing `a` and the set containing `b` into one set. + /// + /// Returns the common root of both keys, after the merge. + pub(crate) fn unify(&mut self, a: Key, b: Key) -> Key { + let mut a = self.find(a); + let mut b = self.find(b); + + // If both keys have the same root, they're already in the same set, + // so there's nothing more to do. + if a == b { + return a; + }; + + // Ensure that `a` has strictly greater rank, swapping if necessary. + // If both keys have the same rank, increment the rank of `a` so that + // future unifications will also prefer `a`, leading to flatter trees. + match Ord::cmp(&self.table[a].rank, &self.table[b].rank) { + Ordering::Less => mem::swap(&mut a, &mut b), + Ordering::Equal => self.table[a].rank += 1, + Ordering::Greater => {} + } + + debug_assert!(self.table[a].rank > self.table[b].rank); + debug_assert_eq!(self.table[b].parent, b); + + // Make `a` the parent of `b`. + self.table[b].parent = a; + + a + } + + /// Creates a snapshot of this disjoint-set forest that can no longer be + /// mutated, but can be queried without mutation. + pub(crate) fn freeze(&mut self) -> FrozenUnionFind<Key> { + // Just resolve each key to its actual root. + let roots = self.table.indices().map(|key| self.find(key)).collect(); + FrozenUnionFind { roots } + } +} + +/// Snapshot of a disjoint-set forest that can no longer be mutated, but can be +/// queried in O(1) time without mutation. +/// +/// This is really just a wrapper around a direct mapping from keys to roots, +/// but with a [`Self::find`] method that resembles [`UnionFind::find`]. +#[derive(Debug)] +pub(crate) struct FrozenUnionFind<Key: Idx> { + roots: IndexVec<Key, Key>, +} + +impl<Key: Idx> FrozenUnionFind<Key> { + /// Returns the "root" key of the disjoint-set containing the given key. + /// If two keys have the same root, they belong to the same set. + pub(crate) fn find(&self, key: Key) -> Key { + self.roots[key] + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs new file mode 100644 index 00000000000..34a4e4f8e6e --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs @@ -0,0 +1,32 @@ +use super::UnionFind; + +#[test] +fn empty() { + let mut sets = UnionFind::<u32>::new(10); + + for i in 1..10 { + assert_eq!(sets.find(i), i); + } +} + +#[test] +fn transitive() { + let mut sets = UnionFind::<u32>::new(10); + + sets.unify(3, 7); + sets.unify(4, 2); + + assert_eq!(sets.find(7), sets.find(3)); + assert_eq!(sets.find(2), sets.find(4)); + assert_ne!(sets.find(3), sets.find(4)); + + sets.unify(7, 4); + + assert_eq!(sets.find(7), sets.find(3)); + assert_eq!(sets.find(2), sets.find(4)); + assert_eq!(sets.find(3), sets.find(4)); + + for i in [0, 1, 5, 6, 8, 9] { + assert_eq!(sets.find(i), i); + } +} |
