about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/coverage/counters
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/coverage/counters')
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs133
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs16
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs290
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs64
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/tests.rs41
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/union_find.rs116
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs32
7 files changed, 651 insertions, 41 deletions
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs
new file mode 100644
index 00000000000..c108f96a564
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs
@@ -0,0 +1,133 @@
+//! A control-flow graph can be said to have “balanced flow” if the flow
+//! (execution count) of each node is equal to the sum of its in-edge flows,
+//! and also equal to the sum of its out-edge flows.
+//!
+//! Control-flow graphs typically have one or more nodes that don't satisfy the
+//! balanced-flow property, e.g.:
+//! - The start node has out-edges, but no in-edges.
+//! - Return nodes have in-edges, but no out-edges.
+//! - `Yield` nodes can have an out-flow that is less than their in-flow.
+//! - Inescapable loops cause the in-flow/out-flow relationship to break down.
+//!
+//! Balanced-flow graphs are nevertheless useful for analysis, so this module
+//! provides a wrapper type ([`BalancedFlowGraph`]) that imposes balanced flow
+//! on an underlying graph. This is done by non-destructively adding synthetic
+//! nodes and edges as necessary.
+
+use rustc_data_structures::graph;
+use rustc_data_structures::graph::iterate::DepthFirstSearch;
+use rustc_data_structures::graph::reversed::ReversedGraph;
+use rustc_index::Idx;
+use rustc_index::bit_set::DenseBitSet;
+
+use crate::coverage::counters::iter_nodes::IterNodes;
+
+/// A view of an underlying graph that has been augmented to have “balanced flow”.
+/// This means that the flow (execution count) of each node is equal to the
+/// sum of its in-edge flows, and also equal to the sum of its out-edge flows.
+///
+/// To achieve this, a synthetic "sink" node is non-destructively added to the
+/// graph, with synthetic in-edges from these nodes:
+/// - Any node that has no out-edges.
+/// - Any node that explicitly requires a sink edge, as indicated by a
+///   caller-supplied `force_sink_edge` function.
+/// - Any node that would otherwise be unable to reach the sink, because it is
+///   part of an inescapable loop.
+///
+/// To make the graph fully balanced, there is also a synthetic edge from the
+/// sink node back to the start node.
+///
+/// ---
+/// The benefit of having a balanced-flow graph is that it can be subsequently
+/// transformed in ways that are guaranteed to preserve balanced flow
+/// (e.g. merging nodes together), which is useful for discovering relationships
+/// between the node flows of different nodes in the graph.
+pub(crate) struct BalancedFlowGraph<G: graph::DirectedGraph> {
+    graph: G,
+    sink_edge_nodes: DenseBitSet<G::Node>,
+    pub(crate) sink: G::Node,
+}
+
+impl<G: graph::DirectedGraph> BalancedFlowGraph<G> {
+    /// Creates a balanced view of an underlying graph, by adding a synthetic
+    /// sink node that has in-edges from nodes that need or request such an edge,
+    /// and a single out-edge to the start node.
+    ///
+    /// Assumes that all nodes in the underlying graph are reachable from the
+    /// start node.
+    pub(crate) fn for_graph(graph: G, force_sink_edge: impl Fn(G::Node) -> bool) -> Self
+    where
+        G: graph::ControlFlowGraph,
+    {
+        let mut sink_edge_nodes = DenseBitSet::new_empty(graph.num_nodes());
+        let mut dfs = DepthFirstSearch::new(ReversedGraph::new(&graph));
+
+        // First, determine the set of nodes that explicitly request or require
+        // an out-edge to the sink.
+        for node in graph.iter_nodes() {
+            if force_sink_edge(node) || graph.successors(node).next().is_none() {
+                sink_edge_nodes.insert(node);
+                dfs.push_start_node(node);
+            }
+        }
+
+        // Next, find all nodes that are currently not reverse-reachable from
+        // `sink_edge_nodes`, and add them to the set as well.
+        dfs.complete_search();
+        sink_edge_nodes.union_not(dfs.visited_set());
+
+        // The sink node is 1 higher than the highest real node.
+        let sink = G::Node::new(graph.num_nodes());
+
+        BalancedFlowGraph { graph, sink_edge_nodes, sink }
+    }
+}
+
+impl<G> graph::DirectedGraph for BalancedFlowGraph<G>
+where
+    G: graph::DirectedGraph,
+{
+    type Node = G::Node;
+
+    /// Returns the number of nodes in this balanced-flow graph, which is 1
+    /// more than the number of nodes in the underlying graph, to account for
+    /// the synthetic sink node.
+    fn num_nodes(&self) -> usize {
+        // The sink node's index is already the size of the underlying graph,
+        // so just add 1 to that instead.
+        self.sink.index() + 1
+    }
+}
+
+impl<G> graph::StartNode for BalancedFlowGraph<G>
+where
+    G: graph::StartNode,
+{
+    fn start_node(&self) -> Self::Node {
+        self.graph.start_node()
+    }
+}
+
+impl<G> graph::Successors for BalancedFlowGraph<G>
+where
+    G: graph::StartNode + graph::Successors,
+{
+    fn successors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> {
+        let real_edges;
+        let sink_edge;
+
+        if node == self.sink {
+            // The sink node has no real out-edges, and one synthetic out-edge
+            // to the start node.
+            real_edges = None;
+            sink_edge = Some(self.graph.start_node());
+        } else {
+            // Real nodes have their real out-edges, and possibly one synthetic
+            // out-edge to the sink node.
+            real_edges = Some(self.graph.successors(node));
+            sink_edge = self.sink_edge_nodes.contains(node).then_some(self.sink);
+        }
+
+        real_edges.into_iter().flatten().chain(sink_edge)
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs b/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs
new file mode 100644
index 00000000000..9d87f7af1b0
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/iter_nodes.rs
@@ -0,0 +1,16 @@
+use rustc_data_structures::graph;
+use rustc_index::Idx;
+
+pub(crate) trait IterNodes: graph::DirectedGraph {
+    /// Iterates over all nodes of a graph in ascending numeric order.
+    /// Assumes that nodes are densely numbered, i.e. every index in
+    /// `0..num_nodes` is a valid node.
+    ///
+    /// FIXME: Can this just be part of [`graph::DirectedGraph`]?
+    fn iter_nodes(
+        &self,
+    ) -> impl Iterator<Item = Self::Node> + DoubleEndedIterator + ExactSizeIterator {
+        (0..self.num_nodes()).map(<Self::Node as Idx>::new)
+    }
+}
+impl<G: graph::DirectedGraph> IterNodes for G {}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs
new file mode 100644
index 00000000000..5e5d6624959
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs
@@ -0,0 +1,290 @@
+//! For each node in a control-flow graph, determines whether that node should
+//! have a physical counter, or a counter expression that is derived from the
+//! physical counters of other nodes.
+//!
+//! Based on the algorithm given in
+//! "Optimal measurement points for program frequency counts"
+//! (Knuth & Stevenson, 1973).
+
+use rustc_data_structures::graph;
+use rustc_index::bit_set::DenseBitSet;
+use rustc_index::{Idx, IndexVec};
+use rustc_middle::mir::coverage::Op;
+use smallvec::SmallVec;
+
+use crate::coverage::counters::iter_nodes::IterNodes;
+use crate::coverage::counters::union_find::{FrozenUnionFind, UnionFind};
+
+#[cfg(test)]
+mod tests;
+
+/// View of some underlying graph, in which each node's successors have been
+/// merged into a single "supernode".
+///
+/// The resulting supernodes have no obvious meaning on their own.
+/// However, merging successor nodes means that a node's out-edges can all
+/// be combined into a single out-edge, whose flow is the same as the flow
+/// (execution count) of its corresponding node in the original graph.
+///
+/// With all node flows now in the original graph now represented as edge flows
+/// in the merged graph, it becomes possible to analyze the original node flows
+/// using techniques for analyzing edge flows.
+#[derive(Debug)]
+pub(crate) struct MergedNodeFlowGraph<Node: Idx> {
+    /// Maps each node to the supernode that contains it, indicated by some
+    /// arbitrary "root" node that is part of that supernode.
+    supernodes: FrozenUnionFind<Node>,
+    /// For each node, stores the single supernode that all of its successors
+    /// have been merged into.
+    ///
+    /// (Note that each node in a supernode can potentially have a _different_
+    /// successor supernode from its peers.)
+    succ_supernodes: IndexVec<Node, Node>,
+}
+
+impl<Node: Idx> MergedNodeFlowGraph<Node> {
+    /// Creates a "merged" view of an underlying graph.
+    ///
+    /// The given graph is assumed to have [“balanced flow”](balanced-flow),
+    /// though it does not necessarily have to be a `BalancedFlowGraph`.
+    ///
+    /// [balanced-flow]: `crate::coverage::counters::balanced_flow::BalancedFlowGraph`.
+    pub(crate) fn for_balanced_graph<G>(graph: G) -> Self
+    where
+        G: graph::DirectedGraph<Node = Node> + graph::Successors,
+    {
+        let mut supernodes = UnionFind::<G::Node>::new(graph.num_nodes());
+
+        // For each node, merge its successors into a single supernode, and
+        // arbitrarily choose one of those successors to represent all of them.
+        let successors = graph
+            .iter_nodes()
+            .map(|node| {
+                graph
+                    .successors(node)
+                    .reduce(|a, b| supernodes.unify(a, b))
+                    .expect("each node in a balanced graph must have at least one out-edge")
+            })
+            .collect::<IndexVec<G::Node, G::Node>>();
+
+        // Now that unification is complete, freeze the supernode forest,
+        // and resolve each arbitrarily-chosen successor to its canonical root.
+        // (This avoids having to explicitly resolve them later.)
+        let supernodes = supernodes.freeze();
+        let succ_supernodes = successors.into_iter().map(|succ| supernodes.find(succ)).collect();
+
+        Self { supernodes, succ_supernodes }
+    }
+
+    fn num_nodes(&self) -> usize {
+        self.succ_supernodes.len()
+    }
+
+    fn is_supernode(&self, node: Node) -> bool {
+        self.supernodes.find(node) == node
+    }
+
+    /// Using the information in this merged graph, together with a given
+    /// permutation of all nodes in the graph, to create physical counters and
+    /// counter expressions for each node in the underlying graph.
+    ///
+    /// The given list must contain exactly one copy of each node in the
+    /// underlying balanced-flow graph. The order of nodes is used as a hint to
+    /// influence counter allocation:
+    /// - Earlier nodes are more likely to receive counter expressions.
+    /// - Later nodes are more likely to receive physical counters.
+    pub(crate) fn make_node_counters(&self, all_nodes_permutation: &[Node]) -> NodeCounters<Node> {
+        let mut builder = SpantreeBuilder::new(self);
+
+        for &node in all_nodes_permutation {
+            builder.visit_node(node);
+        }
+
+        NodeCounters { counter_exprs: builder.finish() }
+    }
+}
+
+/// End result of allocating physical counters and counter expressions for the
+/// nodes of a graph.
+#[derive(Debug)]
+pub(crate) struct NodeCounters<Node: Idx> {
+    counter_exprs: IndexVec<Node, CounterExprVec<Node>>,
+}
+
+impl<Node: Idx> NodeCounters<Node> {
+    /// For the given node, returns the finished list of terms that represent
+    /// its physical counter or counter expression. Always non-empty.
+    ///
+    /// If a node was given a physical counter, its "expression" will contain
+    /// that counter as its sole element.
+    pub(crate) fn counter_expr(&self, this: Node) -> &[CounterTerm<Node>] {
+        self.counter_exprs[this].as_slice()
+    }
+}
+
+#[derive(Debug)]
+struct SpantreeEdge<Node> {
+    /// If true, this edge in the spantree has been reversed an odd number of
+    /// times, so all physical counters added to its node's counter expression
+    /// need to be negated.
+    is_reversed: bool,
+    /// Each spantree edge is "claimed" by the (regular) node that caused it to
+    /// be created. When a node with a physical counter traverses this edge,
+    /// that counter is added to the claiming node's counter expression.
+    claiming_node: Node,
+    /// Supernode at the other end of this spantree edge. Transitively points
+    /// to the "root" of this supernode's spantree component.
+    span_parent: Node,
+}
+
+/// Part of a node's counter expression, which is a sum of counter terms.
+#[derive(Debug)]
+pub(crate) struct CounterTerm<Node> {
+    /// Whether to add or subtract the value of the node's physical counter.
+    pub(crate) op: Op,
+    /// The node whose physical counter is represented by this term.
+    pub(crate) node: Node,
+}
+
+/// Stores the list of counter terms that make up a node's counter expression.
+type CounterExprVec<Node> = SmallVec<[CounterTerm<Node>; 2]>;
+
+#[derive(Debug)]
+struct SpantreeBuilder<'a, Node: Idx> {
+    graph: &'a MergedNodeFlowGraph<Node>,
+    is_unvisited: DenseBitSet<Node>,
+    /// Links supernodes to each other, gradually forming a spanning tree of
+    /// the merged-flow graph.
+    ///
+    /// A supernode without a span edge is the root of its component of the
+    /// spantree. Nodes that aren't supernodes cannot have a spantree edge.
+    span_edges: IndexVec<Node, Option<SpantreeEdge<Node>>>,
+    /// An in-progress counter expression for each node. Each expression is
+    /// initially empty, and will be filled in as relevant nodes are visited.
+    counter_exprs: IndexVec<Node, CounterExprVec<Node>>,
+}
+
+impl<'a, Node: Idx> SpantreeBuilder<'a, Node> {
+    fn new(graph: &'a MergedNodeFlowGraph<Node>) -> Self {
+        let num_nodes = graph.num_nodes();
+        Self {
+            graph,
+            is_unvisited: DenseBitSet::new_filled(num_nodes),
+            span_edges: IndexVec::from_fn_n(|_| None, num_nodes),
+            counter_exprs: IndexVec::from_fn_n(|_| SmallVec::new(), num_nodes),
+        }
+    }
+
+    /// Given a supernode, finds the supernode that is the "root" of its
+    /// spantree component. Two nodes that have the same spantree root are
+    /// connected in the spantree.
+    fn spantree_root(&self, this: Node) -> Node {
+        debug_assert!(self.graph.is_supernode(this));
+
+        match self.span_edges[this] {
+            None => this,
+            Some(SpantreeEdge { span_parent, .. }) => self.spantree_root(span_parent),
+        }
+    }
+
+    /// Rotates edges in the spantree so that `this` is the root of its
+    /// spantree component.
+    fn yank_to_spantree_root(&mut self, this: Node) {
+        debug_assert!(self.graph.is_supernode(this));
+
+        // Temporarily remove this supernode (any any spantree-children) from its
+        // spantree component, by disconnecting the edge to its spantree-parent.
+        let Some(SpantreeEdge { is_reversed, claiming_node, span_parent }) =
+            self.span_edges[this].take()
+        else {
+            // This supernode has no spantree-parent edge, so it is already the
+            // root of its spantree component.
+            return;
+        };
+
+        // Recursively make our immediate spantree-parent the root of what's
+        // left of its component, so that only one more edge rotation is needed.
+        self.yank_to_spantree_root(span_parent);
+
+        // Recreate the removed edge, but in the opposite direction.
+        // Now `this` is the root of its spantree component.
+        self.span_edges[span_parent] =
+            Some(SpantreeEdge { is_reversed: !is_reversed, claiming_node, span_parent: this });
+    }
+
+    /// Must be called exactly once for each node in the balanced-flow graph.
+    fn visit_node(&mut self, this: Node) {
+        // Assert that this node was unvisited, and mark it visited.
+        assert!(self.is_unvisited.remove(this), "node has already been visited: {this:?}");
+
+        // Get the supernode containing `this`, and make it the root of its
+        // component of the spantree.
+        let this_supernode = self.graph.supernodes.find(this);
+        self.yank_to_spantree_root(this_supernode);
+
+        // Get the supernode containing all of this's successors.
+        let succ_supernode = self.graph.succ_supernodes[this];
+        debug_assert!(self.graph.is_supernode(succ_supernode));
+
+        // If two supernodes are already connected in the spantree, they will
+        // have the same spantree root. (Each supernode is connected to itself.)
+        if this_supernode != self.spantree_root(succ_supernode) {
+            // Adding this node's flow edge to the spantree would cause two
+            // previously-disconnected supernodes to become connected, so add
+            // it. That spantree-edge is now "claimed" by this node.
+            //
+            // Claiming a spantree-edge means that this node will get a counter
+            // expression instead of a physical counter. That expression is
+            // currently empty, but will be built incrementally as the other
+            // nodes are visited.
+            self.span_edges[this_supernode] = Some(SpantreeEdge {
+                is_reversed: false,
+                claiming_node: this,
+                span_parent: succ_supernode,
+            });
+        } else {
+            // This node's flow edge would join two supernodes that are already
+            // connected in the spantree (or are the same supernode). That would
+            // create a cycle in the spantree, so don't add an edge.
+            //
+            // Instead, create a physical counter for this node, and add that
+            // counter to all expressions on the path from `succ_supernode` to
+            // `this_supernode`.
+
+            // Instead of setting `this.measure = true` as in the original paper,
+            // we just add the node's ID to its own "expression".
+            self.counter_exprs[this].push(CounterTerm { node: this, op: Op::Add });
+
+            // Walk the spantree from `this.successor` back to `this`. For each
+            // spantree edge along the way, add this node's physical counter to
+            // the counter expression of the node that claimed the spantree edge.
+            let mut curr = succ_supernode;
+            while curr != this_supernode {
+                let &SpantreeEdge { is_reversed, claiming_node, span_parent } =
+                    self.span_edges[curr].as_ref().unwrap();
+                let op = if is_reversed { Op::Subtract } else { Op::Add };
+                self.counter_exprs[claiming_node].push(CounterTerm { node: this, op });
+
+                curr = span_parent;
+            }
+        }
+    }
+
+    /// Asserts that all nodes have been visited, and returns the computed
+    /// counter expressions (made up of physical counters) for each node.
+    fn finish(self) -> IndexVec<Node, CounterExprVec<Node>> {
+        let Self { graph, is_unvisited, span_edges, counter_exprs } = self;
+        assert!(is_unvisited.is_empty(), "some nodes were never visited: {is_unvisited:?}");
+        debug_assert!(
+            span_edges
+                .iter_enumerated()
+                .all(|(node, span_edge)| { span_edge.is_some() <= graph.is_supernode(node) }),
+            "only supernodes can have a span edge",
+        );
+        debug_assert!(
+            counter_exprs.iter().all(|expr| !expr.is_empty()),
+            "after visiting all nodes, every node should have a non-empty expression",
+        );
+        counter_exprs
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs
new file mode 100644
index 00000000000..9e7f754523d
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs
@@ -0,0 +1,64 @@
+use itertools::Itertools;
+use rustc_data_structures::graph;
+use rustc_data_structures::graph::vec_graph::VecGraph;
+use rustc_index::Idx;
+use rustc_middle::mir::coverage::Op;
+
+use super::{CounterTerm, MergedNodeFlowGraph, NodeCounters};
+
+fn merged_node_flow_graph<G: graph::Successors>(graph: G) -> MergedNodeFlowGraph<G::Node> {
+    MergedNodeFlowGraph::for_balanced_graph(graph)
+}
+
+fn make_graph<Node: Idx + Ord>(num_nodes: usize, edge_pairs: Vec<(Node, Node)>) -> VecGraph<Node> {
+    VecGraph::new(num_nodes, edge_pairs)
+}
+
+/// Example used in "Optimal Measurement Points for Program Frequency Counts"
+/// (Knuth & Stevenson, 1973), but with 0-based node IDs.
+#[test]
+fn example_driver() {
+    let graph = make_graph::<u32>(5, vec![
+        (0, 1),
+        (0, 3),
+        (1, 0),
+        (1, 2),
+        (2, 1),
+        (2, 4),
+        (3, 3),
+        (3, 4),
+        (4, 0),
+    ]);
+
+    let merged = merged_node_flow_graph(&graph);
+    let counters = merged.make_node_counters(&[3, 1, 2, 0, 4]);
+
+    assert_eq!(format_counter_expressions(&counters), &[
+        // (comment to force vertical formatting for clarity)
+        "[0]: +c0",
+        "[1]: +c0 +c2 -c4",
+        "[2]: +c2",
+        "[3]: +c3",
+        "[4]: +c4",
+    ]);
+}
+
+fn format_counter_expressions<Node: Idx>(counters: &NodeCounters<Node>) -> Vec<String> {
+    let format_item = |&CounterTerm { node, op }| {
+        let op = match op {
+            Op::Subtract => '-',
+            Op::Add => '+',
+        };
+        format!("{op}c{node:?}")
+    };
+
+    counters
+        .counter_exprs
+        .indices()
+        .map(|node| {
+            let mut expr = counters.counter_expr(node).iter().collect::<Vec<_>>();
+            expr.sort_by_key(|item| item.node.index());
+            format!("[{node:?}]: {}", expr.into_iter().map(format_item).join(" "))
+        })
+        .collect()
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/tests.rs
deleted file mode 100644
index 794d4358f82..00000000000
--- a/compiler/rustc_mir_transform/src/coverage/counters/tests.rs
+++ /dev/null
@@ -1,41 +0,0 @@
-use std::fmt::Debug;
-
-use super::sort_and_cancel;
-
-fn flatten<T>(input: Vec<Option<T>>) -> Vec<T> {
-    input.into_iter().flatten().collect()
-}
-
-fn sort_and_cancel_and_flatten<T: Clone + Ord>(pos: Vec<T>, neg: Vec<T>) -> (Vec<T>, Vec<T>) {
-    let (pos_actual, neg_actual) = sort_and_cancel(pos, neg);
-    (flatten(pos_actual), flatten(neg_actual))
-}
-
-#[track_caller]
-fn check_test_case<T: Clone + Debug + Ord>(
-    pos: Vec<T>,
-    neg: Vec<T>,
-    pos_expected: Vec<T>,
-    neg_expected: Vec<T>,
-) {
-    eprintln!("pos = {pos:?}; neg = {neg:?}");
-    let output = sort_and_cancel_and_flatten(pos, neg);
-    assert_eq!(output, (pos_expected, neg_expected));
-}
-
-#[test]
-fn cancellation() {
-    let cases: &[(Vec<u32>, Vec<u32>, Vec<u32>, Vec<u32>)] = &[
-        (vec![], vec![], vec![], vec![]),
-        (vec![4, 2, 1, 5, 3], vec![], vec![1, 2, 3, 4, 5], vec![]),
-        (vec![5, 5, 5, 5, 5], vec![5], vec![5, 5, 5, 5], vec![]),
-        (vec![1, 1, 2, 2, 3, 3], vec![1, 2, 3], vec![1, 2, 3], vec![]),
-        (vec![1, 1, 2, 2, 3, 3], vec![2, 4, 2], vec![1, 1, 3, 3], vec![4]),
-    ];
-
-    for (pos, neg, pos_expected, neg_expected) in cases {
-        check_test_case(pos.to_vec(), neg.to_vec(), pos_expected.to_vec(), neg_expected.to_vec());
-        // Same test case, but with its inputs flipped and its outputs flipped.
-        check_test_case(neg.to_vec(), pos.to_vec(), neg_expected.to_vec(), pos_expected.to_vec());
-    }
-}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs
new file mode 100644
index 00000000000..2da4f5f5fce
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs
@@ -0,0 +1,116 @@
+use std::cmp::Ordering;
+use std::mem;
+
+use rustc_index::{Idx, IndexVec};
+
+#[cfg(test)]
+mod tests;
+
+/// Simple implementation of a union-find data structure, i.e. a disjoint-set
+/// forest.
+#[derive(Debug)]
+pub(crate) struct UnionFind<Key: Idx> {
+    table: IndexVec<Key, UnionFindEntry<Key>>,
+}
+
+#[derive(Debug)]
+struct UnionFindEntry<Key> {
+    /// Transitively points towards the "root" of the set containing this key.
+    ///
+    /// Invariant: A root key is its own parent.
+    parent: Key,
+    /// When merging two "root" keys, their ranks determine which key becomes
+    /// the new root, to prevent the parent tree from becoming unnecessarily
+    /// tall. See [`UnionFind::unify`] for details.
+    rank: u32,
+}
+
+impl<Key: Idx> UnionFind<Key> {
+    /// Creates a new disjoint-set forest containing the keys `0..num_keys`.
+    /// Initially, every key is part of its own one-element set.
+    pub(crate) fn new(num_keys: usize) -> Self {
+        // Initially, every key is the root of its own set, so its parent is itself.
+        Self { table: IndexVec::from_fn_n(|key| UnionFindEntry { parent: key, rank: 0 }, num_keys) }
+    }
+
+    /// Returns the "root" key of the disjoint-set containing the given key.
+    /// If two keys have the same root, they belong to the same set.
+    ///
+    /// Also updates internal data structures to make subsequent `find`
+    /// operations faster.
+    pub(crate) fn find(&mut self, key: Key) -> Key {
+        // Loop until we find a key that is its own parent.
+        let mut curr = key;
+        while let parent = self.table[curr].parent
+            && curr != parent
+        {
+            // Perform "path compression" by peeking one layer ahead, and
+            // setting the current key's parent to that value.
+            // (This works even when `parent` is the root of its set, because
+            // of the invariant that a root is its own parent.)
+            let parent_parent = self.table[parent].parent;
+            self.table[curr].parent = parent_parent;
+
+            // Advance by one step and continue.
+            curr = parent;
+        }
+        curr
+    }
+
+    /// Merges the set containing `a` and the set containing `b` into one set.
+    ///
+    /// Returns the common root of both keys, after the merge.
+    pub(crate) fn unify(&mut self, a: Key, b: Key) -> Key {
+        let mut a = self.find(a);
+        let mut b = self.find(b);
+
+        // If both keys have the same root, they're already in the same set,
+        // so there's nothing more to do.
+        if a == b {
+            return a;
+        };
+
+        // Ensure that `a` has strictly greater rank, swapping if necessary.
+        // If both keys have the same rank, increment the rank of `a` so that
+        // future unifications will also prefer `a`, leading to flatter trees.
+        match Ord::cmp(&self.table[a].rank, &self.table[b].rank) {
+            Ordering::Less => mem::swap(&mut a, &mut b),
+            Ordering::Equal => self.table[a].rank += 1,
+            Ordering::Greater => {}
+        }
+
+        debug_assert!(self.table[a].rank > self.table[b].rank);
+        debug_assert_eq!(self.table[b].parent, b);
+
+        // Make `a` the parent of `b`.
+        self.table[b].parent = a;
+
+        a
+    }
+
+    /// Creates a snapshot of this disjoint-set forest that can no longer be
+    /// mutated, but can be queried without mutation.
+    pub(crate) fn freeze(&mut self) -> FrozenUnionFind<Key> {
+        // Just resolve each key to its actual root.
+        let roots = self.table.indices().map(|key| self.find(key)).collect();
+        FrozenUnionFind { roots }
+    }
+}
+
+/// Snapshot of a disjoint-set forest that can no longer be mutated, but can be
+/// queried in O(1) time without mutation.
+///
+/// This is really just a wrapper around a direct mapping from keys to roots,
+/// but with a [`Self::find`] method that resembles [`UnionFind::find`].
+#[derive(Debug)]
+pub(crate) struct FrozenUnionFind<Key: Idx> {
+    roots: IndexVec<Key, Key>,
+}
+
+impl<Key: Idx> FrozenUnionFind<Key> {
+    /// Returns the "root" key of the disjoint-set containing the given key.
+    /// If two keys have the same root, they belong to the same set.
+    pub(crate) fn find(&self, key: Key) -> Key {
+        self.roots[key]
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs
new file mode 100644
index 00000000000..34a4e4f8e6e
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs
@@ -0,0 +1,32 @@
+use super::UnionFind;
+
+#[test]
+fn empty() {
+    let mut sets = UnionFind::<u32>::new(10);
+
+    for i in 1..10 {
+        assert_eq!(sets.find(i), i);
+    }
+}
+
+#[test]
+fn transitive() {
+    let mut sets = UnionFind::<u32>::new(10);
+
+    sets.unify(3, 7);
+    sets.unify(4, 2);
+
+    assert_eq!(sets.find(7), sets.find(3));
+    assert_eq!(sets.find(2), sets.find(4));
+    assert_ne!(sets.find(3), sets.find(4));
+
+    sets.unify(7, 4);
+
+    assert_eq!(sets.find(7), sets.find(3));
+    assert_eq!(sets.find(2), sets.find(4));
+    assert_eq!(sets.find(3), sets.find(4));
+
+    for i in [0, 1, 5, 6, 8, 9] {
+        assert_eq!(sets.find(i), i);
+    }
+}