about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDylan DPC <99973273+Dylan-DPC@users.noreply.github.com>2023-02-15 12:24:55 +0530
committerGitHub <noreply@github.com>2023-02-15 12:24:55 +0530
commitc78e3c735aff8aa9bad9cd2dc46b5fdbf4ca1af6 (patch)
tree70b814539bb7b7ff9c4c7b10f400611d93099804
parenta110cf5d1649aa920c046c764c28c666ebdd85cd (diff)
parent09797a463cd1bc70bc439aaf0c94b7b5a80f5bfb (diff)
downloadrust-c78e3c735aff8aa9bad9cd2dc46b5fdbf4ca1af6.tar.gz
rust-c78e3c735aff8aa9bad9cd2dc46b5fdbf4ca1af6.zip
Rollup merge of #107411 - cjgillot:dataflow-discriminant, r=oli-obk
Handle discriminant in DataflowConstProp

cc ``@jachris``
r? ``@JakobDegen``

This PR attempts to extend the DataflowConstProp pass to handle propagation of discriminants. We handle this by adding 2 new variants to `TrackElem`: `TrackElem::Variant` for enum variants and `TrackElem::Discriminant` for the enum discriminant pseudo-place.

The difficulty is that the enum discriminant and enum variants may alias each another. This is the issue of the `Option<NonZeroUsize>` test, which is the equivalent of https://github.com/rust-lang/unsafe-code-guidelines/issues/84 with a direct write.

To handle that, we generalize the flood process to flood all the potentially aliasing places. In particular:
- any write to `(PLACE as Variant)`, either direct or through a projection, floods `(PLACE as OtherVariant)` for all other variants and `discriminant(PLACE)`;
- `SetDiscriminant(PLACE)` floods `(PLACE as Variant)` for each variant.

This implies that flooding is not hierarchical any more, and that an assignment to a non-tracked place may need to flood a tracked place. This is handled by `for_each_aliasing_place` which generalizes `preorder_invoke`.

As we deaggregate enums by putting `SetDiscriminant` last, this allows to propagate the value of the discriminant.

This refactor will allow to make https://github.com/rust-lang/rust/pull/107009 able to handle discriminants too.
-rw-r--r--compiler/rustc_middle/src/mir/mod.rs8
-rw-r--r--compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs4
-rw-r--r--compiler/rustc_mir_dataflow/src/lib.rs1
-rw-r--r--compiler/rustc_mir_dataflow/src/value_analysis.rs275
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs127
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs18
-rw-r--r--tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff26
-rw-r--r--tests/mir-opt/dataflow-const-prop/enum.rs45
-rw-r--r--tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff (renamed from tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff)16
9 files changed, 408 insertions, 112 deletions
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index 752f5be7853..6996d91a80d 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -1642,6 +1642,14 @@ impl<'tcx> PlaceRef<'tcx> {
         }
     }
 
+    /// Returns `true` if this `Place` contains a `Deref` projection.
+    ///
+    /// If `Place::is_indirect` returns false, the caller knows that the `Place` refers to the
+    /// same region of memory as its base.
+    pub fn is_indirect(&self) -> bool {
+        self.projection.iter().any(|elem| elem.is_indirect())
+    }
+
     /// If MirPhase >= Derefered and if projection contains Deref,
     /// It's guaranteed to be in the first place
     pub fn has_deref(&self) -> bool {
diff --git a/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs b/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs
index 0f8e86d1d66..6f4e7fd4682 100644
--- a/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs
+++ b/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs
@@ -121,7 +121,9 @@ where
                 // for now. See discussion on [#61069].
                 //
                 // [#61069]: https://github.com/rust-lang/rust/pull/61069
-                self.trans.gen(dropped_place.local);
+                if !dropped_place.is_indirect() {
+                    self.trans.gen(dropped_place.local);
+                }
             }
 
             TerminatorKind::Abort
diff --git a/compiler/rustc_mir_dataflow/src/lib.rs b/compiler/rustc_mir_dataflow/src/lib.rs
index 7f40cfca32f..3e382f500af 100644
--- a/compiler/rustc_mir_dataflow/src/lib.rs
+++ b/compiler/rustc_mir_dataflow/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(associated_type_defaults)]
 #![feature(box_patterns)]
 #![feature(exact_size_is_empty)]
