about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_const_eval/src/check_consts/check.rs62
1 files changed, 45 insertions, 17 deletions
diff --git a/compiler/rustc_const_eval/src/check_consts/check.rs b/compiler/rustc_const_eval/src/check_consts/check.rs
index 86a5afa65ba..d7594c9ae75 100644
--- a/compiler/rustc_const_eval/src/check_consts/check.rs
+++ b/compiler/rustc_const_eval/src/check_consts/check.rs
@@ -1,6 +1,7 @@
 //! The `Visitor` responsible for actually checking a `mir::Body` for invalid operations.
 
 use std::assert_matches::assert_matches;
+use std::borrow::Cow;
 use std::mem;
 use std::ops::Deref;
 
@@ -15,7 +16,9 @@ use rustc_middle::mir::*;
 use rustc_middle::span_bug;
 use rustc_middle::ty::adjustment::PointerCoercion;
 use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt, TypeVisitableExt};
-use rustc_mir_dataflow::Analysis;
+use rustc_mir_dataflow::impls::MaybeStorageLive;
+use rustc_mir_dataflow::storage::always_storage_live_locals;
+use rustc_mir_dataflow::{Analysis, ResultsCursor};
 use rustc_span::{sym, Span, Symbol, DUMMY_SP};
 use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
 use rustc_trait_selection::traits::{self, ObligationCauseCode, ObligationCtxt};
@@ -188,8 +191,9 @@ pub struct Checker<'mir, 'tcx> {
     /// The span of the current statement.
     span: Span,
 
-    /// A set that stores for each local whether it has a `StorageDead` for it somewhere.
-    local_has_storage_dead: Option<BitSet<Local>>,
+    /// A set that stores for each local whether it is "transient", i.e. guaranteed to be dead
+    /// when this MIR body returns.
+    transient_locals: Option<BitSet<Local>>,
 
     error_emitted: Option<ErrorGuaranteed>,
     secondary_errors: Vec<Diag<'tcx>>,
@@ -209,7 +213,7 @@ impl<'mir, 'tcx> Checker<'mir, 'tcx> {
             span: ccx.body.span,
             ccx,
             qualifs: Default::default(),
-            local_has_storage_dead: None,
+            transient_locals: None,
             error_emitted: None,
             secondary_errors: Vec::new(),
         }
@@ -264,23 +268,47 @@ impl<'mir, 'tcx> Checker<'mir, 'tcx> {
         }
     }
 
-    fn local_has_storage_dead(&mut self, local: Local) -> bool {
+    fn local_is_transient(&mut self, local: Local) -> bool {
         let ccx = self.ccx;
-        self.local_has_storage_dead
+        self.transient_locals
             .get_or_insert_with(|| {
-                struct StorageDeads {
-                    locals: BitSet<Local>,
+                // A local is "transient" if it is guaranteed dead at all `Return`.
+                // So first compute the say of "maybe live" locals at each program point.
+                let always_live_locals = &always_storage_live_locals(&ccx.body);
+                let maybe_storage_live = MaybeStorageLive::new(Cow::Borrowed(always_live_locals))
+                    .into_engine(ccx.tcx, &ccx.body)
+                    .iterate_to_fixpoint()
+                    .into_results_cursor(&ccx.body);
+
+                // And then check all `Return` in the MIR, and if a local is "maybe live" at a
+                // `Return` then it is definitely not transient.
+                struct TransientLocalVisitor<'a, 'tcx> {
+                    maybe_storage_live: ResultsCursor<'a, 'tcx, MaybeStorageLive<'a>>,
+                    transient: BitSet<Local>,
                 }
-                impl<'tcx> Visitor<'tcx> for StorageDeads {
-                    fn visit_statement(&mut self, stmt: &Statement<'tcx>, _: Location) {
-                        if let StatementKind::StorageDead(l) = stmt.kind {
-                            self.locals.insert(l);
+                impl<'a, 'tcx> Visitor<'tcx> for TransientLocalVisitor<'a, 'tcx> {
+                    fn visit_terminator(
+                        &mut self,
+                        terminator: &Terminator<'tcx>,
+                        location: Location,
+                    ) {
+                        if matches!(terminator.kind, TerminatorKind::Return) {
+                            self.maybe_storage_live.seek_after_primary_effect(location);
+                            for local in self.maybe_storage_live.get().iter() {
+                                // If a local may be live here, it is definitely not transient.
+                                self.transient.remove(local);
+                            }
                         }
                     }
                 }
-                let mut v = StorageDeads { locals: BitSet::new_empty(ccx.body.local_decls.len()) };
-                v.visit_body(ccx.body);
-                v.locals
+
+                let mut v = TransientLocalVisitor {
+                    maybe_storage_live,
+                    transient: BitSet::new_filled(ccx.body.local_decls.len()),
+                };
+                v.visit_body(&ccx.body);
+
+                v.transient
             })
             .contains(local)
     }
@@ -375,7 +403,7 @@ impl<'mir, 'tcx> Checker<'mir, 'tcx> {
                 // `StorageDead` in every control flow path leading to a `return` terminator.
                 // The good news is that interning will detect if any unexpected mutable
                 // pointer slips through.
-                if place.is_indirect() || self.local_has_storage_dead(place.local) {
+                if place.is_indirect() || self.local_is_transient(place.local) {
                     self.check_op(ops::TransientMutBorrow(kind));
                 } else {
                     self.check_op(ops::MutBorrow(kind));
@@ -526,7 +554,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
                             // `StorageDead` in every control flow path leading to a `return` terminator.
                             // The good news is that interning will detect if any unexpected mutable
                             // pointer slips through.
-                            if self.local_has_storage_dead(place.local) {
+                            if self.local_is_transient(place.local) {
                                 self.check_op(ops::TransientCellBorrow);
                             } else {
                                 self.check_op(ops::CellBorrow);