about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs25
-rw-r--r--compiler/rustc_codegen_llvm/src/back/write.rs41
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs46
-rw-r--r--compiler/rustc_codegen_llvm/src/coverageinfo/mapgen/covfun.rs55
-rw-r--r--compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs23
-rw-r--r--compiler/rustc_codegen_llvm/src/lib.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs1
7 files changed, 95 insertions, 97 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index 78c759bbe8c..8bad437eeb7 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -606,10 +606,31 @@ pub(crate) fn run_pass_manager(
 
     // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
     // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
-    let first_run = true;
     debug!("running llvm pm opt pipeline");
     unsafe {
-        write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
+        write::llvm_optimize(
+            cgcx,
+            dcx,
+            module,
+            config,
+            opt_level,
+            opt_stage,
+            write::AutodiffStage::DuringAD,
+        )?;
+    }
+    // FIXME(ZuseZ4): Make this more granular
+    if cfg!(llvm_enzyme) && !thin {
+        unsafe {
+            write::llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                llvm::OptStage::FatLTO,
+                write::AutodiffStage::PostAD,
+            )?;
+        }
     }
     debug!("lto done");
     Ok(())
diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs
index 4706744f353..ae4c4d5876e 100644
--- a/compiler/rustc_codegen_llvm/src/back/write.rs
+++ b/compiler/rustc_codegen_llvm/src/back/write.rs
@@ -530,6 +530,16 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option<CString> {
     config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
 }
 
+// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling)
+// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes.
+// PostAD will run all opts, including size increasing opts.
+#[derive(Debug, Eq, PartialEq)]
+pub(crate) enum AutodiffStage {
+    PreAD,
+    DuringAD,
+    PostAD,
+}
+
 pub(crate) unsafe fn llvm_optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     dcx: DiagCtxtHandle<'_>,
@@ -537,7 +547,7 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
-    skip_size_increasing_opts: bool,
+    autodiff_stage: AutodiffStage,
 ) -> Result<(), FatalError> {
     // Enzyme:
     // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
@@ -550,12 +560,16 @@ pub(crate) unsafe fn llvm_optimize(
     let unroll_loops;
     let vectorize_slp;
     let vectorize_loop;
+    let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
 
     // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
-    // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
+    // optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
+    // We therefore have two calls to llvm_optimize, if autodiff is used.
+    //
+    // FIXME(ZuseZ4): Before shipping on nightly,
     // we should make this more granular, or at least check that the user has at least one autodiff
     // call in their code, to justify altering the compilation pipeline.
-    if skip_size_increasing_opts && cfg!(llvm_enzyme) {
+    if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
         unroll_loops = false;
         vectorize_slp = false;
         vectorize_loop = false;
@@ -565,7 +579,7 @@ pub(crate) unsafe fn llvm_optimize(
         vectorize_slp = config.vectorize_slp;
         vectorize_loop = config.vectorize_loop;
     }
-    trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop);
+    trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme);
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
@@ -633,6 +647,7 @@ pub(crate) unsafe fn llvm_optimize(
             vectorize_loop,
             config.no_builtins,
             config.emit_lifetime_markers,
+            run_enzyme,
             sanitizer_options.as_ref(),
             pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
             pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
@@ -684,18 +699,14 @@ pub(crate) unsafe fn optimize(
             _ => llvm::OptStage::PreLinkNoLTO,
         };
 
-        // If we know that we will later run AD, then we disable vectorization and loop unrolling
-        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling.
+        // Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
+        // FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
+        // usages, not just if we build rustc with autodiff support.
+        let autodiff_stage =
+            if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
         return unsafe {
-            llvm_optimize(
-                cgcx,
-                dcx,
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )
+            llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
         };
     }
     Ok(())
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index dd5e726160d..b2c1088e3fc 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
 use rustc_errors::FatalError;
-use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
-use crate::back::write::{llvm_err, llvm_optimize};
+use crate::back::write::llvm_err;
 use crate::builder::SBuilder;
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
@@ -53,8 +52,6 @@ fn generate_enzyme_call<'ll>(
     let mut ad_name: String = match attrs.mode {
         DiffMode::Forward => "__enzyme_fwddiff",
         DiffMode::Reverse => "__enzyme_autodiff",
-        DiffMode::ForwardFirst => "__enzyme_fwddiff",
-        DiffMode::ReverseFirst => "__enzyme_autodiff",
         _ => panic!("logic bug in autodiff, unrecognized mode"),
     }
     .to_string();
@@ -153,7 +150,7 @@ fn generate_enzyme_call<'ll>(
             _ => {}
         }
 
