about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2025-02-18 07:49:09 +0000
committerbors <bors@rust-lang.org>2025-02-18 07:49:09 +0000
commit3b022d8ceea570db9730be34d964f0cc663a567f (patch)
tree98b0584387859ce2ebcc221b9b86997cce17ba7b /compiler/rustc_codegen_ssa
parentaaa861493456e8a10e552dd208f85486de772007 (diff)
parent7bb5f4dd78a7a45729ab503805ba1b3065cb0de9 (diff)
downloadrust-3b022d8ceea570db9730be34d964f0cc663a567f.tar.gz
rust-3b022d8ceea570db9730be34d964f0cc663a567f.zip
Auto merge of #133852 - x17jiri:cold_path, r=saethlin
improve cold_path()

#120370 added a new instrinsic `cold_path()` and used it to fix `likely` and `unlikely`

However, in order to limit scope, the information about cold code paths is only used in 2-target switch instructions. This is sufficient for `likely` and `unlikely`, but limits usefulness of `cold_path` for idiomatic rust. For example, code like this:

```
if let Some(x) = y { ... }
```

may generate 3-target switch:

```
switch y.discriminator:
0 => true branch
1 = > false branch
_ => unreachable
```

and therefore marking a branch as cold will have no effect.

This PR improves `cold_path()` to work with arbitrary switch instructions.

Note that for 2-target switches, we can use `llvm.expect`, but for multiple targets we need to manually emit branch weights. I checked Clang and it also emits weights in this situation. The Clang's weight calculation is more complex that this PR, which I believe is mainly because `switch` in `C/C++` can have multiple cases going to the same target.
Diffstat (limited to 'compiler/rustc_codegen_ssa')
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/block.rs33
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/mod.rs27
-rw-r--r--compiler/rustc_codegen_ssa/src/traits/builder.rs14
3 files changed, 61 insertions, 13 deletions
diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs
index 4630ed48c52..616d748a299 100644
--- a/compiler/rustc_codegen_ssa/src/mir/block.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/block.rs
@@ -429,11 +429,34 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
             let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
             bx.cond_br(cmp, ll1, ll2);
         } else {
-            bx.switch(
-                discr_value,
-                helper.llbb_with_cleanup(self, targets.otherwise()),
-                target_iter.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
-            );
+            let otherwise = targets.otherwise();
+            let otherwise_cold = self.cold_blocks[otherwise];
+            let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
+            let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
+            let none_cold = cold_count == 0;
+            let all_cold = cold_count == targets.iter().len();
+            if (none_cold && (!otherwise_cold || otherwise_unreachable))
+                || (all_cold && (otherwise_cold || otherwise_unreachable))
+            {
+                // All targets have the same weight,
+                // or `otherwise` is unreachable and it's the only target with a different weight.
+                bx.switch(
+                    discr_value,
+                    helper.llbb_with_cleanup(self, targets.otherwise()),
+                    target_iter
+                        .map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
+                );
+            } else {
+                // Targets have different weights
+                bx.switch_with_weights(
+                    discr_value,
+                    helper.llbb_with_cleanup(self, targets.otherwise()),
+                    otherwise_cold,
+                    target_iter.map(|(value, target)| {
+                        (value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
+                    }),
+                );
+            }
         }
     }
 
diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs
index 3a896071bc6..ba28720afec 100644
--- a/compiler/rustc_codegen_ssa/src/mir/mod.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs
@@ -502,14 +502,25 @@ fn find_cold_blocks<'tcx>(
     for (bb, bb_data) in traversal::postorder(mir) {
         let terminator = bb_data.terminator();
 
-        // If a BB ends with a call to a cold function, mark it as cold.
-        if let mir::TerminatorKind::Call { ref func, .. } = terminator.kind
-            && let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
-            && let attrs = tcx.codegen_fn_attrs(def_id)
-            && attrs.flags.contains(CodegenFnAttrFlags::COLD)
-        {
-            cold_blocks[bb] = true;
-            continue;
+        match terminator.kind {
+            // If a BB ends with a call to a cold function, mark it as cold.
+            mir::TerminatorKind::Call { ref func, .. }
+            | mir::TerminatorKind::TailCall { ref func, .. }
+                if let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
+                    && let attrs = tcx.codegen_fn_attrs(def_id)
+                    && attrs.flags.contains(CodegenFnAttrFlags::COLD) =>
+            {
+                cold_blocks[bb] = true;
+                continue;
+            }
+
+            // If a BB ends with an `unreachable`, also mark it as cold.
+            mir::TerminatorKind::Unreachable => {
+                cold_blocks[bb] = true;
+                continue;
+            }
+
+            _ => {}
         }
 
         // If all successors of a BB are cold and there's at least one of them, mark this BB as cold
diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs
index 2c843e2f5e4..48ae000f2c6 100644
--- a/compiler/rustc_codegen_ssa/src/traits/builder.rs
+++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs
@@ -110,6 +110,20 @@ pub trait BuilderMethods<'a, 'tcx>:
         else_llbb: Self::BasicBlock,
         cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
     );
+
+    // This is like `switch()`, but every case has a bool flag indicating whether it's cold.
+    //
+    // Default implementation throws away the cold flags and calls `switch()`.
+    fn switch_with_weights(
+        &mut self,
+        v: Self::Value,
+        else_llbb: Self::BasicBlock,
+        _else_is_cold: bool,
+        cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
+    ) {
+        self.switch(v, else_llbb, cases.map(|(val, bb, _)| (val, bb)))
+    }
+
     fn invoke(
         &mut self,
         llty: Self::Type,