diff options
| author | Jacob Pratt <jacob@jhpratt.dev> | 2025-01-31 00:26:30 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-31 00:26:30 -0500 |
| commit | c19c4b91f53dbe9bda18171b493514ed2bbec2b2 (patch) | |
| tree | df4f28f749f427d1e391919bc5cb937fe6f247c0 /compiler/rustc_middle | |
| parent | ae9dbf169fa43b45617107d8bd90eced26b1dea2 (diff) | |
| parent | 1f30517d40a9a8fe3b89479891c7a167adb75cbd (diff) | |
| download | rust-c19c4b91f53dbe9bda18171b493514ed2bbec2b2.tar.gz rust-c19c4b91f53dbe9bda18171b493514ed2bbec2b2.zip | |
Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk
Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle This PR should not be merged until the rustc_codegen_llvm part is merged. I will also alter it a little based on what get's shaved off from the cg_llvm PR, and address some of the feedback I received in the other PR (including cleanups). I am putting it already up to 1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and 2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost. Re 1: My test require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either that work pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thoughts and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please? For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: https://github.com/EnzymeAD/rust/issues/173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just get's dropped, but I don't know WHY? Any help would be super appreciated, as rustc queries look a bit voodoo to me. Tracking: - https://github.com/rust-lang/rust/issues/124509 r? `@jieyouxu`
Diffstat (limited to 'compiler/rustc_middle')
| -rw-r--r-- | compiler/rustc_middle/messages.ftl | 4 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/arena.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/error.rs | 14 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/middle/codegen_fn_attrs.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/mir/mono.rs | 2 |
5 files changed, 25 insertions, 0 deletions
diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 8dc529b4d7a..cfb495df9d5 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -32,6 +32,8 @@ middle_assert_shl_overflow = middle_assert_shr_overflow = attempt to shift right by `{$val}`, which would overflow +middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe + middle_bounds_check = index out of bounds: the length is {$len} but the index is {$index} @@ -107,6 +109,8 @@ middle_type_length_limit = reached the type-length limit while instantiating `{$ middle_unknown_layout = the type `{$ty}` has an unknown layout +middle_unsupported_union = we don't support unions yet: '{$ty_name}' + middle_values_too_big = values of the type `{$ty}` are too big for the target architecture diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 750531b638e..eaccd8c360e 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -87,6 +87,7 @@ macro_rules! arena_types { [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, [decode] attribute: rustc_hir::Attribute, [] name_set: rustc_data_structures::unord::UnordSet<rustc_span::Symbol>, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>, [] pats: rustc_middle::ty::PatternKind<'tcx>, diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index b0187a1848c..b30d3a950c6 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -37,6 +37,20 @@ pub struct OpaqueHiddenTypeMismatch<'tcx> { pub sub: TypeMismatchReason, } +#[derive(Diagnostic)] +#[diag(middle_unsupported_union)] +pub struct UnsupportedUnion { + pub ty_name: String, +} + +#[derive(Diagnostic)] +#[diag(middle_autodiff_unsafe_inner_const_ref)] +pub struct AutodiffUnsafeInnerConstRef { + #[primary_span] + pub span: Span, + pub ty: String, +} + #[derive(Subdiagnostic)] pub enum TypeMismatchReason { #[label(middle_conflict_types)] diff --git a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs index cc980f6e62a..1784665bcae 100644 --- a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs +++ b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs @@ -1,4 +1,5 @@ use rustc_abi::Align; +use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_macros::{HashStable, TyDecodable, TyEncodable}; use rustc_span::Symbol; @@ -52,6 +53,8 @@ pub struct CodegenFnAttrs { /// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around /// the function entry. pub patchable_function_entry: Option<PatchableFunctionEntry>, + /// For the `#[autodiff]` macros. + pub autodiff_item: Option<AutoDiffAttrs>, } #[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)] @@ -160,6 +163,7 @@ impl CodegenFnAttrs { instruction_set: None, alignment: None, patchable_function_entry: None, + autodiff_item: None, } } diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 6fa3fa2432d..e49fc376a13 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -1,6 +1,7 @@ use std::fmt; use std::hash::Hash; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_attr_parsing::InlineAttr; use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN}; use rustc_data_structures::fingerprint::Fingerprint; @@ -246,6 +247,7 @@ impl ToStableHashKey<StableHashingContext<'_>> for MonoItem<'_> { pub struct MonoItemPartitions<'tcx> { pub codegen_units: &'tcx [CodegenUnit<'tcx>], pub all_mono_items: &'tcx DefIdSet, + pub autodiff_items: &'tcx [AutoDiffItem], } #[derive(Debug, HashStable)] |
