diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/add_retag.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_unsafety.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/const_goto.rs | 9 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/counters.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/graph.rs | 14 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/mod.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/spans.rs | 22 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/function_item_references.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/generator.rs | 458 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 58 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/instcombine.rs | 86 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 8 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/shim.rs | 12 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/simplify_try.rs | 822 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/sroa.rs | 2 |
15 files changed, 587 insertions, 922 deletions
diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index 3d22035f078..7d2146214c6 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -120,7 +120,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag { // PART 3 // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not. for block_data in basic_blocks { - // We want to insert statements as we iterate. To this end, we + // We want to insert statements as we iterate. To this end, we // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (retag_kind, place) = match block_data.statements[i].kind { diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index adf6ae4c727..9f006a76162 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -445,7 +445,7 @@ impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> { _fd: &'tcx hir::FnDecl<'tcx>, b: hir::BodyId, _s: rustc_span::Span, - _id: HirId, + _id: LocalDefId, ) { if matches!(fk, intravisit::FnKind::Closure) { self.visit_body(self.tcx.hir().body(b)) diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs index 40eefda4f07..da101ca7ad2 100644 --- a/compiler/rustc_mir_transform/src/const_goto.rs +++ b/compiler/rustc_mir_transform/src/const_goto.rs @@ -57,6 +57,15 @@ impl<'tcx> MirPass<'tcx> for ConstGoto { } impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { + if data.is_cleanup { + // Because of the restrictions around control flow in cleanup blocks, we don't perform + // this optimization at all in such blocks. + return; + } + self.super_basic_block_data(block, data); + } + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { let _: Option<_> = try { let target = terminator.kind.as_goto()?; diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 45de0c28035..658e01d9310 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -520,7 +520,7 @@ impl<'a> BcbCounters<'a> { let mut found_loop_exit = false; for &branch in branches.iter() { if backedge_from_bcbs.iter().any(|&backedge_from_bcb| { - self.bcb_is_dominated_by(backedge_from_bcb, branch.target_bcb) + self.bcb_dominates(branch.target_bcb, backedge_from_bcb) }) { if let Some(reloop_branch) = some_reloop_branch { if reloop_branch.counter(&self.basic_coverage_blocks).is_none() { @@ -603,8 +603,8 @@ impl<'a> BcbCounters<'a> { } #[inline] - fn bcb_is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.basic_coverage_blocks.is_dominated_by(node, dom) + fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.basic_coverage_blocks.dominates(dom, node) } #[inline] diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index 78d28f1ebab..a2671eef2e9 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -209,8 +209,8 @@ impl CoverageGraph { } #[inline(always)] - pub fn is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.dominators.as_ref().unwrap().is_dominated_by(node, dom) + pub fn dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.dominators.as_ref().unwrap().dominates(dom, node) } #[inline(always)] @@ -312,7 +312,7 @@ rustc_index::newtype_index! { /// to the BCB's primary counter or expression). /// /// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based -/// queries (`is_dominated_by()`, `predecessors`, `successors`, etc.) have branch (control flow) +/// queries (`dominates()`, `predecessors`, `successors`, etc.) have branch (control flow) /// significance. #[derive(Debug, Clone)] pub(super) struct BasicCoverageBlockData { @@ -594,7 +594,7 @@ impl TraverseCoverageGraphWithLoops { // branching block would have given an `Expression` (or vice versa). let (some_successor_to_add, some_loop_header) = if let Some((_, loop_header)) = context.loop_backedges { - if basic_coverage_blocks.is_dominated_by(successor, loop_header) { + if basic_coverage_blocks.dominates(loop_header, successor) { (Some(successor), Some(loop_header)) } else { (None, None) @@ -666,15 +666,15 @@ pub(super) fn find_loop_backedges( // // The overall complexity appears to be comparable to many other MIR transform algorithms, and I // don't expect that this function is creating a performance hot spot, but if this becomes an - // issue, there may be ways to optimize the `is_dominated_by` algorithm (as indicated by an + // issue, there may be ways to optimize the `dominates` algorithm (as indicated by an // existing `FIXME` comment in that code), or possibly ways to optimize it's usage here, perhaps // by keeping track of results for visited `BasicCoverageBlock`s if they can be used to short - // circuit downstream `is_dominated_by` checks. + // circuit downstream `dominates` checks. // // For now, that kind of optimization seems unnecessarily complicated. for (bcb, _) in basic_coverage_blocks.iter_enumerated() { for &successor in &basic_coverage_blocks.successors[bcb] { - if basic_coverage_blocks.is_dominated_by(bcb, successor) { + if basic_coverage_blocks.dominates(successor, bcb) { let loop_header = successor; let backedge_from_bcb = bcb; debug!( diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index 1468afc6456..9a617159813 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -540,7 +540,8 @@ fn fn_sig_and_body( // FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back // to HIR for it. let hir_node = tcx.hir().get_if_local(def_id).expect("expected DefId is local"); - let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body"); + let (_, fn_body_id) = + hir::map::associated_body(hir_node).expect("HIR node is a function with body"); (hir_node.fn_sig(), tcx.hir().body(fn_body_id)) } diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 9f842c929dc..31d5541a31b 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -63,7 +63,7 @@ impl CoverageStatement { /// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that /// is not part of the `CoverageSpan` bcb if the statement was included because it's `Span` matches /// or is subsumed by the `Span` associated with this `CoverageSpan`, and it's `BasicBlock` -/// `is_dominated_by()` the `BasicBlock`s in this `CoverageSpan`. +/// `dominates()` the `BasicBlock`s in this `CoverageSpan`. #[derive(Debug, Clone)] pub(super) struct CoverageSpan { pub span: Span, @@ -341,11 +341,11 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { if a.is_in_same_bcb(b) { Some(Ordering::Equal) } else { - // Sort equal spans by dominator relationship, in reverse order (so - // dominators always come after the dominated equal spans). When later - // comparing two spans in order, the first will either dominate the second, - // or they will have no dominator relationship. - self.basic_coverage_blocks.dominators().rank_partial_cmp(b.bcb, a.bcb) + // Sort equal spans by dominator relationship (so dominators always come + // before the dominated equal spans). When later comparing two spans in + // order, the first will either dominate the second, or they will have no + // dominator relationship. + self.basic_coverage_blocks.dominators().rank_partial_cmp(a.bcb, b.bcb) } } else { // Sort hi() in reverse order so shorter spans are attempted after longer spans. @@ -705,12 +705,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { fn hold_pending_dups_unless_dominated(&mut self) { // Equal coverage spans are ordered by dominators before dominated (if any), so it should be // impossible for `curr` to dominate any previous `CoverageSpan`. - debug_assert!(!self.span_bcb_is_dominated_by(self.prev(), self.curr())); + debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev())); let initial_pending_count = self.pending_dups.len(); if initial_pending_count > 0 { let mut pending_dups = self.pending_dups.split_off(0); - pending_dups.retain(|dup| !self.span_bcb_is_dominated_by(self.curr(), dup)); + pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr())); self.pending_dups.append(&mut pending_dups); if self.pending_dups.len() < initial_pending_count { debug!( @@ -721,7 +721,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } } - if self.span_bcb_is_dominated_by(self.curr(), self.prev()) { + if self.span_bcb_dominates(self.prev(), self.curr()) { debug!( " different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}", self.prev() @@ -787,8 +787,8 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } } - fn span_bcb_is_dominated_by(&self, covspan: &CoverageSpan, dom_covspan: &CoverageSpan) -> bool { - self.basic_coverage_blocks.is_dominated_by(covspan.bcb, dom_covspan.bcb) + fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool { + self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb) } } diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index b708d780b70..aa19b1fdb5e 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -79,7 +79,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { for bound in bounds { if let Some(bound_ty) = self.is_pointer_trait(&bound.kind().skip_binder()) { // Get the argument types as they appear in the function signature. - let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs(); + let arg_defs = self.tcx.fn_sig(def_id).subst_identity().skip_binder().inputs(); for (arg_num, arg_def) in arg_defs.iter().enumerate() { // For all types reachable from the argument type in the fn sig for generic_inner_ty in arg_def.walk() { @@ -161,7 +161,8 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { .as_ref() .assert_crate_local() .lint_root; - let fn_sig = self.tcx.fn_sig(fn_id); + // FIXME: use existing printing routines to print the function signature + let fn_sig = self.tcx.fn_sig(fn_id).subst(self.tcx, fn_substs); let unsafety = fn_sig.unsafety().prefix_str(); let abi = match fn_sig.abi() { Abi::Rust => String::from(""), diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index c097af61611..e8871ff37f2 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -54,7 +54,8 @@ use crate::deref_separator::deref_finder; use crate::simplify; use crate::util::expand_aggregate; use crate::MirPass; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; use rustc_hir::GeneratorKind; @@ -70,6 +71,9 @@ use rustc_mir_dataflow::impls::{ }; use rustc_mir_dataflow::storage::always_storage_live_locals; use rustc_mir_dataflow::{self, Analysis}; +use rustc_span::def_id::DefId; +use rustc_span::symbol::sym; +use rustc_span::Span; use rustc_target::abi::VariantIdx; use rustc_target::spec::PanicStrategy; use std::{iter, ops}; @@ -460,6 +464,104 @@ fn replace_local<'tcx>( new_local } +/// Transforms the `body` of the generator applying the following transforms: +/// +/// - Eliminates all the `get_context` calls that async lowering created. +/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`). +/// +/// The `Local`s that have their types replaced are: +/// - The `resume` argument itself. +/// - The argument to `get_context`. +/// - The yielded value of a `yield`. +/// +/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the +/// `get_context` function is being used to convert that back to a `&mut Context<'_>`. +/// +/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection, +/// but rather directly use `&mut Context<'_>`, however that would currently +/// lead to higher-kinded lifetime errors. +/// See <https://github.com/rust-lang/rust/issues/105501>. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `ResumeTy` indirection for the time being, and that indirection +/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`. +fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let context_mut_ref = tcx.mk_task_context(); + + // replace the type of the `resume` argument + replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref); + + let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); + + for bb in BasicBlock::new(0)..body.basic_blocks.next_index() { + let bb_data = &body[bb]; + if bb_data.is_cleanup { + continue; + } + + match &bb_data.terminator().kind { + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(body, tcx); + if let ty::FnDef(def_id, _) = *func_ty.kind() { + if def_id == get_context_def_id { + let local = eliminate_get_context_call(&mut body[bb]); + replace_resume_ty_local(tcx, body, local, context_mut_ref); + } + } else { + continue; + } + } + TerminatorKind::Yield { resume_arg, .. } => { + replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref); + } + _ => {} + } + } +} + +fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local { + let terminator = bb_data.terminator.take().unwrap(); + if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind { + let arg = args.pop().unwrap(); + let local = arg.place().unwrap().local; + + let arg = Rvalue::Use(arg); + let assign = Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new((destination, arg))), + }; + bb_data.statements.push(assign); + bb_data.terminator = Some(Terminator { + source_info: terminator.source_info, + kind: TerminatorKind::Goto { target: target.unwrap() }, + }); + local + } else { + bug!(); + } +} + +#[cfg_attr(not(debug_assertions), allow(unused))] +fn replace_resume_ty_local<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + local: Local, + context_mut_ref: Ty<'tcx>, +) { + let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref); + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` in MIR. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = local_ty.kind() { + let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", local_ty); + }; + } +} + struct LivenessInfo { /// Which locals are live across any suspension point. saved_locals: GeneratorSavedLocals, @@ -756,7 +858,7 @@ fn sanitize_witness<'tcx>( body: &Body<'tcx>, witness: Ty<'tcx>, upvars: Vec<Ty<'tcx>>, - saved_locals: &GeneratorSavedLocals, + layout: &GeneratorLayout<'tcx>, ) { let did = body.source.def_id(); let param_env = tcx.param_env(did); @@ -775,31 +877,36 @@ fn sanitize_witness<'tcx>( } }; - for (local, decl) in body.local_decls.iter_enumerated() { - // Ignore locals which are internal or not saved between yields. - if !saved_locals.contains(local) || decl.internal { + let mut mismatches = Vec::new(); + for fty in &layout.field_tys { + if fty.ignore_for_traits { continue; } - let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty); + let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty); // Sanity check that typeck knows about the type of locals which are // live across a suspension point if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) { - span_bug!( - body.span, - "Broken MIR: generator contains type {} in MIR, \ - but typeck only knows about {} and {:?}", - decl_ty, - allowed, - allowed_upvars - ); + mismatches.push(decl_ty); } } + + if !mismatches.is_empty() { + span_bug!( + body.span, + "Broken MIR: generator contains type {:?} in MIR, \ + but typeck only knows about {} and {:?}", + mismatches, + allowed, + allowed_upvars + ); + } } fn compute_layout<'tcx>( + tcx: TyCtxt<'tcx>, liveness: LivenessInfo, - body: &mut Body<'tcx>, + body: &Body<'tcx>, ) -> ( FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>, GeneratorLayout<'tcx>, @@ -817,9 +924,33 @@ fn compute_layout<'tcx>( let mut locals = IndexVec::<GeneratorSavedLocal, _>::new(); let mut tys = IndexVec::<GeneratorSavedLocal, _>::new(); for (saved_local, local) in saved_locals.iter_enumerated() { - locals.push(local); - tys.push(body.local_decls[local].ty); debug!("generator saved local {:?} => {:?}", saved_local, local); + + locals.push(local); + let decl = &body.local_decls[local]; + debug!(?decl); + + let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir { + match decl.local_info { + // Do not include raw pointers created from accessing `static` items, as those could + // well be re-created by another access to the same static. + Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local, + // Fake borrows are only read by fake reads, so do not have any reality in + // post-analysis MIR. + Some(box LocalInfo::FakeBorrow) => true, + _ => false, + } + } else { + // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that + // MIR building may introduce. This leads to wrongly ignored types, but this is + // necessary for internal consistency and to avoid ICEs. + decl.internal + }; + let decl = + GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + debug!(?decl); + + tys.push(decl); } // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. @@ -849,7 +980,7 @@ fn compute_layout<'tcx>( // just use the first one here. That's fine; fields do not move // around inside generators, so it doesn't matter which variant // index we access them by. - remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx)); + remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); } variant_fields.push(fields); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); @@ -859,6 +990,7 @@ fn compute_layout<'tcx>( let layout = GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts }; + debug!(?layout); (remap, layout, storage_liveness) } @@ -1253,6 +1385,52 @@ fn create_cases<'tcx>( .collect() } +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_generator_witnesses<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, +) -> GeneratorLayout<'tcx> { + let def_id = def_id.expect_local(); + + let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id)); + let body = body.borrow(); + let body = &*body; + + // The first argument is the generator type passed by value + let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + // Get the interior types and substs which typeck computed + let (upvars, interior, movable) = match *gen_ty.kind() { + ty::Generator(_, substs, movability) => { + let substs = substs.as_generator(); + ( + substs.upvar_tys().collect::<Vec<_>>(), + substs.witness(), + movability == hir::Movability::Movable, + ) + } + _ => span_bug!(body.span, "unexpected generator type {}", gen_ty), + }; + + // When first entering the generator, move the resume argument into its new local. + let always_live_locals = always_storage_live_locals(&body); + + let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto generator struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body); + + if tcx.sess.opts.unstable_opts.drop_tracking_mir { + check_suspend_tys(tcx, &generator_layout, &body); + } else { + sanitize_witness(tcx, body, interior, upvars, &generator_layout); + } + + generator_layout +} + impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(yield_ty) = body.yield_ty() else { @@ -1265,16 +1443,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // The first argument is the generator type passed by value let gen_ty = body.local_decls.raw[1].ty; - // Get the interior types and substs which typeck computed - let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() { + // Get the discriminant type and substs which typeck computed + let (discr_ty, movable) = match *gen_ty.kind() { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); - ( - substs.upvar_tys().collect(), - substs.witness(), - substs.discr_ty(tcx), - movability == hir::Movability::Movable, - ) + (substs.discr_ty(tcx), movability == hir::Movability::Movable) } _ => { tcx.sess @@ -1283,13 +1456,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } }; - let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen; + let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_))); let (state_adt_ref, state_substs) = if is_async_kind { // Compute Poll<return_ty> - let state_did = tcx.require_lang_item(LangItem::Poll, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[body.return_ty().into()]); - (state_adt_ref, state_substs) + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_substs = tcx.intern_substs(&[body.return_ty().into()]); + (poll_adt_ref, poll_substs) } else { // Compute GeneratorState<yield_ty, return_ty> let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); @@ -1303,13 +1476,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // RETURN_PLACE then is a fresh unused local with type ret_ty. let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx); + // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. + if is_async_kind { + transform_async_context(tcx, body); + } + // We also replace the resume argument and insert an `Assign`. // This is needed because the resume argument `_2` might be live across a `yield`, in which // case there is no `Assign` to it that the transform can turn into a store to the generator // state. After the yield the slot in the generator state would then be uninitialized. let resume_local = Local::new(2); - let new_resume_local = - replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx); + let resume_ty = + if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty }; + let new_resume_local = replace_local(resume_local, resume_ty, body, tcx); // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); @@ -1330,8 +1509,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); - sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals); - if tcx.sess.opts.unstable_opts.validate_mir { let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { assigned_local: None, @@ -1345,7 +1522,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto generator struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body); let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); @@ -1527,3 +1704,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { } } } + +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) { + let mut linted_tys = FxHashSet::default(); + + // We want a user-facing param-env. + let param_env = tcx.param_env(body.source.def_id()); + + for (variant, yield_source_info) in + layout.variant_fields.iter().zip(&layout.variant_source_info) + { + debug!(?variant); + for &local in variant { + let decl = &layout.field_tys[local]; + debug!(?decl); + + if !decl.ignore_for_traits && linted_tys.insert(decl.ty) { + let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue }; + + check_must_not_suspend_ty( + tcx, + decl.ty, + hir_id, + param_env, + SuspendCheckData { + source_span: decl.source_info.span, + yield_span: yield_source_info.span, + plural_len: 1, + ..Default::default() + }, + ); + } + } + } +} + +#[derive(Default)] +struct SuspendCheckData<'a> { + source_span: Span, + yield_span: Span, + descr_pre: &'a str, + descr_post: &'a str, + plural_len: usize, +} + +// Returns whether it emitted a diagnostic or not +// Note that this fn and the proceeding one are based on the code +// for creating must_use diagnostics +// +// Note that this technique was chosen over things like a `Suspend` marker trait +// as it is simpler and has precedent in the compiler +fn check_must_not_suspend_ty<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + hir_id: hir::HirId, + param_env: ty::ParamEnv<'tcx>, + data: SuspendCheckData<'_>, +) -> bool { + if ty.is_unit() { + return false; + } + + let plural_suffix = pluralize!(data.plural_len); + + debug!("Checking must_not_suspend for {}", ty); + + match *ty.kind() { + ty::Adt(..) if ty.is_box() => { + let boxed_ty = ty.boxed_ty(); + let descr_pre = &format!("{}boxed ", data.descr_pre); + check_must_not_suspend_ty( + tcx, + boxed_ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), + // FIXME: support adding the attribute to TAITs + ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => { + let mut has_emitted = false; + for &(predicate, _) in tcx.explicit_item_bounds(def) { + // We only look at the `DefId`, so it is safe to skip the binder here. + if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) = + predicate.kind().skip_binder() + { + let def_id = poly_trait_predicate.trait_ref.def_id; + let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_pre, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Dynamic(binder, _, _) => { + let mut has_emitted = false; + for predicate in binder.iter() { + if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() { + let def_id = trait_ref.def_id; + let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Tuple(fields) => { + let mut has_emitted = false; + for (i, ty) in fields.iter().enumerate() { + let descr_post = &format!(" in tuple element {i}"); + if check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + } + } + has_emitted + } + ty::Array(ty, len) => { + let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { + descr_pre, + plural_len: len.try_eval_usize(tcx, param_env).unwrap_or(0) as usize + 1, + ..data + }, + ) + } + // If drop tracking is enabled, we want to look through references, since the referrent + // may not be considered live across the await point. + ty::Ref(_region, ty, _mutability) => { + let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + _ => false, + } +} + +fn check_must_not_suspend_def( + tcx: TyCtxt<'_>, + def_id: DefId, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) { + let msg = format!( + "{}`{}`{} held across a suspend point, but should not be", + data.descr_pre, + tcx.def_path_str(def_id), + data.descr_post, + ); + tcx.struct_span_lint_hir( + rustc_session::lint::builtin::MUST_NOT_SUSPEND, + hir_id, + data.source_span, + msg, + |lint| { + // add span pointing to the offending yield/await + lint.span_label(data.yield_span, "the value is held across this suspend point"); + + // Add optional reason note + if let Some(note) = attr.value_str() { + // FIXME(guswynn): consider formatting this better + lint.span_note(data.source_span, note.as_str()); + } + + // Add some quick suggestions on what to do + // FIXME: can `drop` work as a suggestion here as well? + lint.span_help( + data.source_span, + "consider using a block (`{ ... }`) \ + to shrink the value's scope, ending before the suspend point", + ) + }, + ); + + true + } else { + false + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 4219e6280eb..84640b703c8 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -331,7 +331,7 @@ impl<'tcx> Inliner<'tcx> { return None; } - let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs); + let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs); let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; return Some(CallSite { callee, fn_sig, block: bb, target, source_info }); @@ -542,6 +542,21 @@ impl<'tcx> Inliner<'tcx> { destination }; + // Always create a local to hold the destination, as `RETURN_PLACE` may appear + // where a full `Place` is not allowed. + let (remap_destination, destination_local) = if let Some(d) = dest.as_local() { + (false, d) + } else { + ( + true, + self.new_call_temp( + caller_body, + &callsite, + destination.ty(caller_body, self.tcx).ty, + ), + ) + }; + // Copy the arguments if needed. let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body); @@ -560,7 +575,7 @@ impl<'tcx> Inliner<'tcx> { new_locals: Local::new(caller_body.local_decls.len()).., new_scopes: SourceScope::new(caller_body.source_scopes.len()).., new_blocks: BasicBlock::new(caller_body.basic_blocks.len()).., - destination: dest, + destination: destination_local, callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(), callsite, cleanup_block: cleanup, @@ -591,6 +606,16 @@ impl<'tcx> Inliner<'tcx> { // To avoid repeated O(n) insert, push any new statements to the end and rotate // the slice once. let mut n = 0; + if remap_destination { + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Move(destination_local.into())), + ))), + }); + n += 1; + } for local in callee_body.vars_and_temps_iter().rev() { if !callee_body.local_decls[local].internal && integrator.always_live_locals.contains(local) @@ -922,12 +947,12 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { return; }; - let Some(&f_ty) = layout.field_tys.get(local) else { + let Some(f_ty) = layout.field_tys.get(local) else { self.validation = Err("malformed MIR"); return; }; - f_ty + f_ty.ty } else { let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else { self.validation = Err("malformed MIR"); @@ -959,7 +984,7 @@ struct Integrator<'a, 'tcx> { new_locals: RangeFrom<Local>, new_scopes: RangeFrom<SourceScope>, new_blocks: RangeFrom<BasicBlock>, - destination: Place<'tcx>, + destination: Local, callsite_scope: SourceScopeData<'tcx>, callsite: &'a CallSite<'tcx>, cleanup_block: Option<BasicBlock>, @@ -972,7 +997,7 @@ struct Integrator<'a, 'tcx> { impl Integrator<'_, '_> { fn map_local(&self, local: Local) -> Local { let new = if local == RETURN_PLACE { - self.destination.local + self.destination } else { let idx = local.index() - 1; if idx < self.args.len() { @@ -1053,27 +1078,6 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { *span = span.fresh_expansion(self.expn_data); } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - for elem in place.projection { - // FIXME: Make sure that return place is not used in an indexing projection, since it - // won't be rebased as it is supposed to be. - assert_ne!(ProjectionElem::Index(RETURN_PLACE), elem); - } - - // If this is the `RETURN_PLACE`, we need to rebase any projections onto it. - let dest_proj_len = self.destination.projection.len(); - if place.local == RETURN_PLACE && dest_proj_len > 0 { - let mut projs = Vec::with_capacity(dest_proj_len + place.projection.len()); - projs.extend(self.destination.projection); - projs.extend(place.projection); - - place.projection = self.tcx.intern_place_elems(&*projs); - } - // Handles integrating any locals that occur in the base - // or projections - self.super_place(place, context, location) - } - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { self.in_cleanup_block = data.is_cleanup; self.super_basic_block_data(block, data); diff --git a/compiler/rustc_mir_transform/src/instcombine.rs b/compiler/rustc_mir_transform/src/instcombine.rs index 2f3c65869ef..e1faa7a08d9 100644 --- a/compiler/rustc_mir_transform/src/instcombine.rs +++ b/compiler/rustc_mir_transform/src/instcombine.rs @@ -6,7 +6,8 @@ use rustc_middle::mir::{ BinOp, Body, Constant, ConstantKind, LocalDecls, Operand, Place, ProjectionElem, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, }; -use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::ty::{self, layout::TyAndLayout, ParamEnv, ParamEnvAnd, SubstsRef, Ty, TyCtxt}; +use rustc_span::symbol::{sym, Symbol}; pub struct InstCombine; @@ -16,7 +17,11 @@ impl<'tcx> MirPass<'tcx> for InstCombine { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ctx = InstCombineContext { tcx, local_decls: &body.local_decls }; + let ctx = InstCombineContext { + tcx, + local_decls: &body.local_decls, + param_env: tcx.param_env_reveal_all_normalized(body.source.def_id()), + }; for block in body.basic_blocks.as_mut() { for statement in block.statements.iter_mut() { match statement.kind { @@ -33,6 +38,10 @@ impl<'tcx> MirPass<'tcx> for InstCombine { &mut block.terminator.as_mut().unwrap(), &mut block.statements, ); + ctx.combine_intrinsic_assert( + &mut block.terminator.as_mut().unwrap(), + &mut block.statements, + ); } } } @@ -40,6 +49,7 @@ impl<'tcx> MirPass<'tcx> for InstCombine { struct InstCombineContext<'tcx, 'a> { tcx: TyCtxt<'tcx>, local_decls: &'a LocalDecls<'tcx>, + param_env: ParamEnv<'tcx>, } impl<'tcx> InstCombineContext<'tcx, '_> { @@ -200,4 +210,76 @@ impl<'tcx> InstCombineContext<'tcx, '_> { }); terminator.kind = TerminatorKind::Goto { target: destination_block }; } + + fn combine_intrinsic_assert( + &self, + terminator: &mut Terminator<'tcx>, + _statements: &mut Vec<Statement<'tcx>>, + ) { + let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else { return; }; + let Some(target_block) = target else { return; }; + let func_ty = func.ty(self.local_decls, self.tcx); + let Some((intrinsic_name, substs)) = resolve_rust_intrinsic(self.tcx, func_ty) else { + return; + }; + // The intrinsics we are interested in have one generic parameter + if substs.is_empty() { + return; + } + let ty = substs.type_at(0); + + // Check this is a foldable intrinsic before we query the layout of our generic parameter + let Some(assert_panics) = intrinsic_assert_panics(intrinsic_name) else { return; }; + let Ok(layout) = self.tcx.layout_of(self.param_env.and(ty)) else { return; }; + if assert_panics(self.tcx, self.param_env.and(layout)) { + // If we know the assert panics, indicate to later opts that the call diverges + *target = None; + } else { + // If we know the assert does not panic, turn the call into a Goto + terminator.kind = TerminatorKind::Goto { target: *target_block }; + } + } +} + +fn intrinsic_assert_panics<'tcx>( + intrinsic_name: Symbol, +) -> Option<fn(TyCtxt<'tcx>, ParamEnvAnd<'tcx, TyAndLayout<'tcx>>) -> bool> { + fn inhabited_predicate<'tcx>( + _tcx: TyCtxt<'tcx>, + param_env_and_layout: ParamEnvAnd<'tcx, TyAndLayout<'tcx>>, + ) -> bool { + let (_param_env, layout) = param_env_and_layout.into_parts(); + layout.abi.is_uninhabited() + } + fn zero_valid_predicate<'tcx>( + tcx: TyCtxt<'tcx>, + param_env_and_layout: ParamEnvAnd<'tcx, TyAndLayout<'tcx>>, + ) -> bool { + !tcx.permits_zero_init(param_env_and_layout) + } + fn mem_uninitialized_valid_predicate<'tcx>( + tcx: TyCtxt<'tcx>, + param_env_and_layout: ParamEnvAnd<'tcx, TyAndLayout<'tcx>>, + ) -> bool { + !tcx.permits_uninit_init(param_env_and_layout) + } + + match intrinsic_name { + sym::assert_inhabited => Some(inhabited_predicate), + sym::assert_zero_valid => Some(zero_valid_predicate), + sym::assert_mem_uninitialized_valid => Some(mem_uninitialized_valid_predicate), + _ => None, + } +} + +fn resolve_rust_intrinsic<'tcx>( + tcx: TyCtxt<'tcx>, + func_ty: Ty<'tcx>, +) -> Option<(Symbol, SubstsRef<'tcx>)> { + if let ty::FnDef(def_id, substs) = *func_ty.kind() { + if tcx.is_intrinsic(def_id) { + return Some((tcx.item_name(def_id), substs)); + } + } + None } diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 16b8a901f36..fe3d5b1cce4 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -90,7 +90,6 @@ mod shim; pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; -mod simplify_try; mod sroa; mod uninhabited_enum_branching; mod unreachable_prop; @@ -124,6 +123,7 @@ pub fn provide(providers: &mut Providers) { mir_drops_elaborated_and_const_checked, mir_for_ctfe, mir_for_ctfe_of_const_arg, + mir_generator_witnesses: generator::mir_generator_witnesses, optimized_mir, is_mir_available, is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did), @@ -426,6 +426,9 @@ fn mir_drops_elaborated_and_const_checked( return tcx.mir_drops_elaborated_and_const_checked(def); } + if tcx.generator_kind(def.did).is_some() { + tcx.ensure().mir_generator_witnesses(def.did); + } let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def); let is_fn_like = tcx.def_kind(def.did).is_fn_like(); @@ -487,7 +490,6 @@ fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx> fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let passes: &[&dyn MirPass<'tcx>] = &[ &cleanup_post_borrowck::CleanupPostBorrowck, - &simplify_branches::SimplifyConstCondition::new("initial"), &remove_noop_landing_pads::RemoveNoopLandingPads, &simplify::SimplifyCfg::new("early-opt"), &deref_separator::Derefer, @@ -568,8 +570,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &o1(simplify_branches::SimplifyConstCondition::new("after-const-prop")), &early_otherwise_branch::EarlyOtherwiseBranch, &simplify_comparison_integral::SimplifyComparisonIntegral, - &simplify_try::SimplifyArmIdentity, - &simplify_try::SimplifyBranchSame, &dead_store_elimination::DeadStoreElimination, &dest_prop::DestinationPropagation, &o1(simplify_branches::SimplifyConstCondition::new("final")), diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index dace540fa29..8d4fe74e7d3 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -152,7 +152,7 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) } else { InternalSubsts::identity_for_item(tcx, def_id) }; - let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); + let sig = tcx.fn_sig(def_id).subst(tcx, substs); let sig = tcx.erase_late_bound_regions(sig); let span = tcx.def_span(def_id); @@ -363,7 +363,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { // we must subst the self_ty because it's // otherwise going to be TySelf and we can't index // or access fields of a Place of type TySelf. - let sig = tcx.bound_fn_sig(def_id).subst(tcx, &[self_ty.into()]); + let sig = tcx.fn_sig(def_id).subst(tcx, &[self_ty.into()]); let sig = tcx.erase_late_bound_regions(sig); let span = tcx.def_span(def_id); @@ -606,7 +606,7 @@ fn build_call_shim<'tcx>( }; let def_id = instance.def_id(); - let sig = tcx.bound_fn_sig(def_id); + let sig = tcx.fn_sig(def_id); let sig = sig.map_bound(|sig| tcx.erase_late_bound_regions(sig)); assert_eq!(sig_substs.is_some(), !instance.has_polymorphic_mir_body()); @@ -798,7 +798,11 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { let param_env = tcx.param_env(ctor_id); // Normalize the sig. - let sig = tcx.fn_sig(ctor_id).no_bound_vars().expect("LBR in ADT constructor signature"); + let sig = tcx + .fn_sig(ctor_id) + .subst_identity() + .no_bound_vars() + .expect("LBR in ADT constructor signature"); let sig = tcx.normalize_erasing_regions(param_env, sig); let ty::Adt(adt_def, substs) = sig.output().kind() else { diff --git a/compiler/rustc_mir_transform/src/simplify_try.rs b/compiler/rustc_mir_transform/src/simplify_try.rs deleted file mode 100644 index e4f3ace9a93..00000000000 --- a/compiler/rustc_mir_transform/src/simplify_try.rs +++ /dev/null @@ -1,822 +0,0 @@ -//! The general point of the optimizations provided here is to simplify something like: -//! -//! ```rust -//! # fn foo<T, E>(x: Result<T, E>) -> Result<T, E> { -//! match x { -//! Ok(x) => Ok(x), -//! Err(x) => Err(x) -//! } -//! # } -//! ``` -//! -//! into just `x`. - -use crate::{simplify, MirPass}; -use itertools::Itertools as _; -use rustc_index::{bit_set::BitSet, vec::IndexVec}; -use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::*; -use rustc_middle::ty::{self, List, Ty, TyCtxt}; -use rustc_target::abi::VariantIdx; -use std::iter::{once, Enumerate, Peekable}; -use std::slice::Iter; - -/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. -/// -/// This is done by transforming basic blocks where the statements match: -/// -/// ```ignore (MIR) -/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY ); -/// _TMP_2 = _LOCAL_TMP; -/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2; -/// discriminant(_LOCAL_0) = VAR_IDX; -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// _LOCAL_0 = move _LOCAL_1 -/// ``` -pub struct SimplifyArmIdentity; - -#[derive(Debug)] -struct ArmIdentityInfo<'tcx> { - /// Storage location for the variant's field - local_temp_0: Local, - /// Storage location holding the variant being read from - local_1: Local, - /// The variant field being read from - vf_s0: VarField<'tcx>, - /// Index of the statement which loads the variant being read - get_variant_field_stmt: usize, - - /// Tracks each assignment to a temporary of the variant's field - field_tmp_assignments: Vec<(Local, Local)>, - - /// Storage location holding the variant's field that was read from - local_tmp_s1: Local, - /// Storage location holding the enum that we are writing to - local_0: Local, - /// The variant field being written to - vf_s1: VarField<'tcx>, - - /// Storage location that the discriminant is being written to - set_discr_local: Local, - /// The variant being written - set_discr_var_idx: VariantIdx, - - /// Index of the statement that should be overwritten as a move - stmt_to_overwrite: usize, - /// SourceInfo for the new move - source_info: SourceInfo, - - /// Indices of matching Storage{Live,Dead} statements encountered. - /// (StorageLive index,, StorageDead index, Local) - storage_stmts: Vec<(usize, usize, Local)>, - - /// The statements that should be removed (turned into nops) - stmts_to_remove: Vec<usize>, - - /// Indices of debug variables that need to be adjusted to point to - // `{local_0}.{dbg_projection}`. - dbg_info_to_adjust: Vec<usize>, - - /// The projection used to rewrite debug info. - dbg_projection: &'tcx List<PlaceElem<'tcx>>, -} - -fn get_arm_identity_info<'a, 'tcx>( - stmts: &'a [Statement<'tcx>], - locals_count: usize, - debug_info: &'a [VarDebugInfo<'tcx>], -) -> Option<ArmIdentityInfo<'tcx>> { - // This can't possibly match unless there are at least 3 statements in the block - // so fail fast on tiny blocks. - if stmts.len() < 3 { - return None; - } - - let mut tmp_assigns = Vec::new(); - let mut nop_stmts = Vec::new(); - let mut storage_stmts = Vec::new(); - let mut storage_live_stmts = Vec::new(); - let mut storage_dead_stmts = Vec::new(); - - type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>; - - fn is_storage_stmt(stmt: &Statement<'_>) -> bool { - matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_)) - } - - /// Eats consecutive Statements which match `test`, performing the specified `action` for each. - /// The iterator `stmt_iter` is not advanced if none were matched. - fn try_eat<'a, 'tcx>( - stmt_iter: &mut StmtIter<'a, 'tcx>, - test: impl Fn(&'a Statement<'tcx>) -> bool, - mut action: impl FnMut(usize, &'a Statement<'tcx>), - ) { - while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) { - let (idx, stmt) = stmt_iter.next().unwrap(); - - action(idx, stmt); - } - } - - /// Eats consecutive `StorageLive` and `StorageDead` Statements. - /// The iterator `stmt_iter` is not advanced if none were found. - fn try_eat_storage_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - storage_live_stmts: &mut Vec<(usize, Local)>, - storage_dead_stmts: &mut Vec<(usize, Local)>, - ) { - try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { - if let StatementKind::StorageLive(l) = stmt.kind { - storage_live_stmts.push((idx, l)); - } else if let StatementKind::StorageDead(l) = stmt.kind { - storage_dead_stmts.push((idx, l)); - } - }) - } - - fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { - place.as_local().is_some() && p.as_local().is_some() - } else { - false - } - } - - /// Eats consecutive `Assign` Statements. - // The iterator `stmt_iter` is not advanced if none were found. - fn try_eat_assign_tmp_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - tmp_assigns: &mut Vec<(Local, Local)>, - nop_stmts: &mut Vec<usize>, - ) { - try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = - &stmt.kind - { - tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); - nop_stmts.push(idx); - } - }) - } - - fn find_storage_live_dead_stmts_for_local( - local: Local, - stmts: &[Statement<'_>], - ) -> Option<(usize, usize)> { - trace!("looking for {:?}", local); - let mut storage_live_stmt = None; - let mut storage_dead_stmt = None; - for (idx, stmt) in stmts.iter().enumerate() { - if stmt.kind == StatementKind::StorageLive(local) { - storage_live_stmt = Some(idx); - } else if stmt.kind == StatementKind::StorageDead(local) { - storage_dead_stmt = Some(idx); - } - } - - Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) - } - - // Try to match the expected MIR structure with the basic block we're processing. - // We want to see something that looks like: - // ``` - // (StorageLive(_) | StorageDead(_));* - // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; - // discriminant(LOCAL_FROM) = VariantIdx; - // (StorageLive(_) | StorageDead(_));* - // ``` - let mut stmt_iter = stmts.iter().enumerate().peekable(); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - let (get_variant_field_stmt, stmt) = stmt_iter.next()?; - let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - let (idx, stmt) = stmt_iter.next()?; - let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; - nop_stmts.push(idx); - - let (idx, stmt) = stmt_iter.next()?; - let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; - let discr_stmt_source_info = stmt.source_info; - nop_stmts.push(idx); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - for (live_idx, live_local) in storage_live_stmts { - if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { - let (dead_idx, _) = storage_dead_stmts.swap_remove(i); - storage_stmts.push((live_idx, dead_idx, live_local)); - - if live_local == local_tmp_s0 { - nop_stmts.push(get_variant_field_stmt); - } - } - } - // We sort primitive usize here so we can use unstable sort - nop_stmts.sort_unstable(); - - // Use one of the statements we're going to discard between the point - // where the storage location for the variant field becomes live and - // is killed. - let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; - let stmt_to_overwrite = - nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); - - let mut tmp_assigned_vars = BitSet::new_empty(locals_count); - for (l, r) in &tmp_assigns { - tmp_assigned_vars.insert(*l); - tmp_assigned_vars.insert(*r); - } - - let dbg_info_to_adjust: Vec<_> = debug_info - .iter() - .enumerate() - .filter_map(|(i, var_info)| { - if let VarDebugInfoContents::Place(p) = var_info.value { - if tmp_assigned_vars.contains(p.local) { - return Some(i); - } - } - - None - }) - .collect(); - - Some(ArmIdentityInfo { - local_temp_0: local_tmp_s0, - local_1, - vf_s0, - get_variant_field_stmt, - field_tmp_assignments: tmp_assigns, - local_tmp_s1, - local_0, - vf_s1, - set_discr_local, - set_discr_var_idx, - stmt_to_overwrite: *stmt_to_overwrite?, - source_info: discr_stmt_source_info, - storage_stmts, - stmts_to_remove: nop_stmts, - dbg_info_to_adjust, - dbg_projection, - }) -} - -fn optimization_applies<'tcx>( - opt_info: &ArmIdentityInfo<'tcx>, - local_decls: &IndexVec<Local, LocalDecl<'tcx>>, - local_uses: &IndexVec<Local, usize>, - var_debug_info: &[VarDebugInfo<'tcx>], -) -> bool { - trace!("testing if optimization applies..."); - - // FIXME(wesleywiser): possibly relax this restriction? - if opt_info.local_0 == opt_info.local_1 { - trace!("NO: moving into ourselves"); - return false; - } else if opt_info.vf_s0 != opt_info.vf_s1 { - trace!("NO: the field-and-variant information do not match"); - return false; - } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { - // FIXME(Centril,oli-obk): possibly relax to same layout? - trace!("NO: source and target locals have different types"); - return false; - } else if (opt_info.local_0, opt_info.vf_s0.var_idx) - != (opt_info.set_discr_local, opt_info.set_discr_var_idx) - { - trace!("NO: the discriminants do not match"); - return false; - } - - // Verify the assignment chain consists of the form b = a; c = b; d = c; etc... - if opt_info.field_tmp_assignments.is_empty() { - trace!("NO: no assignments found"); - return false; - } - let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; - let source_local = last_assigned_to; - for (l, r) in &opt_info.field_tmp_assignments { - if *r != last_assigned_to { - trace!("NO: found unexpected assignment {:?} = {:?}", l, r); - return false; - } - - last_assigned_to = *l; - } - - // Check that the first and last used locals are only used twice - // since they are of the form: - // - // ``` - // _first = ((_x as Variant).n: ty); - // _n = _first; - // ... - // ((_y as Variant).n: ty) = _n; - // discriminant(_y) = z; - // ``` - for (l, r) in &opt_info.field_tmp_assignments { - if local_uses[*l] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); - return false; - } else if local_uses[*r] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); - return false; - } - } - - // Check that debug info only points to full Locals and not projections. - for dbg_idx in &opt_info.dbg_info_to_adjust { - let dbg_info = &var_debug_info[*dbg_idx]; - if let VarDebugInfoContents::Place(p) = dbg_info.value { - if !p.projection.is_empty() { - trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p); - return false; - } - } - } - - if source_local != opt_info.local_temp_0 { - trace!( - "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", - source_local, - opt_info.local_temp_0 - ); - return false; - } else if last_assigned_to != opt_info.local_tmp_s1 { - trace!( - "NO: end of assignment chain does not match written enum temp: {:?} != {:?}", - last_assigned_to, - opt_info.local_tmp_s1 - ); - return false; - } - - trace!("SUCCESS: optimization applies!"); - true -} - -impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // FIXME(77359): This optimization can result in unsoundness. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - let source = body.source; - trace!("running SimplifyArmIdentity on {:?}", source); - - let local_uses = LocalUseCounter::get_local_uses(body); - for bb in body.basic_blocks.as_mut() { - if let Some(opt_info) = - get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info) - { - trace!("got opt_info = {:#?}", opt_info); - if !optimization_applies( - &opt_info, - &body.local_decls, - &local_uses, - &body.var_debug_info, - ) { - debug!("optimization skipped for {:?}", source); - continue; - } - - // Also remove unused Storage{Live,Dead} statements which correspond - // to temps used previously. - for (live_idx, dead_idx, local) in &opt_info.storage_stmts { - // The temporary that we've read the variant field into is scoped to this block, - // so we can remove the assignment. - if *local == opt_info.local_temp_0 { - bb.statements[opt_info.get_variant_field_stmt].make_nop(); - } - - for (left, right) in &opt_info.field_tmp_assignments { - if local == left || local == right { - bb.statements[*live_idx].make_nop(); - bb.statements[*dead_idx].make_nop(); - } - } - } - - // Right shape; transform - for stmt_idx in opt_info.stmts_to_remove { - bb.statements[stmt_idx].make_nop(); - } - - let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; - stmt.source_info = opt_info.source_info; - stmt.kind = StatementKind::Assign(Box::new(( - opt_info.local_0.into(), - Rvalue::Use(Operand::Move(opt_info.local_1.into())), - ))); - - bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); - - // Fix the debug info to point to the right local - for dbg_index in opt_info.dbg_info_to_adjust { - let dbg_info = &mut body.var_debug_info[dbg_index]; - assert!( - matches!(dbg_info.value, VarDebugInfoContents::Place(_)), - "value was not a Place" - ); - if let VarDebugInfoContents::Place(p) = &mut dbg_info.value { - assert!(p.projection.is_empty()); - p.local = opt_info.local_0; - p.projection = opt_info.dbg_projection; - } - } - - trace!("block is now {:?}", bb.statements); - } - } - } -} - -struct LocalUseCounter { - local_uses: IndexVec<Local, usize>, -} - -impl LocalUseCounter { - fn get_local_uses(body: &Body<'_>) -> IndexVec<Local, usize> { - let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; - counter.visit_body(body); - counter.local_uses - } -} - -impl Visitor<'_> for LocalUseCounter { - fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) { - if context.is_storage_marker() - || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) - { - return; - } - - self.local_uses[local] += 1; - } -} - -/// Match on: -/// ```ignore (MIR) -/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); -/// ``` -fn match_get_variant_field<'tcx>( - stmt: &Statement<'tcx>, -) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> { - match &stmt.kind { - StatementKind::Assign(box ( - place_into, - Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), - )) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*pf)?; - Some((local_into, local_from, vf, pf.projection)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; -/// ``` -fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { - match &stmt.kind { - StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*place_from)?; - Some((local_into, local_from, vf)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// discriminant(_LOCAL_TO_SET) = VAR_IDX; -/// ``` -fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> { - match &stmt.kind { - StatementKind::SetDiscriminant { place, variant_index } => { - Some((place.as_local()?, *variant_index)) - } - _ => None, - } -} - -#[derive(PartialEq, Debug)] -struct VarField<'tcx> { - field: Field, - field_ty: Ty<'tcx>, - var_idx: VariantIdx, -} - -/// Match on `((_LOCAL as Variant).FIELD: TY)`. -fn match_variant_field_place(place: Place<'_>) -> Option<(Local, VarField<'_>)> { - match place.as_ref() { - PlaceRef { - local, - projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], - } => Some((local, VarField { field, field_ty: ty, var_idx })), - _ => None, - } -} - -/// Simplifies `SwitchInt(_) -> [targets]`, -/// where all the `targets` have the same form, -/// into `goto -> target_first`. -pub struct SimplifyBranchSame; - -impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // This optimization is disabled by default for now due to - // soundness concerns; see issue #89485 and PR #89489. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - trace!("Running SimplifyBranchSame on {:?}", body.source); - let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; - let opts = finder.find(); - - let did_remove_blocks = opts.len() > 0; - for opt in opts.iter() { - trace!("SUCCESS: Applying optimization {:?}", opt); - // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. - body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = - TerminatorKind::Goto { target: opt.bb_to_goto }; - } - - if did_remove_blocks { - // We have dead blocks now, so remove those. - simplify::remove_dead_blocks(tcx, body); - } - } -} - -#[derive(Debug)] -struct SimplifyBranchSameOptimization { - /// All basic blocks are equal so go to this one - bb_to_goto: BasicBlock, - /// Basic block where the terminator can be simplified to a goto - bb_to_opt_terminator: BasicBlock, -} - -struct SwitchTargetAndValue { - target: BasicBlock, - // None in case of the `otherwise` case - value: Option<u128>, -} - -struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { - body: &'a Body<'tcx>, - tcx: TyCtxt<'tcx>, -} - -impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> { - fn find(&self) -> Vec<SimplifyBranchSameOptimization> { - self.body - .basic_blocks - .iter_enumerated() - .filter_map(|(bb_idx, bb)| { - let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { - TerminatorKind::SwitchInt { targets, discr, .. } => { - let targets_and_values: Vec<_> = targets.iter() - .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) - .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) - .collect(); - (discr, targets_and_values) - }, - _ => return None, - }; - - // find the adt that has its discriminant read - // assuming this must be the last statement of the block - let adt_matched_on = match &bb.statements.last()?.kind { - StatementKind::Assign(box (place, rhs)) - if Some(*place) == discr_switched_on.place() => - { - match rhs { - Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, - _ => { - trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); - return None; - } - } - } - other => { - trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); - return None - }, - }; - - let mut iter_bbs_reachable = targets_and_values - .iter() - .map(|target_and_value| (target_and_value, &self.body.basic_blocks[target_and_value.target])) - .filter(|(_, bb)| { - // Reaching `unreachable` is UB so assume it doesn't happen. - bb.terminator().kind != TerminatorKind::Unreachable - }) - .peekable(); - - let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx); - let mut all_successors_equivalent = StatementEquality::TrivialEqual; - - // All successor basic blocks must be equal or contain statements that are pairwise considered equal. - for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { - let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup - && bb_l.terminator().kind == bb_r.terminator().kind - && bb_l.statements.len() == bb_r.statements.len(); - let statement_check = || { - bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { - let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); - if matches!(stmt_equality, StatementEquality::NotEqual) { - // short circuit - None - } else { - Some(acc.combine(&stmt_equality)) - } - }) - .unwrap_or(StatementEquality::NotEqual) - }; - if !trivial_checks { - all_successors_equivalent = StatementEquality::NotEqual; - break; - } - all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); - }; - - match all_successors_equivalent{ - StatementEquality::TrivialEqual => { - // statements are trivially equal, so just take first - trace!("Statements are trivially equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_first.target, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::ConsideredEqual(bb_to_choose) => { - trace!("Statements are considered equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_to_choose, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::NotEqual => { - trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); - None - } - } - }) - .collect() - } - - /// Tests if two statements can be considered equal - /// - /// Statements can be trivially equal if the kinds match. - /// But they can also be considered equal in the following case A: - /// ```ignore (MIR) - /// discriminant(_0) = 0; // bb1 - /// _0 = move _1; // bb2 - /// ``` - /// In this case the two statements are equal iff - /// - `_0` is an enum where the variant index 0 is fieldless, and - /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on - fn statement_equality( - &self, - adt_matched_on: Place<'tcx>, - x: &Statement<'tcx>, - x_target_and_value: &SwitchTargetAndValue, - y: &Statement<'tcx>, - y_target_and_value: &SwitchTargetAndValue, - ) -> StatementEquality { - let helper = |rhs: &Rvalue<'tcx>, - place: &Place<'tcx>, - variant_index: VariantIdx, - switch_value: u128, - side_to_choose| { - let place_type = place.ty(self.body, self.tcx).ty; - let adt = match *place_type.kind() { - ty::Adt(adt, _) if adt.is_enum() => adt, - _ => return StatementEquality::NotEqual, - }; - // We need to make sure that the switch value that targets the bb with - // SetDiscriminant is the same as the variant discriminant. - let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val; - if variant_discr != switch_value { - trace!( - "NO: variant discriminant {} does not equal switch value {}", - variant_discr, - switch_value - ); - return StatementEquality::NotEqual; - } - let variant_is_fieldless = adt.variant(variant_index).fields.is_empty(); - if !variant_is_fieldless { - trace!("NO: variant {:?} was not fieldless", variant_index); - return StatementEquality::NotEqual; - } - - match rhs { - Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { - StatementEquality::ConsideredEqual(side_to_choose) - } - _ => { - trace!( - "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", - rhs, - adt_matched_on - ); - StatementEquality::NotEqual - } - } - }; - match (&x.kind, &y.kind) { - // trivial case - (x, y) if x == y => StatementEquality::TrivialEqual, - - // check for case A - ( - StatementKind::Assign(box (_, rhs)), - &StatementKind::SetDiscriminant { ref place, variant_index }, - ) if y_target_and_value.value.is_some() => { - // choose basic block of x, as that has the assign - helper( - rhs, - place, - variant_index, - y_target_and_value.value.unwrap(), - x_target_and_value.target, - ) - } - ( - &StatementKind::SetDiscriminant { ref place, variant_index }, - &StatementKind::Assign(box (_, ref rhs)), - ) if x_target_and_value.value.is_some() => { - // choose basic block of y, as that has the assign - helper( - rhs, - place, - variant_index, - x_target_and_value.value.unwrap(), - y_target_and_value.target, - ) - } - _ => { - trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); - StatementEquality::NotEqual - } - } - } -} - -#[derive(Copy, Clone, Eq, PartialEq)] -enum StatementEquality { - /// The two statements are trivially equal; same kind - TrivialEqual, - /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. - /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index - ConsideredEqual(BasicBlock), - /// The two statements are not equal - NotEqual, -} - -impl StatementEquality { - fn combine(&self, other: &StatementEquality) -> StatementEquality { - use StatementEquality::*; - match (self, other) { - (TrivialEqual, TrivialEqual) => TrivialEqual, - (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => { - ConsideredEqual(*b) - } - (ConsideredEqual(b1), ConsideredEqual(b2)) => { - if b1 == b2 { - ConsideredEqual(*b1) - } else { - NotEqual - } - } - (_, NotEqual) | (NotEqual, _) => NotEqual, - } - } -} diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 3a2bf051516..42124f5a480 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -215,7 +215,7 @@ struct ReplacementVisitor<'tcx, 'll> { replacements: ReplacementMap<'tcx>, /// This is used to check that we are not leaving references to replaced locals behind. all_dead_locals: BitSet<Local>, - /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage /// and deinit statement and debuginfo. fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>, } |
