about summary refs log tree commit diff
path: root/compiler/rustc_data_structures
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_data_structures')
-rw-r--r--compiler/rustc_data_structures/Cargo.toml2
-rw-r--r--compiler/rustc_data_structures/src/marker.rs13
-rw-r--r--compiler/rustc_data_structures/src/sync/parallel.rs104
3 files changed, 89 insertions, 30 deletions
diff --git a/compiler/rustc_data_structures/Cargo.toml b/compiler/rustc_data_structures/Cargo.toml
index fcaf2750507..f48c73b13b9 100644
--- a/compiler/rustc_data_structures/Cargo.toml
+++ b/compiler/rustc_data_structures/Cargo.toml
@@ -14,7 +14,7 @@ indexmap = "2.4.0"
 jobserver_crate = { version = "0.1.28", package = "jobserver" }
 measureme = "12.0.1"
 rustc-hash = "2.0.0"
-rustc-rayon = { version = "0.5.1", features = ["indexmap"] }
+rustc-rayon-core = { version = "0.5.0" }
 rustc-stable-hash = { version = "0.1.0", features = ["nightly"] }
 rustc_arena = { path = "../rustc_arena" }
 rustc_graphviz = { path = "../rustc_graphviz" }
diff --git a/compiler/rustc_data_structures/src/marker.rs b/compiler/rustc_data_structures/src/marker.rs
index 64c64bfa3c2..744ae9b6fe2 100644
--- a/compiler/rustc_data_structures/src/marker.rs
+++ b/compiler/rustc_data_structures/src/marker.rs
@@ -180,6 +180,12 @@ impl<T> FromDyn<T> {
     }
 
     #[inline(always)]
+    pub fn derive<O>(&self, val: O) -> FromDyn<O> {
+        // We already did the check for `sync::is_dyn_thread_safe()` when creating `Self`
+        FromDyn(val)
+    }
+
+    #[inline(always)]
     pub fn into_inner(self) -> T {
         self.0
     }
@@ -200,6 +206,13 @@ impl<T> std::ops::Deref for FromDyn<T> {
     }
 }
 
