about summary refs log tree commit diff
path: root/library/alloc
diff options
context:
space:
mode:
authorThe 8472 <git@infinite-source.de>2023-04-15 12:59:16 +0200
committerThe 8472 <git@infinite-source.de>2023-09-03 19:59:47 +0200
commit3ca6bb0b44c3e65dab07e12aec9efb277dc206f9 (patch)
treeb5be981314a0c22721924dfe8eea614151f226e7 /library/alloc
parent2a1af898b2cb535a45cefe67acf9d023eff16b27 (diff)
downloadrust-3ca6bb0b44c3e65dab07e12aec9efb277dc206f9.tar.gz
rust-3ca6bb0b44c3e65dab07e12aec9efb277dc206f9.zip
Expand in-place iteration specialization to Flatten, FlatMap and ArrayChunks
Diffstat (limited to 'library/alloc')
-rw-r--r--library/alloc/src/collections/binary_heap/mod.rs11
-rw-r--r--library/alloc/src/lib.rs1
-rw-r--r--library/alloc/src/vec/in_place_collect.rs83
-rw-r--r--library/alloc/src/vec/into_iter.rs12
-rw-r--r--library/alloc/tests/lib.rs1
-rw-r--r--library/alloc/tests/vec.rs34
6 files changed, 119 insertions, 23 deletions
diff --git a/library/alloc/src/collections/binary_heap/mod.rs b/library/alloc/src/collections/binary_heap/mod.rs
index 66573b90db9..ad2c4e483a3 100644
--- a/library/alloc/src/collections/binary_heap/mod.rs
+++ b/library/alloc/src/collections/binary_heap/mod.rs
@@ -145,7 +145,7 @@
 
 use core::alloc::Allocator;
 use core::fmt;
