about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters.rs125
1 files changed, 54 insertions, 71 deletions
diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs
index e96c7c84d88..ca7feb942fe 100644
--- a/compiler/rustc_mir_transform/src/coverage/counters.rs
+++ b/compiler/rustc_mir_transform/src/coverage/counters.rs
@@ -94,6 +94,14 @@ impl CoverageCounters {
         BcbCounter::Expression { id }
     }
 
+    /// Variant of `make_expression` that makes `lhs` optional and assumes [`Op::Add`].
+    ///
+    /// This is useful when using [`Iterator::fold`] to build an arbitrary-length sum.
+    fn make_sum_expression(&mut self, lhs: Option<BcbCounter>, rhs: BcbCounter) -> BcbCounter {
+        let Some(lhs) = lhs else { return rhs };
+        self.make_expression(lhs, Op::Add, rhs)
+    }
+
     /// Counter IDs start from one and go up.
     fn next_counter(&mut self) -> CounterId {
         let next = self.next_counter_id;
@@ -158,8 +166,8 @@ impl CoverageCounters {
         }
     }
 
-    pub(super) fn bcb_counter(&self, bcb: BasicCoverageBlock) -> Option<&BcbCounter> {
-        self.bcb_counters[bcb].as_ref()
+    pub(super) fn bcb_counter(&self, bcb: BasicCoverageBlock) -> Option<BcbCounter> {
+        self.bcb_counters[bcb]
     }
 
     pub(super) fn bcb_node_counters(
@@ -278,41 +286,30 @@ impl<'a> MakeBcbCounters<'a> {
         // counter.)
         let expression_branch = self.choose_preferred_expression_branch(traversal, &branches);
 
-        // Assign a Counter or Expression to each branch, plus additional `Expression`s, as needed,
-        // to sum up intermediate results.
-        let mut some_sumup_counter_operand = None;
-        for branch in branches {
-            // Skip the selected `expression_branch`, if any. It's expression will be assigned after
-            // all others.
-            if branch != expression_branch {
-                let branch_counter_operand = if branch.is_only_path_to_target() {
-                    debug!(
-                        "  {branch:?} has only one incoming edge (from {from_bcb:?}), \
-                        so adding a counter",
-                    );
-                    self.get_or_make_counter_operand(branch.target_bcb)
-                } else {
-                    debug!("  {:?} has multiple incoming edges, so adding an edge counter", branch);
-                    self.get_or_make_edge_counter_operand(from_bcb, branch.target_bcb)
-                };
-                if let Some(sumup_counter_operand) =
-                    some_sumup_counter_operand.replace(branch_counter_operand)
-                {
-                    let intermediate_expression = self.coverage_counters.make_expression(
-                        branch_counter_operand,
-                        Op::Add,
-                        sumup_counter_operand,
-                    );
-                    debug!("  [new intermediate expression: {:?}]", intermediate_expression);
-                    some_sumup_counter_operand.replace(intermediate_expression);
-                }
-            }
-        }
-
-        // Assign the final expression to the `expression_branch` by subtracting the total of all
-        // other branches from the counter of the branching BCB.
-        let sumup_counter_operand =
-            some_sumup_counter_operand.expect("sumup_counter_operand should have a value");
+        // For each branch arm other than the one that was chosen to get an expression,
+        // ensure that it has a counter (existing counter/expression or a new counter),
+        // and accumulate the corresponding terms into a single sum term.
+        let sum_of_all_other_branches: BcbCounter = {
+            let _span = debug_span!("sum_of_all_other_branches", ?expression_branch).entered();
+            branches
+                .into_iter()
+                // Skip the chosen branch, since we'll calculate it from the other branches.
+                .filter(|branch| branch != &expression_branch)
+                .fold(None, |accum, branch| {
+                    let _span = debug_span!("branch", ?accum, ?branch).entered();
+                    let branch_counter = if branch.is_only_path_to_target() {
+                        self.get_or_make_counter_operand(branch.target_bcb)
+                    } else {
+                        self.get_or_make_edge_counter_operand(from_bcb, branch.target_bcb)
+                    };
+                    Some(self.coverage_counters.make_sum_expression(accum, branch_counter))
+                })
+                .expect("there must be at least one other branch")
+        };
+
+        // For the branch that was chosen to get an expression, create that expression
+        // by taking the count of the node we're branching from, and subtracting the
+        // sum of all the other branches.
         debug!(
             "Making an expression for the selected expression_branch: {:?} \
             (expression_branch predecessors: {:?})",
@@ -322,7 +319,7 @@ impl<'a> MakeBcbCounters<'a> {
         let expression = self.coverage_counters.make_expression(
             from_bcb_operand,
             Op::Subtract,
-            sumup_counter_operand,
+            sum_of_all_other_branches,
         );
         debug!("{:?} gets an expression: {:?}", expression_branch, expression);
         let bcb = expression_branch.target_bcb;
@@ -359,39 +356,25 @@ impl<'a> MakeBcbCounters<'a> {
             return self.coverage_counters.set_bcb_counter(bcb, counter_kind);
         }
 
-        // A BCB with multiple incoming edges can compute its count by `Expression`, summing up the
-        // counters and/or expressions of its incoming edges. This will recursively get or create
-        // counters for those incoming edges first, then call `make_expression()` to sum them up,
-        // with additional intermediate expressions as needed.
-        let _sumup_debug_span = debug_span!("(preparing sum-up expression)").entered();
-
-        let mut predecessors = self.bcb_predecessors(bcb).to_owned().into_iter();
-        let first_edge_counter_operand =
-            self.get_or_make_edge_counter_operand(predecessors.next().unwrap(), bcb);
-        let mut some_sumup_edge_counter_operand = None;
-        for predecessor in predecessors {
-            let edge_counter_operand = self.get_or_make_edge_counter_operand(predecessor, bcb);
-            if let Some(sumup_edge_counter_operand) =
-                some_sumup_edge_counter_operand.replace(edge_counter_operand)
-            {
-                let intermediate_expression = self.coverage_counters.make_expression(
-                    sumup_edge_counter_operand,
-                    Op::Add,
-                    edge_counter_operand,
-                );
-                debug!("new intermediate expression: {intermediate_expression:?}");
-                some_sumup_edge_counter_operand.replace(intermediate_expression);
-            }
-        }
-        let counter_kind = self.coverage_counters.make_expression(
-            first_edge_counter_operand,
-            Op::Add,
-            some_sumup_edge_counter_operand.unwrap(),
-        );
-        drop(_sumup_debug_span);
-
-        debug!("{bcb:?} gets a new counter (sum of predecessor counters): {counter_kind:?}");
-        self.coverage_counters.set_bcb_counter(bcb, counter_kind)
+        // A BCB with multiple incoming edges can compute its count by ensuring that counters
+        // exist for each of those edges, and then adding them up to get a total count.
+        let sum_of_in_edges: BcbCounter = {
+            let _span = debug_span!("sum_of_in_edges", ?bcb).entered();
+            // We avoid calling `self.bcb_predecessors` here so that we can
+            // call methods on `&mut self` inside the fold.
+            self.basic_coverage_blocks.predecessors[bcb]
+                .iter()
+                .copied()
+                .fold(None, |accum, from_bcb| {
+                    let _span = debug_span!("from_bcb", ?accum, ?from_bcb).entered();
+                    let edge_counter = self.get_or_make_edge_counter_operand(from_bcb, bcb);
+                    Some(self.coverage_counters.make_sum_expression(accum, edge_counter))
+                })
+                .expect("there must be at least one in-edge")
+        };
+
+        debug!("{bcb:?} gets a new counter (sum of predecessor counters): {sum_of_in_edges:?}");
+        self.coverage_counters.set_bcb_counter(bcb, sum_of_in_edges)
     }
 
     #[instrument(level = "debug", skip(self))]