+impl<T> std::ops::DerefMut for FromDyn<T> {
+    #[inline(always)]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
 // A wrapper to convert a struct that is already a `Send` or `Sync` into
 // an instance of `DynSend` and `DynSync`, since the compiler cannot infer
 // it automatically in some cases. (e.g. Box<dyn Send / Sync>)
diff --git a/compiler/rustc_data_structures/src/sync/parallel.rs b/compiler/rustc_data_structures/src/sync/parallel.rs
index 8ef8a3f3585..ba3c85ef5b1 100644
--- a/compiler/rustc_data_structures/src/sync/parallel.rs
+++ b/compiler/rustc_data_structures/src/sync/parallel.rs
@@ -7,7 +7,6 @@ use std::any::Any;
 use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
 
 use parking_lot::Mutex;
-use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
 
 use crate::FatalErrorMarker;
 use crate::sync::{DynSend, DynSync, FromDyn, IntoDynSyncSend, mode};
@@ -97,11 +96,11 @@ macro_rules! parallel {
 // This function only works when `mode::is_dyn_thread_safe()`.
 pub fn scope<'scope, OP, R>(op: OP) -> R
 where
-    OP: FnOnce(&rayon::Scope<'scope>) -> R + DynSend,
+    OP: FnOnce(&rayon_core::Scope<'scope>) -> R + DynSend,
     R: DynSend,
 {
     let op = FromDyn::from(op);
-    rayon::scope(|s| FromDyn::from(op.into_inner()(s))).into_inner()
+    rayon_core::scope(|s| FromDyn::from(op.into_inner()(s))).into_inner()
 }
 
 #[inline]
@@ -114,7 +113,7 @@ where
         let oper_a = FromDyn::from(oper_a);
         let oper_b = FromDyn::from(oper_b);
         let (a, b) = parallel_guard(|guard| {
-            rayon::join(
+            rayon_core::join(
                 move || guard.run(move || FromDyn::from(oper_a.into_inner()())),
                 move || guard.run(move || FromDyn::from(oper_b.into_inner()())),
             )
@@ -125,56 +124,103 @@ where
     }
 }
 
-pub fn par_for_each_in<I, T: IntoIterator<Item = I> + IntoParallelIterator<Item = I>>(
+fn par_slice<I: DynSend>(
+    items: &mut [I],
+    guard: &ParallelGuard,
+    for_each: impl Fn(&mut I) + DynSync + DynSend,
+) {
+    struct State<'a, F> {
+        for_each: FromDyn<F>,
+        guard: &'a ParallelGuard,
+        group: usize,
+    }
+
+    fn par_rec<I: DynSend, F: Fn(&mut I) + DynSync + DynSend>(
+        items: &mut [I],
+        state: &State<'_, F>,
+    ) {
+        if items.len() <= state.group {
+            for item in items {
+                state.guard.run(|| (state.for_each)(item));
+            }
+        } else {
+            let (left, right) = items.split_at_mut(items.len() / 2);
+            let mut left = state.for_each.derive(left);
+            let mut right = state.for_each.derive(right);
+            rayon_core::join(move || par_rec(*left, state), move || par_rec(*right, state));
+        }
+    }
+
+    let state = State {
+        for_each: FromDyn::from(for_each),
+        guard,
+        group: std::cmp::max(items.len() / 128, 1),
+    };
+    par_rec(items, &state)
+}
+
+pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
     t: T,
-    for_each: impl Fn(I) + DynSync + DynSend,
+    for_each: impl Fn(&I) + DynSync + DynSend,
 ) {
     parallel_guard(|guard| {
         if mode::is_dyn_thread_safe() {
-            let for_each = FromDyn::from(for_each);
-            t.into_par_iter().for_each(|i| {
-                guard.run(|| for_each(i));
-            });
+            let mut items: Vec<_> = t.into_iter().collect();
+            par_slice(&mut items, guard, |i| for_each(&*i))
         } else {
             t.into_iter().for_each(|i| {
-                guard.run(|| for_each(i));
+                guard.run(|| for_each(&i));
             });
         }
     });
 }
 
-pub fn try_par_for_each_in<
-    T: IntoIterator + IntoParallelIterator<Item = <T as IntoIterator>::Item>,
-    E: Send,
->(
+/// This runs `for_each` in parallel for each iterator item. If one or more of the
+/// `for_each` calls returns `Err`, the function will also return `Err`. The error returned
+/// will be non-deterministic, but this is expected to be used with `ErrorGuaranteed` which
+/// are all equivalent.
+pub fn try_par_for_each_in<T: IntoIterator, E: DynSend>(
     t: T,
-    for_each: impl Fn(<T as IntoIterator>::Item) -> Result<(), E> + DynSync + DynSend,
-) -> Result<(), E> {
+    for_each: impl Fn(&<T as IntoIterator>::Item) -> Result<(), E> + DynSync + DynSend,
+) -> Result<(), E>
+where
+    <T as IntoIterator>::Item: DynSend,
+{
     parallel_guard(|guard| {
         if mode::is_dyn_thread_safe() {
-            let for_each = FromDyn::from(for_each);
-            t.into_par_iter()
-                .filter_map(|i| guard.run(|| for_each(i)))
-                .reduce(|| Ok(()), Result::and)
+            let mut items: Vec<_> = t.into_iter().collect();
+
+            let error = Mutex::new(None);
+
+            par_slice(&mut items, guard, |i| {
+                if let Err(err) = for_each(&*i) {
+                    *error.lock() = Some(err);
+                }
+            });
+
+            if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) }
         } else {
-            t.into_iter().filter_map(|i| guard.run(|| for_each(i))).fold(Ok(()), Result::and)
+            t.into_iter().filter_map(|i| guard.run(|| for_each(&i))).fold(Ok(()), Result::and)
         }
     })
 }
 
-pub fn par_map<
-    I,
-    T: IntoIterator<Item = I> + IntoParallelIterator<Item = I>,
-    R: std::marker::Send,
-    C: FromIterator<R> + FromParallelIterator<R>,
->(
+pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterator<R>>(
     t: T,
     map: impl Fn(I) -> R + DynSync + DynSend,
 ) -> C {
     parallel_guard(|guard| {
         if mode::is_dyn_thread_safe() {
             let map = FromDyn::from(map);
-            t.into_par_iter().filter_map(|i| guard.run(|| map(i))).collect()
+
+            let mut items: Vec<(Option<I>, Option<R>)> =
+                t.into_iter().map(|i| (Some(i), None)).collect();
+
+            par_slice(&mut items, guard, |i| {
+                i.1 = Some(map(i.0.take().unwrap()));
+            });
+
+            items.into_iter().filter_map(|i| i.1).collect()
         } else {
             t.into_iter().filter_map(|i| guard.run(|| map(i))).collect()
         }