-use core::iter::{FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
+use core::iter::{FusedIterator, InPlaceIterable, SourceIter, TrustedFused, TrustedLen};
 use core::mem::{self, swap, ManuallyDrop};
 use core::num::NonZeroUsize;
 use core::ops::{Deref, DerefMut};
@@ -1540,6 +1540,10 @@ impl<T, A: Allocator> ExactSizeIterator for IntoIter<T, A> {
 #[stable(feature = "fused", since = "1.26.0")]
 impl<T, A: Allocator> FusedIterator for IntoIter<T, A> {}
 
+#[doc(hidden)]
+#[unstable(issue = "none", feature = "trusted_fused")]
+unsafe impl<T, A: Allocator> TrustedFused for IntoIter<T, A> {}
+
 #[stable(feature = "default_iters", since = "1.70.0")]
 impl<T> Default for IntoIter<T> {
     /// Creates an empty `binary_heap::IntoIter`.
@@ -1569,7 +1573,10 @@ unsafe impl<T, A: Allocator> SourceIter for IntoIter<T, A> {
 
 #[unstable(issue = "none", feature = "inplace_iteration")]
 #[doc(hidden)]
-unsafe impl<I, A: Allocator> InPlaceIterable for IntoIter<I, A> {}
+unsafe impl<I, A: Allocator> InPlaceIterable for IntoIter<I, A> {
+    const EXPAND_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
+    const MERGE_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
+}
 
 unsafe impl<I> AsVecIntoIter for IntoIter<I> {
     type Item = I;
diff --git a/library/alloc/src/lib.rs b/library/alloc/src/lib.rs
index e43b6ac4039..c9fee435f1e 100644
--- a/library/alloc/src/lib.rs
+++ b/library/alloc/src/lib.rs
@@ -155,6 +155,7 @@
 #![feature(std_internals)]
 #![feature(str_internals)]
 #![feature(strict_provenance)]
+#![feature(trusted_fused)]
 #![feature(trusted_len)]
 #![feature(trusted_random_access)]
 #![feature(try_trait_v2)]
diff --git a/library/alloc/src/vec/in_place_collect.rs b/library/alloc/src/vec/in_place_collect.rs
index 5ecd0479971..d174074902a 100644
--- a/library/alloc/src/vec/in_place_collect.rs
+++ b/library/alloc/src/vec/in_place_collect.rs
@@ -137,44 +137,73 @@
 //! }
 //! vec.truncate(write_idx);
 //! ```
+use crate::alloc::{handle_alloc_error, Global};
+use core::alloc::Allocator;
+use core::alloc::Layout;
 use core::iter::{InPlaceIterable, SourceIter, TrustedRandomAccessNoCoerce};
 use core::mem::{self, ManuallyDrop, SizedTypeProperties};
-use core::ptr::{self};
+use core::num::NonZeroUsize;
+use core::ptr::{self, NonNull};
 
 use super::{InPlaceDrop, InPlaceDstBufDrop, SpecFromIter, SpecFromIterNested, Vec};
 
-/// Specialization marker for collecting an iterator pipeline into a Vec while reusing the
-/// source allocation, i.e. executing the pipeline in place.
-#[rustc_unsafe_specialization_marker]
-pub(super) trait InPlaceIterableMarker {}
+const fn in_place_collectible<DEST, SRC>(
+    step_merge: Option<NonZeroUsize>,
+    step_expand: Option<NonZeroUsize>,
+) -> bool {
+    if DEST::IS_ZST || mem::align_of::<SRC>() != mem::align_of::<DEST>() {
+        return false;
+    }
+
+    match (step_merge, step_expand) {
+        (Some(step_merge), Some(step_expand)) => {
+            // At least N merged source items -> at most M expanded destination items
+            // e.g.
+            // - 1 x [u8; 4] -> 4x u8, via flatten
+            // - 4 x u8 -> 1x [u8; 4], via array_chunks
+            mem::size_of::<SRC>() * step_merge.get() == mem::size_of::<DEST>() * step_expand.get()
+        }
+        // Fall back to other from_iter impls if an overflow occured in the step merge/expansion
+        // tracking.
+        _ => false,
+    }
+}
 
-impl<T> InPlaceIterableMarker for T where T: InPlaceIterable {}
+/// This provides a shorthand for the source type since local type aliases aren't a thing.
+#[rustc_specialization_trait]
+trait InPlaceCollect: SourceIter<Source: AsVecIntoIter> + InPlaceIterable {
+    type Src;
+}
+
+impl<T> InPlaceCollect for T
+where
+    T: SourceIter<Source: AsVecIntoIter> + InPlaceIterable,
+{
+    type Src = <<T as SourceIter>::Source as AsVecIntoIter>::Item;
+}
 
 impl<T, I> SpecFromIter<T, I> for Vec<T>
 where
-    I: Iterator<Item = T> + SourceIter<Source: AsVecIntoIter> + InPlaceIterableMarker,
+    I: Iterator<Item = T> + InPlaceCollect,
+    <I as SourceIter>::Source: AsVecIntoIter,
 {
     default fn from_iter(mut iterator: I) -> Self {
         // See "Layout constraints" section in the module documentation. We rely on const
         // optimization here since these conditions currently cannot be expressed as trait bounds
-        if T::IS_ZST
-            || mem::size_of::<T>()
-                != mem::size_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>()
-            || mem::align_of::<T>()
-                != mem::align_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>()
-        {
+        if const { !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) } {
             // fallback to more generic implementations
             return SpecFromIterNested::from_iter(iterator);
         }
 
-        let (src_buf, src_ptr, dst_buf, dst_end, cap) = unsafe {
+        let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
             let inner = iterator.as_inner().as_into_iter();
             (
                 inner.buf.as_ptr(),
                 inner.ptr,
+                inner.cap,
                 inner.buf.as_ptr() as *mut T,
                 inner.end as *const T,
-                inner.cap,
+                inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
             )
         };
 
@@ -203,11 +232,31 @@ where
         // Note: This access to the source wouldn't be allowed by the TrustedRandomIteratorNoCoerce
         // contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
         // module documentation why this is ok anyway.
-        let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap };
+        let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap: dst_cap };
         src.forget_allocation_drop_remaining();
         mem::forget(dst_guard);
 
-        let vec = unsafe { Vec::from_raw_parts(dst_buf, len, cap) };
+        // Adjust the allocation size if the source had a capacity in bytes that wasn't a multiple
+        // of the destination type size.
+        // Since the discrepancy should generally be small this should only result in some
+        // bookkeeping updates and no memmove.
+        if const { mem::size_of::<T>() > mem::size_of::<I::Src>() }
+            && src_cap * mem::size_of::<I::Src>() != dst_cap * mem::size_of::<T>()
+        {
+            let alloc = Global;
+            unsafe {
+                let new_layout = Layout::array::<T>(dst_cap).unwrap();
+                let result = alloc.shrink(
+                    NonNull::new_unchecked(dst_buf as *mut u8),
+                    Layout::array::<I::Src>(src_cap).unwrap(),
+                    new_layout,
+                );
+                let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
+                dst_buf = reallocated.as_ptr() as *mut T;
+            }
+        }
+
+        let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };
 
         vec
     }
diff --git a/library/alloc/src/vec/into_iter.rs b/library/alloc/src/vec/into_iter.rs
index b2db2fdfd18..40b93b34e18 100644
--- a/library/alloc/src/vec/into_iter.rs
+++ b/library/alloc/src/vec/into_iter.rs
@@ -7,7 +7,8 @@ use crate::raw_vec::RawVec;
 use core::array;
 use core::fmt;
 use core::iter::{
-    FusedIterator, InPlaceIterable, SourceIter, TrustedLen, TrustedRandomAccessNoCoerce,
+    FusedIterator, InPlaceIterable, SourceIter, TrustedFused, TrustedLen,
+    TrustedRandomAccessNoCoerce,
 };
 use core::marker::PhantomData;
 use core::mem::{self, ManuallyDrop, MaybeUninit, SizedTypeProperties};
@@ -339,6 +340,10 @@ impl<T, A: Allocator> ExactSizeIterator for IntoIter<T, A> {
 #[stable(feature = "fused", since = "1.26.0")]
 impl<T, A: Allocator> FusedIterator for IntoIter<T, A> {}
 
+#[doc(hidden)]
+#[unstable(issue = "none", feature = "trusted_fused")]
+unsafe impl<T, A: Allocator> TrustedFused for IntoIter<T, A> {}
+
 #[unstable(feature = "trusted_len", issue = "37572")]
 unsafe impl<T, A: Allocator> TrustedLen for IntoIter<T, A> {}
 
@@ -423,7 +428,10 @@ unsafe impl<#[may_dangle] T, A: Allocator> Drop for IntoIter<T, A> {
 // also refer to the vec::in_place_collect module documentation to get an overview
 #[unstable(issue = "none", feature = "inplace_iteration")]
 #[doc(hidden)]
-unsafe impl<T, A: Allocator> InPlaceIterable for IntoIter<T, A> {}
+unsafe impl<T, A: Allocator> InPlaceIterable for IntoIter<T, A> {
+    const EXPAND_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
+    const MERGE_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
+}
 
 #[unstable(issue = "none", feature = "inplace_iteration")]
 #[doc(hidden)]
diff --git a/library/alloc/tests/lib.rs b/library/alloc/tests/lib.rs
index aa7a331b368..5bb531f0e62 100644
--- a/library/alloc/tests/lib.rs
+++ b/library/alloc/tests/lib.rs
@@ -1,5 +1,6 @@
 #![feature(allocator_api)]
 #![feature(alloc_layout_extra)]
+#![feature(iter_array_chunks)]
 #![feature(assert_matches)]
 #![feature(btree_extract_if)]
 #![feature(cow_is_borrowed)]
diff --git a/library/alloc/tests/vec.rs b/library/alloc/tests/vec.rs
index 9cb27899f10..33dd2658e1e 100644
--- a/library/alloc/tests/vec.rs
+++ b/library/alloc/tests/vec.rs
@@ -1,6 +1,7 @@
+use alloc::vec::Vec;
 use core::alloc::{Allocator, Layout};
-use core::assert_eq;
-use core::iter::IntoIterator;
+use core::{assert_eq, assert_ne};
+use core::iter::{IntoIterator, Iterator};
 use core::num::NonZeroUsize;
 use core::ptr::NonNull;
 use std::alloc::System;
@@ -1185,6 +1186,35 @@ fn test_from_iter_specialization_with_iterator_adapters() {
 }
 
 #[test]
+fn test_in_place_specialization_step_up_down() {
+    fn assert_in_place_trait<T: InPlaceIterable>(_: &T) {}
+    let src = vec![[0u8; 4]; 256];
+    let srcptr = src.as_ptr();
+    let src_cap = src.capacity();
+    let iter = src.into_iter().flatten();
+    assert_in_place_trait(&iter);
+    let sink = iter.collect::<Vec<_>>();
+    let sinkptr = sink.as_ptr();
+    assert_eq!(srcptr as *const u8, sinkptr);
+    assert_eq!(src_cap * 4, sink.capacity());
+
+    let iter = sink.into_iter().array_chunks::<4>();
+    assert_in_place_trait(&iter);
+    let sink = iter.collect::<Vec<_>>();
+    let sinkptr = sink.as_ptr();
+    assert_eq!(srcptr, sinkptr);
+    assert_eq!(src_cap, sink.capacity());
+
+    let mut src: Vec<u8> = Vec::with_capacity(17);
+    let src_bytes = src.capacity();
+    src.resize(8, 0u8);
+    let sink: Vec<[u8; 4]> = src.into_iter().array_chunks::<4>().collect();
+    let sink_bytes = sink.capacity() * 4;
+    assert_ne!(src_bytes, sink_bytes);
+    assert_eq!(sink.len(), 2);
+}
+
+#[test]
 fn test_from_iter_specialization_head_tail_drop() {
     let drop_count: Vec<_> = (0..=2).map(|_| Rc::new(())).collect();
     let src: Vec<_> = drop_count.iter().cloned().collect();