about summary refs log tree commit diff
path: root/compiler/rustc_mir_dataflow
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_dataflow')
-rw-r--r--compiler/rustc_mir_dataflow/src/value_analysis.rs149
1 files changed, 81 insertions, 68 deletions
diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs
index bfbfff7e259..0364c23bfcb 100644
--- a/compiler/rustc_mir_dataflow/src/value_analysis.rs
+++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs
@@ -39,7 +39,7 @@ use std::ops::Range;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_index::bit_set::BitSet;
-use rustc_index::{IndexSlice, IndexVec};
+use rustc_index::IndexVec;
 use rustc_middle::bug;
 use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
 use rustc_middle::mir::*;
@@ -336,14 +336,14 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
     const NAME: &'static str = T::NAME;
 
     fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain {
-        State(StateData::Unreachable)
+        State::Unreachable
     }
 
     fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
         // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥.
-        assert!(matches!(state.0, StateData::Unreachable));
-        let values = IndexVec::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
-        *state = State(StateData::Reachable(values));
+        assert!(matches!(state, State::Unreachable));
+        let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
+        *state = State::Reachable(values);
         for arg in body.args_iter() {
             state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
         }
@@ -415,27 +415,30 @@ rustc_index::newtype_index!(
 
 /// See [`State`].
 #[derive(PartialEq, Eq, Debug)]
-enum StateData<V> {
-    Reachable(IndexVec<ValueIndex, V>),
-    Unreachable,
+struct StateData<V> {
+    map: IndexVec<ValueIndex, V>,
+}
+
+impl<V: Clone> StateData<V> {
+    fn from_elem_n(elem: V, n: usize) -> StateData<V> {
+        StateData { map: IndexVec::from_elem_n(elem, n) }
+    }
 }
 
 impl<V: Clone> Clone for StateData<V> {
     fn clone(&self) -> Self {
-        match self {
-            Self::Reachable(x) => Self::Reachable(x.clone()),
-            Self::Unreachable => Self::Unreachable,
-        }
+        StateData { map: self.map.clone() }
     }
 
     fn clone_from(&mut self, source: &Self) {
-        match (&mut *self, source) {
-            (Self::Reachable(x), Self::Reachable(y)) => {
-                // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
-                x.raw.clone_from(&y.raw);
-            }
-            _ => *self = source.clone(),
-        }
+        // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
+        self.map.raw.clone_from(&source.map.raw)
+    }
+}
+
+impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> {
+    fn join(&mut self, other: &Self) -> bool {
+        self.map.join(&other.map)
     }
 }
 
@@ -450,33 +453,43 @@ impl<V: Clone> Clone for StateData<V> {
 ///
 /// Flooding means assigning a value (by default `⊤`) to all tracked projections of a given place.
 #[derive(PartialEq, Eq, Debug)]
-pub struct State<V>(StateData<V>);
+pub enum State<V> {
+    Unreachable,
+    Reachable(StateData<V>),
+}
 
 impl<V: Clone> Clone for State<V> {
     fn clone(&self) -> Self {
-        Self(self.0.clone())
+        match self {
+            Self::Reachable(x) => Self::Reachable(x.clone()),
+            Self::Unreachable => Self::Unreachable,
+        }
     }
 
     fn clone_from(&mut self, source: &Self) {
-        self.0.clone_from(&source.0);
+        match (&mut *self, source) {
+            (Self::Reachable(x), Self::Reachable(y)) => {
+                x.clone_from(&y);
+            }
+            _ => *self = source.clone(),
+        }
     }
 }
 
 impl<V: Clone> State<V> {
     pub fn new(init: V, map: &Map) -> State<V> {
-        let values = IndexVec::from_elem_n(init, map.value_count);
-        State(StateData::Reachable(values))
+        State::Reachable(StateData::from_elem_n(init, map.value_count))
     }
 
     pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
-        match self.0 {
-            StateData::Unreachable => true,
-            StateData::Reachable(ref values) => values.iter().all(f),
+        match self {
+            State::Unreachable => true,
+            State::Reachable(ref values) => values.map.iter().all(f),
         }
     }
 
     fn is_reachable(&self) -> bool {
-        matches!(&self.0, StateData::Reachable(_))
+        matches!(self, State::Reachable(_))
     }
 
     /// Assign `value` to all places that are contained in `place` or may alias one.
@@ -519,9 +532,9 @@ impl<V: Clone> State<V> {
         map: &Map,
         value: V,
     ) {
-        let StateData::Reachable(values) = &mut self.0 else { return };
+        let State::Reachable(values) = self else { return };
         map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
-            values[vi] = value.clone();
+            values.map[vi] = value.clone();
         });
     }
 
@@ -541,9 +554,9 @@ impl<V: Clone> State<V> {
     ///
     /// The target place must have been flooded before calling this method.
     pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
-        let StateData::Reachable(values) = &mut self.0 else { return };
+        let State::Reachable(values) = self else { return };
         if let Some(value_index) = map.places[target].value_index {
-            values[value_index] = value;
+            values.map[value_index] = value;
         }
     }
 
@@ -555,14 +568,14 @@ impl<V: Clone> State<V> {
     ///
     /// The target place must have been flooded before calling this method.
     pub fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
-        let StateData::Reachable(values) = &mut self.0 else { return };
+        let State::Reachable(values) = self else { return };
 
         // If both places are tracked, we copy the value to the target.
         // If the target is tracked, but the source is not, we do nothing, as invalidation has
         // already been performed.
         if let Some(target_value) = map.places[target].value_index {
             if let Some(source_value) = map.places[source].value_index {
-                values[target_value] = values[source_value].clone();
+                values.map[target_value] = values.map[source_value].clone();
             }
         }
         for target_child in map.children(target) {
@@ -616,11 +629,11 @@ impl<V: Clone> State<V> {
 
     /// Retrieve the value stored for a place index, or `None` if it is not tracked.
     pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
-        match &self.0 {
-            StateData::Reachable(values) => {
-                map.places[place].value_index.map(|v| values[v].clone())
+        match self {
+            State::Reachable(values) => {
+                map.places[place].value_index.map(|v| values.map[v].clone())
             }
-            StateData::Unreachable => None,
+            State::Unreachable => None,
         }
     }
 
@@ -631,10 +644,10 @@ impl<V: Clone> State<V> {
     where
         V: HasBottom + HasTop,
     {
-        match &self.0 {
-            StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
+        match self {
+            State::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
             // Because this is unreachable, we can return any value we want.
-            StateData::Unreachable => V::BOTTOM,
+            State::Unreachable => V::BOTTOM,
         }
     }
 
@@ -645,10 +658,10 @@ impl<V: Clone> State<V> {
     where
         V: HasBottom + HasTop,
     {
-        match &self.0 {
-            StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
+        match self {
+            State::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
             // Because this is unreachable, we can return any value we want.
-            StateData::Unreachable => V::BOTTOM,
+            State::Unreachable => V::BOTTOM,
         }
     }
 
@@ -659,10 +672,10 @@ impl<V: Clone> State<V> {
     where
         V: HasBottom + HasTop,
     {
-        match &self.0 {
-            StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
+        match self {
+            State::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
             // Because this is unreachable, we can return any value we want.
-            StateData::Unreachable => V::BOTTOM,
+            State::Unreachable => V::BOTTOM,
         }
     }
 
@@ -673,11 +686,11 @@ impl<V: Clone> State<V> {
     where
         V: HasBottom + HasTop,
     {
-        match &self.0 {
-            StateData::Reachable(values) => {
-                map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)
+        match self {
+            State::Reachable(values) => {
+                map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP)
             }
-            StateData::Unreachable => {
+            State::Unreachable => {
                 // Because this is unreachable, we can return any value we want.
                 V::BOTTOM
             }
@@ -687,13 +700,13 @@ impl<V: Clone> State<V> {
 
 impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> {
     fn join(&mut self, other: &Self) -> bool {
-        match (&mut self.0, &other.0) {
-            (_, StateData::Unreachable) => false,
-            (StateData::Unreachable, _) => {
+        match (&mut *self, other) {
+            (_, State::Unreachable) => false,
+            (State::Unreachable, _) => {
                 *self = other.clone();
                 true
             }
-            (StateData::Reachable(this), StateData::Reachable(other)) => this.join(other),
+            (State::Reachable(this), State::Reachable(ref other)) => this.join(other),
         }
     }
 }
@@ -1194,9 +1207,9 @@ where
     T::Value: Debug,
 {
     fn fmt_with(&self, ctxt: &ValueAnalysisWrapper<T>, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match &self.0 {
-            StateData::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
-            StateData::Unreachable => write!(f, "unreachable"),
+        match self {
+            State::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
+            State::Unreachable => write!(f, "unreachable"),
         }
     }
 
@@ -1206,8 +1219,8 @@ where
         ctxt: &ValueAnalysisWrapper<T>,
         f: &mut Formatter<'_>,
     ) -> std::fmt::Result {
-        match (&self.0, &old.0) {
-            (StateData::Reachable(this), StateData::Reachable(old)) => {
+        match (self, old) {
+            (State::Reachable(this), State::Reachable(old)) => {
                 debug_with_context(this, Some(old), ctxt.0.map(), f)
             }
             _ => Ok(()), // Consider printing something here.
@@ -1218,18 +1231,18 @@ where
 fn debug_with_context_rec<V: Debug + Eq>(
     place: PlaceIndex,
     place_str: &str,
-    new: &IndexSlice<ValueIndex, V>,
-    old: Option<&IndexSlice<ValueIndex, V>>,
+    new: &StateData<V>,
+    old: Option<&StateData<V>>,
     map: &Map,
     f: &mut Formatter<'_>,
 ) -> std::fmt::Result {
     if let Some(value) = map.places[place].value_index {
         match old {
-            None => writeln!(f, "{}: {:?}", place_str, new[value])?,
+            None => writeln!(f, "{}: {:?}", place_str, new.map[value])?,
             Some(old) => {
-                if new[value] != old[value] {
-                    writeln!(f, "\u{001f}-{}: {:?}", place_str, old[value])?;
-                    writeln!(f, "\u{001f}+{}: {:?}", place_str, new[value])?;
+                if new.map[value] != old.map[value] {
+                    writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?;
+                    writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?;
                 }
             }
         }
@@ -1262,8 +1275,8 @@ fn debug_with_context_rec<V: Debug + Eq>(
 }
 
 fn debug_with_context<V: Debug + Eq>(
-    new: &IndexSlice<ValueIndex, V>,
-    old: Option<&IndexSlice<ValueIndex, V>>,
+    new: &StateData<V>,
+    old: Option<&StateData<V>>,
     map: &Map,
     f: &mut Formatter<'_>,
 ) -> std::fmt::Result {