+#![feature(let_chains)]
 #![feature(min_specialization)]
 #![feature(once_cell)]
 #![feature(stmt_expr_attributes)]
diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs
index a6ef2a742c8..401db890a98 100644
--- a/compiler/rustc_mir_dataflow/src/value_analysis.rs
+++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs
@@ -24,7 +24,7 @@
 //! - The bottom state denotes uninitialized memory. Because we are only doing a sound approximation
 //! of the actual execution, we can also use this state for places where access would be UB.
 //!
-//! - The assignment logic in `State::assign_place_idx` assumes that the places are non-overlapping,
+//! - The assignment logic in `State::insert_place_idx` assumes that the places are non-overlapping,
 //! or identical. Note that this refers to place expressions, not memory locations.
 //!
 //! - Currently, places that have their reference taken cannot be tracked. Although this would be
@@ -35,6 +35,7 @@
 use std::fmt::{Debug, Formatter};
 
 use rustc_data_structures::fx::FxHashMap;
+use rustc_index::bit_set::BitSet;
 use rustc_index::vec::IndexVec;
 use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
 use rustc_middle::mir::*;
@@ -64,10 +65,8 @@ pub trait ValueAnalysis<'tcx> {
             StatementKind::Assign(box (place, rvalue)) => {
                 self.handle_assign(*place, rvalue, state);
             }
-            StatementKind::SetDiscriminant { .. } => {
-                // Could treat this as writing a constant to a pseudo-place.
-                // But discriminants are currently not tracked, so we do nothing.
-                // Related: https://github.com/rust-lang/unsafe-code-guidelines/issues/84
+            StatementKind::SetDiscriminant { box ref place, .. } => {
+                state.flood_discr(place.as_ref(), self.map());
             }
             StatementKind::Intrinsic(box intrinsic) => {
                 self.handle_intrinsic(intrinsic, state);
@@ -446,26 +445,51 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
     }
 
     pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
-        if let Some(root) = map.find(place) {
-            self.flood_idx_with(root, map, value);
-        }
+        let StateData::Reachable(values) = &mut self.0 else { return };
+        map.for_each_aliasing_place(place, None, &mut |place| {
+            if let Some(vi) = map.places[place].value_index {
+                values[vi] = value.clone();
+            }
+        });
     }
 
     pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
         self.flood_with(place, map, V::top())
     }
 
-    pub fn flood_idx_with(&mut self, place: PlaceIndex, map: &Map, value: V) {
+    pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
         let StateData::Reachable(values) = &mut self.0 else { return };
-        map.preorder_invoke(place, &mut |place| {
+        map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |place| {
             if let Some(vi) = map.places[place].value_index {
                 values[vi] = value.clone();
             }
         });
     }
 
-    pub fn flood_idx(&mut self, place: PlaceIndex, map: &Map) {
-        self.flood_idx_with(place, map, V::top())
+    pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
+        self.flood_discr_with(place, map, V::top())
+    }
+
+    /// Low-level method that assigns to a place.
+    /// This does nothing if the place is not tracked.
+    ///
+    /// The target place must have been flooded before calling this method.
+    pub fn insert_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
+        match result {
+            ValueOrPlace::Value(value) => self.insert_value_idx(target, value, map),
+            ValueOrPlace::Place(source) => self.insert_place_idx(target, source, map),
+        }
+    }
+
+    /// Low-level method that assigns a value to a place.
+    /// This does nothing if the place is not tracked.
+    ///
+    /// The target place must have been flooded before calling this method.
+    pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
+        let StateData::Reachable(values) = &mut self.0 else { return };
+        if let Some(value_index) = map.places[target].value_index {
+            values[value_index] = value;
+        }
     }
 
     /// Copies `source` to `target`, including all tracked places beneath.
@@ -473,50 +497,41 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
     /// If `target` contains a place that is not contained in `source`, it will be overwritten with
     /// Top. Also, because this will copy all entries one after another, it may only be used for
     /// places that are non-overlapping or identical.
-    pub fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
+    ///
+    /// The target place must have been flooded before calling this method.
+    fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
         let StateData::Reachable(values) = &mut self.0 else { return };
 
