diff options
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 111 |
1 files changed, 48 insertions, 63 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 38f7eaa090f..b2c1088e3fc 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,20 +3,18 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; -use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; -use rustc_middle::ty::TyCtxt; -use rustc_session::config::Lto; use tracing::{debug, trace}; -use crate::back::write::{llvm_err, llvm_optimize}; -use crate::builder::Builder; -use crate::declare::declare_raw_fn; +use crate::back::write::llvm_err; +use crate::builder::SBuilder; +use crate::context::SimpleCx; +use crate::declare::declare_simple_fn; use crate::errors::LlvmError; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm}; +use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; fn get_params(fnc: &Value) -> Vec<&Value> { unsafe { @@ -38,8 +36,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> { /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/> // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll, 'tcx>( - cx: &context::CodegenCx<'ll, 'tcx>, +fn generate_enzyme_call<'ll>( + cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_fn: &'ll Value, attrs: AutoDiffAttrs, @@ -54,8 +52,6 @@ fn generate_enzyme_call<'ll, 'tcx>( let mut ad_name: String = match attrs.mode { DiffMode::Forward => "__enzyme_fwddiff", DiffMode::Reverse => "__enzyme_autodiff", - DiffMode::ForwardFirst => "__enzyme_fwddiff", - DiffMode::ReverseFirst => "__enzyme_autodiff", _ => panic!("logic bug in autodiff, unrecognized mode"), } .to_string(); @@ -63,8 +59,8 @@ fn generate_enzyme_call<'ll, 'tcx>( // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap(); - ad_name.push_str(outer_fn_name.to_string().as_str()); + let outer_fn_name = std::str::from_utf8(name).unwrap(); + ad_name.push_str(outer_fn_name); // Let us assume the user wrote the following function square: // @@ -112,7 +108,7 @@ fn generate_enzyme_call<'ll, 'tcx>( //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and // think a bit more about what should go here. let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_raw_fn( + let ad_fn = declare_simple_fn( cx, &ad_name, llvm::CallConv::try_from(cc).expect("invalid callconv"), @@ -132,7 +128,7 @@ fn generate_enzyme_call<'ll, 'tcx>( llvm::LLVMRustEraseInstFromParent(br); let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = Builder::build(cx, entry); + let mut builder = SBuilder::build(cx, entry); let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); @@ -154,7 +150,7 @@ fn generate_enzyme_call<'ll, 'tcx>( _ => {} } - trace!("matching autodiff arguments"); + debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on // llvm-ir level. The number of activities matches the number of Rust level arguments, so we @@ -165,10 +161,10 @@ fn generate_enzyme_call<'ll, 'tcx>( let mut activity_pos = 0; let outer_args: Vec<&llvm::Value> = get_params(outer_fn); while activity_pos < inputs.len() { - let activity = inputs[activity_pos as usize]; + let diff_activity = inputs[activity_pos as usize]; // Duplicated arguments received a shadow argument, into which enzyme will write the // gradient. - let (activity, duplicated): (&Metadata, bool) = match activity { + let (activity, duplicated): (&Metadata, bool) = match diff_activity { DiffActivity::None => panic!("not a valid input activity"), DiffActivity::Const => (enzyme_const, false), DiffActivity::Active => (enzyme_out, false), @@ -223,7 +219,15 @@ fn generate_enzyme_call<'ll, 'tcx>( // A duplicated pointer will have the following two outer_fn arguments: // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: // (..., metadata! enzyme_dup, ptr, ptr, ...). - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + if matches!( + diff_activity, + DiffActivity::Duplicated | DiffActivity::DuplicatedOnly + ) { + assert!( + llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer + ); + } + // In the case of Dual we don't have assumptions, e.g. f32 would be valid. args.push(next_outer_arg); outer_pos += 2; activity_pos += 1; @@ -236,7 +240,7 @@ fn generate_enzyme_call<'ll, 'tcx>( } } - let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); + let call = builder.call(enzyme_ty, ad_fn, &args, None); // This part is a bit iffy. LLVM requires that a call to an inlineable function has some // metadata attachted to it, but we just created this code oota. Given that the @@ -256,14 +260,14 @@ fn generate_enzyme_call<'ll, 'tcx>( // have no debug info to copy, which would then be ok. trace!("no dbg info"); } + // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstBefore(entry, last_inst); - llvm::LLVMRustEraseInstFromParent(last_inst); + llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); - if cx.val_ty(outer_fn) != cx.type_void() { - builder.ret(call); - } else { + if cx.val_ty(call) == cx.type_void() { builder.ret_void(); + } else { + builder.ret(call); } // Let's crash in case that we messed something up above and generated invalid IR. @@ -274,40 +278,44 @@ fn generate_enzyme_call<'ll, 'tcx>( } } -pub(crate) fn differentiate<'ll, 'tcx>( +pub(crate) fn differentiate<'ll>( module: &'ll ModuleCodegen<ModuleLlvm>, cgcx: &CodegenContext<LlvmCodegenBackend>, - tcx: TyCtxt<'tcx>, diff_items: Vec<AutoDiffItem>, - config: &ModuleConfig, + _config: &ModuleConfig, ) -> Result<(), FatalError> { for item in &diff_items { trace!("{}", item); } let diag_handler = cgcx.create_dcx(); - let (_, cgus) = tcx.collect_and_partition_mono_items(()); - let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm); + let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx }; // Before dumping the module, we want all the TypeTrees to become part of the module. for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); let Some(fn_def) = fn_def else { - return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find source function".to_owned(), - })); + return Err(llvm_err( + diag_handler.handle(), + LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find source function".to_owned(), + }, + )); }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); let Some(fn_target) = fn_target else { - return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find target function".to_owned(), - })); + return Err(llvm_err( + diag_handler.handle(), + LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find target function".to_owned(), + }, + )); }; generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); @@ -315,29 +323,6 @@ pub(crate) fn differentiate<'ll, 'tcx>( // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts - if let Some(opt_level) = config.opt_level { - let opt_stage = match cgcx.lto { - Lto::Fat => llvm::OptStage::PreLinkFatLTO, - Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, - _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, - _ => llvm::OptStage::PreLinkNoLTO, - }; - // This is our second opt call, so now we run all opts, - // to make sure we get the best performance. - let skip_size_increasing_opts = false; - trace!("running Module Optimization after differentiation"); - unsafe { - llvm_optimize( - cgcx, - diag_handler.handle(), - module, - config, - opt_level, - opt_stage, - skip_size_increasing_opts, - )? - }; - } trace!("done with differentiate()"); Ok(()) |
