diff options
| author | Camille GILLOT <gillot.camille@gmail.com> | 2024-06-26 21:57:59 +0000 |
|---|---|---|
| committer | Camille GILLOT <gillot.camille@gmail.com> | 2024-07-01 15:41:21 +0000 |
| commit | 76244d4dbc768e15e429c1f66ec021884f369f5f (patch) | |
| tree | 247d7d63cb92520e7c1316a19e929e0a336200bd /compiler/rustc_mir_dataflow/src | |
| parent | 1834f5a272d567a714f78c7f48c0d3ae4a6238bb (diff) | |
| download | rust-76244d4dbc768e15e429c1f66ec021884f369f5f.tar.gz rust-76244d4dbc768e15e429c1f66ec021884f369f5f.zip | |
Make jump threading state sparse.
Diffstat (limited to 'compiler/rustc_mir_dataflow/src')
| -rw-r--r-- | compiler/rustc_mir_dataflow/src/framework/lattice.rs | 14 | ||||
| -rw-r--r-- | compiler/rustc_mir_dataflow/src/value_analysis.rs | 91 |
2 files changed, 72 insertions, 33 deletions
diff --git a/compiler/rustc_mir_dataflow/src/framework/lattice.rs b/compiler/rustc_mir_dataflow/src/framework/lattice.rs index 1c2b475a43c..23738f7a4a5 100644 --- a/compiler/rustc_mir_dataflow/src/framework/lattice.rs +++ b/compiler/rustc_mir_dataflow/src/framework/lattice.rs @@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq { /// A set that has a "bottom" element, which is less than or equal to any other element. pub trait HasBottom { const BOTTOM: Self; + + fn is_bottom(&self) -> bool; } /// A set that has a "top" element, which is greater than or equal to any other element. @@ -114,6 +116,10 @@ impl MeetSemiLattice for bool { impl HasBottom for bool { const BOTTOM: Self = false; + + fn is_bottom(&self) -> bool { + !self + } } impl HasTop for bool { @@ -267,6 +273,10 @@ impl<T: Clone + Eq> MeetSemiLattice for FlatSet<T> { impl<T> HasBottom for FlatSet<T> { const BOTTOM: Self = Self::Bottom; + + fn is_bottom(&self) -> bool { + matches!(self, Self::Bottom) + } } impl<T> HasTop for FlatSet<T> { @@ -291,6 +301,10 @@ impl<T> MaybeReachable<T> { impl<T> HasBottom for MaybeReachable<T> { const BOTTOM: Self = MaybeReachable::Unreachable; + + fn is_bottom(&self) -> bool { + matches!(self, Self::Unreachable) + } } impl<T: HasTop> HasTop for MaybeReachable<T> { diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index 0364c23bfcb..7c1ff6fda53 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -36,7 +36,7 @@ use std::collections::VecDeque; use std::fmt::{Debug, Formatter}; use std::ops::Range; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, StdEntry}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; @@ -342,8 +342,7 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥. assert!(matches!(state, State::Unreachable)); - let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count); - *state = State::Reachable(values); + *state = State::new_reachable(); for arg in body.args_iter() { state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map()); } @@ -415,30 +414,54 @@ rustc_index::newtype_index!( /// See [`State`]. #[derive(PartialEq, Eq, Debug)] -struct StateData<V> { - map: IndexVec<ValueIndex, V>, +pub struct StateData<V> { + bottom: V, + /// This map only contains values that are not `⊥`. + map: FxHashMap<ValueIndex, V>, } -impl<V: Clone> StateData<V> { - fn from_elem_n(elem: V, n: usize) -> StateData<V> { - StateData { map: IndexVec::from_elem_n(elem, n) } +impl<V: HasBottom> StateData<V> { + fn new() -> StateData<V> { + StateData { bottom: V::BOTTOM, map: FxHashMap::default() } + } + + fn get(&self, idx: ValueIndex) -> &V { + self.map.get(&idx).unwrap_or(&self.bottom) + } + + fn insert(&mut self, idx: ValueIndex, elem: V) { + if elem.is_bottom() { + self.map.remove(&idx); + } else { + self.map.insert(idx, elem); + } } } impl<V: Clone> Clone for StateData<V> { fn clone(&self) -> Self { - StateData { map: self.map.clone() } + StateData { bottom: self.bottom.clone(), map: self.map.clone() } } fn clone_from(&mut self, source: &Self) { - // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`. - self.map.raw.clone_from(&source.map.raw) + self.map.clone_from(&source.map) } } -impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> { +impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for StateData<V> { fn join(&mut self, other: &Self) -> bool { - self.map.join(&other.map) + let mut changed = false; + #[allow(rustc::potential_query_instability)] + for (i, v) in other.map.iter() { + match self.map.entry(*i) { + StdEntry::Vacant(e) => { + e.insert(v.clone()); + changed = true + } + StdEntry::Occupied(e) => changed |= e.into_mut().join(v), + } + } + changed } } @@ -476,15 +499,19 @@ impl<V: Clone> Clone for State<V> { } } -impl<V: Clone> State<V> { - pub fn new(init: V, map: &Map) -> State<V> { - State::Reachable(StateData::from_elem_n(init, map.value_count)) +impl<V: Clone + HasBottom> State<V> { + pub fn new_reachable() -> State<V> { + State::Reachable(StateData::new()) } - pub fn all(&self, f: impl Fn(&V) -> bool) -> bool { + pub fn all_bottom(&self) -> bool { match self { - State::Unreachable => true, - State::Reachable(ref values) => values.map.iter().all(f), + State::Unreachable => false, + State::Reachable(ref values) => + { + #[allow(rustc::potential_query_instability)] + values.map.values().all(V::is_bottom) + } } } @@ -533,9 +560,7 @@ impl<V: Clone> State<V> { value: V, ) { let State::Reachable(values) = self else { return }; - map.for_each_aliasing_place(place, tail_elem, &mut |vi| { - values.map[vi] = value.clone(); - }); + map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone())); } /// Low-level method that assigns to a place. @@ -556,7 +581,7 @@ impl<V: Clone> State<V> { pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) { let State::Reachable(values) = self else { return }; if let Some(value_index) = map.places[target].value_index { - values.map[value_index] = value; + values.insert(value_index, value) } } @@ -575,7 +600,7 @@ impl<V: Clone> State<V> { // already been performed. if let Some(target_value) = map.places[target].value_index { if let Some(source_value) = map.places[source].value_index { - values.map[target_value] = values.map[source_value].clone(); + values.insert(target_value, values.get(source_value).clone()); } } for target_child in map.children(target) { @@ -631,7 +656,7 @@ impl<V: Clone> State<V> { pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> { match self { State::Reachable(values) => { - map.places[place].value_index.map(|v| values.map[v].clone()) + map.places[place].value_index.map(|v| values.get(v).clone()) } State::Unreachable => None, } @@ -688,7 +713,7 @@ impl<V: Clone> State<V> { { match self { State::Reachable(values) => { - map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP) + map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP) } State::Unreachable => { // Because this is unreachable, we can return any value we want. @@ -698,7 +723,7 @@ impl<V: Clone> State<V> { } } -impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> { +impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for State<V> { fn join(&mut self, other: &Self) -> bool { match (&mut *self, other) { (_, State::Unreachable) => false, @@ -1228,7 +1253,7 @@ where } } -fn debug_with_context_rec<V: Debug + Eq>( +fn debug_with_context_rec<V: Debug + Eq + HasBottom>( place: PlaceIndex, place_str: &str, new: &StateData<V>, @@ -1238,11 +1263,11 @@ fn debug_with_context_rec<V: Debug + Eq>( ) -> std::fmt::Result { if let Some(value) = map.places[place].value_index { match old { - None => writeln!(f, "{}: {:?}", place_str, new.map[value])?, + None => writeln!(f, "{}: {:?}", place_str, new.get(value))?, Some(old) => { - if new.map[value] != old.map[value] { - writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?; - writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?; + if new.get(value) != old.get(value) { + writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?; + writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?; } } } @@ -1274,7 +1299,7 @@ fn debug_with_context_rec<V: Debug + Eq>( Ok(()) } -fn debug_with_context<V: Debug + Eq>( +fn debug_with_context<V: Debug + Eq + HasBottom>( new: &StateData<V>, old: Option<&StateData<V>>, map: &Map, |
