diff options
Diffstat (limited to 'compiler')
| -rw-r--r-- | compiler/rustc_mir_transform/src/sroa.rs | 50 |
1 files changed, 45 insertions, 5 deletions
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 2d77291293d..e4b3b8b9262 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -6,23 +6,29 @@ use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; -use rustc_target::abi::FieldIdx; +use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT}; pub struct ScalarReplacementOfAggregates; impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 3 + sess.mir_opt_level() >= 2 } #[instrument(level = "debug", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { debug!(def_id = ?body.source.def_id()); + + // Avoid query cycles (generators require optimized MIR for layout). + if tcx.type_of(body.source.def_id()).subst_identity().is_generator() { + return; + } + let mut excluded = excluded_locals(body); let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); loop { debug!(?excluded); - let escaping = escaping_locals(&excluded, body); + let escaping = escaping_locals(tcx, param_env, &excluded, body); debug!(?escaping); let replacements = compute_flattening(tcx, param_env, body, escaping); debug!(?replacements); @@ -48,11 +54,45 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { /// - the locals is a union or an enum; /// - the local's address is taken, and thus the relative addresses of the fields are observable to /// client code. -fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> { +fn escaping_locals<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + excluded: &BitSet<Local>, + body: &Body<'tcx>, +) -> BitSet<Local> { + let is_excluded_ty = |ty: Ty<'tcx>| { + if ty.is_union() || ty.is_enum() { + return true; + } + if let ty::Adt(def, _substs) = ty.kind() { + if def.repr().flags.contains(ReprFlags::IS_SIMD) { + // Exclude #[repr(simd)] types so that they are not de-optimized into an array + return true; + } + // We already excluded unions and enums, so this ADT must have one variant + let variant = def.variant(FIRST_VARIANT); + if variant.fields.len() > 1 { + // If this has more than one field, it cannot be a wrapper that only provides a + // niche, so we do not want to automatically exclude it. + return false; + } + let Ok(layout) = tcx.layout_of(param_env.and(ty)) else { + // We can't get the layout + return true; + }; + if layout.layout.largest_niche().is_some() { + // This type has a niche + return true; + } + } + // Default for non-ADTs + false + }; + let mut set = BitSet::new_empty(body.local_decls.len()); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); for (local, decl) in body.local_decls().iter_enumerated() { - if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) { + if excluded.contains(local) || is_excluded_ty(decl.ty) { set.insert(local); } } |
