about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs111
1 files changed, 48 insertions, 63 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 38f7eaa090f..b2c1088e3fc 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,20 +3,18 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
-use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
 use rustc_errors::FatalError;
-use rustc_middle::ty::TyCtxt;
-use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
-use crate::back::write::{llvm_err, llvm_optimize};
-use crate::builder::Builder;
-use crate::declare::declare_raw_fn;
+use crate::back::write::llvm_err;
+use crate::builder::SBuilder;
+use crate::context::SimpleCx;
+use crate::declare::declare_simple_fn;
 use crate::errors::LlvmError;
 use crate::llvm::AttributePlace::Function;
 use crate::llvm::{Metadata, True};
 use crate::value::Value;
-use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
+use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
 
 fn get_params(fnc: &Value) -> Vec<&Value> {
     unsafe {
@@ -38,8 +36,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
 /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
 // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
 // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
-fn generate_enzyme_call<'ll, 'tcx>(
-    cx: &context::CodegenCx<'ll, 'tcx>,
+fn generate_enzyme_call<'ll>(
+    cx: &SimpleCx<'ll>,
     fn_to_diff: &'ll Value,
     outer_fn: &'ll Value,
     attrs: AutoDiffAttrs,
@@ -54,8 +52,6 @@ fn generate_enzyme_call<'ll, 'tcx>(
     let mut ad_name: String = match attrs.mode {
         DiffMode::Forward => "__enzyme_fwddiff",
         DiffMode::Reverse => "__enzyme_autodiff",
-        DiffMode::ForwardFirst => "__enzyme_fwddiff",
-        DiffMode::ReverseFirst => "__enzyme_autodiff",
         _ => panic!("logic bug in autodiff, unrecognized mode"),
     }
     .to_string();
@@ -63,8 +59,8 @@ fn generate_enzyme_call<'ll, 'tcx>(
     // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
     // functions. Unwrap will only panic, if LLVM gave us an invalid string.
     let name = llvm::get_value_name(outer_fn);
-    let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
-    ad_name.push_str(outer_fn_name.to_string().as_str());
+    let outer_fn_name = std::str::from_utf8(name).unwrap();
+    ad_name.push_str(outer_fn_name);
 
     // Let us assume the user wrote the following function square:
     //
@@ -112,7 +108,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
         //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
         // think a bit more about what should go here.
         let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
-        let ad_fn = declare_raw_fn(
+        let ad_fn = declare_simple_fn(
             cx,
             &ad_name,
             llvm::CallConv::try_from(cc).expect("invalid callconv"),
@@ -132,7 +128,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
         llvm::LLVMRustEraseInstFromParent(br);
 
         let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
-        let mut builder = Builder::build(cx, entry);
+        let mut builder = SBuilder::build(cx, entry);
 
         let num_args = llvm::LLVMCountParams(&fn_to_diff);
         let mut args = Vec::with_capacity(num_args as usize + 1);
@@ -154,7 +150,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
             _ => {}
         }
 
-        trace!("matching autodiff arguments");
+        debug!("matching autodiff arguments");
         // We now handle the issue that Rust level arguments not always match the llvm-ir level
         // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
         // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -165,10 +161,10 @@ fn generate_enzyme_call<'ll, 'tcx>(
         let mut activity_pos = 0;
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         while activity_pos < inputs.len() {
-            let activity = inputs[activity_pos as usize];
+            let diff_activity = inputs[activity_pos as usize];
             // Duplicated arguments received a shadow argument, into which enzyme will write the
             // gradient.
-            let (activity, duplicated): (&Metadata, bool) = match activity {
+            let (activity, duplicated): (&Metadata, bool) = match diff_activity {
                 DiffActivity::None => panic!("not a valid input activity"),
                 DiffActivity::Const => (enzyme_const, false),
                 DiffActivity::Active => (enzyme_out, false),
@@ -223,7 +219,15 @@ fn generate_enzyme_call<'ll, 'tcx>(
                     // A duplicated pointer will have the following two outer_fn arguments:
                     // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                     // (..., metadata! enzyme_dup, ptr, ptr, ...).
-                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
+                    if matches!(
+                        diff_activity,
+                        DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
+                    ) {
+                        assert!(
+                            llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
+                        );
+                    }
+                    // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                     args.push(next_outer_arg);
                     outer_pos += 2;
                     activity_pos += 1;
@@ -236,7 +240,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
             }
         }
 
-        let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
+        let call = builder.call(enzyme_ty, ad_fn, &args, None);
 
         // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
         // metadata attachted to it, but we just created this code oota. Given that the
@@ -256,14 +260,14 @@ fn generate_enzyme_call<'ll, 'tcx>(
             // have no debug info to copy, which would then be ok.
             trace!("no dbg info");
         }
+
         // Now that we copied the metadata, get rid of dummy code.
-        llvm::LLVMRustEraseInstBefore(entry, last_inst);
-        llvm::LLVMRustEraseInstFromParent(last_inst);
+        llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
 
-        if cx.val_ty(outer_fn) != cx.type_void() {
-            builder.ret(call);
-        } else {
+        if cx.val_ty(call) == cx.type_void() {
             builder.ret_void();
+        } else {
+            builder.ret(call);
         }
 
         // Let's crash in case that we messed something up above and generated invalid IR.
@@ -274,40 +278,44 @@ fn generate_enzyme_call<'ll, 'tcx>(
     }
 }
 
-pub(crate) fn differentiate<'ll, 'tcx>(
+pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
-    tcx: TyCtxt<'tcx>,
     diff_items: Vec<AutoDiffItem>,
-    config: &ModuleConfig,
+    _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
     for item in &diff_items {
         trace!("{}", item);
     }
 
     let diag_handler = cgcx.create_dcx();
-    let (_, cgus) = tcx.collect_and_partition_mono_items(());
-    let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
+    let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
 
     // Before dumping the module, we want all the TypeTrees to become part of the module.
     for item in diff_items.iter() {
         let name = item.source.clone();
         let fn_def: Option<&llvm::Value> = cx.get_function(&name);
         let Some(fn_def) = fn_def else {
-            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
-                src: item.source.clone(),
-                target: item.target.clone(),
-                error: "could not find source function".to_owned(),
-            }));
+            return Err(llvm_err(
+                diag_handler.handle(),
+                LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find source function".to_owned(),
+                },
+            ));
         };
         debug!(?item.target);
         let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
         let Some(fn_target) = fn_target else {
-            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
-                src: item.source.clone(),
-                target: item.target.clone(),
-                error: "could not find target function".to_owned(),
-            }));
+            return Err(llvm_err(
+                diag_handler.handle(),
+                LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find target function".to_owned(),
+                },
+            ));
         };
 
         generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
@@ -315,29 +323,6 @@ pub(crate) fn differentiate<'ll, 'tcx>(
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
 
-    if let Some(opt_level) = config.opt_level {
-        let opt_stage = match cgcx.lto {
-            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
-            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
-            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
-            _ => llvm::OptStage::PreLinkNoLTO,
-        };
-        // This is our second opt call, so now we run all opts,
-        // to make sure we get the best performance.
-        let skip_size_increasing_opts = false;
-        trace!("running Module Optimization after differentiation");
-        unsafe {
-            llvm_optimize(
-                cgcx,
-                diag_handler.handle(),
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )?
-        };
-    }
     trace!("done with differentiate()");
 
     Ok(())