about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-12-24 05:19:43 +0000
committerbors <bors@rust-lang.org>2023-12-24 05:19:43 +0000
commit1754946b3ed693834bb79760169e528277fd98d2 (patch)
treeb3d97623b0f43b549734a6730ef55f0adf29dbea /compiler/rustc_mir_transform/src
parent2c7e0fd373ae8f7d8c5aee1c0ec8e3249e3f3649 (diff)
parent29f25ee3f381922b39a67089bb07d70bfbe2f17e (diff)
downloadrust-1754946b3ed693834bb79760169e528277fd98d2.tar.gz
rust-1754946b3ed693834bb79760169e528277fd98d2.zip
Auto merge of #3238 - rust-lang:rustup-2023-12-24, r=saethlin
Automatic Rustup
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs44
-rw-r--r--compiler/rustc_mir_transform/src/errors.rs8
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs1
-rw-r--r--compiler/rustc_mir_transform/src/lint.rs119
-rw-r--r--compiler/rustc_mir_transform/src/pass_manager.rs12
-rw-r--r--compiler/rustc_mir_transform/src/ref_prop.rs3
6 files changed, 165 insertions, 22 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index d7dd44af7d2..6da102dcb1c 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -59,7 +59,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_errors::pluralize;
 use rustc_hir as hir;
 use rustc_hir::lang_items::LangItem;
-use rustc_hir::CoroutineKind;
+use rustc_hir::{CoroutineDesugaring, CoroutineKind};
 use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
 use rustc_index::{Idx, IndexVec};
 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
@@ -254,10 +254,12 @@ impl<'tcx> TransformVisitor<'tcx> {
         let source_info = SourceInfo::outermost(body.span);
 
         let none_value = match self.coroutine_kind {
-            CoroutineKind::Async(_) => span_bug!(body.span, "`Future`s are not fused inherently"),
+            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
+                span_bug!(body.span, "`Future`s are not fused inherently")
+            }
             CoroutineKind::Coroutine => span_bug!(body.span, "`Coroutine`s cannot be fused"),
             // `gen` continues return `None`
-            CoroutineKind::Gen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
                 Rvalue::Aggregate(
                     Box::new(AggregateKind::Adt(
@@ -271,7 +273,7 @@ impl<'tcx> TransformVisitor<'tcx> {
                 )
             }
             // `async gen` continues to return `Poll::Ready(None)`
-            CoroutineKind::AsyncGen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
                 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
                 let yield_ty = args.type_at(0);
@@ -316,7 +318,7 @@ impl<'tcx> TransformVisitor<'tcx> {
         statements: &mut Vec<Statement<'tcx>>,
     ) {
         let rvalue = match self.coroutine_kind {
-            CoroutineKind::Async(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                 let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
                 let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
                 if is_return {
@@ -345,7 +347,7 @@ impl<'tcx> TransformVisitor<'tcx> {
                     )
                 }
             }
-            CoroutineKind::Gen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
                 let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
                 if is_return {
@@ -374,7 +376,7 @@ impl<'tcx> TransformVisitor<'tcx> {
                     )
                 }
             }
-            CoroutineKind::AsyncGen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                 if is_return {
                     let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
                     let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
@@ -1426,10 +1428,11 @@ fn create_coroutine_resume_function<'tcx>(
 
     if can_return {
         let block = match coroutine_kind {
-            CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine => {
                 insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
             }
-            CoroutineKind::AsyncGen(_) | CoroutineKind::Gen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
+            | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 transform.insert_none_ret_block(body)
             }
         };
@@ -1443,7 +1446,7 @@ fn create_coroutine_resume_function<'tcx>(
     match coroutine_kind {
         // Iterator::next doesn't accept a pinned argument,
         // unlike for all other coroutine kinds.
-        CoroutineKind::Gen(_) => {}
+        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
         _ => {
             make_coroutine_state_argument_pinned(tcx, body);
         }
@@ -1609,25 +1612,34 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
             }
         };
 
-        let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
-        let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_)));
-        let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
+        let is_async_kind = matches!(
+            body.coroutine_kind(),
+            Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
+        );
+        let is_async_gen_kind = matches!(
+            body.coroutine_kind(),
+            Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
+        );
+        let is_gen_kind = matches!(
+            body.coroutine_kind(),
+            Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
+        );
         let new_ret_ty = match body.coroutine_kind().unwrap() {
-            CoroutineKind::Async(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                 // Compute Poll<return_ty>
                 let poll_did = tcx.require_lang_item(LangItem::Poll, None);
                 let poll_adt_ref = tcx.adt_def(poll_did);
                 let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
                 Ty::new_adt(tcx, poll_adt_ref, poll_args)
             }
-            CoroutineKind::Gen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                 // Compute Option<yield_ty>
                 let option_did = tcx.require_lang_item(LangItem::Option, None);
                 let option_adt_ref = tcx.adt_def(option_did);
                 let option_args = tcx.mk_args(&[old_yield_ty.into()]);
                 Ty::new_adt(tcx, option_adt_ref, option_args)
             }
-            CoroutineKind::AsyncGen(_) => {
+            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                 // The yield ty is already `Poll<Option<yield_ty>>`
                 old_yield_ty
             }
diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs
index fd4af31501c..17916e16daf 100644
--- a/compiler/rustc_mir_transform/src/errors.rs
+++ b/compiler/rustc_mir_transform/src/errors.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;
 
 use rustc_errors::{
     Applicability, DecorateLint, DiagCtxt, DiagnosticArgValue, DiagnosticBuilder,
-    DiagnosticMessage, EmissionGuarantee, ErrorGuaranteed, IntoDiagnostic,
+    DiagnosticMessage, EmissionGuarantee, IntoDiagnostic, Level,
 };
 use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic};
 use rustc_middle::mir::{AssertKind, UnsafetyViolationDetails};
