about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa/src/back
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_ssa/src/back')
-rw-r--r--compiler/rustc_codegen_ssa/src/back/write.rs38
1 files changed, 33 insertions, 5 deletions
diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs
index b40bb4ed5d2..914f2c21fa7 100644
--- a/compiler/rustc_codegen_ssa/src/back/write.rs
+++ b/compiler/rustc_codegen_ssa/src/back/write.rs
@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
 use std::{fs, io, mem, str, thread};
 
 use rustc_ast::attr;
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
 use rustc_data_structures::jobserver::{self, Acquired};
 use rustc_data_structures::memmap::Mmap;
@@ -40,7 +41,7 @@ use tracing::debug;
 use super::link::{self, ensure_removed};
 use super::lto::{self, SerializedModule};
 use super::symbol_export::symbol_name_for_instance_in_crate;
-use crate::errors::ErrorCreatingRemarkDir;
+use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
 use crate::traits::*;
 use crate::{
     CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
     pub merge_functions: bool,
     pub emit_lifetime_markers: bool,
     pub llvm_plugins: Vec<String>,
+    pub autodiff: Vec<config::AutoDiff>,
 }
 
 impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
 
             emit_lifetime_markers: sess.emit_lifetime_markers(),
             llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
+            autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
         }
     }
 
@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
 
 fn generate_lto_work<B: ExtraBackendMethods>(
     cgcx: &CodegenContext<B>,
+    autodiff: Vec<AutoDiffItem>,
     needs_fat_lto: Vec<FatLtoInput<B>>,
     needs_thin_lto: Vec<(String, B::ThinBuffer)>,
     import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
 
     if !needs_fat_lto.is_empty() {
         assert!(needs_thin_lto.is_empty());
-        let module =
+        let mut module =
             B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
+        if cgcx.lto == Lto::Fat {
+            let config = cgcx.config(ModuleKind::Regular);
+            module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
+        }
         // We are adding a single work item, so the cost doesn't matter.
         vec![(WorkItem::LTO(module), 0)]
     } else {
+        if !autodiff.is_empty() {
+            let dcx = cgcx.create_dcx();
+            dcx.handle().emit_fatal(AutodiffWithoutLto {});
+        }
         assert!(needs_fat_lto.is_empty());
         let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
             .unwrap_or_else(|e| e.raise());
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
     /// Sent from a backend worker thread.
     WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
 
+    /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
+    AddAutoDiffItems(Vec<AutoDiffItem>),
+
     /// The frontend has finished generating something (backend IR or a
     /// post-LTO artifact) for a codegen unit, and it should be passed to the
     /// backend. Sent from the main thread.
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
 
         // This is where we collect codegen units that have gone all the way
         // through codegen and LLVM.
+        let mut autodiff_items = Vec::new();
         let mut compiled_modules = vec![];
         let mut compiled_allocator_module = None;
         let mut needs_link = Vec::new();
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
                     let needs_thin_lto = mem::take(&mut needs_thin_lto);
                     let import_only_modules = mem::take(&mut lto_import_only_modules);
 
-                    for (work, cost) in
-                        generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
-                    {
+                    for (work, cost) in generate_lto_work(
+                        &cgcx,
+                        autodiff_items.clone(),
+                        needs_fat_lto,
+                        needs_thin_lto,
+                        import_only_modules,
+                    ) {
                         let insertion_index = work_items
                             .binary_search_by_key(&cost, |&(_, cost)| cost)
                             .unwrap_or_else(|e| e);
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
                     main_thread_state = MainThreadState::Idle;
                 }
 
+                Message::AddAutoDiffItems(mut items) => {
+                    autodiff_items.append(&mut items);
+                }
+
                 Message::CodegenComplete => {
                     if codegen_state != Aborted {
                         codegen_state = Completed;
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
         drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
     }
 
+    pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
+        drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
+    }
+
     pub(crate) fn check_for_errors(&self, sess: &Session) {
         self.shared_emitter_main.check(sess, false);
     }