about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJiri Bobek <jiri.bobek@gmail.com>2024-12-04 12:30:42 +0100
committerJiri Bobek <jiri.bobek@gmail.com>2025-02-17 06:39:58 +0100
commit7bb5f4dd78a7a45729ab503805ba1b3065cb0de9 (patch)
tree3f5ff28b464773e1418cc5900198bed2d1265343
parentc705b7d6f748823352f95247835a5839dcbd7bed (diff)
downloadrust-7bb5f4dd78a7a45729ab503805ba1b3065cb0de9.tar.gz
rust-7bb5f4dd78a7a45729ab503805ba1b3065cb0de9.zip
improve cold_path()
-rw-r--r--compiler/rustc_codegen_llvm/src/builder.rs48
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/block.rs33
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/mod.rs27
-rw-r--r--compiler/rustc_codegen_ssa/src/traits/builder.rs14
-rw-r--r--tests/codegen/intrinsics/cold_path2.rs36
-rw-r--r--tests/codegen/intrinsics/cold_path3.rs87
6 files changed, 230 insertions, 15 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs
index 2d007416263..b63c8f8caef 100644
--- a/compiler/rustc_codegen_llvm/src/builder.rs
+++ b/compiler/rustc_codegen_llvm/src/builder.rs
@@ -4,7 +4,7 @@ use std::{iter, ptr};
 
 pub(crate) mod autodiff;
 
-use libc::{c_char, c_uint};
+use libc::{c_char, c_uint, size_t};
 use rustc_abi as abi;
 use rustc_abi::{Align, Size, WrappingRange};
 use rustc_codegen_ssa::MemFlags;
@@ -32,7 +32,7 @@ use crate::abi::FnAbiLlvmExt;
 use crate::attributes;
 use crate::common::Funclet;
 use crate::context::{CodegenCx, SimpleCx};
-use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
+use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True};
 use crate::type_::Type;
 use crate::type_of::LayoutLlvmExt;
 use crate::value::Value;
@@ -333,6 +333,50 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         }
     }
 
+    fn switch_with_weights(
+        &mut self,
+        v: Self::Value,
+        else_llbb: Self::BasicBlock,
+        else_is_cold: bool,
+        cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
+    ) {
+        if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
+            self.switch(v, else_llbb, cases.map(|(val, dest, _)| (val, dest)));
+            return;
+        }
+
+        let id_str = "branch_weights";
+        let id = unsafe {
+            llvm::LLVMMDStringInContext2(self.cx.llcx, id_str.as_ptr().cast(), id_str.len())
+        };
+
+        // For switch instructions with 2 targets, the `llvm.expect` intrinsic is used.
+        // This function handles switch instructions with more than 2 targets and it needs to
+        // emit branch weights metadata instead of using the intrinsic.
+        // The values 1 and 2000 are the same as the values used by the `llvm.expect` intrinsic.
+        let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) };
+        let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) };
+        let weight =
+            |is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } };
+
+        let mut md: SmallVec<[&Metadata; 16]> = SmallVec::with_capacity(cases.len() + 2);
+        md.push(id);
+        md.push(weight(else_is_cold));
+
+        let switch =
+            unsafe { llvm::LLVMBuildSwitch(self.llbuilder, v, else_llbb, cases.len() as c_uint) };
+        for (on_val, dest, is_cold) in cases {
+            let on_val = self.const_uint_big(self.val_ty(v), on_val);
+            unsafe { llvm::LLVMAddCase(switch, on_val, dest) }
+            md.push(weight(is_cold));
+        }
+
+        unsafe {
+            let md_node = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len() as size_t);
+            self.cx.set_metadata(switch, llvm::MD_prof, md_node);
+        }
+    }
+
     fn invoke(
         &mut self,
         llty: &'ll Type,
diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs
index 4be363ca9a2..66008a47650 100644
--- a/compiler/rustc_codegen_ssa/src/mir/block.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/block.rs
@@ -429,11 +429,34 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
             let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
             bx.cond_br(cmp, ll1, ll2);
         } else {
-            bx.switch(
-                discr_value,
-                helper.llbb_with_cleanup(self, targets.otherwise()),
-                target_iter.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
-            );
+            let otherwise = targets.otherwise();
+            let otherwise_cold = self.cold_blocks[otherwise];
+            let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
+            let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
+            let none_cold = cold_count == 0;
+            let all_cold = cold_count == targets.iter().len();
+            if (none_cold && (!otherwise_cold || otherwise_unreachable))
+                || (all_cold && (otherwise_cold || otherwise_unreachable))
+            {
+                // All targets have the same weight,
+                // or `otherwise` is unreachable and it's the only target with a different weight.
+                bx.switch(
+                    discr_value,
+                    helper.llbb_with_cleanup(self, targets.otherwise()),
+                    target_iter
+                        .map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
+                );
+            } else {
+                // Targets have different weights
+                bx.switch_with_weights(
+                    discr_value,
+                    helper.llbb_with_cleanup(self, targets.otherwise()),
+                    otherwise_cold,
+                    target_iter.map(|(value, target)| {
+                        (value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
+                    }),
+                );
+            }
         }
     }
 
diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs
index 3a896071bc6..ba28720afec 100644
--- a/compiler/rustc_codegen_ssa/src/mir/mod.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs
@@ -502,14 +502,25 @@ fn find_cold_blocks<'tcx>(
     for (bb, bb_data) in traversal::postorder(mir) {
         let terminator = bb_data.terminator();
 
-        // If a BB ends with a call to a cold function, mark it as cold.
-        if let mir::TerminatorKind::Call { ref func, .. } = terminator.kind
-            && let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
-            && let attrs = tcx.codegen_fn_attrs(def_id)
-            && attrs.flags.contains(CodegenFnAttrFlags::COLD)
-        {
-            cold_blocks[bb] = true;
-            continue;
+        match terminator.kind {
+            // If a BB ends with a call to a cold function, mark it as cold.
+            mir::TerminatorKind::Call { ref func, .. }
+            | mir::TerminatorKind::TailCall { ref func, .. }
+                if let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
+                    && let attrs = tcx.codegen_fn_attrs(def_id)
+                    && attrs.flags.contains(CodegenFnAttrFlags::COLD) =>
+            {
+                cold_blocks[bb] = true;
+                continue;
+            }
+
+            // If a BB ends with an `unreachable`, also mark it as cold.
+            mir::TerminatorKind::Unreachable => {
+                cold_blocks[bb] = true;
+                continue;
+            }
+
+            _ => {}
         }
 
         // If all successors of a BB are cold and there's at least one of them, mark this BB as cold
diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs
index bbf87a59942..cbe0f2e8059 100644
--- a/compiler/rustc_codegen_ssa/src/traits/builder.rs
+++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs
@@ -110,6 +110,20 @@ pub trait BuilderMethods<'a, 'tcx>:
         else_llbb: Self::BasicBlock,
         cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
     );
+
+    // This is like `switch()`, but every case has a bool flag indicating whether it's cold.
+    //
+    // Default implementation throws away the cold flags and calls `switch()`.
+    fn switch_with_weights(
+        &mut self,
+        v: Self::Value,
+        else_llbb: Self::BasicBlock,
+        _else_is_cold: bool,
+        cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
+    ) {
+        self.switch(v, else_llbb, cases.map(|(val, bb, _)| (val, bb)))
+    }
+
     fn invoke(
         &mut self,
         llty: Self::Type,
diff --git a/tests/codegen/intrinsics/cold_path2.rs b/tests/codegen/intrinsics/cold_path2.rs
new file mode 100644
index 00000000000..1e7e0478f4f
--- /dev/null
+++ b/tests/codegen/intrinsics/cold_path2.rs
@@ -0,0 +1,36 @@
+//@ compile-flags: -O
+#![crate_type = "lib"]
+#![feature(core_intrinsics)]
+
+use std::intrinsics::cold_path;
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_a() {
+    println!("path a");
+}
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_b() {
+    println!("path b");
+}
+
+#[no_mangle]
+pub fn test(x: Option<bool>) {
+    if let Some(_) = x {
+        path_a();
+    } else {
+        cold_path();
+        path_b();
+    }
+
+    // CHECK-LABEL: @test(
+    // CHECK: br i1 %1, label %bb2, label %bb1, !prof ![[NUM:[0-9]+]]
+    // CHECK: bb1:
+    // CHECK: path_a
+    // CHECK: bb2:
+    // CHECK: path_b
+}
+
+// CHECK: ![[NUM]] = !{!"branch_weights", {{(!"expected", )?}}i32 1, i32 2000}
diff --git a/tests/codegen/intrinsics/cold_path3.rs b/tests/codegen/intrinsics/cold_path3.rs
new file mode 100644
index 00000000000..bf3347de665
--- /dev/null
+++ b/tests/codegen/intrinsics/cold_path3.rs
@@ -0,0 +1,87 @@
+//@ compile-flags: -O
+#![crate_type = "lib"]
+#![feature(core_intrinsics)]
+
+use std::intrinsics::cold_path;
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_a() {
+    println!("path a");
+}
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_b() {
+    println!("path b");
+}
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_c() {
+    println!("path c");
+}
+
+#[inline(never)]
+#[no_mangle]
+pub fn path_d() {
+    println!("path d");
+}
+
+#[no_mangle]
+pub fn test(x: Option<u32>) {
+    match x {
+        Some(0) => path_a(),
+        Some(1) => {
+            cold_path();
+            path_b()
+        }
+        Some(2) => path_c(),
+        Some(3) => {
+            cold_path();
+            path_d()
+        }
+        _ => path_a(),
+    }
+
+    // CHECK-LABEL: @test(
+    // CHECK: switch i32 %1, label %bb1 [
+    // CHECK: i32 0, label %bb6
+    // CHECK: i32 1, label %bb5
+    // CHECK: i32 2, label %bb4
+    // CHECK: i32 3, label %bb3
+    // CHECK: ], !prof ![[NUM1:[0-9]+]]
+}
+
+#[no_mangle]
+pub fn test2(x: Option<u32>) {
+    match x {
+        Some(10) => path_a(),
+        Some(11) => {
+            cold_path();
+            path_b()
+        }
+        Some(12) => {
+            unsafe { core::intrinsics::unreachable() };
+            path_c()
+        }
+        Some(13) => {
+            cold_path();
+            path_d()
+        }
+        _ => {
+            cold_path();
+            path_a()
+        }
+    }
+
+    // CHECK-LABEL: @test2(
+    // CHECK: switch i32 %1, label %bb1 [
+    // CHECK: i32 10, label %bb5
+    // CHECK: i32 11, label %bb4
+    // CHECK: i32 13, label %bb3
+    // CHECK: ], !prof ![[NUM2:[0-9]+]]
+}
+
+// CHECK: ![[NUM1]] = !{!"branch_weights", i32 2000, i32 2000, i32 1, i32 2000, i32 1}
+// CHECK: ![[NUM2]] = !{!"branch_weights", i32 1, i32 2000, i32 1, i32 1}