diff options
Diffstat (limited to 'compiler/rustc_ast/src/expand/autodiff_attrs.rs')
| -rw-r--r-- | compiler/rustc_ast/src/expand/autodiff_attrs.rs | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 00000000000..05714731b9d --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,283 @@ +//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, +//! we create an [`AutoDiffItem`] which contains the source and target function names. The source +//! is the function to which the autodiff attribute is applied, and the target is the function +//! getting generated by us (with a name given by the user as the first autodiff arg). + +use std::fmt::{self, Display, Formatter}; +use std::str::FromStr; + +use crate::expand::typetree::TypeTree; +use crate::expand::{Decodable, Encodable, HashStable_Generic}; +use crate::ptr::P; +use crate::{Ty, TyKind}; + +/// Forward and Reverse Mode are well known names for automatic differentiation implementations. +/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants +/// are a hack to support higher order derivatives. We need to compute first order derivatives +/// before we compute second order derivatives, otherwise we would differentiate our placeholder +/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations, +/// as it's already done in the C++ and Julia frontend of Enzyme. +/// +/// (FIXME) remove *First variants. +/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and +/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online. +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffMode { + /// No autodiff is applied (used during error handling). + Error, + /// The primal function which we will differentiate. + Source, + /// The target function, to be created using forward mode AD. + Forward, + /// The target function, to be created using reverse mode AD. + Reverse, + /// The target function, to be created using forward mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ForwardFirst, + /// The target function, to be created using reverse mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ReverseFirst, +} + +/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. +/// However, under forward mode we overwrite the previous shadow value, while for reverse mode +/// we add to the previous shadow value. To not surprise users, we picked different names. +/// Dual numbers is also a quite well known name for forward mode AD types. +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffActivity { + /// Implicit or Explicit () return type, so a special case of Const. + None, + /// Don't compute derivatives with respect to this input/output. + Const, + /// Reverse Mode, Compute derivatives for this scalar input/output. + Active, + /// Reverse Mode, Compute derivatives for this scalar output, but don't compute + /// the original return value. + ActiveOnly, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. + Dual, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. Drop the code which updates the original input/output for maximum performance. + DualOnly, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + Duplicated, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + /// Drop the code which updates the original input for maximum performance. + DuplicatedOnly, + /// All Integers must be Const, but these are used to mark the integer which represents the + /// length of a slice/vec. This is used for safety checks on slices. + FakeActivitySize, +} +/// We generate one of these structs for each `#[autodiff(...)]` attribute. +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + /// The name of the function getting differentiated + pub source: String, + /// The name of the function being generated + pub target: String, + pub attrs: AutoDiffAttrs, + /// Describe the memory layout of input types + pub inputs: Vec<TypeTree>, + /// Describe the memory layout of the output type + pub output: TypeTree, +} +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffAttrs { + /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and + /// e.g. in the [JAX + /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec<DiffActivity>, +} + +impl DiffMode { + pub fn is_rev(&self) -> bool { + matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst) + } + pub fn is_fwd(&self) -> bool { + matches!(self, DiffMode::Forward | DiffMode::ForwardFirst) + } +} + +impl Display for DiffMode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + DiffMode::Error => write!(f, "Error"), + DiffMode::Source => write!(f, "Source"), + DiffMode::Forward => write!(f, "Forward"), + DiffMode::Reverse => write!(f, "Reverse"), + DiffMode::ForwardFirst => write!(f, "ForwardFirst"), + DiffMode::ReverseFirst => write!(f, "ReverseFirst"), + } + } +} + +/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...). +/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...). +/// Const is valid for all cases and means that we don't compute derivatives wrt. this output. +/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg, +/// but this is too complex to verify here. Also it's just a logic error if users get this wrong. +pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { + if activity == DiffActivity::None { + // Only valid if primal returns (), but we can't check that here. + return true; + } + match mode { + DiffMode::Error => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + activity == DiffActivity::Dual + || activity == DiffActivity::DualOnly + || activity == DiffActivity::Const + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + activity == DiffActivity::Const + || activity == DiffActivity::Active + || activity == DiffActivity::ActiveOnly + } + } +} + +/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value +/// for the given argument, but we generally can't know the size of such a type. +/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated, +/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value +/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent +/// users here from marking scalars as Duplicated, due to type aliases. +pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool { + use DiffActivity::*; + // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it. + if matches!(activity, Const) { + return true; + } + if matches!(activity, Dual | DualOnly) { + return true; + } + // FIXME(ZuseZ4) We should make this more robust to also + // handle type aliases. Once that is done, we can be more restrictive here. + if matches!(activity, Active | ActiveOnly) { + return true; + } + matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..)) + && matches!(activity, Duplicated | DuplicatedOnly) +} +pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { + use DiffActivity::*; + return match mode { + DiffMode::Error => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + matches!(activity, Dual | DualOnly | Const) + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) + } + }; +} + +impl Display for DiffActivity { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DiffActivity::None => write!(f, "None"), + DiffActivity::Const => write!(f, "Const"), + DiffActivity::Active => write!(f, "Active"), + DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), + DiffActivity::Dual => write!(f, "Dual"), + DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::Duplicated => write!(f, "Duplicated"), + DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), + DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), + } + } +} + +impl FromStr for DiffMode { + type Err = (); + + fn from_str(s: &str) -> Result<DiffMode, ()> { + match s { + "Error" => Ok(DiffMode::Error), + "Source" => Ok(DiffMode::Source), + "Forward" => Ok(DiffMode::Forward), + "Reverse" => Ok(DiffMode::Reverse), + "ForwardFirst" => Ok(DiffMode::ForwardFirst), + "ReverseFirst" => Ok(DiffMode::ReverseFirst), + _ => Err(()), + } + } +} +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result<DiffActivity, ()> { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "ActiveOnly" => Ok(DiffActivity::ActiveOnly), + "Const" => Ok(DiffActivity::Const), + "Dual" => Ok(DiffActivity::Dual), + "DualOnly" => Ok(DiffActivity::DualOnly), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), + _ => Err(()), + } + } +} + +impl AutoDiffAttrs { + pub fn has_ret_activity(&self) -> bool { + self.ret_activity != DiffActivity::None + } + pub fn has_active_only_ret(&self) -> bool { + self.ret_activity == DiffActivity::ActiveOnly + } + + pub fn error() -> Self { + AutoDiffAttrs { + mode: DiffMode::Error, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + pub fn source() -> Self { + AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + self.mode != DiffMode::Error + } + + pub fn is_source(&self) -> bool { + self.mode == DiffMode::Source + } + pub fn apply_autodiff(&self) -> bool { + !matches!(self.mode, DiffMode::Error | DiffMode::Source) + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec<TypeTree>, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +impl fmt::Display for AutoDiffItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Differentiating {} -> {}", self.source, self.target)?; + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) + } +} |