-        // If both places are tracked, we copy the value to the target. If the target is tracked,
-        // but the source is not, we have to invalidate the value in target. If the target is not
-        // tracked, then we don't have to do anything.
+        // If both places are tracked, we copy the value to the target.
+        // If the target is tracked, but the source is not, we do nothing, as invalidation has
+        // already been performed.
         if let Some(target_value) = map.places[target].value_index {
             if let Some(source_value) = map.places[source].value_index {
                 values[target_value] = values[source_value].clone();
-            } else {
-                values[target_value] = V::top();
             }
         }
         for target_child in map.children(target) {
             // Try to find corresponding child and recurse. Reasoning is similar as above.
             let projection = map.places[target_child].proj_elem.unwrap();
             if let Some(source_child) = map.projections.get(&(source, projection)) {
-                self.assign_place_idx(target_child, *source_child, map);
-            } else {
-                self.flood_idx(target_child, map);
+                self.insert_place_idx(target_child, *source_child, map);
             }
         }
     }
 
+    /// Helper method to interpret `target = result`.
     pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
+        self.flood(target, map);
         if let Some(target) = map.find(target) {
-            self.assign_idx(target, result, map);
-        } else {
-            // We don't track this place nor any projections, assignment can be ignored.
+            self.insert_idx(target, result, map);
         }
     }
 
-    pub fn assign_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
-        match result {
-            ValueOrPlace::Value(value) => {
-                // First flood the target place in case we also track any projections (although
-                // this scenario is currently not well-supported by the API).
-                self.flood_idx(target, map);
-                let StateData::Reachable(values) = &mut self.0 else { return };
-                if let Some(value_index) = map.places[target].value_index {
-                    values[value_index] = value;
-                }
-            }
-            ValueOrPlace::Place(source) => self.assign_place_idx(target, source, map),
+    /// Helper method for assignments to a discriminant.
+    pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
+        self.flood_discr(target, map);
+        if let Some(target) = map.find_discr(target) {
+            self.insert_idx(target, result, map);
         }
     }
 
@@ -525,6 +540,14 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
         map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::top())
     }
 
