about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-01-24 16:05:26 -0500
committerManuel Drehwald <git@manuel.drehwald.info>2025-01-24 16:05:26 -0500
commit386c233858874c5412345df6fd6ebf87298727dd (patch)
tree00e32a4928e01d989bda6ddf4aba700e8e699cf6 /compiler/rustc_codegen_llvm/src/builder/autodiff.rs
parenta48e7b00570baaaba9d32d783d5702c06afd104d (diff)
downloadrust-386c233858874c5412345df6fd6ebf87298727dd.tar.gz
rust-386c233858874c5412345df6fd6ebf87298727dd.zip
Make CodegenCx and Builder generic
Co-authored-by: Oli Scherer <github35764891676564198441@oli-obk.de>
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs25
1 files changed, 11 insertions, 14 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 38f7eaa090f..6b17b5f6989 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,20 +3,19 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
-use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
 use rustc_errors::FatalError;
-use rustc_middle::ty::TyCtxt;
 use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
 use crate::back::write::{llvm_err, llvm_optimize};
-use crate::builder::Builder;
-use crate::declare::declare_raw_fn;
+use crate::builder::SBuilder;
+use crate::context::SimpleCx;
+use crate::declare::declare_simple_fn;
 use crate::errors::LlvmError;
 use crate::llvm::AttributePlace::Function;
 use crate::llvm::{Metadata, True};
 use crate::value::Value;
-use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
+use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
 
 fn get_params(fnc: &Value) -> Vec<&Value> {
     unsafe {
@@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
 /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
 // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
 // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
-fn generate_enzyme_call<'ll, 'tcx>(
-    cx: &context::CodegenCx<'ll, 'tcx>,
+fn generate_enzyme_call<'ll>(
+    cx: &SimpleCx<'ll>,
     fn_to_diff: &'ll Value,
     outer_fn: &'ll Value,
     attrs: AutoDiffAttrs,
@@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
         //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
         // think a bit more about what should go here.
         let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
-        let ad_fn = declare_raw_fn(
+        let ad_fn = declare_simple_fn(
             cx,
             &ad_name,
             llvm::CallConv::try_from(cc).expect("invalid callconv"),
@@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
         llvm::LLVMRustEraseInstFromParent(br);
 
         let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
-        let mut builder = Builder::build(cx, entry);
+        let mut builder = SBuilder::build(cx, entry);
 
         let num_args = llvm::LLVMCountParams(&fn_to_diff);
         let mut args = Vec::with_capacity(num_args as usize + 1);
@@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
             }
         }
 
-        let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
+        let call = builder.call(enzyme_ty, ad_fn, &args, None);
 
         // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
         // metadata attachted to it, but we just created this code oota. Given that the
@@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
     }
 }
 
-pub(crate) fn differentiate<'ll, 'tcx>(
+pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
-    tcx: TyCtxt<'tcx>,
     diff_items: Vec<AutoDiffItem>,
     config: &ModuleConfig,
 ) -> Result<(), FatalError> {
@@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
     }
 
     let diag_handler = cgcx.create_dcx();
-    let (_, cgus) = tcx.collect_and_partition_mono_items(());
-    let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
+    let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
 
     // Before dumping the module, we want all the TypeTrees to become part of the module.
     for item in diff_items.iter() {