From 8f648d7185a8455151951f30357b93dee726c6cb Mon Sep 17 00:00:00 2001 From: bjorn3 <17426603+bjorn3@users.noreply.github.com> Date: Fri, 8 Aug 2025 11:24:28 +0000 Subject: Remove bitcode_llvm_cmdline It used to be necessary on Apple platforms to ship with the App Store, but XCode 15 has stopped embedding LLVM bitcode and the App Store no longer accepts apps with bitcode embedded. --- compiler/rustc_codegen_llvm/src/back/write.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 85a06f457eb..62998003ca1 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -862,7 +862,7 @@ pub(crate) fn codegen( .generic_activity_with_arg("LLVM_module_codegen_embed_bitcode", &*module.name); let thin_bc = module.thin_lto_buffer.as_deref().expect("cannot find embedded bitcode"); - embed_bitcode(cgcx, llcx, llmod, &config.bc_cmdline, &thin_bc); + embed_bitcode(cgcx, llcx, llmod, &thin_bc); } } @@ -1058,7 +1058,6 @@ fn embed_bitcode( cgcx: &CodegenContext, llcx: &llvm::Context, llmod: &llvm::Module, - cmdline: &str, bitcode: &[u8], ) { // We're adding custom sections to the output object file, but we definitely @@ -1074,7 +1073,9 @@ fn embed_bitcode( // * Mach-O - this is for macOS. Inspecting the source code for the native // linker here shows that the `.llvmbc` and `.llvmcmd` sections are // automatically skipped by the linker. In that case there's nothing extra - // that we need to do here. + // that we need to do here. We do need to make sure that the + // `__LLVM,__cmdline` section exists even though it is empty as otherwise + // ld64 rejects the object file. // // * Wasm - the native LLD linker is hard-coded to skip `.llvmbc` and // `.llvmcmd` sections, so there's nothing extra we need to do. @@ -1111,7 +1112,7 @@ fn embed_bitcode( llvm::set_linkage(llglobal, llvm::Linkage::PrivateLinkage); llvm::LLVMSetGlobalConstant(llglobal, llvm::True); - let llconst = common::bytes_in_context(llcx, cmdline.as_bytes()); + let llconst = common::bytes_in_context(llcx, &[]); let llglobal = llvm::add_global(llmod, common::val_ty(llconst), c"rustc.embedded.cmdline"); llvm::set_initializer(llglobal, llconst); let section = if cgcx.target_is_like_darwin { @@ -1128,7 +1129,7 @@ fn embed_bitcode( let section_flags = if cgcx.is_pe_coff { "n" } else { "e" }; let asm = create_section_with_flags_asm(".llvmbc", section_flags, bitcode); llvm::append_module_inline_asm(llmod, &asm); - let asm = create_section_with_flags_asm(".llvmcmd", section_flags, cmdline.as_bytes()); + let asm = create_section_with_flags_asm(".llvmcmd", section_flags, &[]); llvm::append_module_inline_asm(llmod, &asm); } } -- cgit 1.4.1-3-g733a5 From 258915a55539593423c3d4c30f7b741f65c56e51 Mon Sep 17 00:00:00 2001 From: Matthew Maurer Date: Fri, 8 Aug 2025 17:38:03 +0000 Subject: llvm: Accept new LLVM lifetime format LLVM removed the size parameter from the lifetime format. Tolerate not having that size parameter. --- compiler/rustc_codegen_llvm/src/builder.rs | 6 +++++- .../codegen-llvm/align-byval-alignment-mismatch.rs | 4 ++-- tests/codegen-llvm/call-tmps-lifetime.rs | 16 +++++++-------- tests/codegen-llvm/intrinsics/transmute-simd.rs | 24 +++++++++++----------- tests/codegen-llvm/intrinsics/transmute.rs | 8 ++++---- .../issues/issue-105386-ub-in-debuginfo.rs | 2 +- tests/codegen-llvm/lifetime_start_end.rs | 16 +++++++-------- .../uninhabited-transparent-return-abi.rs | 4 ++-- 8 files changed, 42 insertions(+), 38 deletions(-) (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index bd175e560c7..35a619da710 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1686,7 +1686,11 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> { return; } - self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]); + if crate::llvm_util::get_version() >= (22, 0, 0) { + self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[ptr]); + } else { + self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]); + } } } impl<'a, 'll, CX: Borrow>> GenericBuilder<'a, 'll, CX> { diff --git a/tests/codegen-llvm/align-byval-alignment-mismatch.rs b/tests/codegen-llvm/align-byval-alignment-mismatch.rs index c69fc2de9d2..8eaf2751ed7 100644 --- a/tests/codegen-llvm/align-byval-alignment-mismatch.rs +++ b/tests/codegen-llvm/align-byval-alignment-mismatch.rs @@ -55,10 +55,10 @@ extern "C" { pub unsafe fn rust_to_c_increases_alignment(x: Align1) { // i686-linux: start: // i686-linux-NEXT: [[ALLOCA:%[0-9a-z]+]] = alloca [48 x i8], align 4 - // i686-linux-NEXT: call void @llvm.lifetime.start.p0(i64 48, ptr {{.*}}[[ALLOCA]]) + // i686-linux-NEXT: call void @llvm.lifetime.start.p0({{(i64 48, )?}}ptr {{.*}}[[ALLOCA]]) // i686-linux-NEXT: call void @llvm.memcpy.{{.+}}(ptr {{.*}}align 4 {{.*}}[[ALLOCA]], ptr {{.*}}align 1 {{.*}}%x // i686-linux-NEXT: call void @extern_c_align1({{.+}} [[ALLOCA]]) - // i686-linux-NEXT: call void @llvm.lifetime.end.p0(i64 48, ptr {{.*}}[[ALLOCA]]) + // i686-linux-NEXT: call void @llvm.lifetime.end.p0({{(i64 48, )?}}ptr {{.*}}[[ALLOCA]]) // x86_64-linux: start: // x86_64-linux-NEXT: call void @extern_c_align1 diff --git a/tests/codegen-llvm/call-tmps-lifetime.rs b/tests/codegen-llvm/call-tmps-lifetime.rs index 7b7b6e17bdd..0d7657ed758 100644 --- a/tests/codegen-llvm/call-tmps-lifetime.rs +++ b/tests/codegen-llvm/call-tmps-lifetime.rs @@ -16,14 +16,14 @@ use minicore::*; // CHECK-NEXT: start: // CHECK-NEXT: [[B:%.*]] = alloca // CHECK-NEXT: [[A:%.*]] = alloca -// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4096, ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 4096, )?}}ptr [[A]]) // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 4096, i1 false) // CHECK-NEXT: call void %h(ptr {{.*}} [[A]]) -// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4096, ptr [[A]]) -// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4096, ptr [[B]]) +// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 4096, )?}}ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 4096, )?}}ptr [[B]]) // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 4096, i1 false) // CHECK-NEXT: call void %h(ptr {{.*}} [[B]]) -// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4096, ptr [[B]]) +// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 4096, )?}}ptr [[B]]) #[no_mangle] pub fn const_indirect(h: extern "C" fn([u32; 1024])) { const C: [u32; 1024] = [0; 1024]; @@ -42,12 +42,12 @@ pub struct Str { // CHECK-LABEL: define void @immediate_indirect(ptr {{.*}}%s.0, i32 {{.*}}%s.1, ptr {{.*}}%g) // CHECK-NEXT: start: // CHECK-NEXT: [[A:%.*]] = alloca -// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr [[A]]) // CHECK-NEXT: store ptr %s.0, ptr [[A]] // CHECK-NEXT: [[B:%.]] = getelementptr inbounds i8, ptr [[A]], i32 4 // CHECK-NEXT: store i32 %s.1, ptr [[B]] // CHECK-NEXT: call void %g(ptr {{.*}} [[A]]) -// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr [[A]]) #[no_mangle] pub fn immediate_indirect(s: Str, g: extern "C" fn(Str)) { g(s); @@ -58,10 +58,10 @@ pub fn immediate_indirect(s: Str, g: extern "C" fn(Str)) { // CHECK-LABEL: define void @align_indirect(ptr{{.*}} align 1{{.*}} %a, ptr{{.*}} %fun) // CHECK-NEXT: start: // CHECK-NEXT: [[A:%.*]] = alloca [1024 x i8], align 4 -// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 1024, ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 1024, )?}}ptr [[A]]) // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 1 %a, i32 1024, i1 false) // CHECK-NEXT: call void %fun(ptr {{.*}} [[A]]) -// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 1024, ptr [[A]]) +// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 1024, )?}}ptr [[A]]) #[no_mangle] pub fn align_indirect(a: [u8; 1024], fun: extern "C" fn([u8; 1024])) { fun(a); diff --git a/tests/codegen-llvm/intrinsics/transmute-simd.rs b/tests/codegen-llvm/intrinsics/transmute-simd.rs index e34b27e1333..92af16945ab 100644 --- a/tests/codegen-llvm/intrinsics/transmute-simd.rs +++ b/tests/codegen-llvm/intrinsics/transmute-simd.rs @@ -97,10 +97,10 @@ pub extern "C" fn float_ptr_same_lanes(v: f64x2) -> PtrX2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <2 x double> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x ptr> %[[RET]] unsafe { transmute(v) } } @@ -111,10 +111,10 @@ pub extern "C" fn ptr_float_same_lanes(v: PtrX2) -> f64x2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <2 x ptr> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x double>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x double> %[[RET]] unsafe { transmute(v) } } @@ -125,10 +125,10 @@ pub extern "C" fn int_ptr_same_lanes(v: i64x2) -> PtrX2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <2 x i64> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x ptr> %[[RET]] unsafe { transmute(v) } } @@ -139,10 +139,10 @@ pub extern "C" fn ptr_int_same_lanes(v: PtrX2) -> i64x2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <2 x ptr> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x i64>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x i64> %[[RET]] unsafe { transmute(v) } } @@ -153,10 +153,10 @@ pub extern "C" fn float_ptr_widen(v: f32x4) -> PtrX2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <4 x float> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x ptr> %[[RET]] unsafe { transmute(v) } } @@ -167,10 +167,10 @@ pub extern "C" fn int_ptr_widen(v: i32x4) -> PtrX2 { // CHECK-NOT: alloca // CHECK: %[[TEMP:.+]] = alloca [16 x i8] // CHECK-NOT: alloca - // CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: store <4 x i32> %v, ptr %[[TEMP]] // CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]] - // CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]]) // CHECK: ret <2 x ptr> %[[RET]] unsafe { transmute(v) } } diff --git a/tests/codegen-llvm/intrinsics/transmute.rs b/tests/codegen-llvm/intrinsics/transmute.rs index 91cff38773d..e327c100d7d 100644 --- a/tests/codegen-llvm/intrinsics/transmute.rs +++ b/tests/codegen-llvm/intrinsics/transmute.rs @@ -192,12 +192,12 @@ pub unsafe fn check_byte_from_bool(x: bool) -> u8 { #[no_mangle] pub unsafe fn check_to_pair(x: u64) -> Option { // CHECK: %[[TEMP:.+]] = alloca [8 x i8], align 8 - // CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr %[[TEMP]]) // CHECK: store i64 %x, ptr %[[TEMP]], align 8 // CHECK: %[[PAIR0:.+]] = load i32, ptr %[[TEMP]], align 8 // CHECK: %[[PAIR1P:.+]] = getelementptr inbounds i8, ptr %[[TEMP]], i64 4 // CHECK: %[[PAIR1:.+]] = load i32, ptr %[[PAIR1P]], align 4 - // CHECK: call void @llvm.lifetime.end.p0(i64 8, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr %[[TEMP]]) // CHECK: insertvalue {{.+}}, i32 %[[PAIR0]], 0 // CHECK: insertvalue {{.+}}, i32 %[[PAIR1]], 1 transmute(x) @@ -207,12 +207,12 @@ pub unsafe fn check_to_pair(x: u64) -> Option { #[no_mangle] pub unsafe fn check_from_pair(x: Option) -> u64 { // CHECK: %[[TEMP:.+]] = alloca [8 x i8], align 8 - // CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr %[[TEMP]]) // CHECK: store i32 %x.0, ptr %[[TEMP]], align 8 // CHECK: %[[PAIR1P:.+]] = getelementptr inbounds i8, ptr %[[TEMP]], i64 4 // CHECK: store i32 %x.1, ptr %[[PAIR1P]], align 4 // CHECK: %[[R:.+]] = load i64, ptr %[[TEMP]], align 8 - // CHECK: call void @llvm.lifetime.end.p0(i64 8, ptr %[[TEMP]]) + // CHECK: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr %[[TEMP]]) // CHECK: ret i64 %[[R]] transmute(x) } diff --git a/tests/codegen-llvm/issues/issue-105386-ub-in-debuginfo.rs b/tests/codegen-llvm/issues/issue-105386-ub-in-debuginfo.rs index 848aa910b58..95e70bf5678 100644 --- a/tests/codegen-llvm/issues/issue-105386-ub-in-debuginfo.rs +++ b/tests/codegen-llvm/issues/issue-105386-ub-in-debuginfo.rs @@ -19,6 +19,6 @@ pub fn outer_function(x: S, y: S) -> usize { // CHECK: [[spill:%.*]] = alloca // CHECK-NOT: [[ptr_tmp:%.*]] = getelementptr inbounds i8, ptr [[spill]] // CHECK-NOT: [[load:%.*]] = load ptr, ptr -// CHECK: call void @llvm.lifetime.start{{.*}}({{.*}}, ptr [[spill]]) +// CHECK: call void @llvm.lifetime.start{{.*}}({{(.*, )?}}ptr [[spill]]) // CHECK: [[inner:%.*]] = getelementptr inbounds i8, ptr [[spill]] // CHECK: call void @llvm.memcpy{{.*}}(ptr {{align .*}} [[inner]], ptr {{align .*}} %x diff --git a/tests/codegen-llvm/lifetime_start_end.rs b/tests/codegen-llvm/lifetime_start_end.rs index 0639e7640aa..2df5acf54df 100644 --- a/tests/codegen-llvm/lifetime_start_end.rs +++ b/tests/codegen-llvm/lifetime_start_end.rs @@ -8,27 +8,27 @@ pub fn test() { let a = 0u8; &a; // keep variable in an alloca - // CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, ptr %a) + // CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}ptr %a) { let b = &Some(a); &b; // keep variable in an alloca - // CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, {{.*}}) + // CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}{{.*}}) - // CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, {{.*}}) + // CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}{{.*}}) - // CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, {{.*}}) + // CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}{{.*}}) - // CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, {{.*}}) + // CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}{{.*}}) } let c = 1u8; &c; // keep variable in an alloca - // CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, ptr %c) + // CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}ptr %c) - // CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, ptr %c) + // CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}ptr %c) - // CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, ptr %a) + // CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}ptr %a) } diff --git a/tests/codegen-llvm/uninhabited-transparent-return-abi.rs b/tests/codegen-llvm/uninhabited-transparent-return-abi.rs index face1577c3f..83d9a7a32ab 100644 --- a/tests/codegen-llvm/uninhabited-transparent-return-abi.rs +++ b/tests/codegen-llvm/uninhabited-transparent-return-abi.rs @@ -23,7 +23,7 @@ extern "Rust" { #[no_mangle] pub fn test_uninhabited_ret_by_ref() { // CHECK: %_1 = alloca [24 x i8], align {{8|4}} - // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %_1) + // CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 24, )?}}ptr nonnull %_1) // CHECK-NEXT: call void @opaque({{.*}} sret([24 x i8]) {{.*}} %_1) #2 // CHECK-NEXT: unreachable unsafe { @@ -35,7 +35,7 @@ pub fn test_uninhabited_ret_by_ref() { #[no_mangle] pub fn test_uninhabited_ret_by_ref_with_arg(rsi: u32) { // CHECK: %_2 = alloca [24 x i8], align {{8|4}} - // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %_2) + // CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 24, )?}}ptr nonnull %_2) // CHECK-NEXT: call void @opaque_with_arg({{.*}} sret([24 x i8]) {{.*}} %_2, i32 noundef %rsi) #2 // CHECK-NEXT: unreachable unsafe { -- cgit 1.4.1-3-g733a5 From f978932903cac8cf508ef78b8133fafd9fe5a5b8 Mon Sep 17 00:00:00 2001 From: StackOverflowExcept1on <109800286+StackOverflowExcept1on@users.noreply.github.com> Date: Tue, 12 Aug 2025 00:39:24 +0300 Subject: fix(compiler/rustc_codegen_llvm): apply `target-cpu` attribute --- Cargo.lock | 2 +- compiler/rustc_codegen_llvm/src/allocator.rs | 17 +++++++++- src/tools/run-make-support/Cargo.toml | 2 +- tests/run-make/wasm-unexpected-features/rmake.rs | 38 ++++++++++++++++++++++ .../wasm32_test/Cargo.toml | 13 ++++++++ .../wasm32_test/src/lib.rs | 34 +++++++++++++++++++ 6 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 tests/run-make/wasm-unexpected-features/rmake.rs create mode 100644 tests/run-make/wasm-unexpected-features/wasm32_test/Cargo.toml create mode 100644 tests/run-make/wasm-unexpected-features/wasm32_test/src/lib.rs (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/Cargo.lock b/Cargo.lock index 4eb246995b1..1a196172cdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3283,7 +3283,7 @@ dependencies = [ "regex", "serde_json", "similar", - "wasmparser 0.219.2", + "wasmparser 0.236.0", ] [[package]] diff --git a/compiler/rustc_codegen_llvm/src/allocator.rs b/compiler/rustc_codegen_llvm/src/allocator.rs index 2b5090ed6db..23610aa856c 100644 --- a/compiler/rustc_codegen_llvm/src/allocator.rs +++ b/compiler/rustc_codegen_llvm/src/allocator.rs @@ -8,11 +8,12 @@ use rustc_middle::bug; use rustc_middle::ty::TyCtxt; use rustc_session::config::{DebugInfo, OomStrategy}; use rustc_symbol_mangling::mangle_internal_symbol; +use smallvec::SmallVec; use crate::builder::SBuilder; use crate::declare::declare_simple_fn; use crate::llvm::{self, False, True, Type, Value}; -use crate::{SimpleCx, attributes, debuginfo}; +use crate::{SimpleCx, attributes, debuginfo, llvm_util}; pub(crate) unsafe fn codegen( tcx: TyCtxt<'_>, @@ -147,6 +148,20 @@ fn create_wrapper_function( llvm::Visibility::from_generic(tcx.sess.default_visibility()), ty, ); + + let mut attrs = SmallVec::<[_; 2]>::new(); + + let target_cpu = llvm_util::target_cpu(tcx.sess); + let target_cpu_attr = llvm::CreateAttrStringValue(cx.llcx, "target-cpu", target_cpu); + + let tune_cpu_attr = llvm_util::tune_cpu(tcx.sess) + .map(|tune_cpu| llvm::CreateAttrStringValue(cx.llcx, "tune-cpu", tune_cpu)); + + attrs.push(target_cpu_attr); + attrs.extend(tune_cpu_attr); + + attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &attrs); + let no_return = if no_return { // -> ! DIFlagNoReturn let no_return = llvm::AttributeKind::NoReturn.create_attr(cx.llcx); diff --git a/src/tools/run-make-support/Cargo.toml b/src/tools/run-make-support/Cargo.toml index a4e7534137d..250e0f65a9f 100644 --- a/src/tools/run-make-support/Cargo.toml +++ b/src/tools/run-make-support/Cargo.toml @@ -17,7 +17,7 @@ object = "0.37" regex = "1.11" serde_json = "1.0" similar = "2.7" -wasmparser = { version = "0.219", default-features = false, features = ["std"] } +wasmparser = { version = "0.236", default-features = false, features = ["std", "features", "validate"] } # tidy-alphabetical-end # Shared with bootstrap and compiletest diff --git a/tests/run-make/wasm-unexpected-features/rmake.rs b/tests/run-make/wasm-unexpected-features/rmake.rs new file mode 100644 index 00000000000..142860c47ec --- /dev/null +++ b/tests/run-make/wasm-unexpected-features/rmake.rs @@ -0,0 +1,38 @@ +//@ only-wasm32-wasip1 + +use std::path::Path; + +use run_make_support::{cargo, path, rfs, target, wasmparser}; + +fn main() { + let target_dir = path("target"); + + cargo() + .args([ + "rustc", + "--manifest-path", + "wasm32_test/Cargo.toml", + "--profile", + "release", + "--target", + "wasm32-wasip1", + "-Zbuild-std=core,alloc,panic_abort", + "--", + "-Clink-arg=--import-memory", + "-Clinker-plugin-lto=on", + ]) + .env("RUSTFLAGS", "-Ctarget-cpu=mvp") + .env("CARGO_TARGET_DIR", &target_dir) + .run(); + + let wasm32_program_path = target_dir.join(target()).join("release").join("wasm32_program.wasm"); + verify_features(&wasm32_program_path); +} + +fn verify_features(path: &Path) { + eprintln!("verify {path:?}"); + let file = rfs::read(&path); + + let mut validator = wasmparser::Validator::new_with_features(wasmparser::WasmFeatures::WASM1); + validator.validate_all(&file).unwrap(); +} diff --git a/tests/run-make/wasm-unexpected-features/wasm32_test/Cargo.toml b/tests/run-make/wasm-unexpected-features/wasm32_test/Cargo.toml new file mode 100644 index 00000000000..9f1070e1c54 --- /dev/null +++ b/tests/run-make/wasm-unexpected-features/wasm32_test/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "wasm32_test" +version = "0.1.0" +edition = "2024" + +[lib] +crate-type = ["cdylib"] +name = "wasm32_program" + +[profile.release] +codegen-units = 1 +lto = "fat" +opt-level = "z" diff --git a/tests/run-make/wasm-unexpected-features/wasm32_test/src/lib.rs b/tests/run-make/wasm-unexpected-features/wasm32_test/src/lib.rs new file mode 100644 index 00000000000..123c2c1b6ca --- /dev/null +++ b/tests/run-make/wasm-unexpected-features/wasm32_test/src/lib.rs @@ -0,0 +1,34 @@ +#![no_std] + +extern crate alloc; + +use core::alloc::{GlobalAlloc, Layout}; +use core::mem::MaybeUninit; + +#[global_allocator] +static ALLOC: GlobalDlmalloc = GlobalDlmalloc; + +struct GlobalDlmalloc; + +unsafe impl GlobalAlloc for GlobalDlmalloc { + #[inline] + unsafe fn alloc(&self, _layout: Layout) -> *mut u8 { + core::ptr::null_mut() + } + + #[inline] + unsafe fn dealloc(&self, _ptr: *mut u8, _layout: Layout) {} +} + +#[used] +static mut BUF: MaybeUninit<[u8; 1024]> = MaybeUninit::uninit(); + +#[unsafe(no_mangle)] +extern "C" fn init() { + alloc::alloc::handle_alloc_error(Layout::new::<[u8; 64 * 1024]>()); +} + +#[panic_handler] +fn my_panic(_: &core::panic::PanicInfo) -> ! { + loop {} +} -- cgit 1.4.1-3-g733a5 From 5c631041aa0b0ad9e161b966b78e6dfdb8011023 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez Date: Thu, 14 Aug 2025 15:22:45 +0000 Subject: Basic implementation of `autodiff` intrinsic --- compiler/rustc_builtin_macros/src/autodiff.rs | 139 +++++++++- .../rustc_codegen_llvm/src/builder/autodiff.rs | 302 +++++---------------- compiler/rustc_codegen_llvm/src/context.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 76 +++++- compiler/rustc_hir_analysis/src/check/intrinsic.rs | 5 +- compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/mod.rs | 4 + 7 files changed, 284 insertions(+), 245 deletions(-) (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index a662840eda5..3f8585d35bc 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -329,17 +329,22 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // TODO(Sa4dUs): Remove this and all the related logic + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); + let d_body = + call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics, + generics: generics.clone(), contract: None, body: Some(d_body), define_opaque: None, @@ -428,12 +433,15 @@ mod llvm_enzyme { tokens: ts, }); + let vis_clone = vis.clone(); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = Box::new(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -443,13 +451,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { @@ -463,7 +471,9 @@ mod llvm_enzyme { } }; - return vec![orig_annotatable, d_annotatable]; + let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); + + return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -484,6 +494,123 @@ mod llvm_enzyme { ty } + // Generate `autodiff` intrinsic call + // ``` + // std::intrinsics::autodiff(source, diff, (args)) + // ``` + fn call_autodiff( + ecx: &ExtCtxt<'_>, + primal: Ident, + diff: Ident, + span: Span, + d_sig: &FnSig, + ) -> P { + let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + + let tuple_expr = ecx.expr_tuple( + span, + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::>() + .into(), + ); + + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::from_str("autodiff"), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + ); + + let block = ecx.block_expr(call_expr); + + block + } + + // Generate dummy const to prevent primal function + // from being optimized away before applying enzyme + // ``` + // const _: () = + // { + // #[used] + // pub static DUMMY_PTR: fn_type = primal_fn; + // }; + // ``` + fn gen_dummy_const( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: FnSig, + generics: Generics, + vis: Visibility, + ) -> Annotatable { + // #[used] + let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let used_attr = outer_normal_attr(&used_attr, new_id, span); + + // static DUMMY_PTR: = + let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); + let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { + safety: sig.header.safety, + ext: sig.header.ext, + generic_params: generics.params, + decl: sig.decl, + decl_span: sig.span, + })); + let static_ty = ecx.ty(span, fn_ptr_ty); + + let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); + let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { + ident: static_ident, + ty: static_ty, + safety: ast::Safety::Default, + mutability: ast::Mutability::Not, + expr: Some(static_expr), + define_opaque: None, + })); + + let static_item = ast::Item { + attrs: thin_vec![used_attr], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: static_item_kind, + tokens: None, + }; + + let block_expr = ecx.expr_block(Box::new(ast::Block { + stmts: thin_vec![ecx.stmt_item(span, P(static_item))], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Default, + span, + tokens: None, + })); + + let const_item = ecx.item_const( + span, + Ident::from_str_and_span("_", span), + ecx.ty(span, ast::TyKind::Tup(thin_vec![])), + block_expr, + ); + + Annotatable::Item(const_item) + } + // Will generate a body of the type: // ``` // { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 829b3c513c2..66c34fbcfb1 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,22 +3,22 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, Type}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; -fn get_params(fnc: &Value) -> Vec<&Value> { +fn _get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; let mut fnc_args: Vec<&Value> = vec![]; fnc_args.reserve(param_num); @@ -29,7 +29,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> { fnc_args } -fn has_sret(fnc: &Value) -> bool { +fn _has_sret(fnc: &Value) -> bool { let num_args = llvm::LLVMCountParams(fnc) as usize; if num_args == 0 { false @@ -48,14 +48,13 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], outer_args: &[&'ll llvm::Value], - has_sret: bool, ) { debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level @@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - if has_sret { - // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the - // inner function will still return something. We increase our outer_pos by one, - // and once we're done with all other args we will take the return of the inner call and - // update the sret pointer with it - outer_pos = 1; - } - - let enzyme_const = cx.create_metadata(b"enzyme_const"); - let enzyme_out = cx.create_metadata(b"enzyme_out"); - let enzyme_dup = cx.create_metadata(b"enzyme_dup"); - let enzyme_dupv = cx.create_metadata(b"enzyme_dupv"); - let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed"); - let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv"); + let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); + let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); + let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); + let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); + let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll>( } } -// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input -// arguments. We do however need to declare them with their correct return type. -// We already figured the correct return type out in our frontend, when generating the outer_fn, -// so we can now just go ahead and use that. This is not always trivial, e.g. because sret. -// Beyond sret, this article describes our challenges nicely: -// -// I.e. (i32, f32) will get merged into i64, but we don't handle that yet. -fn compute_enzyme_fn_ty<'ll>( - cx: &SimpleCx<'ll>, - attrs: &AutoDiffAttrs, - fn_to_diff: &'ll Value, - outer_fn: &'ll Value, -) -> &'ll llvm::Type { - let fn_ty = cx.get_type_of_global(outer_fn); - let mut ret_ty = cx.get_return_type(fn_ty); - - let has_sret = has_sret(outer_fn); - - if has_sret { - // Now we don't just forward the return type, so we have to figure it out based on the - // primal return type, in combination with the autodiff settings. - let fn_ty = cx.get_type_of_global(fn_to_diff); - let inner_ret_ty = cx.get_return_type(fn_ty); - - let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) }; - if inner_ret_ty == void_ty { - // This indicates that even the inner function has an sret. - // Right now I only look for an sret in the outer function. - // This *probably* needs some extra handling, but I never ran - // into such a case. So I'll wait for user reports to have a test case. - bug!("sret in inner function"); - } - - if attrs.width == 1 { - // Enzyme returns a struct of style: - // `{ original_ret(if requested), float, float, ... }` - let mut struct_elements = vec![]; - if attrs.has_primal_ret() { - struct_elements.push(inner_ret_ty); - } - // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, - // and therefore part of the return struct. - let param_tys = cx.func_params_types(fn_ty); - for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { - if matches!(act, DiffActivity::Active) { - // Now find the float type at position i based on the fn_ty, - // to know what (f16/f32/f64/...) to add to the struct. - struct_elements.push(param_ty); - } - } - ret_ty = cx.type_struct(&struct_elements, false); - } else { - // First we check if we also have to deal with the primal return. - match attrs.mode { - DiffMode::Forward => match attrs.ret_activity { - DiffActivity::Dual => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) }; - ret_ty = arr_ty; - } - DiffActivity::DualOnly => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) }; - ret_ty = arr_ty; - } - DiffActivity::Const => { - todo!("Not sure, do we need to do something here?"); - } - _ => { - bug!("unreachable"); - } - }, - DiffMode::Reverse => { - todo!("Handle sret for reverse mode"); - } - _ => { - bug!("unreachable"); - } - } - } - } - - // LLVM can figure out the input types on it's own, so we take a shortcut here. - unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) } -} - /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -288,11 +193,15 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, - outer_fn: &'ll Value, + outer_name: &str, + ret_ty: &'ll Type, + fn_args: &[&'ll Value], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -302,11 +211,9 @@ fn generate_enzyme_call<'ll>( } .to_string(); - // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. - let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::str::from_utf8(&name).unwrap(); - ad_name.push_str(outer_fn_name); + ad_name.push_str(outer_name); // Let us assume the user wrote the following function square: // @@ -317,13 +224,7 @@ fn generate_enzyme_call<'ll>( // ret double %0 // } // ``` - // - // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`. - // Our macro generates the following placeholder code (slightly simplified): - // - // ```llvm // define double @dsquare(double %x) { - // ; placeholder code // return 0.0; // } // ``` @@ -340,120 +241,52 @@ fn generate_enzyme_call<'ll>( // ret double %0 // } // ``` - unsafe { - let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn); - - // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and - // think a bit more about what should go here. - let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_simple_fn( - cx, - &ad_name, - llvm::CallConv::try_from(cc).expect("invalid callconv"), - llvm::UnnamedAddr::No, - llvm::Visibility::Default, - enzyme_ty, - ); - - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - - // We add a made-up attribute just such that we can recognize it after AD to update - // (no)-inline attributes. We'll then also remove this attribute. - let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); - attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - - let num_args = llvm::LLVMCountParams(&fn_to_diff); - let mut args = Vec::with_capacity(num_args as usize + 1); - args.push(fn_to_diff); - - let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return"); - if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { - args.push(cx.get_metadata_value(enzyme_primal_ret)); - } - if attrs.width > 1 { - let enzyme_width = cx.create_metadata(b"enzyme_width"); - args.push(cx.get_metadata_value(enzyme_width)); - args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); - } - - let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); - match_args_from_caller_to_enzyme( - &cx, - &builder, - attrs.width, - &mut args, - &attrs.input_activity, - &outer_args, - has_sret, - ); - - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. - // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. - let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } - - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); + let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }; + + // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and + // think a bit more about what should go here. + // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now + let cc = 8; + let ad_fn = declare_simple_fn( + cx, + &ad_name, + llvm::CallConv::try_from(cc).expect("invalid callconv"), + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + enzyme_ty, + ); + + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. + let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { + args.push(cx.get_metadata_value(enzyme_primal_ret)); + } + if attrs.width > 1 { + let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + args.push(cx.get_metadata_value(enzyme_width)); + args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); + } - if cx.val_ty(call) == cx.type_void() || has_sret { - if has_sret { - // This is what we already have in our outer_fn (shortened): - // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) { - // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>) - // - // store [4 x double] %7, ptr %0, align 8 - // ret void - // } + match_args_from_caller_to_enzyme( + &cx, + builder, + attrs.width, + &mut args, + &attrs.input_activity, + fn_args, + ); - // now store the result of the enzyme call into the sret pointer. - let sret_ptr = outer_args[0]; - let call_ty = cx.val_ty(call); - if attrs.width == 1 { - assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); - } else { - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); - } - llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); - } - builder.ret_void(); - } else { - builder.ret(call); - } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Let's crash in case that we messed something up above and generated invalid IR. - llvm::LLVMRustVerifyFunction( - outer_fn, - llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, - ); - } + builder.store_to_place(call, dest.val); } pub(crate) fn differentiate<'ll>( @@ -461,6 +294,7 @@ pub(crate) fn differentiate<'ll>( cgcx: &CodegenContext, diff_items: Vec, ) -> Result<(), FatalError> { + // TODO(Sa4dUs): delete all this logic for item in &diff_items { trace!("{}", item); } @@ -480,7 +314,7 @@ pub(crate) fn differentiate<'ll>( for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { + let Some(_fn_def) = fn_def else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -492,7 +326,7 @@ pub(crate) fn differentiate<'ll>( }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { + let Some(_fn_target) = fn_target else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -503,7 +337,7 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 27ae729a531..8eb15571e82 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -660,7 +660,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type { + pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { assert_eq!(self.type_kind(ty), TypeKind::Function); unsafe { llvm::LLVMGetReturnType(ty) } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 7b27e496986..1102fc1d0c8 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,23 +3,26 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -174,10 +177,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; + let callee_ty = instance.ty(tcx, self.typing_env()); - let name = tcx.item_name(instance.def_id()); let fn_args = instance.args; + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); + let ret_ty = sig.output(); + let name = tcx.item_name(instance.def_id()); + + let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let simple = call_simple_intrinsic(self, name, args); let llval = match name { _ if simple.is_some() => simple.unwrap(), @@ -189,6 +199,66 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } + sym::autodiff => { + let val_arr: Vec<&'ll Value> = match args[2].val { + crate::intrinsic::OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; + + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(self, i); + let field_layout = tuple_place.layout.field(self, i); + let llvm_ty = field_layout.llvm_type(self.cx); + + let field_val = + self.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } + + ret_arr + } + crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + }; + + // Get source, diff, and attrs + let source_id = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_source = Instance::mono(tcx, *source_id); + let source_symbol = + symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + let diff_id = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_diff = Instance::mono(tcx, *diff_id); + let diff_symbol = + symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + // Build body + generate_enzyme_call( + self, + self.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); + + return Ok(()); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 4441dd6ebd6..46371cfe591 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::round_ties_even_f32 | sym::round_ties_even_f64 | sym::round_ties_even_f128 + | sym::autodiff | sym::const_eval_select => hir::Safety::Safe, _ => hir::Safety::Unsafe, }; @@ -171,6 +172,8 @@ pub(crate) fn check_intrinsic_type( } }; + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -195,9 +198,9 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { + sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 416ce27367e..f7a8258a9d8 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -542,6 +542,7 @@ symbols! { audit_that, augmented_assignments, auto_traits, + autodiff, autodiff_forward, autodiff_reverse, automatically_derived, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 7228ad0ed6d..6c389c55a5f 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3157,6 +3157,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn autodiff(f: F, df: G, args: T) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] -- cgit 1.4.1-3-g733a5 From 250d77e5d72fde69a6406050a3b037635f685378 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez Date: Thu, 14 Aug 2025 15:27:57 +0000 Subject: Complete functionality and general cleanup --- Cargo.lock | 2 - compiler/rustc_builtin_macros/src/autodiff.rs | 492 ++++----------------- compiler/rustc_codegen_gcc/src/lib.rs | 6 - .../rustc_codegen_llvm/src/builder/autodiff.rs | 178 ++++---- compiler/rustc_codegen_llvm/src/context.rs | 5 - compiler/rustc_codegen_llvm/src/intrinsic.rs | 212 ++++++--- compiler/rustc_codegen_llvm/src/lib.rs | 6 - compiler/rustc_codegen_ssa/src/back/write.rs | 18 +- compiler/rustc_codegen_ssa/src/base.rs | 7 +- compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 10 +- compiler/rustc_codegen_ssa/src/traits/write.rs | 2 - compiler/rustc_hir_analysis/src/check/intrinsic.rs | 3 +- .../rustc_middle/src/middle/codegen_fn_attrs.rs | 4 - compiler/rustc_middle/src/mir/mono.rs | 2 - compiler/rustc_monomorphize/Cargo.toml | 2 - compiler/rustc_monomorphize/src/collector.rs | 5 + .../rustc_monomorphize/src/collector/autodiff.rs | 48 ++ compiler/rustc_monomorphize/src/partitioning.rs | 34 +- .../src/partitioning/autodiff.rs | 143 ------ library/core/src/intrinsics/mod.rs | 34 ++ library/core/src/macros/mod.rs | 2 + src/doc/rustc-dev-guide/src/SUMMARY.md | 1 - .../rustc-dev-guide/src/autodiff/limitations.md | 27 -- triagebot.toml | 3 - 24 files changed, 419 insertions(+), 827 deletions(-) create mode 100644 compiler/rustc_monomorphize/src/collector/autodiff.rs delete mode 100644 compiler/rustc_monomorphize/src/partitioning/autodiff.rs delete mode 100644 src/doc/rustc-dev-guide/src/autodiff/limitations.md (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/Cargo.lock b/Cargo.lock index 8a878faecbc..e8590fa484b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4308,7 +4308,6 @@ name = "rustc_monomorphize" version = "0.0.0" dependencies = [ "rustc_abi", - "rustc_ast", "rustc_data_structures", "rustc_errors", "rustc_fluent_macro", @@ -4317,7 +4316,6 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", - "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 3f8585d35bc..c260dca87c0 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -15,11 +15,12 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode, + FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind, + MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; - use rustc_span::{Ident, Span, Symbol, kw, sym}; + use rustc_span::{Ident, Span, Symbol, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; @@ -179,11 +180,8 @@ mod llvm_enzyme { } /// We expand the autodiff macro to generate a new placeholder function which passes - /// type-checking and can be called by users. The function body of the placeholder function will - /// later be replaced on LLVM-IR level, so the design of the body is less important and for now - /// should just prevent early inlining and optimizations which alter the function signature. - /// The exact signature of the generated function depends on the configuration provided by the - /// user, but here is an example: + /// type-checking and can be called by users. The exact signature of the generated function + /// depends on the configuration provided by the user, but here is an example: /// /// ``` /// #[autodiff(cos_box, Reverse, Duplicated, Active)] @@ -199,14 +197,8 @@ mod llvm_enzyme { /// f32::sin(**x) /// } /// #[rustc_autodiff(Reverse, Duplicated, Active)] - /// #[inline(never)] /// fn cos_box(x: &Box, dx: &mut Box, dret: f32) -> f32 { - /// unsafe { - /// asm!("NOP"); - /// }; - /// ::core::hint::black_box(sin(x)); - /// ::core::hint::black_box((dx, dret)); - /// ::core::hint::black_box(sin(x)) + /// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret)) /// } /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked @@ -227,16 +219,24 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics)) = (match &item { - Annotatable::Item(iitem) => extract_item_info(iitem), + let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { + Annotatable::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } Annotatable::Stmt(stmt) => match &stmt.kind { - ast::StmtKind::Item(iitem) => extract_item_info(iitem), + ast::StmtKind::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) - } + Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + *of_trait, + )), _ => None, }, _ => None, @@ -254,7 +254,6 @@ mod llvm_enzyme { }; let has_ret = has_ret(&sig.decl.output); - let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -323,28 +322,27 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let n_active: u32 = x - .input_activity - .iter() - .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) - .count() as u32; - let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - - // TODO(Sa4dUs): Remove this and all the related logic - let _d_body = gen_enzyme_body( - ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, - &generics, - ); + let d_sig = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = - call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + let d_body = ecx.block( + span, + thin_vec![call_autodiff( + ecx, + primal, + first_ident(&meta_item_vec[0]), + span, + &d_sig, + &generics, + impl_of_trait, + )], + ); // The first element of it is the name of the function to be generated - let asdf = Box::new(ast::Fn { + let d_fn = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: generics.clone(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -433,13 +431,11 @@ mod llvm_enzyme { tokens: ts, }); - let vis_clone = vis.clone(); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { - let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn); let d_fn = Box::new(ast::AssocItem { attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, @@ -451,13 +447,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { @@ -471,9 +467,7 @@ mod llvm_enzyme { } }; - let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); - - return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; + return vec![orig_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -504,9 +498,11 @@ mod llvm_enzyme { diff: Ident, span: Span, d_sig: &FnSig, - ) -> P { - let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); - let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + generics: &Generics, + is_impl: bool, + ) -> rustc_ast::Stmt { + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); let tuple_expr = ecx.expr_tuple( span, @@ -522,371 +518,65 @@ mod llvm_enzyme { .into(), ); - let enzyme_path = ecx.path( - span, - vec![ - Ident::from_str("std"), - Ident::from_str("intrinsics"), - Ident::from_str("autodiff"), - ], - ); + let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]); + let enzyme_path = ecx.path(span, enzyme_path_idents); let call_expr = ecx.expr_call( span, ecx.expr_path(enzyme_path), vec![primal_path_expr, diff_path_expr, tuple_expr].into(), ); - let block = ecx.block_expr(call_expr); - - block - } - - // Generate dummy const to prevent primal function - // from being optimized away before applying enzyme - // ``` - // const _: () = - // { - // #[used] - // pub static DUMMY_PTR: fn_type = primal_fn; - // }; - // ``` - fn gen_dummy_const( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - sig: FnSig, - generics: Generics, - vis: Visibility, - ) -> Annotatable { - // #[used] - let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let used_attr = outer_normal_attr(&used_attr, new_id, span); - - // static DUMMY_PTR: = - let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); - let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { - safety: sig.header.safety, - ext: sig.header.ext, - generic_params: generics.params, - decl: sig.decl, - decl_span: sig.span, - })); - let static_ty = ecx.ty(span, fn_ptr_ty); - - let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); - let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { - ident: static_ident, - ty: static_ty, - safety: ast::Safety::Default, - mutability: ast::Mutability::Not, - expr: Some(static_expr), - define_opaque: None, - })); - - let static_item = ast::Item { - attrs: thin_vec![used_attr], - id: ast::DUMMY_NODE_ID, - span, - vis, - kind: static_item_kind, - tokens: None, - }; - - let block_expr = ecx.expr_block(Box::new(ast::Block { - stmts: thin_vec![ecx.stmt_item(span, P(static_item))], - id: ast::DUMMY_NODE_ID, - rules: ast::BlockCheckMode::Default, - span, - tokens: None, - })); - - let const_item = ecx.item_const( - span, - Ident::from_str_and_span("_", span), - ecx.ty(span, ast::TyKind::Tup(thin_vec![])), - block_expr, - ); - - Annotatable::Item(const_item) + ecx.stmt_expr(call_expr) } - // Will generate a body of the type: - // ``` - // { - // unsafe { - // asm!("NOP"); - // } - // ::core::hint::black_box(primal(args)); - // ::core::hint::black_box((args, ret)); - // - // } - // ``` - fn init_body_helper( + // Generate turbofish expression from fn name and generics + // Given `foo` and `` params, gen `foo::` + // We use this expression when passing primal and diff function to the autodiff intrinsic + fn gen_turbofish_expr( ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - new_names: &[String], - sig_span: Span, - new_decl_span: Span, - idents: &[Ident], - errored: bool, + ident: Ident, generics: &Generics, - ) -> (Box, Box, Box, Box) { - let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let noop = ast::InlineAsm { - asm_macro: ast::AsmMacro::Asm, - template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], - template_strs: Box::new([]), - operands: vec![], - clobber_abis: vec![], - options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, - line_spans: vec![], - }; - let noop_expr = ecx.expr_asm(span, Box::new(noop)); - let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); - let unsf_block = ast::Block { - stmts: thin_vec![ecx.stmt_semi(noop_expr)], - id: ast::DUMMY_NODE_ID, - tokens: None, - rules: unsf, - span, - }; - let unsf_expr = ecx.expr_block(Box::new(unsf_block)); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let black_box_primal_call = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); - let tup_args = new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(); - - let black_box_remaining_args = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - thin_vec![ecx.expr_tuple(sig_span, tup_args)], - ); - - let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(unsf_expr)); - - // This uses primal args which won't be available if we errored before - if !errored { - body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); - } - body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); - - (body, primal_call, black_box_primal_call, blackbox_call_expr) - } - - /// We only want this function to type-check, since we will replace the body - /// later on llvm level. Using `loop {}` does not cover all return types anymore, - /// so instead we manually build something that should pass the type checker. - /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining - /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another - /// bug would ever try to accidentally differentiate this placeholder function body. - /// Finally, we also add back_box usages of all input arguments, to prevent rustc - /// from optimizing any arguments away. - fn gen_enzyme_body( - ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, - d_sig: &ast::FnSig, - primal: Ident, - new_names: &[String], span: Span, - sig_span: Span, - idents: Vec, - errored: bool, - generics: &Generics, - ) -> Box { - let new_decl_span = d_sig.span; - - // Just adding some default inline-asm and black_box usages to prevent early inlining - // and optimizations which alter the function signature. - // - // The bb_primal_call is the black_box call of the primal function. We keep it around, - // since it has the convenient property of returning the type of the primal function, - // Remember, we only care to match types here. - // No matter which return we pick, we always wrap it into a std::hint::black_box call, - // to prevent rustc from propagating it into the caller. - let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( - ecx, - span, - primal, - new_names, - sig_span, - new_decl_span, - &idents, - errored, - generics, - ); - - if !has_ret(&d_sig.decl.output) { - // there is no return type that we have to match, () works fine. - return body; - } - - // Everything from here onwards just tries to fulfil the return type. Fun! - - // having an active-only return means we'll drop the original return type. - // So that can be treated identical to not having one in the first place. - let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); - - if primal_ret && n_active == 0 && x.mode.is_rev() { - // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(bb_primal_call)); - return body; - } - - if !primal_ret && n_active == 1 { - // Again no tuple return, so return default float val. - let ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); + is_impl: bool, + ) -> Box { + let generic_args = generics + .params + .iter() + .filter_map(|p| match &p.kind { + GenericParamKind::Type { .. } => { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + Some(AngleBracketedArg::Arg(GenericArg::Type(ty))) } - }; - let arg = ty.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - body.stmts.push(ecx.stmt_expr(default_call_expr)); - return body; - } - - let mut exprs: Box = primal_call; - let d_ret_ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - if x.mode.is_fwd() { - // Fwd mode is easy. If the return activity is Const, we support arbitrary types. - // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. - // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. - // In all three cases, we can return `std::hint::black_box(::default())`. - if x.ret_activity == DiffActivity::Const { - // Here we call the primal function, since our dummy function has the same return - // type due to the Const return activity. - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; - let y = ExprKind::Path( - Some(Box::new(q)), - ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), - ); - let default_call_expr = ecx.expr(span, y); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); - } - } else if x.mode.is_rev() { - if x.width == 1 { - // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. - match d_ret_ty.kind { - TyKind::Tup(ref args) => { - // We have a tuple return type. We need to create a tuple of the same size - // and fill it with default values. - let mut exprs2 = thin_vec![exprs]; - for arg in args.iter().skip(1) { - let arg = arg.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs2.push(default_call_expr); - } - exprs = ecx.expr_tuple(new_decl_span, exprs2); - } - _ => { - // Interestingly, even the `-> ArbitraryType` case - // ends up getting matched and handled correctly above, - // so we don't have to handle any other case for now. - panic!("Unsupported return type: {:?}", d_ret_ty); - } + GenericParamKind::Const { .. } => { + let expr = ecx.expr_path(ast::Path::from_ident(p.ident)); + let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr }; + Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const))) } - } - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - unreachable!("Unsupported mode: {:?}", x.mode); - } - - body.stmts.push(ecx.stmt_expr(exprs)); + GenericParamKind::Lifetime { .. } => None, + }) + .collect::>(); - body - } + let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args }; - fn gen_primal_call( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - idents: &[Ident], - generics: &Generics, - ) -> Box { - let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + let segment = PathSegment { + ident, + id: ast::DUMMY_NODE_ID, + args: Some(Box::new(GenericArgs::AngleBracketed(args))), + }; - if has_self { - let args: ThinVec<_> = - idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let self_expr = ecx.expr_self(span); - ecx.expr_method_call(span, self_expr, primal, args) + let segments = if is_impl { + thin_vec![ + PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None }, + segment, + ] } else { - let args: ThinVec<_> = - idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let mut primal_path = ecx.path_ident(span, primal); - - let is_generic = !generics.params.is_empty(); - - match (is_generic, primal_path.segments.last_mut()) { - (true, Some(function_path)) => { - let primal_generic_types = generics - .params - .iter() - .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); - - let generated_generic_types = primal_generic_types - .map(|type_param| { - let generic_param = TyKind::Path( - None, - ast::Path { - span, - segments: thin_vec![ast::PathSegment { - ident: type_param.ident, - args: None, - id: ast::DUMMY_NODE_ID, - }], - tokens: None, - }, - ); - - ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty { - id: type_param.id, - span, - kind: generic_param, - tokens: None, - }))) - }) - .collect(); - - function_path.args = - Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { - span, - args: generated_generic_types, - }))); - } - _ => {} - } + thin_vec![segment] + }; - let primal_call_expr = ecx.expr_path(primal_path); - ecx.expr_call(span, primal_call_expr, args) - } + let path = Path { span, segments, tokens: None }; + + ecx.expr_path(path) } // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must @@ -905,7 +595,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec, Vec, bool) { + ) -> ast::FnSig { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -917,7 +607,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], vec![], true); + return sig.clone(); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -960,7 +650,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. - return (sig.clone(), new_inputs, idents, true); + return sig.clone(); } let unsafe_activities = x @@ -1174,7 +864,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, new_inputs, idents, false) + d_sig } } diff --git a/compiler/rustc_codegen_gcc/src/lib.rs b/compiler/rustc_codegen_gcc/src/lib.rs index b11f11d38e3..4025aba82da 100644 --- a/compiler/rustc_codegen_gcc/src/lib.rs +++ b/compiler/rustc_codegen_gcc/src/lib.rs @@ -93,7 +93,6 @@ use gccjit::{CType, Context, OptimizationLevel}; #[cfg(feature = "master")] use gccjit::{TargetInfo, Version}; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn, @@ -363,12 +362,7 @@ impl WriteBackendMethods for GccCodegenBackend { _exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_functions: Vec, ) -> Result, FatalError> { - if !diff_functions.is_empty() { - unimplemented!(); - } - back::lto::run_fat(cgcx, each_linked_rlib_for_lto, modules) } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 66c34fbcfb1..56116959a62 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,40 +1,93 @@ use std::ptr; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; -use rustc_codegen_ssa::ModuleCodegen; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_errors::FatalError; -use rustc_middle::bug; -use tracing::{debug, trace}; +use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; +use rustc_middle::{bug, ty}; +use tracing::debug; -use crate::back::write::llvm_err; use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; +use crate::{attributes, llvm}; -fn _get_params(fnc: &Value) -> Vec<&Value> { - let param_num = llvm::LLVMCountParams(fnc) as usize; - let mut fnc_args: Vec<&Value> = vec![]; - fnc_args.reserve(param_num); - unsafe { - llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); - fnc_args.set_len(param_num); +pub(crate) fn adjust_activity_to_abi<'tcx>( + tcx: TyCtxt<'tcx>, + fn_ty: Ty<'tcx>, + da: &mut Vec, +) { + if !matches!(fn_ty.kind(), ty::FnDef(..)) { + bug!("expected fn def for autodiff, got {:?}", fn_ty); } - fnc_args -} -fn _has_sret(fnc: &Value) -> bool { - let num_args = llvm::LLVMCountParams(fnc) as usize; - if num_args == 0 { - false - } else { - unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) } + // We don't actually pass the types back into the type system. + // All we do is decide how to handle the arguments. + let sig = fn_ty.fn_sig(tcx).skip_binder(); + + let mut new_activities = vec![]; + let mut new_positions = vec![]; + for (i, ty) in sig.inputs().iter().enumerate() { + if let Some(inner_ty) = ty.builtin_deref(true) { + if inner_ty.is_slice() { + // Now we need to figure out the size of each slice element in memory to allow + // safety checks and usability improvements in the backend. + let sty = match inner_ty.builtin_index() { + Some(sty) => sty, + None => { + panic!("slice element type unknown"); + } + }; + let pci = PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: sty, + }; + + let layout = tcx.layout_of(pci); + let elem_size = match layout { + Ok(layout) => layout.size, + Err(_) => { + bug!("autodiff failed to compute slice element size"); + } + }; + let elem_size: u32 = elem_size.bytes() as u32; + + // We know that the length will be passed as extra arg. + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + // There is one FakeActivitySize per slice, so for convenience we store the + // slice element size in bytes in it. We will use the size in the backend. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::Dualv + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => { + DiffActivity::FakeActivitySize(Some(elem_size)) + } + DiffActivity::Const => DiffActivity::Const, + _ => bug!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + + continue; + } + } + } + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); } } @@ -66,12 +119,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); - let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); - let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); - let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); - let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); - let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); + let enzyme_const = cx.create_metadata(b"enzyme_const"); + let enzyme_out = cx.create_metadata(b"enzyme_out"); + let enzyme_dup = cx.create_metadata(b"enzyme_dup"); + let enzyme_dupv = cx.create_metadata(b"enzyme_dupv"); + let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed"); + let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv"); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -223,7 +276,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // %0 = fmul double %x, %x // ret double %0 // } - // ``` + // // define double @dsquare(double %x) { // return 0.0; // } @@ -245,8 +298,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and // think a bit more about what should go here. - // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now - let cc = 8; + let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) }; let ad_fn = declare_simple_fn( cx, &ad_name, @@ -265,12 +317,12 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); - let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return"); if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { args.push(cx.get_metadata_value(enzyme_primal_ret)); } if attrs.width > 1 { - let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + let enzyme_width = cx.create_metadata(b"enzyme_width"); args.push(cx.get_metadata_value(enzyme_width)); args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); } @@ -288,61 +340,3 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder.store_to_place(call, dest.val); } - -pub(crate) fn differentiate<'ll>( - module: &'ll ModuleCodegen, - cgcx: &CodegenContext, - diff_items: Vec, -) -> Result<(), FatalError> { - // TODO(Sa4dUs): delete all this logic - for item in &diff_items { - trace!("{}", item); - } - - let diag_handler = cgcx.create_dcx(); - - let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); - - // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? - if !diff_items.is_empty() - && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) - { - return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable)); - } - - // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme. - for item in diff_items.iter() { - let name = item.source.clone(); - let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(_fn_def) = fn_def else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find source function".to_owned(), - }, - )); - }; - debug!(?item.target); - let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(_fn_target) = fn_target else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find target function".to_owned(), - }, - )); - }; - - // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); - } - - // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts - - trace!("done with differentiate()"); - - Ok(()) -} diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 8eb15571e82..da8c1e5f47b 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -8,7 +8,6 @@ use std::str; use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx}; use rustc_codegen_ssa::back::versioned_llvm_target; use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh}; -use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::errors as ssa_errors; use rustc_codegen_ssa::traits::*; use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN}; @@ -660,10 +659,6 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { - assert_eq!(self.type_kind(ty), TypeKind::Function); - unsafe { llvm::LLVMGetReturnType(ty) } - } pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type { unsafe { llvm::LLVMGlobalGetValueType(val) } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 1102fc1d0c8..4935f8d7dff 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,21 +9,23 @@ use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphizati use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; -use rustc_hir as hir; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; +use rustc_target::callconv::PassMode; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; -use crate::builder::autodiff::generate_enzyme_call; +use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; use crate::context::CodegenCx; +use crate::errors::AutoDiffWithoutEnable; use crate::llvm::{self, Metadata}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; @@ -177,16 +179,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; - let callee_ty = instance.ty(tcx, self.typing_env()); - let fn_args = instance.args; - - let sig = callee_ty.fn_sig(tcx); - let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); - let ret_ty = sig.output(); let name = tcx.item_name(instance.def_id()); - - let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let fn_args = instance.args; let simple = call_simple_intrinsic(self, name, args); let llval = match name { @@ -200,63 +195,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } sym::autodiff => { - let val_arr: Vec<&'ll Value> = match args[2].val { - crate::intrinsic::OperandValue::Ref(ref place_value) => { - let mut ret_arr = vec![]; - let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; - - for i in 0..tuple_place.layout.layout.0.fields.count() { - let field_place = tuple_place.project_field(self, i); - let field_layout = tuple_place.layout.field(self, i); - let llvm_ty = field_layout.llvm_type(self.cx); - - let field_val = - self.load(llvm_ty, field_place.val.llval, field_place.val.align); - - ret_arr.push(field_val) - } - - ret_arr - } - crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], - OperandValue::Immediate(v) => vec![v], - OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), - }; - - // Get source, diff, and attrs - let source_id = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, _) => def_id, - _ => bug!("invalid args"), - }; - let fn_source = Instance::mono(tcx, *source_id); - let source_symbol = - symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); - let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); - let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - - let diff_id = match fn_args.into_type_list(tcx)[1].kind() { - ty::FnDef(def_id, _) => def_id, - _ => bug!("invalid args"), - }; - let fn_diff = Instance::mono(tcx, *diff_id); - let diff_symbol = - symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - - let diff_attrs = autodiff_attrs(tcx, *diff_id); - let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; - - // Build body - generate_enzyme_call( - self, - self.cx, - fn_to_diff, - &diff_symbol, - llret_ty, - &val_arr, - diff_attrs.clone(), - result, - ); - + codegen_autodiff(self, tcx, instance, args, result); return Ok(()); } sym::is_val_statically_known => { @@ -1183,6 +1122,143 @@ fn get_rust_try_fn<'a, 'll, 'tcx>( rust_try } +fn codegen_autodiff<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + args: &[OperandRef<'tcx, &'ll Value>], + result: PlaceRef<'tcx, &'ll Value>, +) { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable); + } + + let fn_args = instance.args; + let callee_ty = instance.ty(tcx, bx.typing_env()); + + let sig = callee_ty.fn_sig(tcx).skip_binder(); + + let ret_ty = sig.output(); + let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); + + // Get source, diff, and attrs + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), + _ => bug!("invalid autodiff intrinsic args"), + }; + + let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific autodiff instance", + source_id, + source_args + ), + Err(_) => { + // An error has already been emitted + return; + } + }; + + let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else { + bug!("could not find source function") + }; + + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), + _ => bug!("invalid args"), + }; + + let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific autodiff instance", + diff_id, + diff_args + ), + Err(_) => { + // An error has already been emitted + return; + } + }; + + let val_arr = get_args_from_tuple(bx, args[2], fn_diff); + let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else { + bug!("could not find autodiff attrs") + }; + + adjust_activity_to_abi( + tcx, + fn_source.ty(tcx, TypingEnv::fully_monomorphized()), + &mut diff_attrs.input_activity, + ); + + // Build body + generate_enzyme_call( + bx, + bx.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); +} + +fn get_args_from_tuple<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tuple_op: OperandRef<'tcx, &'ll Value>, + fn_instance: Instance<'tcx>, +) -> Vec<&'ll Value> { + let cx = bx.cx; + let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty()); + + match tuple_op.val { + OperandValue::Immediate(val) => vec![val], + OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Ref(ptr) => { + let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout }; + + let mut result = Vec::with_capacity(fn_abi.args.len()); + let mut tuple_index = 0; + + for arg in &fn_abi.args { + match arg.mode { + PassMode::Ignore => {} + PassMode::Direct(_) | PassMode::Cast { .. } => { + let field = tuple_place.project_field(bx, tuple_index); + let llvm_ty = field.layout.llvm_type(bx.cx); + let val = bx.load(llvm_ty, field.val.llval, field.val.align); + result.push(val); + tuple_index += 1; + } + PassMode::Pair(_, _) => { + let field = tuple_place.project_field(bx, tuple_index); + let llvm_ty = field.layout.llvm_type(bx.cx); + let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); + result.push(bx.extract_value(pair_val, 0)); + result.push(bx.extract_value(pair_val, 1)); + tuple_index += 1; + } + PassMode::Indirect { .. } => { + let field = tuple_place.project_field(bx, tuple_index); + result.push(field.val.llval); + tuple_index += 1; + } + } + } + + result + } + + OperandValue::ZeroSized => vec![], + } +} + fn generic_simd_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, name: Symbol, diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index ca84b6de8b1..79e80db6f55 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -30,7 +30,6 @@ use context::SimpleCx; use errors::ParseTargetMachineConfig; use llvm_util::target_config; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -173,15 +172,10 @@ impl WriteBackendMethods for LlvmCodegenBackend { exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_fncs: Vec, ) -> Result, FatalError> { let mut module = back::lto::run_fat(cgcx, exported_symbols_for_lto, each_linked_rlib_for_lto, modules)?; - if !diff_fncs.is_empty() { - builder::autodiff::differentiate(&module, cgcx, diff_fncs)?; - } - let dcx = cgcx.create_dcx(); let dcx = dcx.handle(); back::lto::run_pass_manager(cgcx, dcx, &mut module, false)?; diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index aa29afb7f5b..2e8122798d1 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -7,7 +7,6 @@ use std::{fs, io, mem, str, thread}; use rustc_abi::Size; use rustc_ast::attr; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::FxIndexMap; use rustc_data_structures::jobserver::{self, Acquired}; use rustc_data_structures::memmap::Mmap; @@ -38,7 +37,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use crate::back::lto::check_lto_allowed; -use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; +use crate::errors::ErrorCreatingRemarkDir; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -454,7 +453,6 @@ pub(crate) fn start_async_codegen( backend: B, tcx: TyCtxt<'_>, target_cpu: String, - autodiff_items: &[AutoDiffItem], ) -> OngoingCodegen { let (coordinator_send, coordinator_receive) = channel(); @@ -473,7 +471,6 @@ pub(crate) fn start_async_codegen( backend.clone(), tcx, &crate_info, - autodiff_items, shared_emitter, codegen_worker_send, coordinator_receive, @@ -728,7 +725,6 @@ pub(crate) enum WorkItem { each_linked_rlib_for_lto: Vec, needs_fat_lto: Vec>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, - autodiff: Vec, }, /// Performs thin-LTO on the given module. ThinLto(lto::ThinModule), @@ -1001,7 +997,6 @@ fn execute_fat_lto_work_item( each_linked_rlib_for_lto: &[PathBuf], mut needs_fat_lto: Vec>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, - autodiff: Vec, module_config: &ModuleConfig, ) -> Result, FatalError> { for (module, wp) in import_only_modules { @@ -1013,7 +1008,6 @@ fn execute_fat_lto_work_item( exported_symbols_for_lto, each_linked_rlib_for_lto, needs_fat_lto, - autodiff, )?; let module = B::codegen(cgcx, module, module_config)?; Ok(WorkItemResult::Finished(module)) @@ -1105,7 +1099,6 @@ fn start_executing_work( backend: B, tcx: TyCtxt<'_>, crate_info: &CrateInfo, - autodiff_items: &[AutoDiffItem], shared_emitter: SharedEmitter, codegen_worker_send: Sender, coordinator_receive: Receiver>, @@ -1115,7 +1108,6 @@ fn start_executing_work( ) -> thread::JoinHandle> { let coordinator_send = tx_to_llvm_workers; let sess = tcx.sess; - let autodiff_items = autodiff_items.to_vec(); let mut each_linked_rlib_for_lto = Vec::new(); let mut each_linked_rlib_file_for_lto = Vec::new(); @@ -1448,7 +1440,6 @@ fn start_executing_work( each_linked_rlib_for_lto: each_linked_rlib_file_for_lto, needs_fat_lto, import_only_modules, - autodiff: autodiff_items.clone(), }, 0, )); @@ -1456,11 +1447,6 @@ fn start_executing_work( helper.request_token(); } } else { - if !autodiff_items.is_empty() { - let dcx = cgcx.create_dcx(); - dcx.handle().emit_fatal(AutodiffWithoutLto {}); - } - for (work, cost) in generate_thin_lto_work( &cgcx, &exported_symbols_for_lto, @@ -1795,7 +1781,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>( each_linked_rlib_for_lto, needs_fat_lto, import_only_modules, - autodiff, } => { let _timer = cgcx .prof @@ -1806,7 +1791,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>( &each_linked_rlib_for_lto, needs_fat_lto, import_only_modules, - autodiff, module_config, ) } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index b4556ced0b3..b483c01da59 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -647,7 +647,7 @@ pub fn codegen_crate( ) -> OngoingCodegen { // Skip crate items and just output metadata in -Z no-codegen mode. if tcx.sess.opts.unstable_opts.no_codegen || !tcx.sess.opts.output_types.should_codegen() { - let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu, &[]); + let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu); ongoing_codegen.codegen_finished(tcx); @@ -665,8 +665,7 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let MonoItemPartitions { codegen_units, autodiff_items, .. } = - tcx.collect_and_partition_mono_items(()); + let MonoItemPartitions { codegen_units, .. } = tcx.collect_and_partition_mono_items(()); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -679,7 +678,7 @@ pub fn codegen_crate( } } - let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu, autodiff_items); + let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu); // Codegen an allocator shim, if necessary. if let Some(kind) = allocator_kind_for_codegen(tcx) { diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index a36a772bc97..af70f0deb07 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -177,14 +177,6 @@ fn process_builtin_attrs( let mut interesting_spans = InterestingAttributeDiagnosticSpans::default(); let rust_target_features = tcx.rust_target_features(LOCAL_CRATE); - // If our rustc version supports autodiff/enzyme, then we call our handler - // to check for any `#[rustc_autodiff(...)]` attributes. - // FIXME(jdonszelmann): merge with loop below - if cfg!(llvm_enzyme) { - let ad = autodiff_attrs(tcx, did.into()); - codegen_fn_attrs.autodiff_item = ad; - } - for attr in attrs.iter() { if let hir::Attribute::Parsed(p) = attr { match p { @@ -612,7 +604,7 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option { /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never /// panic, unless we introduced a bug when parsing the autodiff macro. //FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed. -fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { +pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index f391c198e1a..c29ad90735b 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,6 +1,5 @@ use std::path::PathBuf; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -23,7 +22,6 @@ pub trait WriteBackendMethods: Clone + 'static { exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_fncs: Vec, ) -> Result, FatalError>; /// Performs thin LTO by performing necessary global analysis and returning two /// lists, one of the modules that need optimization and another for modules that diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 46371cfe591..f50aed0b3c2 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -172,8 +172,6 @@ pub(crate) fn check_intrinsic_type( } }; - let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); - let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -198,6 +196,7 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; + let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), diff --git a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs index 2852c4cbd34..7d2fc0995aa 100644 --- a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs +++ b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use rustc_abi::Align; -use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_hir::attrs::{InlineAttr, InstructionSetAttr, Linkage, OptimizeAttr}; use rustc_hir::def_id::DefId; use rustc_macros::{HashStable, TyDecodable, TyEncodable}; @@ -75,8 +74,6 @@ pub struct CodegenFnAttrs { /// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around /// the function entry. pub patchable_function_entry: Option, - /// For the `#[autodiff]` macros. - pub autodiff_item: Option, } #[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)] @@ -182,7 +179,6 @@ impl CodegenFnAttrs { instruction_set: None, alignment: None, patchable_function_entry: None, - autodiff_item: None, } } diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 440771b3d68..0e6f797b1e4 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; use std::fmt; use std::hash::Hash; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN}; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::FxIndexMap; @@ -336,7 +335,6 @@ impl ToStableHashKey> for MonoItem<'_> { pub struct MonoItemPartitions<'tcx> { pub codegen_units: &'tcx [CodegenUnit<'tcx>], pub all_mono_items: &'tcx DefIdSet, - pub autodiff_items: &'tcx [AutoDiffItem], } #[derive(Debug, HashStable)] diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index 0ed5b4fc0d0..09a55f0b5f8 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -6,7 +6,6 @@ edition = "2024" [dependencies] # tidy-alphabetical-start rustc_abi = { path = "../rustc_abi" } -rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } @@ -15,7 +14,6 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } -rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 26ca8518434..af2c3177067 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -205,6 +205,8 @@ //! this is not implemented however: a mono item will be produced //! regardless of whether it is actually needed or not. +mod autodiff; + use std::cell::OnceCell; use rustc_data_structures::fx::FxIndexMap; @@ -235,6 +237,7 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan}; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; +use crate::collector::autodiff::collect_autodiff_fn; use crate::errors::{ self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm, NoOptimizedMir, RecursionLimit, @@ -911,6 +914,8 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { + collect_autodiff_fn(tcx, instance, intrinsic, output); + if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will // be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs new file mode 100644 index 00000000000..13868cca944 --- /dev/null +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -0,0 +1,48 @@ +use rustc_middle::bug; +use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt}; + +use crate::collector::{MonoItems, create_fn_mono_item}; + +// Here, we force both primal and diff function to be collected in +// mono so this does not interfere in `autodiff` intrinsics +// codegen process. If they are unused, LLVM will remove them when +// compiling with O3. +pub(crate) fn collect_autodiff_fn<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + intrinsic: IntrinsicDef, + output: &mut MonoItems<'tcx>, +) { + if intrinsic.name != rustc_span::sym::autodiff { + return; + }; + + collect_autodiff_fn_from_arg(instance.args[0], tcx, output); +} + +fn collect_autodiff_fn_from_arg<'tcx>( + arg: GenericArg<'tcx>, + tcx: TyCtxt<'tcx>, + output: &mut MonoItems<'tcx>, +) { + let (instance, span) = match arg.kind() { + ty::GenericArgKind::Type(ty) => match ty.kind() { + ty::FnDef(def_id, substs) => { + let span = tcx.def_span(def_id); + let instance = ty::Instance::expect_resolve( + tcx, + ty::TypingEnv::non_body_analysis(tcx, def_id), + *def_id, + substs, + span, + ); + + (instance, span) + } + _ => bug!("expected autodiff function"), + }, + _ => bug!("expected type when matching autodiff arg"), + }; + + output.push(create_fn_mono_item(tcx, instance, span)); +} diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 628ea2b63de..d784d3540c4 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,8 +92,6 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. -mod autodiff; - use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -251,17 +249,7 @@ where always_export_generics, ); - // We can't differentiate a function that got inlined. - let autodiff_active = cfg!(llvm_enzyme) - && matches!(mono_item, MonoItem::Fn(_)) - && cx - .tcx - .codegen_fn_attrs(mono_item.def_id()) - .autodiff_item - .as_ref() - .is_some_and(|ad| ad.is_active()); - - if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { + if visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1157,27 +1145,15 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - #[cfg(not(llvm_enzyme))] - let autodiff_mono_items: Vec<_> = vec![]; - #[cfg(llvm_enzyme)] - let mut autodiff_mono_items: Vec<_> = vec![]; let mono_items: DefIdSet = items .iter() .filter_map(|mono_item| match *mono_item { - MonoItem::Fn(ref instance) => { - #[cfg(llvm_enzyme)] - autodiff_mono_items.push((mono_item, instance)); - Some(instance.def_id()) - } + MonoItem::Fn(ref instance) => Some(instance.def_id()), MonoItem::Static(def_id) => Some(def_id), _ => None, }) .collect(); - let autodiff_items = - autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items); - let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); - // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats && let Err(err) = @@ -1235,11 +1211,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - MonoItemPartitions { - all_mono_items: tcx.arena.alloc(mono_items), - codegen_units, - autodiff_items, - } + MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units } } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs deleted file mode 100644 index 22d593b80b8..00000000000 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ /dev/null @@ -1,143 +0,0 @@ -use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; -use rustc_hir::def_id::LOCAL_CRATE; -use rustc_middle::bug; -use rustc_middle::mir::mono::MonoItem; -use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; -use rustc_symbol_mangling::symbol_name_for_instance_in_crate; -use tracing::{debug, trace}; - -use crate::partitioning::UsageMap; - -fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec) { - if !matches!(fn_ty.kind(), ty::FnDef(..)) { - bug!("expected fn def for autodiff, got {:?}", fn_ty); - } - - // We don't actually pass the types back into the type system. - // All we do is decide how to handle the arguments. - let sig = fn_ty.fn_sig(tcx).skip_binder(); - - let mut new_activities = vec![]; - let mut new_positions = vec![]; - for (i, ty) in sig.inputs().iter().enumerate() { - if let Some(inner_ty) = ty.builtin_deref(true) { - if inner_ty.is_slice() { - // Now we need to figure out the size of each slice element in memory to allow - // safety checks and usability improvements in the backend. - let sty = match inner_ty.builtin_index() { - Some(sty) => sty, - None => { - panic!("slice element type unknown"); - } - }; - let pci = PseudoCanonicalInput { - typing_env: TypingEnv::fully_monomorphized(), - value: sty, - }; - - let layout = tcx.layout_of(pci); - let elem_size = match layout { - Ok(layout) => layout.size, - Err(_) => { - bug!("autodiff failed to compute slice element size"); - } - }; - let elem_size: u32 = elem_size.bytes() as u32; - - // We know that the length will be passed as extra arg. - if !da.is_empty() { - // We are looking at a slice. The length of that slice will become an - // extra integer on llvm level. Integers are always const. - // However, if the slice get's duplicated, we want to know to later check the - // size. So we mark the new size argument as FakeActivitySize. - // There is one FakeActivitySize per slice, so for convenience we store the - // slice element size in bytes in it. We will use the size in the backend. - let activity = match da[i] { - DiffActivity::DualOnly - | DiffActivity::Dual - | DiffActivity::Dualv - | DiffActivity::DuplicatedOnly - | DiffActivity::Duplicated => { - DiffActivity::FakeActivitySize(Some(elem_size)) - } - DiffActivity::Const => DiffActivity::Const, - _ => bug!("unexpected activity for ptr/ref"), - }; - new_activities.push(activity); - new_positions.push(i + 1); - } - - continue; - } - } - } - // now add the extra activities coming from slices - // Reverse order to not invalidate the indices - for _ in 0..new_activities.len() { - let pos = new_positions.pop().unwrap(); - let activity = new_activities.pop().unwrap(); - da.insert(pos, activity); - } -} - -pub(crate) fn find_autodiff_source_functions<'tcx>( - tcx: TyCtxt<'tcx>, - usage_map: &UsageMap<'tcx>, - autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>, -) -> Vec { - let mut autodiff_items: Vec = vec![]; - for (item, instance) in autodiff_mono_items { - let target_id = instance.def_id(); - let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item; - let Some(target_attrs) = cg_fn_attr else { - continue; - }; - let mut input_activities: Vec = target_attrs.input_activity.clone(); - if target_attrs.is_source() { - trace!("source found: {:?}", target_id); - } - if !target_attrs.apply_autodiff() { - continue; - } - - let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - - let source = - usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { - MonoItem::Fn(ref instance_s) => { - let source_id = instance_s.def_id(); - if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item - && ad.is_active() - { - return Some(instance_s); - } - None - } - _ => None, - }); - let inst = match source { - Some(source) => source, - None => continue, - }; - - debug!("source_id: {:?}", inst.def_id()); - let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized()); - assert!(fn_ty.is_fn()); - adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); - let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); - - let mut new_target_attrs = target_attrs.clone(); - new_target_attrs.input_activity = input_activities; - let itm = new_target_attrs.into_item(symb, target_symbol); - autodiff_items.push(itm); - } - - if !autodiff_items.is_empty() { - trace!("AUTODIFF ITEMS EXIST"); - for item in &mut *autodiff_items { - trace!("{}", &item); - } - } - - autodiff_items -} diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 6c389c55a5f..dd838d494bc 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3157,6 +3157,40 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +/// Generates the LLVM body for the automatic differentiation of `f` using Enzyme, +/// with `df` as the derivative function and `args` as its arguments. +/// +/// Used internally as the body of `df` when expanding the `#[autodiff_forward]` +/// and `#[autodiff_reverse]` attribute macros. +/// +/// Type Parameters: +/// - `F`: The original function to differentiate. Must be a function item. +/// - `G`: The derivative function. Must be a function item. +/// - `T`: A tuple of arguments passed to `df`. +/// - `R`: The return type of the derivative function. +/// +/// This shows where the `autodiff` intrinsic is used during macro expansion: +/// +/// ```rust,ignore (macro example) +/// #[autodiff_forward(df1, Dual, Const, Dual)] +/// pub fn f1(x: &[f64], y: f64) -> f64 { +/// unimplemented!() +/// } +/// ``` +/// +/// expands to: +/// +/// ```rust,ignore (macro example) +/// #[rustc_autodiff] +/// #[inline(never)] +/// pub fn f1(x: &[f64], y: f64) -> f64 { +/// ::core::panicking::panic("not implemented") +/// } +/// #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] +/// pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { +/// ::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y)) +/// } +/// ``` #[rustc_nounwind] #[rustc_intrinsic] pub const fn autodiff(f: F, df: G, args: T) -> R; diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index c59290a757b..888369d73f4 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1494,6 +1494,7 @@ pub(crate) mod builtin { /// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities. #[unstable(feature = "autodiff", issue = "124509")] #[allow_internal_unstable(rustc_attrs)] + #[allow_internal_unstable(core_intrinsics)] #[rustc_builtin_macro] pub macro autodiff_forward($item:item) { /* compiler built-in */ @@ -1512,6 +1513,7 @@ pub(crate) mod builtin { /// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities. #[unstable(feature = "autodiff", issue = "124509")] #[allow_internal_unstable(rustc_attrs)] + #[allow_internal_unstable(core_intrinsics)] #[rustc_builtin_macro] pub macro autodiff_reverse($item:item) { /* compiler built-in */ diff --git a/src/doc/rustc-dev-guide/src/SUMMARY.md b/src/doc/rustc-dev-guide/src/SUMMARY.md index 9ded467d5cd..025a078ae5b 100644 --- a/src/doc/rustc-dev-guide/src/SUMMARY.md +++ b/src/doc/rustc-dev-guide/src/SUMMARY.md @@ -107,7 +107,6 @@ - [Installation](./autodiff/installation.md) - [How to debug](./autodiff/debugging.md) - [Autodiff flags](./autodiff/flags.md) - - [Current limitations](./autodiff/limitations.md) # Source Code Representation diff --git a/src/doc/rustc-dev-guide/src/autodiff/limitations.md b/src/doc/rustc-dev-guide/src/autodiff/limitations.md deleted file mode 100644 index 90afbd51f3f..00000000000 --- a/src/doc/rustc-dev-guide/src/autodiff/limitations.md +++ /dev/null @@ -1,27 +0,0 @@ -# Current limitations - -## Safety and Soundness - -Enzyme currently assumes that the user passes shadow arguments (`dx`, `dy`, ...) of appropriate size. Under Reverse Mode, we additionally assume that shadow arguments are mutable. In Reverse Mode we adjust the outermost pointer or reference to be mutable. Therefore `&f32` will receive the shadow type `&mut f32`. However, we do not check length for other types than slices (e.g. enums, Vec). We also do not enforce mutability of inner references, but will warn if we recognize them. We do intend to add additional checks over time. - -## ABI adjustments - -In some cases, a function parameter might get lowered in a way that we currently don't handle correctly, leading to a compile time type mismatch in the `rustc_codegen_llvm` backend. Here are some [examples](https://github.com/EnzymeAD/rust/issues/105). - -## Compile Times - -Enzyme will often achieve excellent runtime performance, but might increase your compile time by a large factor. For Rust, we already have made significant improvements and have a list of further improvements planed - please reach out if you have time to help here. - -### Type Analysis - -Most of the times, Type Analysis (TA) is the reason of large (>5x) compile time increases when using Enzyme. This poster explains why we need to run Type Analysis in the bottom left part: [Poster Link](https://c.wsmoses.com/posters/Enzyme-llvmdev.pdf). - -We intend to increase the number of locations where we pass down Type information based on Rust types, which in turn will reduce the number of locations where Enzyme has to run Type Analysis, which will help compile times. - -### Duplicated Optimizations - -The key reason for Enzyme offering often excellent performance is that Enzyme differentiates already optimized LLVM-IR. However, we also (have to) run LLVM's optimization pipeline after differentiating, to make sure that the code which Enzyme generates is optimized properly. As a result you should have excellent runtime performance (please fill an issue if not), but at a compile time cost for running optimizations twice. - -### Fat-LTO - -The usage of `#[autodiff(...)]` currently requires compiling your project with Fat-LTO. We technically only need LTO if the function being differentiated calls functions in other compilation units. Therefore, other solutions are possible, but this is the most simple one to get started. diff --git a/triagebot.toml b/triagebot.toml index 6f6e95c5b50..df81bb71160 100644 --- a/triagebot.toml +++ b/triagebot.toml @@ -283,7 +283,6 @@ trigger_files = [ "src/tools/enzyme", "src/doc/unstable-book/src/compiler-flags/autodiff.md", "compiler/rustc_ast/src/expand/autodiff_attrs.rs", - "compiler/rustc_monomorphize/src/partitioning/autodiff.rs", "compiler/rustc_codegen_llvm/src/builder/autodiff.rs", "compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs", ] @@ -1285,8 +1284,6 @@ cc = ["@ZuseZ4"] cc = ["@ZuseZ4"] [mentions."compiler/rustc_builtin_macros/src/autodiff.rs"] cc = ["@ZuseZ4"] -[mentions."compiler/rustc_monomorphize/src/partitioning/autodiff.rs"] -cc = ["@ZuseZ4"] [mentions."compiler/rustc_codegen_llvm/src/builder/autodiff.rs"] cc = ["@ZuseZ4"] [mentions."compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs"] -- cgit 1.4.1-3-g733a5 From c9c1c171289aa575040678f3d0c005342b1e29e1 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez Date: Thu, 14 Aug 2025 15:29:37 +0000 Subject: Remove inlining for autodiff handling --- compiler/rustc_builtin_macros/src/autodiff.rs | 32 ++++++++++++++-------- .../rustc_codegen_llvm/src/builder/autodiff.rs | 8 +----- 2 files changed, 21 insertions(+), 19 deletions(-) (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index c260dca87c0..48d0795af5e 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -192,7 +192,6 @@ mod llvm_enzyme { /// which becomes expanded to: /// ``` /// #[rustc_autodiff] - /// #[inline(never)] /// fn sin(x: &Box) -> f32 { /// f32::sin(**x) /// } @@ -371,7 +370,7 @@ mod llvm_enzyme { let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let inline_never = outer_normal_attr(&inline_never_attr, new_id, span); - // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`. + // We're avoid duplicating the attribute `#[rustc_autodiff]`. fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool { match (attr, item) { (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => { @@ -384,14 +383,16 @@ mod llvm_enzyme { } } + let mut has_inline_never = false; + // Don't add it multiple times: let orig_annotatable: Annotatable = match item { Annotatable::Item(ref mut iitem) => { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } - if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { - iitem.attrs.push(inline_never.clone()); + if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { + has_inline_never = true; } Annotatable::Item(iitem.clone()) } @@ -399,8 +400,8 @@ mod llvm_enzyme { if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { assoc_item.attrs.push(attr); } - if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { - assoc_item.attrs.push(inline_never.clone()); + if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { + has_inline_never = true; } Annotatable::AssocItem(assoc_item.clone(), i) } @@ -410,9 +411,8 @@ mod llvm_enzyme { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } - if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) - { - iitem.attrs.push(inline_never.clone()); + if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { + has_inline_never = true; } } _ => unreachable!("stmt kind checked previously"), @@ -433,11 +433,19 @@ mod llvm_enzyme { let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); + + // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function + let mut d_attrs = thin_vec![d_attr]; + + if has_inline_never { + d_attrs.push(inline_never); + } + let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn); let d_fn = Box::new(ast::AssocItem { - attrs: thin_vec![d_attr], + attrs: d_attrs, id: ast::DUMMY_NODE_ID, span, vis, @@ -447,13 +455,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); + let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); + let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 56116959a62..e2df3265f6f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -10,10 +10,9 @@ use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::llvm::AttributePlace::Function; +use crate::llvm; use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{attributes, llvm}; pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, @@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( enzyme_ty, ); - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); -- cgit 1.4.1-3-g733a5 From e1d79b9aad9cfcfcf23c68ec8625411819d4e3f7 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez Date: Thu, 14 Aug 2025 08:21:15 +0000 Subject: Remove lto inline logic --- compiler/rustc_codegen_llvm/src/attributes.rs | 16 --------------- compiler/rustc_codegen_llvm/src/back/lto.rs | 28 +-------------------------- compiler/rustc_codegen_llvm/src/context.rs | 10 ---------- compiler/rustc_codegen_llvm/src/llvm/mod.rs | 26 ------------------------- 4 files changed, 1 insertion(+), 79 deletions(-) (limited to 'compiler/rustc_codegen_llvm/src') diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index c548f467583..a6daacd95ef 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -28,22 +28,6 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[ } } -pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool { - llvm::HasAttributeAtIndex(llfn, idx, attr) -} - -pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool { - llvm::HasStringAttribute(llfn, name) -} - -pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) { - llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind); -} - -pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) { - llvm::RemoveStringAttrFromFn(llfn, name); -} - /// Get LLVM attribute for the provided inline heuristic. #[inline] fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> { diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index c269f11e931..853d0295238 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -24,9 +24,8 @@ use crate::back::write::{ self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode, }; use crate::errors::{LlvmError, LtoBitcodeFromRlib}; -use crate::llvm::AttributePlace::Function; use crate::llvm::{self, build_string}; -use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx}; /// We keep track of the computed LTO cache keys from the previous /// session to determine which CGUs we can reuse. @@ -593,31 +592,6 @@ pub(crate) fn run_pass_manager( } if cfg!(llvm_enzyme) && enable_ad && !thin { - let cx = - SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); - - for function in cx.get_functions() { - let enzyme_marker = "enzyme_marker"; - if attributes::has_string_attr(function, enzyme_marker) { - // Sanity check: Ensure 'noinline' is present before replacing it. - assert!( - attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), - "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" - ); - - attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); - attributes::remove_string_attr_from_llfn(function, enzyme_marker); - - assert!( - !attributes::has_string_attr(function, enzyme_marker), - "Expected function to not have 'enzyme_marker'" - ); - - let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx); - attributes::apply_to_llfn(function, Function, &[always_inline]); - } - } - let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; if !config.autodiff.contains(&config::AutoDiff::NoPostopt) { diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index da8c1e5f47b..b0f3494ea68 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -722,16 +722,6 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len()) } } - - pub(crate) fn get_functions(&self) -> Vec<&'ll Value> { - let mut functions = vec![]; - let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) }; - while let Some(f) = func { - functions.push(f); - func = unsafe { llvm::LLVMGetNextFunction(f) } - } - functions - } } impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 154ba4fd690..0ea0af0c9af 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -42,32 +42,6 @@ pub(crate) fn AddFunctionAttributes<'ll>( } } -pub(crate) fn HasAttributeAtIndex<'ll>( - llfn: &'ll Value, - idx: AttributePlace, - kind: AttributeKind, -) -> bool { - unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) } -} - -pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool { - unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } -} - -pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) { - unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } -} - -pub(crate) fn RemoveRustEnumAttributeAtIndex( - llfn: &Value, - place: AttributePlace, - kind: AttributeKind, -) { - unsafe { - LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind); - } -} - pub(crate) fn AddCallSiteAttributes<'ll>( callsite: &'ll Value, idx: AttributePlace, -- cgit 1.4.1-3-g733a5