+    /// Retrieve the value stored for a place, or ⊤ if it is not tracked.
+    pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
+        match map.find_discr(place) {
+            Some(place) => self.get_idx(place, map),
+            None => V::top(),
+        }
+    }
+
     /// Retrieve the value stored for a place index, or ⊤ if it is not tracked.
     pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
         match &self.0 {
@@ -581,15 +604,15 @@ impl Map {
     /// This is currently the only way to create a [`Map`]. The way in which the tracked places are
     /// chosen is an implementation detail and may not be relied upon (other than that their type
     /// passes the filter).
-    #[instrument(skip_all, level = "debug")]
     pub fn from_filter<'tcx>(
         tcx: TyCtxt<'tcx>,
         body: &Body<'tcx>,
         filter: impl FnMut(Ty<'tcx>) -> bool,
+        place_limit: Option<usize>,
     ) -> Self {
         let mut map = Self::new();
         let exclude = excluded_locals(body);
-        map.register_with_filter(tcx, body, filter, &exclude);
+        map.register_with_filter(tcx, body, filter, exclude, place_limit);
         debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len());
         map
     }
@@ -600,20 +623,28 @@ impl Map {
         tcx: TyCtxt<'tcx>,
         body: &Body<'tcx>,
         mut filter: impl FnMut(Ty<'tcx>) -> bool,
-        exclude: &IndexVec<Local, bool>,
+        exclude: BitSet<Local>,
+        place_limit: Option<usize>,
     ) {
         // We use this vector as stack, pushing and popping projections.
         let mut projection = Vec::new();
         for (local, decl) in body.local_decls.iter_enumerated() {
-            if !exclude[local] {
-                self.register_with_filter_rec(tcx, local, &mut projection, decl.ty, &mut filter);
+            if !exclude.contains(local) {
+                self.register_with_filter_rec(
+                    tcx,
+                    local,
+                    &mut projection,
+                    decl.ty,
+                    &mut filter,
+                    place_limit,
+                );
             }
         }
     }
 
     /// Potentially register the (local, projection) place and its fields, recursively.
     ///
-    /// Invariant: The projection must only contain fields.
+    /// Invariant: The projection must only contain trackable elements.
     fn register_with_filter_rec<'tcx>(
         &mut self,
         tcx: TyCtxt<'tcx>,
@@ -621,27 +652,56 @@ impl Map {
         projection: &mut Vec<PlaceElem<'tcx>>,
         ty: Ty<'tcx>,
         filter: &mut impl FnMut(Ty<'tcx>) -> bool,
+        place_limit: Option<usize>,
     ) {
-        // Note: The framework supports only scalars for now.
-        if filter(ty) && ty.is_scalar() {
-            // We know that the projection only contains trackable elements.
-            let place = self.make_place(local, projection).unwrap();
+        if let Some(place_limit) = place_limit && self.value_count >= place_limit {
+            return
+        }
+
+        // We know that the projection only contains trackable elements.
+        let place = self.make_place(local, projection).unwrap();
 
-            // Allocate a value slot if it doesn't have one.
-            if self.places[place].value_index.is_none() {
-                self.places[place].value_index = Some(self.value_count.into());
-                self.value_count += 1;
+        // Allocate a value slot if it doesn't have one, and the user requested one.
+        if self.places[place].value_index.is_none() && filter(ty) {
+            self.places[place].value_index = Some(self.value_count.into());
+            self.value_count += 1;
+        }
+
+        if ty.is_enum() {
+            let discr_ty = ty.discriminant_ty(tcx);
+            if filter(discr_ty) {
+                let discr = *self
+                    .projections
+                    .entry((place, TrackElem::Discriminant))
+                    .or_insert_with(|| {
+                        // Prepend new child to the linked list.
+                        let next = self.places.push(PlaceInfo::new(Some(TrackElem::Discriminant)));
+                        self.places[next].next_sibling = self.places[place].first_child;
+                        self.places[place].first_child = Some(next);
+                        next
+                    });
+
+                // Allocate a value slot if it doesn't have one.
+                if self.places[discr].value_index.is_none() {
+                    self.places[discr].value_index = Some(self.value_count.into());
+                    self.value_count += 1;
+                }
             }
         }
 
         // Recurse with all fields of this place.
         iter_fields(ty, tcx, |variant, field, ty| {
-            if variant.is_some() {
-                // Downcasts are currently not supported.
+            if let Some(variant) = variant {
+                projection.push(PlaceElem::Downcast(None, variant));
+                let _ = self.make_place(local, projection);
+                projection.push(PlaceElem::Field(field, ty));
+                self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit);
+                projection.pop();
+                projection.pop();
                 return;
             }
             projection.push(PlaceElem::Field(field, ty));
-            self.register_with_filter_rec(tcx, local, projection, ty, filter);
+            self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit);
             projection.pop();
         });
     }
@@ -684,23 +744,105 @@ impl Map {
     }
 
     /// Locates the given place, if it exists in the tree.
-    pub fn find(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
+    pub fn find_extra(
+        &self,
+        place: PlaceRef<'_>,
+        extra: impl IntoIterator<Item = TrackElem>,
+    ) -> Option<PlaceIndex> {
         let mut index = *self.locals.get(place.local)?.as_ref()?;
 
         for &elem in place.projection {
             index = self.apply(index, elem.try_into().ok()?)?;
         }
+        for elem in extra {
+            index = self.apply(index, elem)?;
+        }
 
         Some(index)
     }
 
+    /// Locates the given place, if it exists in the tree.
+    pub fn find(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
+        self.find_extra(place, [])
+    }
+
+    /// Locates the given place and applies `Discriminant`, if it exists in the tree.
+    pub fn find_discr(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
+        self.find_extra(place, [TrackElem::Discriminant])
+    }
+
     /// Iterate over all direct children.
     pub fn children(&self, parent: PlaceIndex) -> impl Iterator<Item = PlaceIndex> + '_ {
         Children::new(self, parent)
     }
 
+    /// Invoke a function on the given place and all places that may alias it.
+    ///
+    /// In particular, when the given place has a variant downcast, we invoke the function on all
+    /// the other variants.
+    ///
+    /// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
+    /// as such.
+    pub fn for_each_aliasing_place(
+        &self,
+        place: PlaceRef<'_>,
+        tail_elem: Option<TrackElem>,
+        f: &mut impl FnMut(PlaceIndex),
+    ) {
+        if place.is_indirect() {
+            // We do not track indirect places.
+            return;
+        }
+        let Some(&Some(mut index)) = self.locals.get(place.local) else {
+            // The local is not tracked at all, so it does not alias anything.
+            return;
+        };
+        let elems = place
+            .projection
+            .iter()
+            .map(|&elem| elem.try_into())
+            .chain(tail_elem.map(Ok).into_iter());
+        for elem in elems {
+            // A field aliases the parent place.
+            f(index);
+
+            let Ok(elem) = elem else { return };
+            let sub = self.apply(index, elem);
+            if let TrackElem::Variant(..) | TrackElem::Discriminant = elem {
+                // Enum variant fields and enum discriminants alias each another.
+                self.for_each_variant_sibling(index, sub, f);
+            }
+            if let Some(sub) = sub {
+                index = sub
+            } else {
+                return;
+            }
+        }
+        self.preorder_invoke(index, f);
+    }
+
+    /// Invoke the given function on all the descendants of the given place, except one branch.
+    fn for_each_variant_sibling(
+        &self,
+        parent: PlaceIndex,
+        preserved_child: Option<PlaceIndex>,
+        f: &mut impl FnMut(PlaceIndex),
+    ) {
+        for sibling in self.children(parent) {
+            let elem = self.places[sibling].proj_elem;
+            // Only invalidate variants and discriminant. Fields (for generators) are not
+            // invalidated by assignment to a variant.
+            if let Some(TrackElem::Variant(..) | TrackElem::Discriminant) = elem
+                // Only invalidate the other variants, the current one is fine.
+                && Some(sibling) != preserved_child
+            {
+                self.preorder_invoke(sibling, f);
+            }
+        }
+    }
+
     /// Invoke a function on the given place and all descendants.
-    pub fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
+    fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
         f(root);
         for child in self.children(root) {
             self.preorder_invoke(child, f);
@@ -759,6 +901,7 @@ impl<'a> Iterator for Children<'a> {
 }
 
 /// Used as the result of an operand or r-value.
+#[derive(Debug)]
 pub enum ValueOrPlace<V> {
     Value(V),
     Place(PlaceIndex),
@@ -776,6 +919,8 @@ impl<V: HasTop> ValueOrPlace<V> {
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
 pub enum TrackElem {
     Field(Field),
+    Variant(VariantIdx),
+    Discriminant,
 }
 
 impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
@@ -784,6 +929,7 @@ impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
     fn try_from(value: ProjectionElem<V, T>) -> Result<Self, Self::Error> {
         match value {
             ProjectionElem::Field(field, _) => Ok(TrackElem::Field(field)),
+            ProjectionElem::Downcast(_, idx) => Ok(TrackElem::Variant(idx)),
             _ => Err(()),
         }
     }
@@ -824,26 +970,27 @@ pub fn iter_fields<'tcx>(
 }
 
 /// Returns all locals with projections that have their reference or address taken.
-pub fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
+pub fn excluded_locals(body: &Body<'_>) -> BitSet<Local> {
     struct Collector {
-        result: IndexVec<Local, bool>,
+        result: BitSet<Local>,
     }
 
     impl<'tcx> Visitor<'tcx> for Collector {
         fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
-            if context.is_borrow()
+            if (context.is_borrow()
                 || context.is_address_of()
                 || context.is_drop()
-                || context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput)
+                || context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput))
+                && !place.is_indirect()
             {
                 // A pointer to a place could be used to access other places with the same local,
                 // hence we have to exclude the local completely.
-                self.result[place.local] = true;
+                self.result.insert(place.local);
             }
         }
     }
 
-    let mut collector = Collector { result: IndexVec::from_elem(false, &body.local_decls) };
+    let mut collector = Collector { result: BitSet::new_empty(body.local_decls.len()) };
     collector.visit_body(body);
     collector.result
 }
@@ -899,6 +1046,12 @@ fn debug_with_context_rec<V: Debug + Eq>(
     for child in map.children(place) {
         let info_elem = map.places[child].proj_elem.unwrap();
         let child_place_str = match info_elem {
+            TrackElem::Discriminant => {
+                format!("discriminant({})", place_str)
+            }
+            TrackElem::Variant(idx) => {
+                format!("({} as {:?})", place_str, idx)
+            }
             TrackElem::Field(field) => {
                 if place_str.starts_with('*') {
                     format!("({}).{}", place_str, field.index())
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index 949a59a97bf..f3ca2337e59 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
 use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::Align;
+use rustc_target::abi::VariantIdx;
 
 use crate::MirPass;
 
@@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
 
     #[instrument(skip_all level = "debug")]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!(def_id = ?body.source.def_id());
         if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
             debug!("aborted dataflow const prop due too many basic blocks");
             return;
         }
 
-        // Decide which places to track during the analysis.
-        let map = Map::from_filter(tcx, body, Ty::is_scalar);
-
         // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
         // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
         // applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
         // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
         // map nodes is strongly correlated to the number of tracked places, this becomes more or
         // less `O(n)` if we place a constant limit on the number of tracked places.
-        if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
-            debug!("aborted dataflow const prop due to too many tracked places");
-            return;
-        }
+        let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
+
+        // Decide which places to track during the analysis.
+        let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
 
         // Perform the actual dataflow analysis.
         let analysis = ConstAnalysis::new(tcx, body, map);
@@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
     }
 }
 
-struct ConstAnalysis<'tcx> {
+struct ConstAnalysis<'a, 'tcx> {
     map: Map,
     tcx: TyCtxt<'tcx>,
+    local_decls: &'a LocalDecls<'tcx>,
     ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
     param_env: ty::ParamEnv<'tcx>,
 }
 
-impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
+impl<'tcx> ConstAnalysis<'_, 'tcx> {
+    fn eval_discriminant(
+        &self,
+        enum_ty: Ty<'tcx>,
+        variant_index: VariantIdx,
+    ) -> Option<ScalarTy<'tcx>> {
+        if !enum_ty.is_enum() {
+            return None;
+        }
+        let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+        let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+        let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
+        Some(ScalarTy(discr_value, discr.ty))
+    }
+}
+
+impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
     type Value = FlatSet<ScalarTy<'tcx>>;
 
     const NAME: &'static str = "ConstAnalysis";
@@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
         &self.map
     }
 
+    fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
+        match statement.kind {
+            StatementKind::SetDiscriminant { box ref place, variant_index } => {
+                state.flood_discr(place.as_ref(), &self.map);
+                if self.map.find_discr(place.as_ref()).is_some() {
+                    let enum_ty = place.ty(self.local_decls, self.tcx).ty;
+                    if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
+                        state.assign_discr(
+                            place.as_ref(),
+                            ValueOrPlace::Value(FlatSet::Elem(discr)),
+                            &self.map,
+                        );
+                    }
+                }
+            }
+            _ => self.super_statement(statement, state),
+        }
+    }
+
     fn handle_assign(
         &self,
         target: Place<'tcx>,
@@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
     ) {
         match rvalue {
             Rvalue::Aggregate(kind, operands) => {
-                let target = self.map().find(target.as_ref());
-                if let Some(target) = target {
-                    state.flood_idx_with(target, self.map(), FlatSet::Bottom);
-                    let field_based = match **kind {
-                        AggregateKind::Tuple | AggregateKind::Closure(..) => true,
-                        AggregateKind::Adt(def_id, ..) => {
-                            matches!(self.tcx.def_kind(def_id), DefKind::Struct)
+                state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
+                if let Some(target_idx) = self.map().find(target.as_ref()) {
+                    let (variant_target, variant_index) = match **kind {
+                        AggregateKind::Tuple | AggregateKind::Closure(..) => {
+                            (Some(target_idx), None)
                         }
-                        _ => false,
+                        AggregateKind::Adt(def_id, variant_index, ..) => {
+                            match self.tcx.def_kind(def_id) {
+                                DefKind::Struct => (Some(target_idx), None),
+                                DefKind::Enum => (Some(target_idx), Some(variant_index)),
+                                _ => (None, None),
+                            }
+                        }
+                        _ => (None, None),
                     };
-                    if field_based {
+                    if let Some(target) = variant_target {
                         for (field_index, operand) in operands.iter().enumerate() {
                             if let Some(field) = self
                                 .map()
                                 .apply(target, TrackElem::Field(Field::from_usize(field_index)))
                             {
                                 let result = self.handle_operand(operand, state);
-                                state.assign_idx(field, result, self.map());
+                                state.insert_idx(field, result, self.map());
                             }
                         }
                     }
+                    if let Some(variant_index) = variant_index
+                        && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
+                    {
+                        let enum_ty = target.ty(self.local_decls, self.tcx).ty;
+                        if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
+                            state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
+                        }
+                    }
                 }
             }
             Rvalue::CheckedBinaryOp(op, box (left, right)) => {
+                // Flood everything now, so we can use `insert_value_idx` directly later.
+                state.flood(target.as_ref(), self.map());
+
                 let target = self.map().find(target.as_ref());
-                if let Some(target) = target {
-                    // We should not track any projections other than
-                    // what is overwritten below, but just in case...
-                    state.flood_idx(target, self.map());
-                }
 
                 let value_target = target
                     .and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
                     let (val, overflow) = self.binary_op(state, *op, left, right);
 
                     if let Some(value_target) = value_target {
-                        state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
+                        // We have flooded `target` earlier.
+                        state.insert_value_idx(value_target, val, self.map());
                     }
                     if let Some(overflow_target) = overflow_target {
                         let overflow = match overflow {
@@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
                             }
                             FlatSet::Bottom => FlatSet::Bottom,
                         };
-                        state.assign_idx(
-                            overflow_target,
-                            ValueOrPlace::Value(overflow),
-                            self.map(),
-                        );
+                        // We have flooded `target` earlier.
+                        state.insert_value_idx(overflow_target, overflow, self.map());
                     }
                 }
             }
@@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
                 FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
                 FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
             },
+            Rvalue::Discriminant(place) => {
+                ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
+            }
             _ => self.super_rvalue(rvalue, state),
         }
     }
