about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_ssa/src')
-rw-r--r--compiler/rustc_codegen_ssa/src/back/lto.rs21
-rw-r--r--compiler/rustc_codegen_ssa/src/traits/write.rs9
2 files changed, 30 insertions, 0 deletions
diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs
index ab8b06a05fc..efccf7687a1 100644
--- a/compiler/rustc_codegen_ssa/src/back/lto.rs
+++ b/compiler/rustc_codegen_ssa/src/back/lto.rs
@@ -1,11 +1,14 @@
 use std::ffi::CString;
 use std::sync::Arc;
 
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_data_structures::memmap::Mmap;
 use rustc_errors::FatalError;
+use rustc_middle::ty::TyCtxt;
 
 use super::write::CodegenContext;
 use crate::ModuleCodegen;
+use crate::back::write::ModuleConfig;
 use crate::traits::*;
 
 pub struct ThinModule<B: WriteBackendMethods> {
@@ -81,6 +84,24 @@ impl<B: WriteBackendMethods> LtoModuleCodegen<B> {
             LtoModuleCodegen::Thin(ref m) => m.cost(),
         }
     }
+
+    /// Run autodiff on Fat LTO module
+    pub unsafe fn autodiff(
+        self,
+        cgcx: &CodegenContext<B>,
+        tcx: TyCtxt<'_>,
+        diff_fncs: Vec<AutoDiffItem>,
+        config: &ModuleConfig,
+    ) -> Result<LtoModuleCodegen<B>, FatalError> {
+        match &self {
+            LtoModuleCodegen::Fat(module) => {
+                B::autodiff(cgcx, tcx, &module, diff_fncs, config)?;
+            }
+            _ => panic!("autodiff called with non-fat LTO module"),
+        }
+
+        Ok(self)
+    }
 }
 
 pub enum SerializedModule<M: ModuleBufferMethods> {
diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs
index aabe9e33c4a..51e2255efe1 100644
--- a/compiler/rustc_codegen_ssa/src/traits/write.rs
+++ b/compiler/rustc_codegen_ssa/src/traits/write.rs
@@ -1,5 +1,7 @@
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_errors::{DiagCtxtHandle, FatalError};
 use rustc_middle::dep_graph::WorkProduct;
+use rustc_middle::ty::TyCtxt;
 
 use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
 use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig};
@@ -61,6 +63,13 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
         want_summary: bool,
     ) -> (String, Self::ThinBuffer);
     fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer);
+    fn autodiff(
+        cgcx: &CodegenContext<Self>,
+        tcx: TyCtxt<'_>,
+        module: &ModuleCodegen<Self::Module>,
+        diff_fncs: Vec<AutoDiffItem>,
+        config: &ModuleConfig,
+    ) -> Result<(), FatalError>;
 }
 
 pub trait ThinBufferMethods: Send + Sync {