diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-01-24 16:05:26 -0500 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-01-24 16:05:26 -0500 |
| commit | 386c233858874c5412345df6fd6ebf87298727dd (patch) | |
| tree | 00e32a4928e01d989bda6ddf4aba700e8e699cf6 /compiler/rustc_codegen_llvm/src/builder | |
| parent | a48e7b00570baaaba9d32d783d5702c06afd104d (diff) | |
| download | rust-386c233858874c5412345df6fd6ebf87298727dd.tar.gz rust-386c233858874c5412345df6fd6ebf87298727dd.zip | |
Make CodegenCx and Builder generic
Co-authored-by: Oli Scherer <github35764891676564198441@oli-obk.de>
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 38f7eaa090f..6b17b5f6989 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,20 +3,19 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; -use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; -use rustc_middle::ty::TyCtxt; use rustc_session::config::Lto; use tracing::{debug, trace}; use crate::back::write::{llvm_err, llvm_optimize}; -use crate::builder::Builder; -use crate::declare::declare_raw_fn; +use crate::builder::SBuilder; +use crate::context::SimpleCx; +use crate::declare::declare_simple_fn; use crate::errors::LlvmError; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm}; +use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; fn get_params(fnc: &Value) -> Vec<&Value> { unsafe { @@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> { /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/> // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll, 'tcx>( - cx: &context::CodegenCx<'ll, 'tcx>, +fn generate_enzyme_call<'ll>( + cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_fn: &'ll Value, attrs: AutoDiffAttrs, @@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>( //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and // think a bit more about what should go here. let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_raw_fn( + let ad_fn = declare_simple_fn( cx, &ad_name, llvm::CallConv::try_from(cc).expect("invalid callconv"), @@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>( llvm::LLVMRustEraseInstFromParent(br); let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = Builder::build(cx, entry); + let mut builder = SBuilder::build(cx, entry); let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); @@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>( } } - let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); + let call = builder.call(enzyme_ty, ad_fn, &args, None); // This part is a bit iffy. LLVM requires that a call to an inlineable function has some // metadata attachted to it, but we just created this code oota. Given that the @@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>( } } -pub(crate) fn differentiate<'ll, 'tcx>( +pub(crate) fn differentiate<'ll>( module: &'ll ModuleCodegen<ModuleLlvm>, cgcx: &CodegenContext<LlvmCodegenBackend>, - tcx: TyCtxt<'tcx>, diff_items: Vec<AutoDiffItem>, config: &ModuleConfig, ) -> Result<(), FatalError> { @@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>( } let diag_handler = cgcx.create_dcx(); - let (_, cgus) = tcx.collect_and_partition_mono_items(()); - let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm); + let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx }; // Before dumping the module, we want all the TypeTrees to become part of the module. for item in diff_items.iter() { |
