about summary refs log tree commit diff
path: root/compiler/rustc_mir_dataflow/src
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2024-06-26 21:57:59 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2024-07-01 15:41:21 +0000
commit76244d4dbc768e15e429c1f66ec021884f369f5f (patch)
tree247d7d63cb92520e7c1316a19e929e0a336200bd /compiler/rustc_mir_dataflow/src
parent1834f5a272d567a714f78c7f48c0d3ae4a6238bb (diff)
downloadrust-76244d4dbc768e15e429c1f66ec021884f369f5f.tar.gz
rust-76244d4dbc768e15e429c1f66ec021884f369f5f.zip
Make jump threading state sparse.
Diffstat (limited to 'compiler/rustc_mir_dataflow/src')
-rw-r--r--compiler/rustc_mir_dataflow/src/framework/lattice.rs14
-rw-r--r--compiler/rustc_mir_dataflow/src/value_analysis.rs91
2 files changed, 72 insertions, 33 deletions
diff --git a/compiler/rustc_mir_dataflow/src/framework/lattice.rs b/compiler/rustc_mir_dataflow/src/framework/lattice.rs
index 1c2b475a43c..23738f7a4a5 100644
--- a/compiler/rustc_mir_dataflow/src/framework/lattice.rs
+++ b/compiler/rustc_mir_dataflow/src/framework/lattice.rs
@@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq {
 /// A set that has a "bottom" element, which is less than or equal to any other element.
 pub trait HasBottom {
     const BOTTOM: Self;
+
+    fn is_bottom(&self) -> bool;
 }
 
 /// A set that has a "top" element, which is greater than or equal to any other element.
@@ -114,6 +116,10 @@ impl MeetSemiLattice for bool {
 
 impl HasBottom for bool {
     const BOTTOM: Self = false;
+
+    fn is_bottom(&self) -> bool {
+        !self
+    }
 }
 
 impl HasTop for bool {
@@ -267,6 +273,10 @@ impl<T: Clone + Eq> MeetSemiLattice for FlatSet<T> {
 
 impl<T> HasBottom for FlatSet<T> {
     const BOTTOM: Self = Self::Bottom;
+
+    fn is_bottom(&self) -> bool {
+        matches!(self, Self::Bottom)
+    }
 }
 
 impl<T> HasTop for FlatSet<T> {
@@ -291,6 +301,10 @@ impl<T> MaybeReachable<T> {
 
 impl<T> HasBottom for MaybeReachable<T> {
     const BOTTOM: Self = MaybeReachable::Unreachable;
+
+    fn is_bottom(&self) -> bool {
+        matches!(self, Self::Unreachable)
+    }
 }
 
 impl<T: HasTop> HasTop for MaybeReachable<T> {
diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs
index 0364c23bfcb..7c1ff6fda53 100644
--- a/compiler/rustc_mir_dataflow/src/value_analysis.rs
+++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs
@@ -36,7 +36,7 @@ use std::collections::VecDeque;
 use std::fmt::{Debug, Formatter};
 use std::ops::Range;
 
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, StdEntry};
 use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_index::bit_set::BitSet;
 use rustc_index::IndexVec;
@@ -342,8 +342,7 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
     fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
         // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥.
         assert!(matches!(state, State::Unreachable));
-        let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
-        *state = State::Reachable(values);
+        *state = State::new_reachable();
         for arg in body.args_iter() {
             state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
         }
@@ -415,30 +414,54 @@ rustc_index::newtype_index!(
 
 /// See [`State`].
 #[derive(PartialEq, Eq, Debug)]
-struct StateData<V> {
-    map: IndexVec<ValueIndex, V>,
+pub struct StateData<V> {
+    bottom: V,
+    /// This map only contains values that are not `⊥`.
+    map: FxHashMap<ValueIndex, V>,
 }
 
-impl<V: Clone> StateData<V> {
-    fn from_elem_n(elem: V, n: usize) -> StateData<V> {
-        StateData { map: IndexVec::from_elem_n(elem, n) }
+impl<V: HasBottom> StateData<V> {
+    fn new() -> StateData<V> {
+        StateData { bottom: V::BOTTOM, map: FxHashMap::default() }
+    }
+
+    fn get(&self, idx: ValueIndex) -> &V {
+        self.map.get(&idx).unwrap_or(&self.bottom)
+    }
+
+    fn insert(&mut self, idx: ValueIndex, elem: V) {
+        if elem.is_bottom() {
+            self.map.remove(&idx);
+        } else {
+            self.map.insert(idx, elem);
+        }
     }
 }
 
 impl<V: Clone> Clone for StateData<V> {
     fn clone(&self) -> Self {
-        StateData { map: self.map.clone() }
+        StateData { bottom: self.bottom.clone(), map: self.map.clone() }
     }
 
     fn clone_from(&mut self, source: &Self) {
-        // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
-        self.map.raw.clone_from(&source.map.raw)
+        self.map.clone_from(&source.map)
     }
 }
 
-impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> {
+impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for StateData<V> {
     fn join(&mut self, other: &Self) -> bool {
-        self.map.join(&other.map)
+        let mut changed = false;
+        #[allow(rustc::potential_query_instability)]
+        for (i, v) in other.map.iter() {
+            match self.map.entry(*i) {
+                StdEntry::Vacant(e) => {
+                    e.insert(v.clone());
+                    changed = true
+                }
+                StdEntry::Occupied(e) => changed |= e.into_mut().join(v),
+            }
+        }
+        changed
     }
 }
 
@@ -476,15 +499,19 @@ impl<V: Clone> Clone for State<V> {
     }
 }
 
-impl<V: Clone> State<V> {
-    pub fn new(init: V, map: &Map) -> State<V> {
-        State::Reachable(StateData::from_elem_n(init, map.value_count))
+impl<V: Clone + HasBottom> State<V> {
+    pub fn new_reachable() -> State<V> {
+        State::Reachable(StateData::new())
     }
 
-    pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
+    pub fn all_bottom(&self) -> bool {
         match self {
-            State::Unreachable => true,
-            State::Reachable(ref values) => values.map.iter().all(f),
+            State::Unreachable => false,
+            State::Reachable(ref values) =>
+            {
+                #[allow(rustc::potential_query_instability)]
+                values.map.values().all(V::is_bottom)
+            }
         }
     }
 
@@ -533,9 +560,7 @@ impl<V: Clone> State<V> {
         value: V,
     ) {
         let State::Reachable(values) = self else { return };
-        map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
-            values.map[vi] = value.clone();
-        });
+        map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone()));
     }
 
     /// Low-level method that assigns to a place.
@@ -556,7 +581,7 @@ impl<V: Clone> State<V> {
     pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
         let State::Reachable(values) = self else { return };
         if let Some(value_index) = map.places[target].value_index {
-            values.map[value_index] = value;
+            values.insert(value_index, value)
         }
     }
 
@@ -575,7 +600,7 @@ impl<V: Clone> State<V> {
         // already been performed.
         if let Some(target_value) = map.places[target].value_index {
             if let Some(source_value) = map.places[source].value_index {
-                values.map[target_value] = values.map[source_value].clone();
+                values.insert(target_value, values.get(source_value).clone());
             }
         }
         for target_child in map.children(target) {
@@ -631,7 +656,7 @@ impl<V: Clone> State<V> {
     pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
         match self {
             State::Reachable(values) => {
-                map.places[place].value_index.map(|v| values.map[v].clone())
+                map.places[place].value_index.map(|v| values.get(v).clone())
             }
             State::Unreachable => None,
         }
@@ -688,7 +713,7 @@ impl<V: Clone> State<V> {
     {
         match self {
             State::Reachable(values) => {
-                map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP)
+                map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP)
             }
             State::Unreachable => {
                 // Because this is unreachable, we can return any value we want.
@@ -698,7 +723,7 @@ impl<V: Clone> State<V> {
     }
 }
 
-impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> {
+impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for State<V> {
     fn join(&mut self, other: &Self) -> bool {
         match (&mut *self, other) {
             (_, State::Unreachable) => false,
@@ -1228,7 +1253,7 @@ where
     }
 }
 
-fn debug_with_context_rec<V: Debug + Eq>(
+fn debug_with_context_rec<V: Debug + Eq + HasBottom>(
     place: PlaceIndex,
     place_str: &str,
     new: &StateData<V>,
@@ -1238,11 +1263,11 @@ fn debug_with_context_rec<V: Debug + Eq>(
 ) -> std::fmt::Result {
     if let Some(value) = map.places[place].value_index {
         match old {
-            None => writeln!(f, "{}: {:?}", place_str, new.map[value])?,
+            None => writeln!(f, "{}: {:?}", place_str, new.get(value))?,
             Some(old) => {
-                if new.map[value] != old.map[value] {
-                    writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?;
-                    writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?;
+                if new.get(value) != old.get(value) {
+                    writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?;
+                    writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?;
                 }
             }
         }
@@ -1274,7 +1299,7 @@ fn debug_with_context_rec<V: Debug + Eq>(
     Ok(())
 }
 
-fn debug_with_context<V: Debug + Eq>(
+fn debug_with_context<V: Debug + Eq + HasBottom>(
     new: &StateData<V>,
     old: Option<&StateData<V>>,
     map: &Map,