@@ -62,10 +62,10 @@ pub(crate) struct RequiresUnsafe {
 // so we need to eagerly translate the label here, which isn't supported by the derive API
 // We could also exhaustively list out the primary messages for all unsafe violations,
 // but this would result in a lot of duplication.
-impl<'sess> IntoDiagnostic<'sess> for RequiresUnsafe {
+impl<'a, G: EmissionGuarantee> IntoDiagnostic<'a, G> for RequiresUnsafe {
     #[track_caller]
-    fn into_diagnostic(self, dcx: &'sess DiagCtxt) -> DiagnosticBuilder<'sess, ErrorGuaranteed> {
-        let mut diag = dcx.struct_err(fluent::mir_transform_requires_unsafe);
+    fn into_diagnostic(self, dcx: &'a DiagCtxt, level: Level) -> DiagnosticBuilder<'a, G> {
+        let mut diag = DiagnosticBuilder::new(dcx, level, fluent::mir_transform_requires_unsafe);
         diag.code(rustc_errors::DiagnosticId::Error("E0133".to_string()));
         diag.set_span(self.span);
         diag.span_label(self.span, self.details.label());
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 89e897191e8..98d4d96d0c7 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -86,6 +86,7 @@ pub mod inline;
 mod instsimplify;
 mod jump_threading;
 mod large_enums;
+mod lint;
 mod lower_intrinsics;
 mod lower_slice_len;
 mod match_branches;
diff --git a/compiler/rustc_mir_transform/src/lint.rs b/compiler/rustc_mir_transform/src/lint.rs
new file mode 100644
index 00000000000..3940d0ddbf3
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/lint.rs
@@ -0,0 +1,119 @@
+//! This pass statically detects code which has undefined behaviour or is likely to be erroneous.
+//! It can be used to locate problems in MIR building or optimizations. It assumes that all code
+//! can be executed, so it has false positives.
+use rustc_index::bit_set::BitSet;
+use rustc_middle::mir::visit::{PlaceContext, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::impls::{MaybeStorageDead, MaybeStorageLive};
+use rustc_mir_dataflow::storage::always_storage_live_locals;
+use rustc_mir_dataflow::{Analysis, ResultsCursor};
+use std::borrow::Cow;
+
+pub fn lint_body<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, when: String) {
+    let reachable_blocks = traversal::reachable_as_bitset(body);
+    let always_live_locals = &always_storage_live_locals(body);
+
+    let maybe_storage_live = MaybeStorageLive::new(Cow::Borrowed(always_live_locals))
+        .into_engine(tcx, body)
+        .iterate_to_fixpoint()
+        .into_results_cursor(body);
+
+    let maybe_storage_dead = MaybeStorageDead::new(Cow::Borrowed(always_live_locals))
+        .into_engine(tcx, body)
+        .iterate_to_fixpoint()
+        .into_results_cursor(body);
+
+    Lint {
+        tcx,
+        when,
+        body,
+        is_fn_like: tcx.def_kind(body.source.def_id()).is_fn_like(),
+        always_live_locals,
+        reachable_blocks,
+        maybe_storage_live,
+        maybe_storage_dead,
+    }
+    .visit_body(body);
+}
+
+struct Lint<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    when: String,
+    body: &'a Body<'tcx>,
+    is_fn_like: bool,
+    always_live_locals: &'a BitSet<Local>,
+    reachable_blocks: BitSet<BasicBlock>,
+    maybe_storage_live: ResultsCursor<'a, 'tcx, MaybeStorageLive<'a>>,
+    maybe_storage_dead: ResultsCursor<'a, 'tcx, MaybeStorageDead<'a>>,
+}
+
+impl<'a, 'tcx> Lint<'a, 'tcx> {
+    #[track_caller]
+    fn fail(&self, location: Location, msg: impl AsRef<str>) {
+        let span = self.body.source_info(location).span;
+        self.tcx.sess.dcx().span_delayed_bug(
+            span,
+            format!(
+                "broken MIR in {:?} ({}) at {:?}:\n{}",
+                self.body.source.instance,
+                self.when,
+                location,
+                msg.as_ref()
+            ),
+        );
+    }
+}
+
+impl<'a, 'tcx> Visitor<'tcx> for Lint<'a, 'tcx> {
+    fn visit_local(&mut self, local: Local, context: PlaceContext, location: Location) {
+        if self.reachable_blocks.contains(location.block) && context.is_use() {
+            self.maybe_storage_dead.seek_after_primary_effect(location);
+            if self.maybe_storage_dead.get().contains(local) {
+                self.fail(location, format!("use of local {local:?}, which has no storage here"));
+            }
+        }
+    }
+
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        match statement.kind {
+            StatementKind::StorageLive(local) => {
+                if self.reachable_blocks.contains(location.block) {
+                    self.maybe_storage_live.seek_before_primary_effect(location);
+                    if self.maybe_storage_live.get().contains(local) {
+                        self.fail(
+                            location,
+                            format!("StorageLive({local:?}) which already has storage here"),
+                        );
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        self.super_statement(statement, location);
+    }
+
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        match terminator.kind {
+            TerminatorKind::Return => {
+                if self.is_fn_like && self.reachable_blocks.contains(location.block) {
+                    self.maybe_storage_live.seek_after_primary_effect(location);
+                    for local in self.maybe_storage_live.get().iter() {
+                        if !self.always_live_locals.contains(local) {
+                            self.fail(
+                                location,
+                                format!(
+                                    "local {local:?} still has storage when returning from function"
+                                ),
+                            );
+                        }
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        self.super_terminator(terminator, location);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index c4eca18ff27..1da1c1920b2 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -2,7 +2,7 @@ use rustc_middle::mir::{self, Body, MirPhase, RuntimePhase};
 use rustc_middle::ty::TyCtxt;
 use rustc_session::Session;
 
-use crate::{validate, MirPass};
+use crate::{lint::lint_body, validate, MirPass};
 
 /// Just like `MirPass`, except it cannot mutate `Body`.
 pub trait MirLint<'tcx> {
@@ -109,6 +109,7 @@ fn run_passes_inner<'tcx>(
     phase_change: Option<MirPhase>,
     validate_each: bool,
 ) {
+    let lint = tcx.sess.opts.unstable_opts.lint_mir & !body.should_skip();
     let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir & !body.should_skip();
     let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes;
     trace!(?overridden_passes);
@@ -131,6 +132,9 @@ fn run_passes_inner<'tcx>(
             if validate {
                 validate_body(tcx, body, format!("before pass {name}"));
             }
+            if lint {
+                lint_body(tcx, body, format!("before pass {name}"));
+            }
 
             if let Some(prof_arg) = &prof_arg {
                 tcx.sess
@@ -147,6 +151,9 @@ fn run_passes_inner<'tcx>(
             if validate {
                 validate_body(tcx, body, format!("after pass {name}"));
             }
+            if lint {
+                lint_body(tcx, body, format!("after pass {name}"));
+            }
 
             body.pass_count += 1;
         }
@@ -164,6 +171,9 @@ fn run_passes_inner<'tcx>(
         if validate || new_phase == MirPhase::Runtime(RuntimePhase::Optimized) {
             validate_body(tcx, body, format!("after phase change to {}", new_phase.name()));
         }
+        if lint {
+            lint_body(tcx, body, format!("after phase change to {}", new_phase.name()));
+        }
 
         body.pass_count = 1;
     }
diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs
index f13ab5b0f1f..05a3ac3cc75 100644
--- a/compiler/rustc_mir_transform/src/ref_prop.rs
+++ b/compiler/rustc_mir_transform/src/ref_prop.rs
@@ -7,6 +7,7 @@ use rustc_middle::ty::TyCtxt;
 use rustc_mir_dataflow::impls::MaybeStorageDead;
 use rustc_mir_dataflow::storage::always_storage_live_locals;
 use rustc_mir_dataflow::Analysis;
+use std::borrow::Cow;
 
 use crate::ssa::{SsaLocals, StorageLiveLocals};
 
@@ -120,7 +121,7 @@ fn compute_replacement<'tcx>(
 
     // Compute `MaybeStorageDead` dataflow to check that we only replace when the pointee is
     // definitely live.
-    let mut maybe_dead = MaybeStorageDead::new(always_live_locals)
+    let mut maybe_dead = MaybeStorageDead::new(Cow::Owned(always_live_locals))
         .into_engine(tcx, body)
         .iterate_to_fixpoint()
         .into_results_cursor(body);