author     bors <bors@rust-lang.org>   2025-09-17 13:56:54 +0000
committer  bors <bors@rust-lang.org>   2025-09-17 13:56:54 +0000
commit     5d1b897a07dc30d810dd541795125c1c216266c7 (patch)
tree       039c32657725cddbb0528af54157fc9385063a95 /compiler
parent     ce6daf3d5a5bffb2a00264197f92dc31608df0da (diff)
parent     7abbc9c8b2c638752e2a6b0913ed1e93e14d21ba (diff)
Auto merge of #146331 - RalfJung:copy-prov-repeat, r=oli-obk
interpret: copy_provenance: avoid large intermediate buffer for large repeat counts

Copying provenance used to work in an odd way: the "preparation" phase (which is only supposed to extract the necessary information from the source range) already did all the work of repeating the result N times for the target range. This was needed so we could use the existing `insert_presorted` function on `SortedMap`.

This PR generalizes `insert_presorted` so that we can avoid this odd structure in copy-provenance, and maybe even improve performance.
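
To illustrate the shape of the change, here is a minimal, self-contained sketch of the repeat-as-iterator idea (hypothetical names and toy data, not the compiler's actual types): instead of materializing `chunk.len() * repeat` entries up front, the repetitions are produced lazily, with offsets shifted per repetition, mirroring the pattern the new `apply_copy` uses.

// Hypothetical sketch: repeat a sorted chunk of (offset, value) pairs `repeat`
// times, shifting offsets for each repetition, without building a large
// intermediate buffer first.
fn repeated_chunk(
    chunk: &[(u64, char)],
    dest_start: u64,
    range_size: u64,
    repeat: u64,
) -> impl Iterator<Item = (u64, char)> + '_ {
    let chunk_len = chunk.len() as u64;
    (0..chunk_len * repeat).map(move |i| {
        let rep = i / chunk_len; // which repetition we are in
        let (offset, val) = chunk[(i % chunk_len) as usize];
        (offset + dest_start + rep * range_size, val)
    })
}

fn main() {
    // A "prepared copy": offsets are relative to the start of the copied range.
    let chunk = [(0, 'a'), (2, 'b')];
    // Repeat it 3 times into a destination starting at offset 10, range size 4.
    let out: Vec<_> = repeated_chunk(&chunk, 10, 4, 3).collect();
    assert_eq!(out, [(10, 'a'), (12, 'b'), (14, 'a'), (16, 'b'), (18, 'a'), (20, 'b')]);
}

Since the iterator is a `Map` over a `Range<u64>`, it is also `DoubleEndedIterator` and `TrustedLen`, which is exactly what the generalized `insert_presorted` below requires.
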
Diffstat (limited to 'compiler')
-rw-r--r--  compiler/rustc_const_eval/src/interpret/memory.rs                       4
-rw-r--r--  compiler/rustc_data_structures/src/lib.rs                               1
-rw-r--r--  compiler/rustc_data_structures/src/sorted_map.rs                       35
-rw-r--r--  compiler/rustc_data_structures/src/sorted_map/tests.rs                 10
-rw-r--r--  compiler/rustc_middle/src/mir/interpret/allocation.rs                   9
-rw-r--r--  compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs  111
6 files changed, 88 insertions, 82 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index ebcdb9461d0..323e1cefd58 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -1504,7 +1504,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // This will also error if copying partial provenance is not supported.
         let provenance = src_alloc
             .provenance()
-            .prepare_copy(src_range, dest_offset, num_copies, self)
+            .prepare_copy(src_range, self)
             .map_err(|e| e.to_interp_error(src_alloc_id))?;
         // Prepare a copy of the initialization mask.
         let init = src_alloc.init_mask().prepare_copy(src_range);
@@ -1590,7 +1590,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             num_copies,
         );
         // copy the provenance to the destination
-        dest_alloc.provenance_apply_copy(provenance);
+        dest_alloc.provenance_apply_copy(provenance, alloc_range(dest_offset, size), num_copies);
 
         interp_ok(())
     }
diff --git a/compiler/rustc_data_structures/src/lib.rs b/compiler/rustc_data_structures/src/lib.rs
index 17da3ea83c8..e4e86bcc41a 100644
--- a/compiler/rustc_data_structures/src/lib.rs
+++ b/compiler/rustc_data_structures/src/lib.rs
@@ -34,6 +34,7 @@
 #![feature(sized_hierarchy)]
 #![feature(test)]
 #![feature(thread_id_value)]
+#![feature(trusted_len)]
 #![feature(type_alias_impl_trait)]
 #![feature(unwrap_infallible)]
 // tidy-alphabetical-end
diff --git a/compiler/rustc_data_structures/src/sorted_map.rs b/compiler/rustc_data_structures/src/sorted_map.rs
index c002d47815b..15e3e6ea4c3 100644
--- a/compiler/rustc_data_structures/src/sorted_map.rs
+++ b/compiler/rustc_data_structures/src/sorted_map.rs
@@ -1,6 +1,7 @@
 use std::borrow::Borrow;
 use std::cmp::Ordering;
 use std::fmt::Debug;
+use std::iter::TrustedLen;
 use std::mem;
 use std::ops::{Bound, Index, IndexMut, RangeBounds};
 
@@ -215,36 +216,40 @@ impl<K: Ord, V> SortedMap<K, V> {
     /// It is up to the caller to make sure that the elements are sorted by key
     /// and that there are no duplicates.
     #[inline]
-    pub fn insert_presorted(&mut self, elements: Vec<(K, V)>) {
-        if elements.is_empty() {
+    pub fn insert_presorted(
+        &mut self,
+        // We require `TrustedLen` to ensure that the `splice` below is actually efficient.
+        mut elements: impl Iterator<Item = (K, V)> + DoubleEndedIterator + TrustedLen,
+    ) {
+        let Some(first) = elements.next() else {
             return;
-        }
-
-        debug_assert!(elements.array_windows().all(|[fst, snd]| fst.0 < snd.0));
+        };
 
-        let start_index = self.lookup_index_for(&elements[0].0);
+        let start_index = self.lookup_index_for(&first.0);
 
         let elements = match start_index {
             Ok(index) => {
-                let mut elements = elements.into_iter();
-                self.data[index] = elements.next().unwrap();
-                elements
+                self.data[index] = first; // overwrite first element
+                elements.chain(None) // insert the rest below
             }
             Err(index) => {
-                if index == self.data.len() || elements.last().unwrap().0 < self.data[index].0 {
+                let last = elements.next_back();
+                if index == self.data.len()
+                    || last.as_ref().is_none_or(|l| l.0 < self.data[index].0)
+                {
                     // We can copy the whole range without having to mix with
                     // existing elements.
-                    self.data.splice(index..index, elements);
+                    self.data
+                        .splice(index..index, std::iter::once(first).chain(elements).chain(last));
                     return;
                 }
 
-                let mut elements = elements.into_iter();
-                self.data.insert(index, elements.next().unwrap());
-                elements
+                self.data.insert(index, first);
+                elements.chain(last) // insert the rest below
             }
         };
 
-        // Insert the rest
+        // Insert the rest. This is super inefficient since each insertion copies the entire tail.
         for (k, v) in elements {
             self.insert(k, v);
         }
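
The efficiency argument behind the new `TrustedLen` bound can be seen with plain `Vec::splice`: when the iterator reports an exact length, the vector can reserve the space and shift the tail a single time. A small std-only sketch (using a sorted `Vec` as a stand-in for `SortedMap`, so the names here are illustrative, not the compiler's API):

fn main() {
    // A sorted vector of (key, value) pairs standing in for SortedMap's data.
    let mut data = vec![(2u32, 'a'), (8, 'b')];

    // Presorted, non-overlapping new elements. A Vec iterator is both
    // DoubleEndedIterator and TrustedLen, matching the new bound.
    let new = vec![(3u32, 'x'), (7, 'y')];

    // Find the insertion point for the first new key.
    let index = data.partition_point(|&(k, _)| k < 3);

    // Splice the whole run in at once: the tail is moved only once.
    data.splice(index..index, new.into_iter());

    assert_eq!(data, [(2, 'a'), (3, 'x'), (7, 'y'), (8, 'b')]);
}
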
diff --git a/compiler/rustc_data_structures/src/sorted_map/tests.rs b/compiler/rustc_data_structures/src/sorted_map/tests.rs
index ea4d2f1feac..17d0d3cb170 100644
--- a/compiler/rustc_data_structures/src/sorted_map/tests.rs
+++ b/compiler/rustc_data_structures/src/sorted_map/tests.rs
@@ -171,7 +171,7 @@ fn test_insert_presorted_non_overlapping() {
     map.insert(2, 0);
     map.insert(8, 0);
 
-    map.insert_presorted(vec![(3, 0), (7, 0)]);
+    map.insert_presorted(vec![(3, 0), (7, 0)].into_iter());
 
     let expected = vec![2, 3, 7, 8];
     assert_eq!(keys(map), expected);
@@ -183,7 +183,7 @@ fn test_insert_presorted_first_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);
 
-    map.insert_presorted(vec![(2, 0), (7, 7)]);
+    map.insert_presorted(vec![(2, 0), (7, 7)].into_iter());
 
     let expected = vec![(2, 0), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -195,7 +195,7 @@ fn test_insert_presorted_last_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);
 
-    map.insert_presorted(vec![(3, 3), (8, 0)]);
+    map.insert_presorted(vec![(3, 3), (8, 0)].into_iter());
 
     let expected = vec![(2, 2), (3, 3), (8, 0)];
     assert_eq!(elements(map), expected);
@@ -207,7 +207,7 @@ fn test_insert_presorted_shuffle() {
     map.insert(2, 2);
     map.insert(7, 7);
 
-    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)]);
+    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)].into_iter());
 
     let expected = vec![(1, 1), (2, 2), (3, 3), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -219,7 +219,7 @@ fn test_insert_presorted_at_end() {
     map.insert(1, 1);
     map.insert(2, 2);
 
-    map.insert_presorted(vec![(3, 3), (8, 8)]);
+    map.insert_presorted(vec![(3, 3), (8, 8)].into_iter());
 
     let expected = vec![(1, 1), (2, 2), (3, 3), (8, 8)];
     assert_eq!(elements(map), expected);
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs
index 67962813ae4..8e603ce1b91 100644
--- a/compiler/rustc_middle/src/mir/interpret/allocation.rs
+++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs
@@ -849,8 +849,13 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
     ///
     /// This is dangerous to use as it can violate internal `Allocation` invariants!
     /// It only exists to support an efficient implementation of `mem_copy_repeatedly`.
-    pub fn provenance_apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        self.provenance.apply_copy(copy)
+    pub fn provenance_apply_copy(
+        &mut self,
+        copy: ProvenanceCopy<Prov>,
+        range: AllocRange,
+        repeat: u64,
+    ) {
+        self.provenance.apply_copy(copy, range, repeat)
     }
 
     /// Applies a previously prepared copy of the init mask.
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs b/compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs
index 720e58d7aa0..67baf63bbfa 100644
--- a/compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs
+++ b/compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs
@@ -278,90 +278,78 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
 
 /// A partial, owned list of provenance to transfer into another allocation.
 ///
-/// Offsets are already adjusted to the destination allocation.
+/// Offsets are relative to the beginning of the copied range.
 pub struct ProvenanceCopy<Prov> {
-    dest_ptrs: Option<Box<[(Size, Prov)]>>,
-    dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
+    ptrs: Box<[(Size, Prov)]>,
+    bytes: Box<[(Size, (Prov, u8))]>,
 }
 
 impl<Prov: Provenance> ProvenanceMap<Prov> {
     pub fn prepare_copy(
         &self,
-        src: AllocRange,
-        dest: Size,
-        count: u64,
+        range: AllocRange,
         cx: &impl HasDataLayout,
     ) -> AllocResult<ProvenanceCopy<Prov>> {
-        let shift_offset = move |idx, offset| {
-            // compute offset for current repetition
-            let dest_offset = dest + src.size * idx; // `Size` operations
-            // shift offsets from source allocation to destination allocation
-            (offset - src.start) + dest_offset // `Size` operations
-        };
+        let shift_offset = move |offset| offset - range.start;
         let ptr_size = cx.data_layout().pointer_size();
 
         // # Pointer-sized provenances
         // Get the provenances that are entirely within this range.
         // (Different from `range_get_ptrs` which asks if they overlap the range.)
         // Only makes sense if we are copying at least one pointer worth of bytes.
-        let mut dest_ptrs_box = None;
-        if src.size >= ptr_size {
-            let adjusted_end = Size::from_bytes(src.end().bytes() - (ptr_size.bytes() - 1));
-            let ptrs = self.ptrs.range(src.start..adjusted_end);
-            // If `count` is large, this is rather wasteful -- we are allocating a big array here, which
-            // is mostly filled with redundant information since it's just N copies of the same `Prov`s
-            // at slightly adjusted offsets. The reason we do this is so that in `mark_provenance_range`
-            // we can use `insert_presorted`. That wouldn't work with an `Iterator` that just produces
-            // the right sequence of provenance for all N copies.
-            // Basically, this large array would have to be created anyway in the target allocation.
-            let mut dest_ptrs = Vec::with_capacity(ptrs.len() * (count as usize));
-            for i in 0..count {
-                dest_ptrs
-                    .extend(ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_ptrs.len(), dest_ptrs.capacity());
-            dest_ptrs_box = Some(dest_ptrs.into_boxed_slice());
+        let mut ptrs_box: Box<[_]> = Box::new([]);
+        if range.size >= ptr_size {
+            let adjusted_end = Size::from_bytes(range.end().bytes() - (ptr_size.bytes() - 1));
+            let ptrs = self.ptrs.range(range.start..adjusted_end);
+            ptrs_box = ptrs.iter().map(|&(offset, reloc)| (shift_offset(offset), reloc)).collect();
         };
 
         // # Byte-sized provenances
         // This includes the existing bytewise provenance in the range, and ptr provenance
         // that overlaps with the begin/end of the range.
-        let mut dest_bytes_box = None;
-        let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
-        let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
+        let mut bytes_box: Box<[_]> = Box::new([]);
+        let begin_overlap = self.range_ptrs_get(alloc_range(range.start, Size::ZERO), cx).first();
+        let end_overlap = self.range_ptrs_get(alloc_range(range.end(), Size::ZERO), cx).first();
         // We only need to go here if there is some overlap or some bytewise provenance.
         if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
             let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
             // First, if there is a part of a pointer at the start, add that.
             if let Some(entry) = begin_overlap {
                 trace!("start overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't run off the end of the `src` range.
-                let entry_end = cmp::min(entry.0 + ptr_size, src.end());
-                for offset in src.start..entry_end {
-                    bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                // For really small copies, make sure we don't run off the end of the range.
+                let entry_end = cmp::min(entry.0 + ptr_size, range.end());
+                for offset in range.start..entry_end {
+                    bytes.push((shift_offset(offset), (entry.1, (offset - entry.0).bytes() as u8)));
                 }
             } else {
                 trace!("no start overlapping entry");
             }
 
             // Then the main part, bytewise provenance from `self.bytes`.
-            bytes.extend(self.range_bytes_get(src));
+            bytes.extend(
+                self.range_bytes_get(range)
+                    .iter()
+                    .map(|&(offset, reloc)| (shift_offset(offset), reloc)),
+            );
 
             // And finally possibly parts of a pointer at the end.
             if let Some(entry) = end_overlap {
                 trace!("end overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't start before `src` does.
-                let entry_start = cmp::max(entry.0, src.start);
-                for offset in entry_start..src.end() {
+                // For really small copies, make sure we don't start before `range` does.
+                let entry_start = cmp::max(entry.0, range.start);
+                for offset in entry_start..range.end() {
                     if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
                         // The last entry, if it exists, has a lower offset than us, so we
                         // can add it at the end and remain sorted.
-                        bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                        bytes.push((
+                            shift_offset(offset),
+                            (entry.1, (offset - entry.0).bytes() as u8),
+                        ));
                     } else {
                         // There already is an entry for this offset in there! This can happen when the
                         // start and end range checks actually end up hitting the same pointer, so we
                         // already added this in the "pointer at the start" part above.
-                        assert!(entry.0 <= src.start);
+                        assert!(entry.0 <= range.start);
                     }
                 }
             } else {
@@ -372,33 +360,40 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
             if !bytes.is_empty() && !Prov::OFFSET_IS_ADDR {
                 // FIXME(#146291): We need to ensure that we don't mix different pointers with
                 // the same provenance.
-                return Err(AllocError::ReadPartialPointer(src.start));
+                return Err(AllocError::ReadPartialPointer(range.start));
             }
 
             // And again a buffer for the new list on the target side.
-            let mut dest_bytes = Vec::with_capacity(bytes.len() * (count as usize));
-            for i in 0..count {
-                dest_bytes
-                    .extend(bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_bytes.len(), dest_bytes.capacity());
-            dest_bytes_box = Some(dest_bytes.into_boxed_slice());
+            bytes_box = bytes.into_boxed_slice();
         }
 
-        Ok(ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box })
+        Ok(ProvenanceCopy { ptrs: ptrs_box, bytes: bytes_box })
     }
 
     /// Applies a provenance copy.
     /// The affected range, as defined in the parameters to `prepare_copy` is expected
     /// to be clear of provenance.
-    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        if let Some(dest_ptrs) = copy.dest_ptrs {
-            self.ptrs.insert_presorted(dest_ptrs.into());
+    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>, range: AllocRange, repeat: u64) {
+        let shift_offset = |idx: u64, offset: Size| offset + range.start + idx * range.size;
+        if !copy.ptrs.is_empty() {
+            // We want to call `insert_presorted` only once so that, if possible, the entries
+            // after the range we insert are moved back only once.
+            let chunk_len = copy.ptrs.len() as u64;
+            self.ptrs.insert_presorted((0..chunk_len * repeat).map(|i| {
+                let chunk = i / chunk_len;
+                let (offset, reloc) = copy.ptrs[(i % chunk_len) as usize];
+                (shift_offset(chunk, offset), reloc)
+            }));
         }
-        if let Some(dest_bytes) = copy.dest_bytes
-            && !dest_bytes.is_empty()
-        {
-            self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
+        if !copy.bytes.is_empty() {
+            let chunk_len = copy.bytes.len() as u64;
+            self.bytes.get_or_insert_with(Box::default).insert_presorted(
+                (0..chunk_len * repeat).map(|i| {
+                    let chunk = i / chunk_len;
+                    let (offset, reloc) = copy.bytes[(i % chunk_len) as usize];
+                    (shift_offset(chunk, offset), reloc)
+                }),
+            );
         }
     }
 }