about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/const_goto.rs2
-rw-r--r--compiler/rustc_mir_transform/src/const_prop.rs54
-rw-r--r--compiler/rustc_mir_transform/src/const_prop_lint.rs6
-rw-r--r--compiler/rustc_mir_transform/src/errors.rs11
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs5
-rw-r--r--compiler/rustc_mir_transform/src/required_consts.rs4
-rw-r--r--compiler/rustc_mir_transform/src/separate_const_switch.rs2
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs50
8 files changed, 103 insertions, 31 deletions
diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs
index da101ca7ad2..024bea62098 100644
--- a/compiler/rustc_mir_transform/src/const_goto.rs
+++ b/compiler/rustc_mir_transform/src/const_goto.rs
@@ -28,7 +28,7 @@ pub struct ConstGoto;
 
 impl<'tcx> MirPass<'tcx> for ConstGoto {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        sess.mir_opt_level() >= 4
+        sess.mir_opt_level() >= 2
     }
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs
index 1ba1951afde..1d43dbda0aa 100644
--- a/compiler/rustc_mir_transform/src/const_prop.rs
+++ b/compiler/rustc_mir_transform/src/const_prop.rs
@@ -4,6 +4,7 @@
 use either::Right;
 
 use rustc_const_eval::const_eval::CheckAlignment;
+use rustc_const_eval::ReportErrorExt;
 use rustc_data_structures::fx::FxHashSet;
 use rustc_hir::def::DefKind;
 use rustc_index::bit_set::BitSet;
@@ -37,6 +38,7 @@ macro_rules! throw_machine_stop_str {
     ($($tt:tt)*) => {{
         // We make a new local type for it. The type itself does not carry any information,
         // but its vtable (for the `MachineStopType` trait) does.
+        #[derive(Debug)]
         struct Zst;
         // Printing this type shows the desired string.
         impl std::fmt::Display for Zst {
@@ -44,7 +46,17 @@ macro_rules! throw_machine_stop_str {
                 write!(f, $($tt)*)
             }
         }
-        impl rustc_middle::mir::interpret::MachineStopType for Zst {}
+
+        impl rustc_middle::mir::interpret::MachineStopType for Zst {
+            fn diagnostic_message(&self) -> rustc_errors::DiagnosticMessage {
+                self.to_string().into()
+            }
+
+            fn add_args(
+                self: Box<Self>,
+                _: &mut dyn FnMut(std::borrow::Cow<'static, str>, rustc_errors::DiagnosticArgValue<'static>),
+            ) {}
+        }
         throw_machine_stop!(Zst)
     }};
 }
@@ -103,7 +115,14 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
         // That would require a uniform one-def no-mutation analysis
         // and RPO (or recursing when needing the value of a local).
         let mut optimization_finder = ConstPropagator::new(body, dummy_body, tcx);
-        optimization_finder.visit_body(body);
+
+        // Traverse the body in reverse post-order, to ensure that `FullConstProp` locals are
+        // assigned before being read.
+        let postorder = body.basic_blocks.postorder().to_vec();
+        for bb in postorder.into_iter().rev() {
+            let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
+            optimization_finder.visit_basic_block_data(bb, data);
+        }
 
         trace!("ConstProp done for {:?}", def_id);
     }
@@ -367,7 +386,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
                 op
             }
             Err(e) => {
-                trace!("get_const failed: {}", e);
+                trace!("get_const failed: {:?}", e.into_kind().debug());
                 return None;
             }
         };
@@ -789,12 +808,6 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
         self.tcx
     }
 
-    fn visit_body(&mut self, body: &mut Body<'tcx>) {
-        for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
-            self.visit_basic_block_data(bb, data);
-        }
-    }
-
     fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
         self.super_operand(operand, location);
 
@@ -885,14 +898,23 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
                 }
             }
             StatementKind::StorageLive(local) => {
-                let frame = self.ecx.frame_mut();
-                frame.locals[local].value =
-                    LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit));
-            }
-            StatementKind::StorageDead(local) => {
-                let frame = self.ecx.frame_mut();
-                frame.locals[local].value = LocalValue::Dead;
+                Self::remove_const(&mut self.ecx, local);
             }
+            // We do not need to mark dead locals as such. For `FullConstProp` locals,
+            // this allows to propagate the single assigned value in this case:
+            // ```
+            // let x = SOME_CONST;
+            // if a {
+            //   f(copy x);
+            //   StorageDead(x);
+            // } else {
+            //   g(copy x);
+            //   StorageDead(x);
+            // }
+            // ```
+            //
+            // This may propagate a constant where the local would be uninit or dead.
+            // In both cases, this does not matter, as those reads would be UB anyway.
             _ => {}
         }
     }
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
index 0fe49b8a1bb..759650fe4db 100644
--- a/compiler/rustc_mir_transform/src/const_prop_lint.rs
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -9,6 +9,7 @@ use rustc_const_eval::interpret::Immediate;
 use rustc_const_eval::interpret::{
     self, InterpCx, InterpResult, LocalValue, MemoryKind, OpTy, Scalar, StackPopCleanup,
 };
+use rustc_const_eval::ReportErrorExt;
 use rustc_hir::def::DefKind;
 use rustc_hir::HirId;
 use rustc_index::bit_set::BitSet;
@@ -232,7 +233,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
                 op
             }
             Err(e) => {
-                trace!("get_const failed: {}", e);
+                trace!("get_const failed: {:?}", e.into_kind().debug());
                 return None;
             }
         };
