Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r-- | compiler/rustc_mir_transform/src/coroutine.rs | 190
-rw-r--r-- | compiler/rustc_mir_transform/src/coverage/counters.rs | 190
-rw-r--r-- | compiler/rustc_mir_transform/src/coverage/mod.rs | 17
-rw-r--r-- | compiler/rustc_mir_transform/src/coverage/tests.rs | 6
-rw-r--r-- | compiler/rustc_mir_transform/src/cross_crate_inline.rs | 9
-rw-r--r-- | compiler/rustc_mir_transform/src/dataflow_const_prop.rs | 14
-rw-r--r-- | compiler/rustc_mir_transform/src/deduce_param_attrs.rs | 1
-rw-r--r-- | compiler/rustc_mir_transform/src/gvn.rs | 759
-rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 5
-rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 7
-rw-r--r-- | compiler/rustc_mir_transform/src/shim.rs | 14
-rw-r--r-- | compiler/rustc_mir_transform/src/simplify.rs | 2
-rw-r--r-- | compiler/rustc_mir_transform/src/simplify_branches.rs | 19
-rw-r--r-- | compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs | 104
-rw-r--r-- | compiler/rustc_mir_transform/src/unreachable_prop.rs | 201
15 files changed, 1039 insertions, 499 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index fa56d59dd80..fc30a718cbb 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -147,7 +147,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> { } struct PinArgVisitor<'tcx> { - ref_gen_ty: Ty<'tcx>, + ref_coroutine_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>, } @@ -168,7 +168,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { local: SELF_ARG, projection: self.tcx().mk_place_elems(&[ProjectionElem::Field( FieldIdx::new(0), - self.ref_gen_ty, + self.ref_coroutine_ty, )]), }, self.tcx, @@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, - is_async_kind: bool, + coroutine_kind: hir::CoroutineKind, state_adt_ref: AdtDef<'tcx>, state_args: GenericArgsRef<'tcx>, @@ -249,6 +249,47 @@ struct TransformVisitor<'tcx> { } impl<'tcx> TransformVisitor<'tcx> { + fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock { + let block = BasicBlock::new(body.basic_blocks.len()); + + let source_info = SourceInfo::outermost(body.span); + + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true); + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + let statements = vec![Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }]; + + body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + block + } + + fn coroutine_state_adt_and_variant_idx( + &self, + is_return: bool, + ) -> (AggregateKind<'tcx>, VariantIdx) { + let idx = VariantIdx::new(match (is_return, self.coroutine_kind) { + (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete + (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded + (true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready + (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending + (true, hir::CoroutineKind::Gen(_)) => 0, // Option::None + (false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some + }); + + let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); + (kind, idx) + } + // Make a `CoroutineState` or `Poll` variant assignment. // // `core::ops::CoroutineState` only has single element tuple variants, @@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> { is_return: bool, statements: &mut Vec<Statement<'tcx>>, ) { - let idx = VariantIdx::new(match (is_return, self.is_async_kind) { - (true, false) => 1, // CoroutineState::Complete - (false, false) => 0, // CoroutineState::Yielded - (true, true) => 0, // Poll::Ready - (false, true) => 1, // Poll::Pending - }); + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return); - let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); + match self.coroutine_kind { + // `Poll::Pending` + CoroutineKind::Async(_) => { + if !is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - // `Poll::Pending` - if self.is_async_kind && idx == VariantIdx::new(1) { - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + // FIXME(swatinem): assert that `val` is indeed unit? 
+ statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + // `Option::None` + CoroutineKind::Gen(_) => { + if is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - // FIXME(swatinem): assert that `val` is indeed unit? - statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), - ))), - source_info, - }); - return; + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + CoroutineKind::Coroutine => {} } - // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)` + // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); statements.push(Statement { @@ -414,34 +468,34 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { } fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let gen_ty = body.local_decls.raw[1].ty; + let coroutine_ty = body.local_decls.raw[1].ty; - let ref_gen_ty = Ty::new_ref( + let ref_coroutine_ty = Ty::new_ref( tcx, tcx.lifetimes.re_erased, - ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut }, + ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut }, ); // Replace the by value coroutine argument - body.local_decls.raw[1].ty = ref_gen_ty; + body.local_decls.raw[1].ty = ref_coroutine_ty; // Add a deref to accesses of the coroutine state DerefArgVisitor { tcx }.visit_body(body); } fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ref_gen_ty = body.local_decls.raw[1].ty; + let ref_coroutine_ty = body.local_decls.raw[1].ty; let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span)); let pin_adt_ref = tcx.adt_def(pin_did); - let args = tcx.mk_args(&[ref_gen_ty.into()]); - let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, args); + let args = tcx.mk_args(&[ref_coroutine_ty.into()]); + let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args); // Replace the by ref coroutine argument - body.local_decls.raw[1].ty = pin_ref_gen_ty; + body.local_decls.raw[1].ty = pin_ref_coroutine_ty; // Add the Pin field access to accesses of the coroutine state - PinArgVisitor { ref_gen_ty, tcx }.visit_body(body); + PinArgVisitor { ref_coroutine_ty, tcx }.visit_body(body); } /// Allocates a new local and replaces all references of `local` with it. Returns the new local. @@ -1050,7 +1104,7 @@ fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { fn create_coroutine_drop_shim<'tcx>( tcx: TyCtxt<'tcx>, transform: &TransformVisitor<'tcx>, - gen_ty: Ty<'tcx>, + coroutine_ty: Ty<'tcx>, body: &mut Body<'tcx>, drop_clean: BasicBlock, ) -> Body<'tcx> { @@ -1082,7 +1136,7 @@ fn create_coroutine_drop_shim<'tcx>( // Change the coroutine argument from &mut to *mut body.local_decls[SELF_ARG] = LocalDecl::with_source_info( - Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }), + Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }), source_info, ); @@ -1092,9 +1146,9 @@ fn create_coroutine_drop_shim<'tcx>( // Update the body's def to become the drop glue. 
// This needs to be updated before the AbortUnwindingCalls pass. - let gen_instance = body.source.instance; + let coroutine_instance = body.source.instance; let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None); - let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(gen_ty)); + let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(coroutine_ty)); body.source.instance = drop_instance; pm::run_passes_no_validate( @@ -1106,7 +1160,7 @@ fn create_coroutine_drop_shim<'tcx>( // Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible // filename. - body.source.instance = gen_instance; + body.source.instance = coroutine_instance; dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(())); body.source.instance = drop_instance; @@ -1263,10 +1317,13 @@ fn create_coroutine_resume_function<'tcx>( } if can_return { - cases.insert( - 1, - (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))), - ); + let block = match coroutine_kind { + CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) + } + CoroutineKind::Gen(_) => transform.insert_none_ret_block(body), + }; + cases.insert(1, (RETURNED, block)); } insert_switch(body, cases, &transform, TerminatorKind::Unreachable); @@ -1390,13 +1447,13 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>( let body = &*body; // The first argument is the coroutine type passed by value - let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; // Get the interior types and args which typeck computed - let movable = match *gen_ty.kind() { + let movable = match *coroutine_ty.kind() { ty::Coroutine(_, _, movability) => movability == hir::Movability::Movable, ty::Error(_) => return None, - _ => span_bug!(body.span, "unexpected coroutine type {}", gen_ty), + _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty), }; // When first entering the coroutine, move the resume argument into its new local. 
@@ -1424,33 +1481,44 @@ impl<'tcx> MirPass<'tcx> for StateTransform { assert!(body.coroutine_drop().is_none()); // The first argument is the coroutine type passed by value - let gen_ty = body.local_decls.raw[1].ty; + let coroutine_ty = body.local_decls.raw[1].ty; // Get the discriminant type and args which typeck computed - let (discr_ty, movable) = match *gen_ty.kind() { + let (discr_ty, movable) = match *coroutine_ty.kind() { ty::Coroutine(_, args, movability) => { let args = args.as_coroutine(); (args.discr_ty(tcx), movability == hir::Movability::Movable) } _ => { - tcx.sess.delay_span_bug(body.span, format!("unexpected coroutine type {gen_ty}")); + tcx.sess + .delay_span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); return; } }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); - let (state_adt_ref, state_args) = if is_async_kind { - // Compute Poll<return_ty> - let poll_did = tcx.require_lang_item(LangItem::Poll, None); - let poll_adt_ref = tcx.adt_def(poll_did); - let poll_args = tcx.mk_args(&[body.return_ty().into()]); - (poll_adt_ref, poll_args) - } else { - // Compute CoroutineState<yield_ty, return_ty> - let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); - (state_adt_ref, state_args) + let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { + CoroutineKind::Async(_) => { + // Compute Poll<return_ty> + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_args = tcx.mk_args(&[body.return_ty().into()]); + (poll_adt_ref, poll_args) + } + CoroutineKind::Gen(_) => { + // Compute Option<yield_ty> + let option_did = tcx.require_lang_item(LangItem::Option, None); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]); + (option_adt_ref, option_args) + } + CoroutineKind::Coroutine => { + // Compute CoroutineState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_args) + } }; let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); @@ -1518,7 +1586,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. 
let mut transform = TransformVisitor { tcx, - is_async_kind, + coroutine_kind: body.coroutine_kind().unwrap(), state_adt_ref, state_args, remap, @@ -1559,7 +1627,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(())); // Create a copy of our MIR and use it to create the drop shim for the coroutine - let drop_shim = create_coroutine_drop_shim(tcx, &transform, gen_ty, body, drop_clean); + let drop_shim = create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean); body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim); diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index d07f59bc72a..b34ec95b4e8 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -1,5 +1,3 @@ -use super::Error; - use super::graph; use graph::{BasicCoverageBlock, BcbBranch, CoverageGraph, TraverseCoverageGraphWithLoops}; @@ -12,8 +10,6 @@ use rustc_middle::mir::coverage::*; use std::fmt::{self, Debug}; -const NESTED_INDENT: &str = " "; - /// The coverage counter or counter expression associated with a particular /// BCB node or BCB edge. #[derive(Clone)] @@ -55,7 +51,7 @@ pub(super) struct CoverageCounters { /// edge between two BCBs. bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, /// Tracks which BCBs have a counter associated with some incoming edge. - /// Only used by debug assertions, to verify that BCBs with incoming edge + /// Only used by assertions, to verify that BCBs with incoming edge /// counters do not have their own physical counters (expressions are allowed). bcb_has_incoming_edge_counters: BitSet<BasicCoverageBlock>, /// Table of expression data, associating each expression ID with its @@ -83,7 +79,7 @@ impl CoverageCounters { &mut self, basic_coverage_blocks: &CoverageGraph, bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, - ) -> Result<(), Error> { + ) { MakeBcbCounters::new(self, basic_coverage_blocks).make_bcb_counters(bcb_has_coverage_spans) } @@ -113,26 +109,23 @@ impl CoverageCounters { self.expressions.len() } - fn set_bcb_counter( - &mut self, - bcb: BasicCoverageBlock, - counter_kind: BcbCounter, - ) -> Result<CovTerm, Error> { - debug_assert!( + fn set_bcb_counter(&mut self, bcb: BasicCoverageBlock, counter_kind: BcbCounter) -> CovTerm { + assert!( // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also // have an expression (to be injected into an existing `BasicBlock` represented by this // `BasicCoverageBlock`). 
counter_kind.is_expression() || !self.bcb_has_incoming_edge_counters.contains(bcb), "attempt to add a `Counter` to a BCB target with existing incoming edge counters" ); + let term = counter_kind.as_term(); if let Some(replaced) = self.bcb_counters[bcb].replace(counter_kind) { - Error::from_string(format!( + bug!( "attempt to set a BasicCoverageBlock coverage counter more than once; \ {bcb:?} already had counter {replaced:?}", - )) + ); } else { - Ok(term) + term } } @@ -141,27 +134,26 @@ impl CoverageCounters { from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock, counter_kind: BcbCounter, - ) -> Result<CovTerm, Error> { - if level_enabled!(tracing::Level::DEBUG) { - // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also - // have an expression (to be injected into an existing `BasicBlock` represented by this - // `BasicCoverageBlock`). - if self.bcb_counter(to_bcb).is_some_and(|c| !c.is_expression()) { - return Error::from_string(format!( - "attempt to add an incoming edge counter from {from_bcb:?} when the target BCB already \ - has a `Counter`" - )); - } + ) -> CovTerm { + // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also + // have an expression (to be injected into an existing `BasicBlock` represented by this + // `BasicCoverageBlock`). + if let Some(node_counter) = self.bcb_counter(to_bcb) && !node_counter.is_expression() { + bug!( + "attempt to add an incoming edge counter from {from_bcb:?} \ + when the target BCB already has {node_counter:?}" + ); } + self.bcb_has_incoming_edge_counters.insert(to_bcb); let term = counter_kind.as_term(); if let Some(replaced) = self.bcb_edge_counters.insert((from_bcb, to_bcb), counter_kind) { - Error::from_string(format!( + bug!( "attempt to set an edge counter more than once; from_bcb: \ {from_bcb:?} already had counter {replaced:?}", - )) + ); } else { - Ok(term) + term } } @@ -215,14 +207,7 @@ impl<'a> MakeBcbCounters<'a> { /// One way to predict which branch executes the least is by considering loops. A loop is exited /// at a branch, so the branch that jumps to a `BasicCoverageBlock` outside the loop is almost /// always executed less than the branch that does not exit the loop. - /// - /// Returns any non-code-span expressions created to represent intermediate values (such as to - /// add two counters so the result can be subtracted from another counter), or an Error with - /// message for subsequent debugging. - fn make_bcb_counters( - &mut self, - bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, - ) -> Result<(), Error> { + fn make_bcb_counters(&mut self, bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool) { debug!("make_bcb_counters(): adding a counter or expression to each BasicCoverageBlock"); // Walk the `CoverageGraph`. For each `BasicCoverageBlock` node with an associated @@ -239,10 +224,10 @@ impl<'a> MakeBcbCounters<'a> { while let Some(bcb) = traversal.next() { if bcb_has_coverage_spans(bcb) { debug!("{:?} has at least one coverage span. 
Get or make its counter", bcb); - let branching_counter_operand = self.get_or_make_counter_operand(bcb)?; + let branching_counter_operand = self.get_or_make_counter_operand(bcb); if self.bcb_needs_branch_counters(bcb) { - self.make_branch_counters(&traversal, bcb, branching_counter_operand)?; + self.make_branch_counters(&traversal, bcb, branching_counter_operand); } } else { debug!( @@ -253,14 +238,11 @@ impl<'a> MakeBcbCounters<'a> { } } - if traversal.is_complete() { - Ok(()) - } else { - Error::from_string(format!( - "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}", - traversal.unvisited(), - )) - } + assert!( + traversal.is_complete(), + "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}", + traversal.unvisited(), + ); } fn make_branch_counters( @@ -268,7 +250,7 @@ impl<'a> MakeBcbCounters<'a> { traversal: &TraverseCoverageGraphWithLoops<'_>, branching_bcb: BasicCoverageBlock, branching_counter_operand: CovTerm, - ) -> Result<(), Error> { + ) { let branches = self.bcb_branches(branching_bcb); debug!( "{:?} has some branch(es) without counters:\n {}", @@ -301,10 +283,10 @@ impl<'a> MakeBcbCounters<'a> { counter", branch, branching_bcb ); - self.get_or_make_counter_operand(branch.target_bcb)? + self.get_or_make_counter_operand(branch.target_bcb) } else { debug!(" {:?} has multiple incoming edges, so adding an edge counter", branch); - self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb)? + self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb) }; if let Some(sumup_counter_operand) = some_sumup_counter_operand.replace(branch_counter_operand) @@ -339,31 +321,18 @@ impl<'a> MakeBcbCounters<'a> { debug!("{:?} gets an expression: {:?}", expression_branch, expression); let bcb = expression_branch.target_bcb; if expression_branch.is_only_path_to_target() { - self.coverage_counters.set_bcb_counter(bcb, expression)?; + self.coverage_counters.set_bcb_counter(bcb, expression); } else { - self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression)?; + self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression); } - Ok(()) - } - - fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> Result<CovTerm, Error> { - self.recursive_get_or_make_counter_operand(bcb, 1) } - fn recursive_get_or_make_counter_operand( - &mut self, - bcb: BasicCoverageBlock, - debug_indent_level: usize, - ) -> Result<CovTerm, Error> { + #[instrument(level = "debug", skip(self))] + fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> CovTerm { // If the BCB already has a counter, return it. if let Some(counter_kind) = &self.coverage_counters.bcb_counters[bcb] { - debug!( - "{}{:?} already has a counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, - ); - return Ok(counter_kind.as_term()); + debug!("{bcb:?} already has a counter: {counter_kind:?}"); + return counter_kind.as_term(); } // A BCB with only one incoming edge gets a simple `Counter` (via `make_counter()`). @@ -373,20 +342,12 @@ impl<'a> MakeBcbCounters<'a> { if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) { let counter_kind = self.coverage_counters.make_counter(); if one_path_to_target { - debug!( - "{}{:?} gets a new counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, - ); + debug!("{bcb:?} gets a new counter: {counter_kind:?}"); } else { debug!( - "{}{:?} has itself as its own predecessor. 
It can't be part of its own \ - Expression sum, so it will get its own new counter: {:?}. (Note, the compiled \ - code will generate an infinite loop.)", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, + "{bcb:?} has itself as its own predecessor. It can't be part of its own \ + Expression sum, so it will get its own new counter: {counter_kind:?}. \ + (Note, the compiled code will generate an infinite loop.)", ); } return self.coverage_counters.set_bcb_counter(bcb, counter_kind); @@ -396,24 +357,14 @@ impl<'a> MakeBcbCounters<'a> { // counters and/or expressions of its incoming edges. This will recursively get or create // counters for those incoming edges first, then call `make_expression()` to sum them up, // with additional intermediate expressions as needed. + let _sumup_debug_span = debug_span!("(preparing sum-up expression)").entered(); + let mut predecessors = self.bcb_predecessors(bcb).to_owned().into_iter(); - debug!( - "{}{:?} has multiple incoming edges and will get an expression that sums them up...", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - ); - let first_edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( - predecessors.next().unwrap(), - bcb, - debug_indent_level + 1, - )?; + let first_edge_counter_operand = + self.get_or_make_edge_counter_operand(predecessors.next().unwrap(), bcb); let mut some_sumup_edge_counter_operand = None; for predecessor in predecessors { - let edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( - predecessor, - bcb, - debug_indent_level + 1, - )?; + let edge_counter_operand = self.get_or_make_edge_counter_operand(predecessor, bcb); if let Some(sumup_edge_counter_operand) = some_sumup_edge_counter_operand.replace(edge_counter_operand) { @@ -422,11 +373,7 @@ impl<'a> MakeBcbCounters<'a> { Op::Add, edge_counter_operand, ); - debug!( - "{}new intermediate expression: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - intermediate_expression - ); + debug!("new intermediate expression: {intermediate_expression:?}"); let intermediate_expression_operand = intermediate_expression.as_term(); some_sumup_edge_counter_operand.replace(intermediate_expression_operand); } @@ -436,59 +383,36 @@ impl<'a> MakeBcbCounters<'a> { Op::Add, some_sumup_edge_counter_operand.unwrap(), ); - debug!( - "{}{:?} gets a new counter (sum of predecessor counters): {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind - ); + drop(_sumup_debug_span); + + debug!("{bcb:?} gets a new counter (sum of predecessor counters): {counter_kind:?}"); self.coverage_counters.set_bcb_counter(bcb, counter_kind) } + #[instrument(level = "debug", skip(self))] fn get_or_make_edge_counter_operand( &mut self, from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock, - ) -> Result<CovTerm, Error> { - self.recursive_get_or_make_edge_counter_operand(from_bcb, to_bcb, 1) - } - - fn recursive_get_or_make_edge_counter_operand( - &mut self, - from_bcb: BasicCoverageBlock, - to_bcb: BasicCoverageBlock, - debug_indent_level: usize, - ) -> Result<CovTerm, Error> { + ) -> CovTerm { // If the source BCB has only one successor (assumed to be the given target), an edge // counter is unnecessary. Just get or make a counter for the source BCB. let successors = self.bcb_successors(from_bcb).iter(); if successors.len() == 1 { - return self.recursive_get_or_make_counter_operand(from_bcb, debug_indent_level + 1); + return self.get_or_make_counter_operand(from_bcb); } // If the edge already has a counter, return it. 
if let Some(counter_kind) = self.coverage_counters.bcb_edge_counters.get(&(from_bcb, to_bcb)) { - debug!( - "{}Edge {:?}->{:?} already has a counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - from_bcb, - to_bcb, - counter_kind - ); - return Ok(counter_kind.as_term()); + debug!("Edge {from_bcb:?}->{to_bcb:?} already has a counter: {counter_kind:?}"); + return counter_kind.as_term(); } // Make a new counter to count this edge. let counter_kind = self.coverage_counters.make_counter(); - debug!( - "{}Edge {:?}->{:?} gets a new counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - from_bcb, - to_bcb, - counter_kind - ); + debug!("Edge {from_bcb:?}->{to_bcb:?} gets a new counter: {counter_kind:?}"); self.coverage_counters.set_bcb_edge_counter(from_bcb, to_bcb, counter_kind) } diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index c9b36ba25ac..97e4468a0e8 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -26,18 +26,6 @@ use rustc_span::def_id::DefId; use rustc_span::source_map::SourceMap; use rustc_span::{ExpnKind, SourceFile, Span, Symbol}; -/// A simple error message wrapper for `coverage::Error`s. -#[derive(Debug)] -struct Error { - message: String, -} - -impl Error { - pub fn from_string<T>(message: String) -> Result<T, Error> { - Err(Self { message }) - } -} - /// Inserts `StatementKind::Coverage` statements that either instrument the binary with injected /// counters, via intrinsic `llvm.instrprof.increment`, and/or inject metadata used during codegen /// to construct the coverage map. @@ -167,10 +155,7 @@ impl<'a, 'tcx> Instrumentor<'a, 'tcx> { // `BasicCoverageBlock`s not already associated with a coverage span. let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); self.coverage_counters - .make_bcb_counters(&mut self.basic_coverage_blocks, bcb_has_coverage_spans) - .unwrap_or_else(|e| { - bug!("Error processing: {:?}: {:?}", self.mir_body.source.def_id(), e.message) - }); + .make_bcb_counters(&self.basic_coverage_blocks, bcb_has_coverage_spans); let mappings = self.create_mappings_and_inject_coverage_statements(&coverage_spans); diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs index 795cbce963d..702fe5f563e 100644 --- a/compiler/rustc_mir_transform/src/coverage/tests.rs +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -647,15 +647,13 @@ fn test_traverse_coverage_with_loops() { fn test_make_bcb_counters() { rustc_span::create_default_session_globals_then(|| { let mir_body = goto_switchint(); - let mut basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); // Historically this test would use `spans` internals to set up fake // coverage spans for BCBs 1 and 2. Now we skip that step and just tell // BCB counter construction that those BCBs have spans. 
let bcb_has_coverage_spans = |bcb: BasicCoverageBlock| (1..=2).contains(&bcb.as_usize()); let mut coverage_counters = counters::CoverageCounters::new(&basic_coverage_blocks); - coverage_counters - .make_bcb_counters(&mut basic_coverage_blocks, bcb_has_coverage_spans) - .expect("should be Ok"); + coverage_counters.make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); assert_eq!(coverage_counters.num_expressions(), 0); let_bcb!(1); diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 24d081f2ac9..4d0e261ed1f 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -1,3 +1,5 @@ +use crate::inline; +use crate::pass_manager as pm; use rustc_attr::InlineAttr; use rustc_hir::def::DefKind; use rustc_hir::def_id::LocalDefId; @@ -40,8 +42,11 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return false; } - // Don't do any inference unless optimizations are enabled. - if matches!(tcx.sess.opts.optimize, OptLevel::No) { + // Don't do any inference if codegen optimizations are disabled and also MIR inlining is not + // enabled. This ensures that we do inference even if someone only passes -Zinline-mir, + // which is less confusing than having to also enable -Copt-level=1. + if matches!(tcx.sess.opts.optimize, OptLevel::No) && !pm::should_run_pass(tcx, &inline::Inline) + { return false; } diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 2c29978173f..81d2bba989a 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -286,9 +286,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { let val = match null_op { NullOp::SizeOf if layout.is_sized() => layout.size.bytes(), NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(), - NullOp::OffsetOf(fields) => layout - .offset_of_subfield(&self.ecx, fields.iter().map(|f| f.index())) - .bytes(), + NullOp::OffsetOf(fields) => { + layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() + } _ => return ValueOrPlace::Value(FlatSet::Top), }; FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx)) @@ -406,7 +406,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), TrackElem::Discriminant => { let variant = self.ecx.read_discriminant(op).ok()?; - let discr_value = self.ecx.discriminant_for_variant(op.layout, variant).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; Some(discr_value.into()) } TrackElem::DerefLen => { @@ -507,7 +508,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { return None; } let enum_ty_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; - let discr_value = self.ecx.discriminant_for_variant(enum_ty_layout, variant_index).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).ok()?; Some(discr_value.to_scalar()) } @@ -854,7 +856,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { } } -struct DummyMachine; +pub(crate) struct DummyMachine; impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine { rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>); diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs 
b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index 79645310a39..990cfb05e60 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -44,6 +44,7 @@ impl<'tcx> Visitor<'tcx> for DeduceReadOnly { // Whether mutating though a `&raw const` is allowed is still undecided, so we // disable any sketchy `readonly` optimizations for now. // But we only need to do this if the pointer would point into the argument. + // IOW: for indirect places, like `&raw (*local).field`, this surely cannot mutate `local`. !place.is_indirect() } PlaceContext::NonMutatingUse(..) | PlaceContext::NonUse(..) => { diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index eece7c3e834..dce298e92e1 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -52,19 +52,59 @@ //! _a = *_b // _b is &Freeze //! _c = *_b // replaced by _c = _a //! ``` +//! +//! # Determinism of constant propagation +//! +//! When registering a new `Value`, we attempt to opportunistically evaluate it as a constant. +//! The evaluated form is inserted in `evaluated` as an `OpTy` or `None` if evaluation failed. +//! +//! The difficulty is non-deterministic evaluation of MIR constants. Some `Const` can have +//! different runtime values each time they are evaluated. This is the case with +//! `Const::Slice` which have a new pointer each time they are evaluated, and constants that +//! contain a fn pointer (`AllocId` pointing to a `GlobalAlloc::Function`) pointing to a different +//! symbol in each codegen unit. +//! +//! Meanwhile, we want to be able to read indirect constants. For instance: +//! ``` +//! static A: &'static &'static u8 = &&63; +//! fn foo() -> u8 { +//! **A // We want to replace by 63. +//! } +//! fn bar() -> u8 { +//! b"abc"[1] // We want to replace by 'b'. +//! } +//! ``` +//! +//! The `Value::Constant` variant stores a possibly unevaluated constant. Evaluating that constant +//! may be non-deterministic. When that happens, we assign a disambiguator to ensure that we do not +//! merge the constants. See `duplicate_slice` test in `gvn.rs`. +//! +//! Second, when writing constants in MIR, we do not write `Const::Slice` or `Const` +//! that contain `AllocId`s. 
+use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind}; +use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar}; use rustc_data_structures::fx::{FxHashMap, FxIndexSet}; use rustc_data_structures::graph::dominators::Dominators; +use rustc_hir::def::DefKind; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; use rustc_macros::newtype_index; +use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::{self, Ty, TyCtxt}; -use rustc_target::abi::{VariantIdx, FIRST_VARIANT}; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::layout::LayoutOf; +use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut}; +use rustc_span::def_id::DefId; +use rustc_span::DUMMY_SP; +use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT}; +use std::borrow::Cow; +use crate::dataflow_const_prop::DummyMachine; use crate::ssa::{AssignedValue, SsaLocals}; use crate::MirPass; +use either::Either; pub struct GVN; @@ -118,22 +158,33 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb]; state.visit_basic_block_data(bb, data); } - let any_replacement = state.any_replacement; // For each local that is reused (`y` above), we remove its storage statements do avoid any // difficulty. Those locals are SSA, so should be easy to optimize by LLVM without storage // statements. StorageRemover { tcx, reused_locals: state.reused_locals }.visit_body_preserves_cfg(body); - - if any_replacement { - crate::simplify::remove_unused_definitions(body); - } } newtype_index! { struct VnIndex {} } +/// Computing the aggregate's type can be quite slow, so we only keep the minimal amount of +/// information to reconstruct it when needed. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum AggregateTy<'tcx> { + /// Invariant: this must not be used for an empty array. + Array, + Tuple, + Def(DefId, ty::GenericArgsRef<'tcx>), +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum AddressKind { + Ref(BorrowKind), + Address(Mutability), +} + #[derive(Debug, PartialEq, Eq, Hash)] enum Value<'tcx> { // Root values. @@ -141,15 +192,21 @@ enum Value<'tcx> { /// The `usize` is a counter incremented by `new_opaque`. Opaque(usize), /// Evaluated or unevaluated constant value. - Constant(Const<'tcx>), + Constant { + value: Const<'tcx>, + /// Some constants do not have a deterministic value. To avoid merging two instances of the + /// same `Const`, we assign them an additional integer index. + disambiguator: usize, + }, /// An aggregate value, either tuple/closure/struct/enum. /// This does not contain unions, as we cannot reason with the value. - Aggregate(Ty<'tcx>, VariantIdx, Vec<VnIndex>), + Aggregate(AggregateTy<'tcx>, VariantIdx, Vec<VnIndex>), /// This corresponds to a `[value; count]` expression. Repeat(VnIndex, ty::Const<'tcx>), /// The address of a place. Address { place: Place<'tcx>, + kind: AddressKind, /// Give each borrow and pointer a different provenance, so we don't merge them. provenance: usize, }, @@ -177,6 +234,7 @@ enum Value<'tcx> { struct VnState<'body, 'tcx> { tcx: TyCtxt<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, param_env: ty::ParamEnv<'tcx>, local_decls: &'body LocalDecls<'tcx>, /// Value stored in each local. @@ -184,13 +242,14 @@ struct VnState<'body, 'tcx> { /// First local to be assigned that value. 
rev_locals: FxHashMap<VnIndex, Vec<Local>>, values: FxIndexSet<Value<'tcx>>, + /// Values evaluated as constants if possible. + evaluated: IndexVec<VnIndex, Option<OpTy<'tcx>>>, /// Counter to generate different values. /// This is an option to stop creating opaques during replacement. next_opaque: Option<usize>, ssa: &'body SsaLocals, dominators: &'body Dominators<BasicBlock>, reused_locals: BitSet<Local>, - any_replacement: bool, } impl<'body, 'tcx> VnState<'body, 'tcx> { @@ -203,23 +262,30 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { ) -> Self { VnState { tcx, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), param_env, local_decls, locals: IndexVec::from_elem(None, local_decls), rev_locals: FxHashMap::default(), values: FxIndexSet::default(), + evaluated: IndexVec::new(), next_opaque: Some(0), ssa, dominators, reused_locals: BitSet::new_empty(local_decls.len()), - any_replacement: false, } } #[instrument(level = "trace", skip(self), ret)] fn insert(&mut self, value: Value<'tcx>) -> VnIndex { - let (index, _) = self.values.insert_full(value); - VnIndex::from_usize(index) + let (index, new) = self.values.insert_full(value); + let index = VnIndex::from_usize(index); + if new { + let evaluated = self.eval_to_const(index); + let _index = self.evaluated.push(evaluated); + debug_assert_eq!(index, _index); + } + index } /// Create a new `Value` for which we have no information at all, except that it is distinct @@ -234,9 +300,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { /// Create a new `Value::Address` distinct from all the others. #[instrument(level = "trace", skip(self), ret)] - fn new_pointer(&mut self, place: Place<'tcx>) -> Option<VnIndex> { + fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> { let next_opaque = self.next_opaque.as_mut()?; - let value = Value::Address { place, provenance: *next_opaque }; + let value = Value::Address { place, kind, provenance: *next_opaque }; *next_opaque += 1; Some(self.insert(value)) } @@ -258,6 +324,343 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } } + fn insert_constant(&mut self, value: Const<'tcx>) -> Option<VnIndex> { + let disambiguator = if value.is_deterministic() { + // The constant is deterministic, no need to disambiguate. + 0 + } else { + // Multiple mentions of this constant will yield different values, + // so assign a different `disambiguator` to ensure they do not get the same `VnIndex`. + let next_opaque = self.next_opaque.as_mut()?; + let disambiguator = *next_opaque; + *next_opaque += 1; + disambiguator + }; + Some(self.insert(Value::Constant { value, disambiguator })) + } + + fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex { + self.insert_constant(Const::from_scalar(self.tcx, scalar, ty)) + .expect("scalars are deterministic") + } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> { + use Value::*; + let op = match *self.get(value) { + Opaque(_) => return None, + // Do not bother evaluating repeat expressions. This would uselessly consume memory. + Repeat(..) => return None, + + Constant { ref value, disambiguator: _ } => { + self.ecx.eval_mir_constant(value, None, None).ok()? 
+ } + Aggregate(kind, variant, ref fields) => { + let fields = fields + .iter() + .map(|&f| self.evaluated[f].as_ref()) + .collect::<Option<Vec<_>>>()?; + let ty = match kind { + AggregateTy::Array => { + assert!(fields.len() > 0); + Ty::new_array(self.tcx, fields[0].layout.ty, fields.len() as u64) + } + AggregateTy::Tuple => { + Ty::new_tup_from_iter(self.tcx, fields.iter().map(|f| f.layout.ty)) + } + AggregateTy::Def(def_id, args) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args) + } + }; + let variant = if ty.is_enum() { Some(variant) } else { None }; + let ty = self.ecx.layout_of(ty).ok()?; + if ty.is_zst() { + ImmTy::uninit(ty).into() + } else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) { + let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?; + let variant_dest = if let Some(variant) = variant { + self.ecx.project_downcast(&dest, variant).ok()? + } else { + dest.clone() + }; + for (field_index, op) in fields.into_iter().enumerate() { + let field_dest = self.ecx.project_field(&variant_dest, field_index).ok()?; + self.ecx.copy_op(op, &field_dest, /*allow_transmute*/ false).ok()?; + } + self.ecx.write_discriminant(variant.unwrap_or(FIRST_VARIANT), &dest).ok()?; + self.ecx.alloc_mark_immutable(dest.ptr().provenance.unwrap()).ok()?; + dest.into() + } else { + return None; + } + } + + Projection(base, elem) => { + let value = self.evaluated[base].as_ref()?; + let elem = match elem { + ProjectionElem::Deref => ProjectionElem::Deref, + ProjectionElem::Downcast(name, read_variant) => { + ProjectionElem::Downcast(name, read_variant) + } + ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty), + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + // This should have been replaced by a `ConstantIndex` earlier. + ProjectionElem::Index(_) => return None, + }; + self.ecx.project(value, elem).ok()? + } + Address { place, kind, provenance: _ } => { + if !place.is_indirect_first_projection() { + return None; + } + let local = self.locals[place.local]?; + let pointer = self.evaluated[local].as_ref()?; + let mut mplace = self.ecx.deref_pointer(pointer).ok()?; + for proj in place.projection.iter().skip(1) { + // We have no call stack to associate a local with a value, so we cannot interpret indexing. 
+ if matches!(proj, ProjectionElem::Index(_)) { + return None; + } + mplace = self.ecx.project(&mplace, proj).ok()?; + } + let pointer = mplace.to_ref(&self.ecx); + let ty = match kind { + AddressKind::Ref(bk) => Ty::new_ref( + self.tcx, + self.tcx.lifetimes.re_erased, + ty::TypeAndMut { ty: mplace.layout.ty, mutbl: bk.to_mutbl_lossy() }, + ), + AddressKind::Address(mutbl) => { + Ty::new_ptr(self.tcx, TypeAndMut { ty: mplace.layout.ty, mutbl }) + } + }; + let layout = self.ecx.layout_of(ty).ok()?; + ImmTy::from_immediate(pointer, layout).into() + } + + Discriminant(base) => { + let base = self.evaluated[base].as_ref()?; + let variant = self.ecx.read_discriminant(base).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(base.layout.ty, variant).ok()?; + discr_value.into() + } + Len(slice) => { + let slice = self.evaluated[slice].as_ref()?; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let len = slice.len(&self.ecx).ok()?; + let imm = ImmTy::try_from_uint(len, usize_layout)?; + imm.into() + } + NullaryOp(null_op, ty) => { + let layout = self.ecx.layout_of(ty).ok()?; + if let NullOp::SizeOf | NullOp::AlignOf = null_op && layout.is_unsized() { + return None; + } + let val = match null_op { + NullOp::SizeOf => layout.size.bytes(), + NullOp::AlignOf => layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => { + layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() + } + }; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let imm = ImmTy::try_from_uint(val, usize_layout)?; + imm.into() + } + UnaryOp(un_op, operand) => { + let operand = self.evaluated[operand].as_ref()?; + let operand = self.ecx.read_immediate(operand).ok()?; + let (val, _) = self.ecx.overflowing_unary_op(un_op, &operand).ok()?; + val.into() + } + BinaryOp(bin_op, lhs, rhs) => { + let lhs = self.evaluated[lhs].as_ref()?; + let lhs = self.ecx.read_immediate(lhs).ok()?; + let rhs = self.evaluated[rhs].as_ref()?; + let rhs = self.ecx.read_immediate(rhs).ok()?; + let (val, _) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + val.into() + } + CheckedBinaryOp(bin_op, lhs, rhs) => { + let lhs = self.evaluated[lhs].as_ref()?; + let lhs = self.ecx.read_immediate(lhs).ok()?; + let rhs = self.evaluated[rhs].as_ref()?; + let rhs = self.ecx.read_immediate(rhs).ok()?; + let (val, overflowed) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + let tuple = Ty::new_tup_from_iter( + self.tcx, + [val.layout.ty, self.tcx.types.bool].into_iter(), + ); + let tuple = self.ecx.layout_of(tuple).ok()?; + ImmTy::from_scalar_pair(val.to_scalar(), Scalar::from_bool(overflowed), tuple) + .into() + } + Cast { kind, value, from: _, to } => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).ok()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).ok()?; + res.into() + } + CastKind::Transmute => { + let value = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates only supports scalar/scalar-pair ABIs, + // so bail out if the target is not one. 
+ if value.as_mplace_or_imm().is_right() { + match (value.layout.abi, to.abi) { + (Abi::Scalar(..), Abi::Scalar(..)) => {} + (Abi::ScalarPair(..), Abi::ScalarPair(..)) => {} + _ => return None, + } + } + value.offset(Size::ZERO, to, &self.ecx).ok()? + } + _ => return None, + }, + }; + Some(op) + } + + fn project( + &mut self, + place: PlaceRef<'tcx>, + value: VnIndex, + proj: PlaceElem<'tcx>, + ) -> Option<VnIndex> { + let proj = match proj { + ProjectionElem::Deref => { + let ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(Mutability::Not) = ty.ref_mutability() + && let Some(pointee_ty) = ty.builtin_deref(true) + && pointee_ty.ty.is_freeze(self.tcx, self.param_env) + { + // An immutable borrow `_x` always points to the same value for the + // lifetime of the borrow, so we can merge all instances of `*_x`. + ProjectionElem::Deref + } else { + return None; + } + } + ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index), + ProjectionElem::Field(f, ty) => { + if let Value::Aggregate(_, _, fields) = self.get(value) { + return Some(fields[f.as_usize()]); + } else if let Value::Projection(outer_value, ProjectionElem::Downcast(_, read_variant)) = self.get(value) + && let Value::Aggregate(_, written_variant, fields) = self.get(*outer_value) + // This pass is not aware of control-flow, so we do not know whether the + // replacement we are doing is actually reachable. We could be in any arm of + // ``` + // match Some(x) { + // Some(y) => /* stuff */, + // None => /* other */, + // } + // ``` + // + // In surface rust, the current statement would be unreachable. + // + // However, from the reference chapter on enums and RFC 2195, + // accessing the wrong variant is not UB if the enum has repr. + // So it's not impossible for a series of MIR opts to generate + // a downcast to an inactive variant. + && written_variant == read_variant + { + return Some(fields[f.as_usize()]); + } + ProjectionElem::Field(f, ty) + } + ProjectionElem::Index(idx) => { + if let Value::Repeat(inner, _) = self.get(value) { + return Some(*inner); + } + let idx = self.locals[idx]?; + ProjectionElem::Index(idx) + } + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + match self.get(value) { + Value::Repeat(inner, _) => { + return Some(*inner); + } + Value::Aggregate(AggregateTy::Array, _, operands) => { + let offset = if from_end { + operands.len() - offset as usize + } else { + offset as usize + }; + return operands.get(offset).copied(); + } + _ => {} + }; + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + }; + + Some(self.insert(Value::Projection(value, proj))) + } + + /// Simplify the projection chain if we know better. + #[instrument(level = "trace", skip(self))] + fn simplify_place_projection(&mut self, place: &mut Place<'tcx>, location: Location) { + // If the projection is indirect, we treat the local as a value, so can replace it with + // another local. 
+ if place.is_indirect() + && let Some(base) = self.locals[place.local] + && let Some(new_local) = self.try_as_local(base, location) + { + place.local = new_local; + self.reused_locals.insert(new_local); + } + + let mut projection = Cow::Borrowed(&place.projection[..]); + + for i in 0..projection.len() { + let elem = projection[i]; + if let ProjectionElem::Index(idx) = elem + && let Some(idx) = self.locals[idx] + { + if let Some(offset) = self.evaluated[idx].as_ref() + && let Ok(offset) = self.ecx.read_target_usize(offset) + { + projection.to_mut()[i] = ProjectionElem::ConstantIndex { + offset, + min_length: offset + 1, + from_end: false, + }; + } else if let Some(new_idx) = self.try_as_local(idx, location) { + projection.to_mut()[i] = ProjectionElem::Index(new_idx); + self.reused_locals.insert(new_idx); + } + } + } + + if projection.is_owned() { + place.projection = self.tcx.mk_place_elems(&projection); + } + + trace!(?place); + } + /// Represent the *value* which would be read from `place`, and point `place` to a preexisting /// place with the same value (if that already exists). #[instrument(level = "trace", skip(self), ret)] @@ -266,6 +669,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { place: &mut Place<'tcx>, location: Location, ) -> Option<VnIndex> { + self.simplify_place_projection(place, location); + // Invariant: `place` and `place_ref` point to the same value, even if they point to // different memory locations. let mut place_ref = place.as_ref(); @@ -280,58 +685,18 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { place_ref = PlaceRef { local, projection: &place.projection[index..] }; } - let proj = match proj { - ProjectionElem::Deref => { - let ty = Place::ty_from( - place.local, - &place.projection[..index], - self.local_decls, - self.tcx, - ) - .ty; - if let Some(Mutability::Not) = ty.ref_mutability() - && let Some(pointee_ty) = ty.builtin_deref(true) - && pointee_ty.ty.is_freeze(self.tcx, self.param_env) - { - // An immutable borrow `_x` always points to the same value for the - // lifetime of the borrow, so we can merge all instances of `*_x`. - ProjectionElem::Deref - } else { - return None; - } - } - ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty), - ProjectionElem::Index(idx) => { - let idx = self.locals[idx]?; - ProjectionElem::Index(idx) - } - ProjectionElem::ConstantIndex { offset, min_length, from_end } => { - ProjectionElem::ConstantIndex { offset, min_length, from_end } - } - ProjectionElem::Subslice { from, to, from_end } => { - ProjectionElem::Subslice { from, to, from_end } - } - ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index), - ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), - ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), - }; - value = self.insert(Value::Projection(value, proj)); + let base = PlaceRef { local: place.local, projection: &place.projection[..index] }; + value = self.project(base, value, proj)?; } - if let Some(local) = self.try_as_local(value, location) - && local != place.local - // in case we had no projection to begin with. 
- { - *place = local.into(); - self.reused_locals.insert(local); - self.any_replacement = true; - } else if place_ref.local != place.local - || place_ref.projection.len() < place.projection.len() - { + if let Some(new_local) = self.try_as_local(value, location) { + place_ref = PlaceRef { local: new_local, projection: &[] }; + } + + if place_ref.local != place.local || place_ref.projection.len() < place.projection.len() { // By the invariant on `place_ref`. *place = place_ref.project_deeper(&[], self.tcx); self.reused_locals.insert(place_ref.local); - self.any_replacement = true; } Some(value) @@ -344,12 +709,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { location: Location, ) -> Option<VnIndex> { match *operand { - Operand::Constant(ref constant) => Some(self.insert(Value::Constant(constant.const_))), + Operand::Constant(ref mut constant) => { + let const_ = constant.const_.normalize(self.tcx, self.param_env); + self.insert_constant(const_) + } Operand::Copy(ref mut place) | Operand::Move(ref mut place) => { let value = self.simplify_place_value(place, location)?; if let Some(const_) = self.try_as_constant(value) { *operand = Operand::Constant(Box::new(const_)); - self.any_replacement = true; } Some(value) } @@ -378,24 +745,15 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { Value::Repeat(op, amount) } Rvalue::NullaryOp(op, ty) => Value::NullaryOp(op, ty), - Rvalue::Aggregate(box ref kind, ref mut fields) => { - let variant_index = match *kind { - AggregateKind::Array(..) - | AggregateKind::Tuple - | AggregateKind::Closure(..) - | AggregateKind::Coroutine(..) => FIRST_VARIANT, - AggregateKind::Adt(_, variant_index, _, _, None) => variant_index, - // Do not track unions. - AggregateKind::Adt(_, _, _, _, Some(_)) => return None, - }; - let fields: Option<Vec<_>> = fields - .iter_mut() - .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque())) - .collect(); - let ty = rvalue.ty(self.local_decls, self.tcx); - Value::Aggregate(ty, variant_index, fields?) + Rvalue::Aggregate(..) => return self.simplify_aggregate(rvalue, location), + Rvalue::Ref(_, borrow_kind, ref mut place) => { + self.simplify_place_projection(place, location); + return self.new_pointer(*place, AddressKind::Ref(borrow_kind)); + } + Rvalue::AddressOf(mutbl, ref mut place) => { + self.simplify_place_projection(place, location); + return self.new_pointer(*place, AddressKind::Address(mutbl)); } - Rvalue::Ref(.., place) | Rvalue::AddressOf(_, place) => return self.new_pointer(place), // Operations. Rvalue::Len(ref mut place) => { @@ -405,6 +763,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { Rvalue::Cast(kind, ref mut value, to) => { let from = value.ty(self.local_decls, self.tcx); let value = self.simplify_operand(value, location)?; + if let CastKind::PointerCoercion( + PointerCoercion::ReifyFnPointer | PointerCoercion::ClosureFnPointer(_), + ) = kind + { + // Each reification of a generic fn may get a different pointer. + // Do not try to merge them. 
+ return self.new_opaque(); + } Value::Cast { kind, value, from, to } } Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => { @@ -423,6 +789,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } Rvalue::Discriminant(ref mut place) => { let place = self.simplify_place_value(place, location)?; + if let Some(discr) = self.simplify_discriminant(place) { + return Some(discr); + } Value::Discriminant(place) } @@ -432,45 +801,182 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { debug!(?value); Some(self.insert(value)) } + + fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> { + if let Value::Aggregate(enum_ty, variant, _) = *self.get(place) + && let AggregateTy::Def(enum_did, enum_substs) = enum_ty + && let DefKind::Enum = self.tcx.def_kind(enum_did) + { + let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_substs); + let discr = self.ecx.discriminant_for_variant(enum_ty, variant).ok()?; + return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty)); + } + + None + } + + fn simplify_aggregate( + &mut self, + rvalue: &mut Rvalue<'tcx>, + location: Location, + ) -> Option<VnIndex> { + let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() }; + + let tcx = self.tcx; + if fields.is_empty() { + let is_zst = match *kind { + AggregateKind::Array(..) | AggregateKind::Tuple | AggregateKind::Closure(..) => { + true + } + // Only enums can be non-ZST. + AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, + // Coroutines are never ZST, as they at least contain the implicit states. + AggregateKind::Coroutine(..) => false, + }; + + if is_zst { + let ty = rvalue.ty(self.local_decls, tcx); + return self.insert_constant(Const::zero_sized(ty)); + } + } + + let (ty, variant_index) = match *kind { + AggregateKind::Array(..) => { + assert!(!fields.is_empty()); + (AggregateTy::Array, FIRST_VARIANT) + } + AggregateKind::Tuple => { + assert!(!fields.is_empty()); + (AggregateTy::Tuple, FIRST_VARIANT) + } + AggregateKind::Closure(did, substs) | AggregateKind::Coroutine(did, substs, _) => { + (AggregateTy::Def(did, substs), FIRST_VARIANT) + } + AggregateKind::Adt(did, variant_index, substs, _, None) => { + (AggregateTy::Def(did, substs), variant_index) + } + // Do not track unions. + AggregateKind::Adt(_, _, _, _, Some(_)) => return None, + }; + + let fields: Option<Vec<_>> = fields + .iter_mut() + .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque())) + .collect(); + let fields = fields?; + + if let AggregateTy::Array = ty && fields.len() > 4 { + let first = fields[0]; + if fields.iter().all(|&v| v == first) { + let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap()); + if let Some(const_) = self.try_as_constant(first) { + *rvalue = Rvalue::Repeat(Operand::Constant(Box::new(const_)), len); + } else if let Some(local) = self.try_as_local(first, location) { + *rvalue = Rvalue::Repeat(Operand::Copy(local.into()), len); + self.reused_locals.insert(local); + } + return Some(self.insert(Value::Repeat(first, len))); + } + } + + Some(self.insert(Value::Aggregate(ty, variant_index, fields))) + } +} + +fn op_to_prop_const<'tcx>( + ecx: &mut InterpCx<'_, 'tcx, DummyMachine>, + op: &OpTy<'tcx>, +) -> Option<ConstValue<'tcx>> { + // Do not attempt to propagate unsized locals. + if op.layout.is_unsized() { + return None; + } + + // This constant is a ZST, just return an empty value. + if op.layout.is_zst() { + return Some(ConstValue::ZeroSized); + } + + // Do not synthetize too large constants. 
Codegen will just memcpy them, which we'd like to avoid. + if !matches!(op.layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) { + return None; + } + + // If this constant has scalar ABI, return it as a `ConstValue::Scalar`. + if let Abi::Scalar(abi::Scalar::Initialized { .. }) = op.layout.abi + && let Ok(scalar) = ecx.read_scalar(op) + && scalar.try_to_int().is_ok() + { + return Some(ConstValue::Scalar(scalar)); + } + + // If this constant is already represented as an `Allocation`, + // try putting it into global memory to return it. + if let Either::Left(mplace) = op.as_mplace_or_imm() { + let (size, _align) = ecx.size_and_align_of_mplace(&mplace).ok()??; + + // Do not try interning a value that contains provenance. + // Due to https://github.com/rust-lang/rust/issues/79738, doing so could lead to bugs. + // FIXME: remove this hack once that issue is fixed. + let alloc_ref = ecx.get_ptr_alloc(mplace.ptr(), size).ok()??; + if alloc_ref.has_provenance() { + return None; + } + + let pointer = mplace.ptr().into_pointer_or_addr().ok()?; + let (alloc_id, offset) = pointer.into_parts(); + intern_const_alloc_for_constprop(ecx, alloc_id).ok()?; + if matches!(ecx.tcx.global_alloc(alloc_id), GlobalAlloc::Memory(_)) { + // `alloc_id` may point to a static. Codegen will choke on an `Indirect` with anything + // by `GlobalAlloc::Memory`, so do fall through to copying if needed. + // FIXME: find a way to treat this more uniformly + // (probably by fixing codegen) + return Some(ConstValue::Indirect { alloc_id, offset }); + } + } + + // Everything failed: create a new allocation to hold the data. + let alloc_id = + ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest, false)).ok()?; + let value = ConstValue::Indirect { alloc_id, offset: Size::ZERO }; + + // Check that we do not leak a pointer. + // Those pointers may lose part of their identity in codegen. + // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed. + if ecx.tcx.global_alloc(alloc_id).unwrap_memory().inner().provenance().ptrs().is_empty() { + return Some(value); + } + + None } impl<'tcx> VnState<'_, 'tcx> { /// If `index` is a `Value::Constant`, return the `Constant` to be put in the MIR. fn try_as_constant(&mut self, index: VnIndex) -> Option<ConstOperand<'tcx>> { - if let Value::Constant(const_) = *self.get(index) { - // Some constants may contain pointers. We need to preserve the provenance of these - // pointers, but not all constants guarantee this: - // - valtrees purposefully do not; - // - ConstValue::Slice does not either. - match const_ { - Const::Ty(c) => match c.kind() { - ty::ConstKind::Value(valtree) => match valtree { - // This is just an integer, keep it. - ty::ValTree::Leaf(_) => {} - ty::ValTree::Branch(_) => return None, - }, - ty::ConstKind::Param(..) - | ty::ConstKind::Unevaluated(..) - | ty::ConstKind::Expr(..) => {} - // Should not appear in runtime MIR. - ty::ConstKind::Infer(..) - | ty::ConstKind::Bound(..) - | ty::ConstKind::Placeholder(..) - | ty::ConstKind::Error(..) => bug!(), - }, - Const::Unevaluated(..) => {} - // If the same slice appears twice in the MIR, we cannot guarantee that we will - // give the same `AllocId` to the data. - Const::Val(ConstValue::Slice { .. }, _) => return None, - Const::Val( - ConstValue::ZeroSized | ConstValue::Scalar(_) | ConstValue::Indirect { .. }, - _, - ) => {} - } - Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ }) - } else { - None + // This was already constant in MIR, do not change it. 
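
The practical effect of `op_to_prop_const`, sketched on hypothetical user code (whether a given fold actually fires depends on the rest of the pass pipeline): small scalar results are worth materializing as MIR constants, while large aggregates are rejected because codegen would only memcpy them.

    pub fn small() -> u32 {
        let x = 21;
        // A scalar-ABI value: GVN may replace the use of `x * 2` with `const 42_u32`.
        x * 2
    }

    pub fn large() -> [u64; 64] {
        // A 512-byte array has neither `Scalar` nor `ScalarPair` ABI, so even if its
        // value were known, no constant would be synthesized for it here.
        [0u64; 64]
    }
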
+ if let Value::Constant { value, disambiguator: _ } = *self.get(index) + // If the constant is not deterministic, adding an additional mention of it in MIR will + // not give the same value as the former mention. + && value.is_deterministic() + { + return Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_: value }); } + + let op = self.evaluated[index].as_ref()?; + if op.layout.is_unsized() { + // Do not attempt to propagate unsized locals. + return None; + } + + let value = op_to_prop_const(&mut self.ecx, op)?; + + // Check that we do not leak a pointer. + // Those pointers may lose part of their identity in codegen. + // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed. + assert!(!value.may_have_provenance(self.tcx, op.layout.size)); + + let const_ = Const::Val(value, op.layout.ty); + Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ }) } /// If there is a local which is assigned `index`, and its assignment strictly dominates `loc`, @@ -489,27 +995,32 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, 'tcx> { self.tcx } + fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, location: Location) { + self.simplify_place_projection(place, location); + } + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { self.simplify_operand(operand, location); } fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) { - self.super_statement(stmt, location); if let StatementKind::Assign(box (_, ref mut rvalue)) = stmt.kind // Do not try to simplify a constant, it's already in canonical shape. && !matches!(rvalue, Rvalue::Use(Operand::Constant(_))) - && let Some(value) = self.simplify_rvalue(rvalue, location) { - if let Some(const_) = self.try_as_constant(value) { - *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_))); - self.any_replacement = true; - } else if let Some(local) = self.try_as_local(value, location) - && *rvalue != Rvalue::Use(Operand::Move(local.into())) + if let Some(value) = self.simplify_rvalue(rvalue, location) { - *rvalue = Rvalue::Use(Operand::Copy(local.into())); - self.reused_locals.insert(local); - self.any_replacement = true; + if let Some(const_) = self.try_as_constant(value) { + *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_))); + } else if let Some(local) = self.try_as_local(value, location) + && *rvalue != Rvalue::Use(Operand::Move(local.into())) + { + *rvalue = Rvalue::Use(Operand::Copy(local.into())); + self.reused_locals.insert(local); + } } + } else { + self.super_statement(stmt, location); } } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 277060573bc..793dcf0d994 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -439,6 +439,11 @@ impl<'tcx> Inliner<'tcx> { } if callee_attrs.target_features != self.codegen_fn_attrs.target_features { + // In general it is not correct to inline a callee with target features that are a + // subset of the caller. This is because the callee might contain calls, and the ABI of + // those calls depends on the target features of the surrounding function. By moving a + // `Call` terminator from one MIR body to another with more target features, we might + // change the ABI of that call! 
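
A sketch of the hazard, assuming an x86_64 target (illustrative functions, not from the compiler): the callee is compiled with a feature the caller lacks, and it contains a call whose lowering may depend on that feature.

    fn inner(xs: &[i32]) -> i32 {
        xs.iter().sum()
    }

    #[target_feature(enable = "avx2")]
    unsafe fn callee(xs: &[i32]) -> i32 {
        // This call is lowered inside an AVX2-enabled function. For calls passing
        // SIMD values, the chosen ABI can depend on that feature being available.
        inner(xs)
    }

    fn caller(xs: &[i32]) -> i32 {
        // MIR-inlining `callee` here would move the `inner` call into a body
        // compiled without AVX2, so its lowering could silently change.
        unsafe { callee(xs) }
    }
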
return Err("incompatible target features"); } diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 68b8911824c..bf5f0ca7cbd 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -2,6 +2,7 @@ #![deny(rustc::untranslatable_diagnostic)] #![deny(rustc::diagnostic_outside_of_impl)] #![feature(box_patterns)] +#![feature(cow_is_borrowed)] #![feature(decl_macro)] #![feature(is_sorted)] #![feature(let_chains)] @@ -567,10 +568,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &[ &check_alignment::CheckAlignment, &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first - &unreachable_prop::UnreachablePropagation, + &inline::Inline, + // Substitutions during inlining may introduce switch on enums with uninhabited branches. &uninhabited_enum_branching::UninhabitedEnumBranching, + &unreachable_prop::UnreachablePropagation, &o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), - &inline::Inline, &remove_storage_markers::RemoveStorageMarkers, &remove_zsts::RemoveZsts, &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering @@ -590,6 +592,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &separate_const_switch::SeparateConstSwitch, &const_prop::ConstProp, &gvn::GVN, + &simplify::SimplifyLocals::AfterGVN, &dataflow_const_prop::DataflowConstProp, &const_debuginfo::ConstDebugInfo, &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 2400cfa21fb..4ae5ea4c8d6 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -69,8 +69,8 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? - if let Some(ty::Coroutine(gen_def_id, args, _)) = ty.map(Ty::kind) { - let body = tcx.optimized_mir(*gen_def_id).coroutine_drop().unwrap(); + if let Some(ty::Coroutine(coroutine_def_id, args, _)) = ty.map(Ty::kind) { + let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap(); let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); debug!("make_shim({:?}) = {:?}", instance, body); @@ -392,8 +392,8 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) - _ if is_copy => builder.copy_shim(), ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()), ty::Tuple(..) 
=> builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), - ty::Coroutine(gen_def_id, args, hir::Movability::Movable) => { - builder.coroutine_shim(dest, src, *gen_def_id, args.as_coroutine()) + ty::Coroutine(coroutine_def_id, args, hir::Movability::Movable) => { + builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine()) } _ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty), }; @@ -597,7 +597,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { &mut self, dest: Place<'tcx>, src: Place<'tcx>, - gen_def_id: DefId, + coroutine_def_id: DefId, args: CoroutineArgs<'tcx>, ) { self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); @@ -607,8 +607,8 @@ impl<'tcx> CloneShimBuilder<'tcx> { let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys()); let target = self.block(vec![], TerminatorKind::Return, false); let unreachable = self.block(vec![], TerminatorKind::Unreachable, false); - let mut cases = Vec::with_capacity(args.state_tys(gen_def_id, self.tcx).count()); - for (index, state_tys) in args.state_tys(gen_def_id, self.tcx).enumerate() { + let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count()); + for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() { let variant_index = VariantIdx::new(index); let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index); let src = self.tcx.mk_place_downcast_unnamed(src, variant_index); diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 88c89e106fd..0a1c011147a 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -366,6 +366,7 @@ pub fn remove_dead_blocks(body: &mut Body<'_>) { pub enum SimplifyLocals { BeforeConstProp, + AfterGVN, Final, } @@ -373,6 +374,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals { fn name(&self) -> &'static str { match &self { SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop", + SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering", SimplifyLocals::Final => "SimplifyLocals-final", } } diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs index b508cd1c9cc..1f0e605c3b8 100644 --- a/compiler/rustc_mir_transform/src/simplify_branches.rs +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -16,8 +16,25 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running SimplifyConstCondition on {:?}", body.source); let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - for block in body.basic_blocks_mut() { + 'blocks: for block in body.basic_blocks_mut() { + for stmt in block.statements.iter_mut() { + if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind + && let NonDivergingIntrinsic::Assume(discr) = intrinsic + && let Operand::Constant(ref c) = discr + && let Some(constant) = c.const_.try_eval_bool(tcx, param_env) + { + if constant { + stmt.make_nop(); + } else { + block.statements.clear(); + block.terminator_mut().kind = TerminatorKind::Unreachable; + continue 'blocks; + } + } + } + let terminator = block.terminator_mut(); terminator.kind = match terminator.kind { TerminatorKind::SwitchInt { diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index 
cb028a92d49..98f67e18a8d 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -3,8 +3,7 @@ use crate::MirPass; use rustc_data_structures::fx::FxHashSet; use rustc_middle::mir::{ - BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, - TerminatorKind, + BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind, }; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Ty, TyCtxt}; @@ -30,17 +29,20 @@ fn get_switched_on_type<'tcx>( let terminator = block_data.terminator(); // Only bother checking blocks which terminate by switching on a local. - if let Some(local) = get_discriminant_local(&terminator.kind) - && let [.., stmt_before_term] = &block_data.statements[..] - && let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind + let local = get_discriminant_local(&terminator.kind)?; + + let stmt_before_term = block_data.statements.last()?; + + if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind && l.as_local() == Some(local) - && let ty = place.ty(body, tcx).ty - && ty.is_enum() { - Some(ty) - } else { - None + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); + } } + + None } fn variant_discriminants<'tcx>( @@ -67,28 +69,6 @@ fn variant_discriminants<'tcx>( } } -/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new -/// bb to use as the new target if not. -fn ensure_otherwise_unreachable<'tcx>( - body: &Body<'tcx>, - targets: &SwitchTargets, -) -> Option<BasicBlockData<'tcx>> { - let otherwise = targets.otherwise(); - let bb = &body.basic_blocks[otherwise]; - if bb.terminator().kind == TerminatorKind::Unreachable - && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_))) - { - return None; - } - - let mut new_block = BasicBlockData::new(Some(Terminator { - source_info: bb.terminator().source_info, - kind: TerminatorKind::Unreachable, - })); - new_block.is_cleanup = bb.is_cleanup; - Some(new_block) -} - impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { sess.mir_opt_level() > 0 @@ -97,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("UninhabitedEnumBranching starting for {:?}", body.source); - for bb in body.basic_blocks.indices() { + let mut removable_switchs = Vec::new(); + + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { trace!("processing block {:?}", bb); - let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body) - else { + if bb_data.is_cleanup { continue; - }; + } + + let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue }; let layout = tcx.layout_of( tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty), @@ -117,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { trace!("allowed_variants = {:?}", allowed_variants); - if let TerminatorKind::SwitchInt { targets, .. 
} = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - let mut new_targets = SwitchTargets::new( - targets.iter().filter(|(val, _)| allowed_variants.contains(val)), - targets.otherwise(), - ); - - if new_targets.iter().count() == allowed_variants.len() { - if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) { - let new_otherwise = body.basic_blocks_mut().push(updated); - *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise; - } - } + let terminator = bb_data.terminator(); + let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() }; - if let TerminatorKind::SwitchInt { targets, .. } = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - *targets = new_targets; + let mut reachable_count = 0; + for (index, (val, _)) in targets.iter().enumerate() { + if allowed_variants.contains(&val) { + reachable_count += 1; } else { - unreachable!() + removable_switchs.push((bb, index)); } - } else { - unreachable!() } + + if reachable_count == allowed_variants.len() { + removable_switchs.push((bb, targets.iter().count())); + } + } + + if removable_switchs.is_empty() { + return; + } + + let new_block = BasicBlockData::new(Some(Terminator { + source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info, + kind: TerminatorKind::Unreachable, + })); + let unreachable_block = body.basic_blocks.as_mut().push(new_block); + + for (bb, index) in removable_switchs { + let bb = &mut body.basic_blocks.as_mut()[bb]; + let terminator = bb.terminator_mut(); + let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() }; + targets.all_targets_mut()[index] = unreachable_block; } } } diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index ea7aafd866b..919e8d6a234 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -2,11 +2,13 @@ //! when all of their successors are unreachable. This is achieved through a //! post-order traversal of the blocks. -use crate::simplify; use crate::MirPass; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_target::abi::Size; pub struct UnreachablePropagation; @@ -21,106 +23,133 @@ impl MirPass<'_> for UnreachablePropagation { } fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut patch = MirPatch::new(body); let mut unreachable_blocks = FxHashSet::default(); - let mut replacements = FxHashMap::default(); for (bb, bb_data) in traversal::postorder(body) { let terminator = bb_data.terminator(); - if terminator.kind == TerminatorKind::Unreachable { - unreachable_blocks.insert(bb); - } else { - let is_unreachable = |succ: BasicBlock| unreachable_blocks.contains(&succ); - let terminator_kind_opt = remove_successors(&terminator.kind, is_unreachable); - - if let Some(terminator_kind) = terminator_kind_opt { - if terminator_kind == TerminatorKind::Unreachable { - unreachable_blocks.insert(bb); - } - replacements.insert(bb, terminator_kind); + let is_unreachable = match &terminator.kind { + TerminatorKind::Unreachable => true, + // This will unconditionally run into an unreachable and is therefore unreachable as well. 
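
The propagation itself is simple; a toy model (assuming acyclic control flow and invented types, not the real MIR API) of how visiting blocks in postorder lets each block observe its successor's classification first:

    #[derive(Clone, Copy)]
    pub enum Term {
        Goto(usize),
        Return,
        Unreachable,
    }

    pub fn propagate(terms: &mut Vec<Term>, postorder: &[usize]) -> Vec<bool> {
        let mut unreachable = vec![false; terms.len()];
        for &bb in postorder {
            let term = terms[bb];
            let is_unreachable = match term {
                Term::Unreachable => true,
                // A goto that can only run into an unreachable block is itself
                // unreachable, so its terminator is rewritten as well.
                Term::Goto(target) if unreachable[target] => {
                    terms[bb] = Term::Unreachable;
                    true
                }
                _ => false,
            };
            unreachable[bb] = is_unreachable;
        }
        unreachable
    }
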
+ TerminatorKind::Goto { target } if unreachable_blocks.contains(target) => { + patch.patch_terminator(bb, TerminatorKind::Unreachable); + true + } + // Try to remove unreachable targets from the switch. + TerminatorKind::SwitchInt { .. } => { + remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch) } + _ => false, + }; + if is_unreachable { + unreachable_blocks.insert(bb); } } + if !tcx + .consider_optimizing(|| format!("UnreachablePropagation {:?} ", body.source.def_id())) + { + return; + } + + patch.apply(body); + // We do want do keep some unreachable blocks, but make them empty. for bb in unreachable_blocks { - if !tcx.consider_optimizing(|| { - format!("UnreachablePropagation {:?} ", body.source.def_id()) - }) { - break; - } - body.basic_blocks_mut()[bb].statements.clear(); } + } +} - let replaced = !replacements.is_empty(); +/// Return whether the current terminator is fully unreachable. +fn remove_successors_from_switch<'tcx>( + tcx: TyCtxt<'tcx>, + bb: BasicBlock, + unreachable_blocks: &FxHashSet<BasicBlock>, + body: &Body<'tcx>, + patch: &mut MirPatch<'tcx>, +) -> bool { + let terminator = body.basic_blocks[bb].terminator(); + let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() }; + let source_info = terminator.source_info; + let location = body.terminator_loc(bb); + + let is_unreachable = |bb| unreachable_blocks.contains(&bb); + + // If there are multiple targets, we want to keep information about reachability for codegen. + // For example (see tests/codegen/match-optimizes-away.rs) + // + // pub enum Two { A, B } + // pub fn identity(x: Two) -> Two { + // match x { + // Two::A => Two::A, + // Two::B => Two::B, + // } + // } + // + // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to + // turn it into just `x` later. Without the unreachable, such a transformation would be illegal. + // + // In order to preserve this information, we record reachable and unreachable targets as + // `Assume` statements in MIR. + + let discr_ty = discr.ty(body, tcx); + let discr_size = Size::from_bits(match discr_ty.kind() { + ty::Uint(uint) => uint.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + let mut add_assumption = |binop, value| { + let local = patch.new_temp(tcx.types.bool, source_info.span); + let value = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::from_scalar(tcx, Scalar::from_uint(value, discr_size), discr_ty), + })); + let cmp = Rvalue::BinaryOp(binop, Box::new((discr.to_copy(), value))); + patch.add_assign(location, local.into(), cmp); + + let assume = NonDivergingIntrinsic::Assume(Operand::Move(local.into())); + patch.add_statement(location, StatementKind::Intrinsic(Box::new(assume))); + }; - for (bb, terminator_kind) in replacements { - if !tcx.consider_optimizing(|| { - format!("UnreachablePropagation {:?} ", body.source.def_id()) - }) { - break; - } + let otherwise = targets.otherwise(); + let otherwise_unreachable = is_unreachable(otherwise); - body.basic_blocks_mut()[bb].terminator_mut().kind = terminator_kind; + let reachable_iter = targets.iter().filter(|&(value, bb)| { + let is_unreachable = is_unreachable(bb); + // We remove this target from the switch, so record the inequality using `Assume`. 
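
For instance (hypothetical, simplified MIR naming): if the arm for the third variant below were proven unreachable while `otherwise` stayed reachable, its target would be dropped from the switch and the lost fact recorded via `assume`.

    pub enum Three { A, B, C }

    pub fn pick(x: Three) -> u8 {
        match x {
            Three::A => 0,
            Three::B => 1,
            Three::C => 2,
        }
    }

    // Sketch of the rewritten switch block:
    //     _2 = discriminant(_1);
    //     _3 = Ne(_2, const 2_isize);
    //     assume(move _3);
    //     switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb4];
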
+ if is_unreachable && !otherwise_unreachable { + add_assumption(BinOp::Ne, value); } - - if replaced { - simplify::remove_dead_blocks(body); + !is_unreachable + }); + + let new_targets = SwitchTargets::new(reachable_iter, otherwise); + + let num_targets = new_targets.all_targets().len(); + let fully_unreachable = num_targets == 1 && otherwise_unreachable; + + let terminator = match (num_targets, otherwise_unreachable) { + // If all targets are unreachable, we can be unreachable as well. + (1, true) => TerminatorKind::Unreachable, + (1, false) => TerminatorKind::Goto { target: otherwise }, + (2, true) => { + // All targets are unreachable except one. Record the equality, and make it a goto. + let (value, target) = new_targets.iter().next().unwrap(); + add_assumption(BinOp::Eq, value); + TerminatorKind::Goto { target } } - } -} - -fn remove_successors<'tcx, F>( - terminator_kind: &TerminatorKind<'tcx>, - is_unreachable: F, -) -> Option<TerminatorKind<'tcx>> -where - F: Fn(BasicBlock) -> bool, -{ - let terminator = match terminator_kind { - // This will unconditionally run into an unreachable and is therefore unreachable as well. - TerminatorKind::Goto { target } if is_unreachable(*target) => TerminatorKind::Unreachable, - TerminatorKind::SwitchInt { targets, discr } => { - let otherwise = targets.otherwise(); - - // If all targets are unreachable, we can be unreachable as well. - if targets.all_targets().iter().all(|bb| is_unreachable(*bb)) { - TerminatorKind::Unreachable - } else if is_unreachable(otherwise) { - // If there are multiple targets, don't delete unreachable branches (like an unreachable otherwise) - // unless otherwise is unreachable, in which case deleting a normal branch causes it to be merged with - // the otherwise, keeping its unreachable. - // This looses information about reachability causing worse codegen. - // For example (see tests/codegen/match-optimizes-away.rs) - // - // pub enum Two { A, B } - // pub fn identity(x: Two) -> Two { - // match x { - // Two::A => Two::A, - // Two::B => Two::B, - // } - // } - // - // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to - // turn it into just `x` later. Without the unreachable, such a transformation would be illegal. - // If the otherwise branch is unreachable, we can delete all other unreachable targets, as they will - // still point to the unreachable and therefore not lose reachability information. - let reachable_iter = targets.iter().filter(|(_, bb)| !is_unreachable(*bb)); - - let new_targets = SwitchTargets::new(reachable_iter, otherwise); - - // No unreachable branches were removed. - if new_targets.all_targets().len() == targets.all_targets().len() { - return None; - } - - TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets } - } else { - // If the otherwise branch is reachable, we don't want to delete any unreachable branches. - return None; - } + _ if num_targets == targets.all_targets().len() => { + // Nothing has changed. + return false; } - _ => return None, + _ => TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }, }; - Some(terminator) + + patch.patch_terminator(bb, terminator); + fully_unreachable } |
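
In compact form, the terminator chosen at the end of `remove_successors_from_switch` follows this decision table (a standalone sketch with invented types; the real code also returns early when nothing changed and emits the `assume(discr == value)` equality in the single-target case):

    #[derive(Debug)]
    pub enum NewTerm {
        Unreachable,
        Goto { target: usize },
        SwitchInt { targets: Vec<(u128, usize)>, otherwise: usize },
    }

    pub fn rebuild(reachable: Vec<(u128, usize)>, otherwise: usize, otherwise_unreachable: bool) -> NewTerm {
        // `reachable` holds the surviving value targets; `otherwise` is always kept as a target.
        let num_targets = reachable.len() + 1;
        match (num_targets, otherwise_unreachable) {
            // Every target, including `otherwise`, is unreachable.
            (1, true) => NewTerm::Unreachable,
            // Only `otherwise` is left: the switch degenerates into a goto.
            (1, false) => NewTerm::Goto { target: otherwise },
            // One value target survives and `otherwise` is unreachable: goto it
            // (after recording the equality assumption in the real pass).
            (2, true) => NewTerm::Goto { target: reachable[0].1 },
            // Otherwise keep a (possibly smaller) switch.
            _ => NewTerm::SwitchInt { targets: reachable, otherwise },
        }
    }
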
