about summary refs log tree commit diff
diff options
context:
space:
mode:
author bors <bors@rust-lang.org> 2024-01-19 06:44:19 +0000
committer bors <bors@rust-lang.org> 2024-01-19 06:44:19 +0000
commit 16fadb3f252bcfc5ee3f0be09472c9600a052202 (patch)
tree 8a24b914c19778b1ae858e2d8cf21f83572dad72
parent 1bd42be8cb707aadaf2068d5ac186154970c4d80 (diff)
parent e68f3039d4f2b12dcfc348ebb50fd7855e6d7fd1 (diff)
download rust-16fadb3f252bcfc5ee3f0be09472c9600a052202.tar.gz
rust-16fadb3f252bcfc5ee3f0be09472c9600a052202.zip
Auto merge of #120069 - Mark-Simulacrum:fast-memcpy, r=oli-obk
Optimize large array creation in const-eval

This changes repeated memcpys to a memset for the case where we're propagating a single byte into a region of memory. It also optimizes the element-by-element copies to have a tighter loop; I'm pretty sure the old code was actually doing a multiply within each loop iteration.

For an 8GB array (`static SLICE: [u8; SIZE] = [0u8; 1 << 33];`) this takes us from ~23 seconds to ~6 seconds locally. The remaining time is spent roughly 50/50 in (a) the memset to zero and (b) the memcpy of the original place into a new place when popping the stack frame. The latter seems hard to avoid, but it is one big memcpy (we're copying the type rather than initializing a region), so it's pretty fast; the former is as good as it's going to get without special-casing constant-valued arrays.

Closes https://github.com/rust-lang/rust/issues/55795. (That issue's references to lint checking don't appear to be true anymore, but I think this closes that case as something that is slow due to *time* pretty fully. An 8GB array taking only 6 seconds feels reasonable enough not to merit further tracking.)
-rw-r--r-- compiler/rustc_const_eval/src/interpret/memory.rs 31
1 files changed, 19 insertions, 12 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 7ff970661d6..3afd14eb574 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -1209,21 +1209,28 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                         throw_ub_custom!(fluent::const_eval_copy_nonoverlapping_overlapping);
                     }
                 }
+            }
 
-                for i in 0..num_copies {
-                    ptr::copy(
-                        src_bytes,
-                        dest_bytes.add((size * i).bytes_usize()), // `Size` multiplication
-                        size.bytes_usize(),
-                    );
+            let size_in_bytes = size.bytes_usize();
+            // For particularly large arrays (where this is perf-sensitive) it's common that
+            // we're writing a single byte repeatedly. So, optimize that case to a memset.
+            if size_in_bytes == 1 && num_copies >= 1 {
+                // SAFETY: `src_bytes` would be read from anyway by copies below (num_copies >= 1).
+                // Since size_in_bytes = 1, then the `init.no_bytes_init()` check above guarantees
+                // that this read at type `u8` is OK -- it must be an initialized byte.
+                let value = *src_bytes;
+                dest_bytes.write_bytes(value, (size * num_copies).bytes_usize());
+            } else if src_alloc_id == dest_alloc_id {
+                let mut dest_ptr = dest_bytes;
+                for _ in 0..num_copies {
+                    ptr::copy(src_bytes, dest_ptr, size_in_bytes);
+                    dest_ptr = dest_ptr.add(size_in_bytes);
                 }
             } else {
-                for i in 0..num_copies {
-                    ptr::copy_nonoverlapping(
-                        src_bytes,
-                        dest_bytes.add((size * i).bytes_usize()), // `Size` multiplication
-                        size.bytes_usize(),
-                    );
+                let mut dest_ptr = dest_bytes;
+                for _ in 0..num_copies {
+                    ptr::copy_nonoverlapping(src_bytes, dest_ptr, size_in_bytes);
+                    dest_ptr = dest_ptr.add(size_in_bytes);
                 }
             }
         }