@@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
     }
 }
 
-impl<'tcx> ConstAnalysis<'tcx> {
-    pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
+impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
+    pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
         let param_env = tcx.param_env(body.source.def_id());
         Self {
             map,
             tcx,
+            local_decls: &body.local_decls,
             ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
             param_env: param_env,
         }
@@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
             _ => (),
         }
     }
+
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
+        match rvalue {
+            Rvalue::Discriminant(place) => {
+                match self.state.get_discr(place.as_ref(), self.visitor.map) {
+                    FlatSet::Top => (),
+                    FlatSet::Elem(value) => {
+                        self.visitor.before_effect.insert((location, *place), value);
+                    }
+                    FlatSet::Bottom => (),
+                }
+            }
+            _ => self.super_rvalue(rvalue, location),
+        }
+    }
 }
 
 struct DummyMachine;
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 30d8511153a..8a37423b2a0 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -1,5 +1,5 @@
 use crate::MirPass;
-use rustc_index::bit_set::BitSet;
+use rustc_index::bit_set::{BitSet, GrowableBitSet};
 use rustc_index::vec::IndexVec;
 use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::visit::*;
@@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
             debug!(?replacements);
             let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
             if !all_dead_locals.is_empty() {
-                for local in excluded.indices() {
-                    excluded[local] |= all_dead_locals.contains(local);
-                }
-                excluded.raw.resize(body.local_decls.len(), false);
+                excluded.union(&all_dead_locals);
+                excluded = {
+                    let mut growable = GrowableBitSet::from(excluded);
+                    growable.ensure(body.local_decls.len());
+                    growable.into()
+                };
             } else {
                 break;
             }
