about summary refs log tree commit diff
path: root/compiler/rustc_ast/src/expand/autodiff_attrs.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_ast/src/expand/autodiff_attrs.rs')
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs283
1 files changed, 283 insertions, 0 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
new file mode 100644
index 00000000000..05714731b9d
--- /dev/null
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -0,0 +1,283 @@
+//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
+//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
+//! is the function to which the autodiff attribute is applied, and the target is the function
+//! getting generated by us (with a name given by the user as the first autodiff arg).
+
+use std::fmt::{self, Display, Formatter};
+use std::str::FromStr;
+
+use crate::expand::typetree::TypeTree;
+use crate::expand::{Decodable, Encodable, HashStable_Generic};
+use crate::ptr::P;
+use crate::{Ty, TyKind};
+
+/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
+/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
+/// are a hack to support higher order derivatives. We need to compute first order derivatives
+/// before we compute second order derivatives, otherwise we would differentiate our placeholder
+/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
+/// as it's already done in the C++ and Julia frontend of Enzyme.
+///
+/// (FIXME) remove *First variants.
+/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
+/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
+#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub enum DiffMode {
+    /// No autodiff is applied (used during error handling).
+    Error,
+    /// The primal function which we will differentiate.
+    Source,
+    /// The target function, to be created using forward mode AD.
+    Forward,
+    /// The target function, to be created using reverse mode AD.
+    Reverse,
+    /// The target function, to be created using forward mode AD.
+    /// This target function will also be used as a source for higher order derivatives,
+    /// so compute it before all Forward/Reverse targets and optimize it through llvm.
+    ForwardFirst,
+    /// The target function, to be created using reverse mode AD.
+    /// This target function will also be used as a source for higher order derivatives,
+    /// so compute it before all Forward/Reverse targets and optimize it through llvm.
+    ReverseFirst,
+}
+
+/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
+/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
+/// we add to the previous shadow value. To not surprise users, we picked different names.
+/// Dual numbers is also a quite well known name for forward mode AD types.
+#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub enum DiffActivity {
+    /// Implicit or Explicit () return type, so a special case of Const.
+    None,
+    /// Don't compute derivatives with respect to this input/output.
+    Const,
+    /// Reverse Mode, Compute derivatives for this scalar input/output.
+    Active,
+    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
+    /// the original return value.
+    ActiveOnly,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it.
+    Dual,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it. Drop the code which updates the original input/output for maximum performance.
+    DualOnly,
+    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
+    Duplicated,
+    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
+    /// Drop the code which updates the original input for maximum performance.
+    DuplicatedOnly,
+    /// All Integers must be Const, but these are used to mark the integer which represents the
+    /// length of a slice/vec. This is used for safety checks on slices.
+    FakeActivitySize,
+}
+/// We generate one of these structs for each `#[autodiff(...)]` attribute.
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct AutoDiffItem {
+    /// The name of the function getting differentiated
+    pub source: String,
+    /// The name of the function being generated
+    pub target: String,
+    pub attrs: AutoDiffAttrs,
+    /// Describe the memory layout of input types
+    pub inputs: Vec<TypeTree>,
+    /// Describe the memory layout of the output type
+    pub output: TypeTree,
+}
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
+pub struct AutoDiffAttrs {
+    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
+    /// e.g. in the [JAX
+    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
+    pub mode: DiffMode,
+    pub ret_activity: DiffActivity,
+    pub input_activity: Vec<DiffActivity>,
+}
+
+impl DiffMode {
+    pub fn is_rev(&self) -> bool {
+        matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
+    }
+    pub fn is_fwd(&self) -> bool {
+        matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
+    }
+}
+
+impl Display for DiffMode {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            DiffMode::Error => write!(f, "Error"),
+            DiffMode::Source => write!(f, "Source"),
+            DiffMode::Forward => write!(f, "Forward"),
+            DiffMode::Reverse => write!(f, "Reverse"),
+            DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
+            DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
+        }
+    }
+}
+
+/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
+/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
+/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
+/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
+/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
+pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
+    if activity == DiffActivity::None {
+        // Only valid if primal returns (), but we can't check that here.
+        return true;
+    }
+    match mode {
+        DiffMode::Error => false,
+        DiffMode::Source => false,
+        DiffMode::Forward | DiffMode::ForwardFirst => {
+            activity == DiffActivity::Dual
+                || activity == DiffActivity::DualOnly
+                || activity == DiffActivity::Const
+        }
+        DiffMode::Reverse | DiffMode::ReverseFirst => {
+            activity == DiffActivity::Const
+                || activity == DiffActivity::Active
+                || activity == DiffActivity::ActiveOnly
+        }
+    }
+}
+
+/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
+/// for the given argument, but we generally can't know the size of such a type.
+/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
+/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
+/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
+/// users here from marking scalars as Duplicated, due to type aliases.
+pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
+    use DiffActivity::*;
+    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
+    if matches!(activity, Const) {
+        return true;
+    }
+    if matches!(activity, Dual | DualOnly) {
+        return true;
+    }
+    // FIXME(ZuseZ4) We should make this more robust to also
+    // handle type aliases. Once that is done, we can be more restrictive here.
+    if matches!(activity, Active | ActiveOnly) {
+        return true;
+    }
+    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
+        && matches!(activity, Duplicated | DuplicatedOnly)
+}
+pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
+    use DiffActivity::*;
+    return match mode {
+        DiffMode::Error => false,
+        DiffMode::Source => false,
+        DiffMode::Forward | DiffMode::ForwardFirst => {
+            matches!(activity, Dual | DualOnly | Const)
+        }
+        DiffMode::Reverse | DiffMode::ReverseFirst => {
+            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
+        }
+    };
+}
+
+impl Display for DiffActivity {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            DiffActivity::None => write!(f, "None"),
+            DiffActivity::Const => write!(f, "Const"),
+            DiffActivity::Active => write!(f, "Active"),
+            DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
+            DiffActivity::Dual => write!(f, "Dual"),
+            DiffActivity::DualOnly => write!(f, "DualOnly"),
+            DiffActivity::Duplicated => write!(f, "Duplicated"),
+            DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
+            DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
+        }
+    }
+}
+
+impl FromStr for DiffMode {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<DiffMode, ()> {
+        match s {
+            "Error" => Ok(DiffMode::Error),
+            "Source" => Ok(DiffMode::Source),
+            "Forward" => Ok(DiffMode::Forward),
+            "Reverse" => Ok(DiffMode::Reverse),
+            "ForwardFirst" => Ok(DiffMode::ForwardFirst),
+            "ReverseFirst" => Ok(DiffMode::ReverseFirst),
+            _ => Err(()),
+        }
+    }
+}
+impl FromStr for DiffActivity {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<DiffActivity, ()> {
+        match s {
+            "None" => Ok(DiffActivity::None),
+            "Active" => Ok(DiffActivity::Active),
+            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
+            "Const" => Ok(DiffActivity::Const),
+            "Dual" => Ok(DiffActivity::Dual),
+            "DualOnly" => Ok(DiffActivity::DualOnly),
+            "Duplicated" => Ok(DiffActivity::Duplicated),
+            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
+            _ => Err(()),
+        }
+    }
+}
+
+impl AutoDiffAttrs {
+    pub fn has_ret_activity(&self) -> bool {
+        self.ret_activity != DiffActivity::None
+    }
+    pub fn has_active_only_ret(&self) -> bool {
+        self.ret_activity == DiffActivity::ActiveOnly
+    }
+
+    pub fn error() -> Self {
+        AutoDiffAttrs {
+            mode: DiffMode::Error,
+            ret_activity: DiffActivity::None,
+            input_activity: Vec::new(),
+        }
+    }
+    pub fn source() -> Self {
+        AutoDiffAttrs {
+            mode: DiffMode::Source,
+            ret_activity: DiffActivity::None,
+            input_activity: Vec::new(),
+        }
+    }
+
+    pub fn is_active(&self) -> bool {
+        self.mode != DiffMode::Error
+    }
+
+    pub fn is_source(&self) -> bool {
+        self.mode == DiffMode::Source
+    }
+    pub fn apply_autodiff(&self) -> bool {
+        !matches!(self.mode, DiffMode::Error | DiffMode::Source)
+    }
+
+    pub fn into_item(
+        self,
+        source: String,
+        target: String,
+        inputs: Vec<TypeTree>,
+        output: TypeTree,
+    ) -> AutoDiffItem {
+        AutoDiffItem { source, target, inputs, output, attrs: self }
+    }
+}
+
+impl fmt::Display for AutoDiffItem {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Differentiating {} -> {}", self.source, self.target)?;
+        write!(f, " with attributes: {:?}", self.attrs)?;
+        write!(f, " with inputs: {:?}", self.inputs)?;
+        write!(f, " with output: {:?}", self.output)
+    }
+}