-        trace!("matching autodiff arguments");
+        debug!("matching autodiff arguments");
         // We now handle the issue that Rust level arguments not always match the llvm-ir level
         // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
         // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -164,10 +161,10 @@ fn generate_enzyme_call<'ll>(
         let mut activity_pos = 0;
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         while activity_pos < inputs.len() {
-            let activity = inputs[activity_pos as usize];
+            let diff_activity = inputs[activity_pos as usize];
             // Duplicated arguments received a shadow argument, into which enzyme will write the
             // gradient.
-            let (activity, duplicated): (&Metadata, bool) = match activity {
+            let (activity, duplicated): (&Metadata, bool) = match diff_activity {
                 DiffActivity::None => panic!("not a valid input activity"),
                 DiffActivity::Const => (enzyme_const, false),
                 DiffActivity::Active => (enzyme_out, false),
@@ -222,7 +219,15 @@ fn generate_enzyme_call<'ll>(
                     // A duplicated pointer will have the following two outer_fn arguments:
                     // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                     // (..., metadata! enzyme_dup, ptr, ptr, ...).
-                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
+                    if matches!(
+                        diff_activity,
+                        DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
+                    ) {
+                        assert!(
+                            llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
+                        );
+                    }
+                    // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                     args.push(next_outer_arg);
                     outer_pos += 2;
                     activity_pos += 1;
@@ -277,7 +282,7 @@ pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     diff_items: Vec<AutoDiffItem>,
-    config: &ModuleConfig,
+    _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
     for item in &diff_items {
         trace!("{}", item);
@@ -318,29 +323,6 @@ pub(crate) fn differentiate<'ll>(
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
 
-    if let Some(opt_level) = config.opt_level {
-        let opt_stage = match cgcx.lto {
-            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
-            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
-            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
-            _ => llvm::OptStage::PreLinkNoLTO,
-        };
-        // This is our second opt call, so now we run all opts,
-        // to make sure we get the best performance.
-        let skip_size_increasing_opts = false;
-        trace!("running Module Optimization after differentiation");
-        unsafe {
-            llvm_optimize(
-                cgcx,
-                diag_handler.handle(),
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )?
-        };
-    }
     trace!("done with differentiate()");
 
     Ok(())
diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen/covfun.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen/covfun.rs
index 460a4664615..c53ea6d4666 100644
--- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen/covfun.rs
+++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen/covfun.rs
@@ -11,7 +11,8 @@ use rustc_codegen_ssa::traits::{
     BaseTypeCodegenMethods, ConstCodegenMethods, StaticCodegenMethods,
 };
 use rustc_middle::mir::coverage::{
-    CovTerm, CoverageIdsInfo, Expression, FunctionCoverageInfo, Mapping, MappingKind, Op,
+    BasicCoverageBlock, CovTerm, CoverageIdsInfo, Expression, FunctionCoverageInfo, Mapping,
+    MappingKind, Op,
 };
 use rustc_middle::ty::{Instance, TyCtxt};
 use rustc_span::Span;
@@ -53,7 +54,7 @@ pub(crate) fn prepare_covfun_record<'tcx>(
     let fn_cov_info = tcx.instance_mir(instance.def).function_coverage_info.as_deref()?;
     let ids_info = tcx.coverage_ids_info(instance.def)?;
 
-    let expressions = prepare_expressions(fn_cov_info, ids_info, is_used);
+    let expressions = prepare_expressions(ids_info);
 
     let mut covfun = CovfunRecord {
         mangled_function_name: tcx.symbol_name(instance).name,
@@ -75,26 +76,14 @@ pub(crate) fn prepare_covfun_record<'tcx>(
 }
 
 /// Convert the function's coverage-counter expressions into a form suitable for FFI.
-fn prepare_expressions(
-    fn_cov_info: &FunctionCoverageInfo,
-    ids_info: &CoverageIdsInfo,
-    is_used: bool,
-) -> Vec<ffi::CounterExpression> {
-    // If any counters or expressions were removed by MIR opts, replace their
-    // terms with zero.
-    let counter_for_term = |term| {
-        if !is_used || ids_info.is_zero_term(term) {
-            ffi::Counter::ZERO
-        } else {
-            ffi::Counter::from_term(term)
-        }
-    };
+fn prepare_expressions(ids_info: &CoverageIdsInfo) -> Vec<ffi::CounterExpression> {
+    let counter_for_term = ffi::Counter::from_term;
 
     // We know that LLVM will optimize out any unused expressions before
     // producing the final coverage map, so there's no need to do the same
     // thing on the Rust side unless we're confident we can do much better.
     // (See `CounterExpressionsMinimizer` in `CoverageMappingWriter.cpp`.)
-    fn_cov_info
+    ids_info
         .expressions
         .iter()
         .map(move |&Expression { lhs, op, rhs }| ffi::CounterExpression {
@@ -136,11 +125,16 @@ fn fill_region_tables<'tcx>(
 
     // For each counter/region pair in this function+file, convert it to a
     // form suitable for FFI.
-    let is_zero_term = |term| !covfun.is_used || ids_info.is_zero_term(term);
     for &Mapping { ref kind, span } in &fn_cov_info.mappings {
-        // If the mapping refers to counters/expressions that were removed by
-        // MIR opts, replace those occurrences with zero.
-        let kind = kind.map_terms(|term| if is_zero_term(term) { CovTerm::Zero } else { term });
+        // If this function is unused, replace all counters with zero.
+        let counter_for_bcb = |bcb: BasicCoverageBlock| -> ffi::Counter {
+            let term = if covfun.is_used {
+                ids_info.term_for_bcb[bcb].expect("every BCB in a mapping was given a term")
+            } else {
+                CovTerm::Zero
+            };
+            ffi::Counter::from_term(term)
+        };
 
         // Convert the `Span` into coordinates that we can pass to LLVM, or
         // discard the span if conversion fails. In rare, cases _all_ of a
@@ -154,23 +148,22 @@ fn fill_region_tables<'tcx>(
             continue;
         }
 
-        match kind {
-            MappingKind::Code(term) => {
-                code_regions
-                    .push(ffi::CodeRegion { cov_span, counter: ffi::Counter::from_term(term) });
+        match *kind {
+            MappingKind::Code { bcb } => {
+                code_regions.push(ffi::CodeRegion { cov_span, counter: counter_for_bcb(bcb) });
             }
-            MappingKind::Branch { true_term, false_term } => {
+            MappingKind::Branch { true_bcb, false_bcb } => {
                 branch_regions.push(ffi::BranchRegion {
                     cov_span,
-                    true_counter: ffi::Counter::from_term(true_term),
-                    false_counter: ffi::Counter::from_term(false_term),
+                    true_counter: counter_for_bcb(true_bcb),
+                    false_counter: counter_for_bcb(false_bcb),
                 });
             }
-            MappingKind::MCDCBranch { true_term, false_term, mcdc_params } => {
+            MappingKind::MCDCBranch { true_bcb, false_bcb, mcdc_params } => {
                 mcdc_branch_regions.push(ffi::MCDCBranchRegion {
                     cov_span,
-                    true_counter: ffi::Counter::from_term(true_term),
-                    false_counter: ffi::Counter::from_term(false_term),
+                    true_counter: counter_for_bcb(true_bcb),
+                    false_counter: counter_for_bcb(false_bcb),
                     mcdc_branch_params: ffi::mcdc::BranchParameters::from(mcdc_params),
                 });
             }
diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs
index 021108cd51c..ea7f581a3cb 100644
--- a/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs
+++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs
@@ -160,21 +160,12 @@ impl<'tcx> CoverageInfoBuilderMethods<'tcx> for Builder<'_, '_, 'tcx> {
             CoverageKind::SpanMarker | CoverageKind::BlockMarker { .. } => unreachable!(
                 "marker statement {kind:?} should have been removed by CleanupPostBorrowck"
             ),
-            CoverageKind::CounterIncrement { id } => {
-                // The number of counters passed to `llvm.instrprof.increment` might
-                // be smaller than the number originally inserted by the instrumentor,
-                // if some high-numbered counters were removed by MIR optimizations.
-                // If so, LLVM's profiler runtime will use fewer physical counters.
-                let num_counters = ids_info.num_counters_after_mir_opts();
-                assert!(
-                    num_counters as usize <= function_coverage_info.num_counters,
-                    "num_counters disagreement: query says {num_counters} but function info only has {}",
-                    function_coverage_info.num_counters
-                );
-
+            CoverageKind::VirtualCounter { bcb }
+                if let Some(&id) = ids_info.phys_counter_for_node.get(&bcb) =>
+            {
                 let fn_name = bx.get_pgo_func_name_var(instance);
                 let hash = bx.const_u64(function_coverage_info.function_source_hash);
-                let num_counters = bx.const_u32(num_counters);
+                let num_counters = bx.const_u32(ids_info.num_counters);
                 let index = bx.const_u32(id.as_u32());
                 debug!(
                     "codegen intrinsic instrprof.increment(fn_name={:?}, hash={:?}, num_counters={:?}, index={:?})",
@@ -182,10 +173,8 @@ impl<'tcx> CoverageInfoBuilderMethods<'tcx> for Builder<'_, '_, 'tcx> {
                 );
                 bx.instrprof_increment(fn_name, hash, num_counters, index);
             }
-            CoverageKind::ExpressionUsed { id: _ } => {
-                // Expression-used statements are markers that are handled by
-                // `coverage_ids_info`, so there's nothing to codegen here.
-            }
+            // If a BCB doesn't have an associated physical counter, there's nothing to codegen.
+            CoverageKind::VirtualCounter { .. } => {}
             CoverageKind::CondBitmapUpdate { index, decision_depth } => {
                 let cond_bitmap = coverage_cx
                     .try_get_mcdc_condition_bitmap(&instance, decision_depth)
diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs
index 14346795fda..f0d04b2b644 100644
--- a/compiler/rustc_codegen_llvm/src/lib.rs
+++ b/compiler/rustc_codegen_llvm/src/lib.rs
@@ -13,6 +13,7 @@
 #![feature(extern_types)]
 #![feature(file_buffered)]
 #![feature(hash_raw_entry)]
+#![feature(if_let_guard)]
 #![feature(impl_trait_in_assoc_type)]
 #![feature(iter_intersperse)]
 #![feature(let_chains)]
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 50a40c9c309..4d6a76b23ea 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -2382,6 +2382,7 @@ unsafe extern "C" {
         LoopVectorize: bool,
         DisableSimplifyLibCalls: bool,
         EmitLifetimeMarkers: bool,
+        RunEnzyme: bool,
         SanitizerOptions: Option<&SanitizerOptions>,
         PGOGenPath: *const c_char,
         PGOUsePath: *const c_char,