@@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
+        if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
             set.insert(local);
         }
     }
@@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
     body: &mut Body<'tcx>,
     replacements: ReplacementMap<'tcx>,
 ) -> BitSet<Local> {
-    let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
+    let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
     for (local, replacements) in replacements.fragments.iter_enumerated() {
         if replacements.is_some() {
             all_dead_locals.insert(local);
diff --git a/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff b/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff
new file mode 100644
index 00000000000..038e6c6bd90
--- /dev/null
+++ b/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff
@@ -0,0 +1,26 @@
+- // MIR for `mutate_discriminant` before DataflowConstProp
++ // MIR for `mutate_discriminant` after DataflowConstProp
+  
+  fn mutate_discriminant() -> u8 {
+      let mut _0: u8;                      // return place in scope 0 at $DIR/enum.rs:+0:29: +0:31
+      let mut _1: std::option::Option<NonZeroUsize>; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
+      let mut _2: isize;                   // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
+  
+      bb0: {
+          discriminant(_1) = 1;            // scope 0 at $DIR/enum.rs:+4:13: +4:34
+          (((_1 as variant#1).0: NonZeroUsize).0: usize) = const 0_usize; // scope 0 at $DIR/enum.rs:+6:13: +6:64
+          _2 = discriminant(_1);           // scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
+          switchInt(_2) -> [0: bb1, otherwise: bb2]; // scope 0 at $DIR/enum.rs:+9:13: +12:14
+      }
+  
+      bb1: {
+          _0 = const 1_u8;                 // scope 0 at $DIR/enum.rs:+15:13: +15:20
+          return;                          // scope 0 at $DIR/enum.rs:+16:13: +16:21
+      }
+  
+      bb2: {
+          _0 = const 2_u8;                 // scope 0 at $DIR/enum.rs:+19:13: +19:20
+          unreachable;                     // scope 0 at $DIR/enum.rs:+20:13: +20:26
+      }
+  }
+  
diff --git a/tests/mir-opt/dataflow-const-prop/enum.rs b/tests/mir-opt/dataflow-const-prop/enum.rs
index 13288577dea..7ea405bd9c4 100644
--- a/tests/mir-opt/dataflow-const-prop/enum.rs
+++ b/tests/mir-opt/dataflow-const-prop/enum.rs
@@ -1,13 +1,52 @@
 // unit-test: DataflowConstProp
 
-// Not trackable, because variants could be aliased.
+#![feature(custom_mir, core_intrinsics, rustc_attrs)]
+
+use std::intrinsics::mir::*;
+
 enum E {
     V1(i32),
     V2(i32)
 }
 
-// EMIT_MIR enum.main.DataflowConstProp.diff
-fn main() {
+// EMIT_MIR enum.simple.DataflowConstProp.diff
+fn simple() {
     let e = E::V1(0);
     let x = match e { E::V1(x) => x, E::V2(x) => x };
 }
+
+#[rustc_layout_scalar_valid_range_start(1)]
+#[rustc_nonnull_optimization_guaranteed]
+struct NonZeroUsize(usize);
+
+// EMIT_MIR enum.mutate_discriminant.DataflowConstProp.diff
+#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
+fn mutate_discriminant() -> u8 {
+    mir!(
+        let x: Option<NonZeroUsize>;
+        {
+            SetDiscriminant(x, 1);
+            // This assignment overwrites the niche in which the discriminant is stored.
+            place!(Field(Field(Variant(x, 1), 0), 0)) = 0_usize;
+            // So we cannot know the value of this discriminant.
+            let a = Discriminant(x);
+            match a {
+                0 => bb1,
+                _ => bad,
+            }
+        }
+        bb1 = {
+            RET = 1;
+            Return()
+        }
+        bad = {
+            RET = 2;
+            Unreachable()
+        }
+    )
+}
+
+fn main() {
+    simple();
+    mutate_discriminant();
+}
diff --git a/tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff b/tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff
index d049c79d78d..1fb65e65845 100644
--- a/tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff
+++ b/tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff
@@ -1,8 +1,8 @@
-- // MIR for `main` before DataflowConstProp
-+ // MIR for `main` after DataflowConstProp
+- // MIR for `simple` before DataflowConstProp
++ // MIR for `simple` after DataflowConstProp
   
-  fn main() -> () {
-      let mut _0: ();                      // return place in scope 0 at $DIR/enum.rs:+0:11: +0:11
+  fn simple() -> () {
+      let mut _0: ();                      // return place in scope 0 at $DIR/enum.rs:+0:13: +0:13
       let _1: E;                           // in scope 0 at $DIR/enum.rs:+1:9: +1:10
       let mut _3: isize;                   // in scope 0 at $DIR/enum.rs:+2:23: +2:31
       scope 1 {
@@ -25,8 +25,10 @@
           StorageLive(_1);                 // scope 0 at $DIR/enum.rs:+1:9: +1:10
           _1 = E::V1(const 0_i32);         // scope 0 at $DIR/enum.rs:+1:13: +1:21
           StorageLive(_2);                 // scope 1 at $DIR/enum.rs:+2:9: +2:10
-          _3 = discriminant(_1);           // scope 1 at $DIR/enum.rs:+2:19: +2:20
-          switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
+-         _3 = discriminant(_1);           // scope 1 at $DIR/enum.rs:+2:19: +2:20
+-         switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
++         _3 = const 0_isize;              // scope 1 at $DIR/enum.rs:+2:19: +2:20
++         switchInt(const 0_isize) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
       }
   
       bb1: {
@@ -50,7 +52,7 @@
       }
   
       bb4: {
-          _0 = const ();                   // scope 0 at $DIR/enum.rs:+0:11: +3:2
+          _0 = const ();                   // scope 0 at $DIR/enum.rs:+0:13: +3:2
           StorageDead(_2);                 // scope 1 at $DIR/enum.rs:+3:1: +3:2
           StorageDead(_1);                 // scope 0 at $DIR/enum.rs:+3:1: +3:2
           return;                          // scope 0 at $DIR/enum.rs:+3:2: +3:2