@@ -272,8 +273,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
                 // dedicated error variants should be introduced instead.
                 assert!(
                     !error.kind().formatted_string(),
-                    "const-prop encountered formatting error: {}",
-                    error
+                    "const-prop encountered formatting error: {error:?}",
                 );
                 None
             }
diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs
index 602e40d5131..22f71bb0851 100644
--- a/compiler/rustc_mir_transform/src/errors.rs
+++ b/compiler/rustc_mir_transform/src/errors.rs
@@ -163,7 +163,14 @@ impl<'a, P: std::fmt::Debug> DecorateLint<'a, ()> for AssertLint<P> {
         self,
         diag: &'b mut DiagnosticBuilder<'a, ()>,
     ) -> &'b mut DiagnosticBuilder<'a, ()> {
-        diag.span_label(self.span(), format!("{:?}", self.panic()));
+        let span = self.span();
+        let assert_kind = self.panic();
+        let message = assert_kind.diagnostic_message();
+        assert_kind.add_args(&mut |name, value| {
+            diag.set_arg(name, value);
+        });
+        diag.span_label(span, message);
+
         diag
     }
 
@@ -191,7 +198,7 @@ impl<P> AssertLint<P> {
             AssertLint::ArithmeticOverflow(sp, _) | AssertLint::UnconditionalPanic(sp, _) => *sp,
         }
     }
-    pub fn panic(&self) -> &AssertKind<P> {
+    pub fn panic(self) -> AssertKind<P> {
         match self {
             AssertLint::ArithmeticOverflow(_, p) | AssertLint::UnconditionalPanic(_, p) => p,
         }
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 54c138b6fbd..7d9f6c38e36 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -559,10 +559,13 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
             // inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
             &multiple_return_terminators::MultipleReturnTerminators,
             &instsimplify::InstSimplify,
-            &separate_const_switch::SeparateConstSwitch,
             &simplify::SimplifyLocals::BeforeConstProp,
             &copy_prop::CopyProp,
             &ref_prop::ReferencePropagation,
+            // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may
+            // destroy the SSA property. It should still happen before const-propagation, so the
+            // latter pass will leverage the created opportunities.
+            &separate_const_switch::SeparateConstSwitch,
             &const_prop::ConstProp,
             &dataflow_const_prop::DataflowConstProp,
             //
diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs
index 0ea8f2ba93f..243cb463560 100644
--- a/compiler/rustc_mir_transform/src/required_consts.rs
+++ b/compiler/rustc_mir_transform/src/required_consts.rs
@@ -17,8 +17,8 @@ impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> {
         let literal = constant.literal;
         match literal {
             ConstantKind::Ty(c) => match c.kind() {
-                ConstKind::Param(_) | ConstKind::Error(_) => {}
-                _ => bug!("only ConstKind::Param should be encountered here, got {:#?}", c),
+                ConstKind::Param(_) | ConstKind::Error(_) | ConstKind::Value(_) => {}
+                _ => bug!("only ConstKind::Param/Value should be encountered here, got {:#?}", c),
             },
             ConstantKind::Unevaluated(..) => self.required_consts.push(*constant),
             ConstantKind::Val(..) => {}
diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs
index 2479856b727..f35a5fb4276 100644
--- a/compiler/rustc_mir_transform/src/separate_const_switch.rs
+++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs
@@ -46,7 +46,7 @@ pub struct SeparateConstSwitch;
 
 impl<'tcx> MirPass<'tcx> for SeparateConstSwitch {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        sess.mir_opt_level() >= 4
+        sess.mir_opt_level() >= 2
     }
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 2d77291293d..e4b3b8b9262 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -6,23 +6,29 @@ use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
-use rustc_target::abi::FieldIdx;
+use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT};
 
 pub struct ScalarReplacementOfAggregates;
 
 impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        sess.mir_opt_level() >= 3
+        sess.mir_opt_level() >= 2
     }
 
     #[instrument(level = "debug", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         debug!(def_id = ?body.source.def_id());
+
+        // Avoid query cycles (generators require optimized MIR for layout).
+        if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
+            return;
+        }
+
         let mut excluded = excluded_locals(body);
         let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
         loop {
             debug!(?excluded);
-            let escaping = escaping_locals(&excluded, body);
+            let escaping = escaping_locals(tcx, param_env, &excluded, body);
             debug!(?escaping);
             let replacements = compute_flattening(tcx, param_env, body, escaping);
             debug!(?replacements);
@@ -48,11 +54,45 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+    excluded: &BitSet<Local>,
+    body: &Body<'tcx>,
+) -> BitSet<Local> {
+    let is_excluded_ty = |ty: Ty<'tcx>| {
+        if ty.is_union() || ty.is_enum() {
+            return true;
+        }
+        if let ty::Adt(def, _substs) = ty.kind() {
+            if def.repr().flags.contains(ReprFlags::IS_SIMD) {
+                // Exclude #[repr(simd)] types so that they are not de-optimized into an array
+                return true;
+            }
+            // We already excluded unions and enums, so this ADT must have one variant
+            let variant = def.variant(FIRST_VARIANT);
+            if variant.fields.len() > 1 {
+                // If this has more than one field, it cannot be a wrapper that only provides a
+                // niche, so we do not want to automatically exclude it.
+                return false;
+            }
+            let Ok(layout) = tcx.layout_of(param_env.and(ty)) else {
+                // We can't get the layout
+                return true;
+            };
+            if layout.layout.largest_niche().is_some() {
+                // This type has a niche
+                return true;
+            }
+        }
+        // Default for non-ADTs
+        false
+    };
+
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
+        if excluded.contains(local) || is_excluded_ty(decl.ty) {
             set.insert(local);
         }
     }