about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs50
1 files changed, 45 insertions, 5 deletions
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 2d77291293d..e4b3b8b9262 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -6,23 +6,29 @@ use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
-use rustc_target::abi::FieldIdx;
+use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT};
 
 pub struct ScalarReplacementOfAggregates;
 
 impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        sess.mir_opt_level() >= 3
+        sess.mir_opt_level() >= 2
     }
 
     #[instrument(level = "debug", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         debug!(def_id = ?body.source.def_id());
+
+        // Avoid query cycles (generators require optimized MIR for layout).
+        if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
+            return;
+        }
+
         let mut excluded = excluded_locals(body);
         let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
         loop {
             debug!(?excluded);
-            let escaping = escaping_locals(&excluded, body);
+            let escaping = escaping_locals(tcx, param_env, &excluded, body);
             debug!(?escaping);
             let replacements = compute_flattening(tcx, param_env, body, escaping);
             debug!(?replacements);
@@ -48,11 +54,45 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+    excluded: &BitSet<Local>,
+    body: &Body<'tcx>,
+) -> BitSet<Local> {
+    let is_excluded_ty = |ty: Ty<'tcx>| {
+        if ty.is_union() || ty.is_enum() {
+            return true;
+        }
+        if let ty::Adt(def, _substs) = ty.kind() {
+            if def.repr().flags.contains(ReprFlags::IS_SIMD) {
+                // Exclude #[repr(simd)] types so that they are not de-optimized into an array
+                return true;
+            }
+            // We already excluded unions and enums, so this ADT must have one variant
+            let variant = def.variant(FIRST_VARIANT);
+            if variant.fields.len() > 1 {
+                // If this has more than one field, it cannot be a wrapper that only provides a
+                // niche, so we do not want to automatically exclude it.
+                return false;
+            }
+            let Ok(layout) = tcx.layout_of(param_env.and(ty)) else {
+                // We can't get the layout
+                return true;
+            };
+            if layout.layout.largest_niche().is_some() {
+                // This type has a niche
+                return true;
+            }
+        }
+        // Default for non-ADTs
+        false
+    };
+
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
+        if excluded.contains(local) || is_excluded_ty(decl.ty) {
             set.insert(local);
         }
     }