Diffstat (limited to 'compiler/rustc_mir_transform/src')
71 files changed, 7727 insertions, 5218 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index 2502e8b603c..5aed89139e2 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -34,14 +34,9 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { return; } - // This pass only runs on functions which themselves cannot unwind, - // forcibly changing the body of the function to structurally provide - // this guarantee by aborting on an unwind. If this function can unwind, - // then there's nothing to do because it already should work correctly. - // // Here we test for this function itself whether its ABI allows // unwinding or not. - let body_ty = tcx.type_of(def_id); + let body_ty = tcx.type_of(def_id).skip_binder(); let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, @@ -56,7 +51,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { // example. let mut calls_to_terminate = Vec::new(); let mut cleanups_to_remove = Vec::new(); - for (id, block) in body.basic_blocks().iter_enumerated() { + for (id, block) in body.basic_blocks.iter_enumerated() { if block.is_cleanup { continue; } @@ -74,7 +69,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { }; layout::fn_can_unwind(tcx, fn_def_id, sig.abi()) } - TerminatorKind::Drop { .. } | TerminatorKind::DropAndReplace { .. } => { + TerminatorKind::Drop { .. } => { tcx.sess.opts.unstable_opts.panic_in_drop == PanicStrategy::Unwind && layout::fn_can_unwind(tcx, None, Abi::Rust) } @@ -107,31 +102,14 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { } } - // For call instructions which need to be terminated, we insert a - // singular basic block which simply terminates, and then configure the - // `cleanup` attribute for all calls we found to this basic block we - // insert which means that any unwinding that happens in the functions - // will force an abort of the process. - if !calls_to_terminate.is_empty() { - let bb = BasicBlockData { - statements: Vec::new(), - is_cleanup: true, - terminator: Some(Terminator { - source_info: SourceInfo::outermost(body.span), - kind: TerminatorKind::Abort, - }), - }; - let abort_bb = body.basic_blocks_mut().push(bb); - - for bb in calls_to_terminate { - let cleanup = body.basic_blocks_mut()[bb].terminator_mut().unwind_mut().unwrap(); - *cleanup = Some(abort_bb); - } + for id in calls_to_terminate { + let cleanup = body.basic_blocks_mut()[id].terminator_mut().unwind_mut().unwrap(); + *cleanup = UnwindAction::Terminate; } for id in cleanups_to_remove { let cleanup = body.basic_blocks_mut()[id].terminator_mut().unwind_mut().unwrap(); - *cleanup = None; + *cleanup = UnwindAction::Unreachable; } // We may have invalidated some `cleanup` blocks so clean those up now. 
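The effect of `AbortUnwindingCalls` is visible at the language level whenever a potentially-unwinding Rust call sits inside a function whose ABI forbids unwinding. A minimal sketch of the affected pattern, assuming abort-on-unwind semantics for `extern "C"` (function names here are illustrative, not from the patch):

```rust
// A stand-in for any Rust call that may unwind.
fn may_unwind() {
    panic!("unwinding");
}

// The "C" ABI cannot unwind, so this pass gives the call below an
// `UnwindAction::Terminate` edge: if `may_unwind` panics, the process
// aborts instead of unwinding across the ABI boundary.
pub extern "C" fn ffi_entry() {
    may_unwind();
}

fn main() {
    // Calling `ffi_entry()` here would abort the whole process.
}
```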
diff --git a/compiler/rustc_mir_transform/src/add_call_guards.rs b/compiler/rustc_mir_transform/src/add_call_guards.rs index f12c8560c0e..fb4705e0754 100644 --- a/compiler/rustc_mir_transform/src/add_call_guards.rs +++ b/compiler/rustc_mir_transform/src/add_call_guards.rs @@ -1,5 +1,5 @@ use crate::MirPass; -use rustc_index::vec::{Idx, IndexVec}; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; @@ -45,15 +45,16 @@ impl AddCallGuards { // We need a place to store the new blocks generated let mut new_blocks = Vec::new(); - let cur_len = body.basic_blocks().len(); + let cur_len = body.basic_blocks.len(); for block in body.basic_blocks_mut() { match block.terminator { Some(Terminator { - kind: TerminatorKind::Call { target: Some(ref mut destination), cleanup, .. }, + kind: TerminatorKind::Call { target: Some(ref mut destination), unwind, .. }, source_info, }) if pred_count[*destination] > 1 - && (cleanup.is_some() || self == &AllCallEdges) => + && (matches!(unwind, UnwindAction::Cleanup(_) | UnwindAction::Terminate) + || self == &AllCallEdges) => { // It's a critical edge, break it let call_guard = BasicBlockData { diff --git a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs index 8de0aad041c..ef2a0c790e9 100644 --- a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs +++ b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs @@ -5,37 +5,36 @@ use crate::util; use crate::MirPass; use rustc_middle::mir::patch::MirPatch; -// This pass moves values being dropped that are within a packed -// struct to a separate local before dropping them, to ensure that -// they are dropped from an aligned address. -// -// For example, if we have something like -// ```Rust -// #[repr(packed)] -// struct Foo { -// dealign: u8, -// data: Vec<u8> -// } -// -// let foo = ...; -// ``` -// -// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned -// address. This means we can't simply drop `foo.data` directly, because -// its address is not aligned. -// -// Instead, we move `foo.data` to a local and drop that: -// ``` -// storage.live(drop_temp) -// drop_temp = foo.data; -// drop(drop_temp) -> next -// next: -// storage.dead(drop_temp) -// ``` -// -// The storage instructions are required to avoid stack space -// blowup. - +/// This pass moves values being dropped that are within a packed +/// struct to a separate local before dropping them, to ensure that +/// they are dropped from an aligned address. +/// +/// For example, if we have something like +/// ```ignore (illustrative) +/// #[repr(packed)] +/// struct Foo { +/// dealign: u8, +/// data: Vec<u8> +/// } +/// +/// let foo = ...; +/// ``` +/// +/// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned +/// address. This means we can't simply drop `foo.data` directly, because +/// its address is not aligned. +/// +/// Instead, we move `foo.data` to a local and drop that: +/// ```ignore (illustrative) +/// storage.live(drop_temp) +/// drop_temp = foo.data; +/// drop(drop_temp) -> next +/// next: +/// storage.dead(drop_temp) +/// ``` +/// +/// The storage instructions are required to avoid stack space +/// blowup. 
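The move-before-drop rewrite described in the doc comment above corresponds roughly to this source-level sketch (illustrative only; the real pass operates on MIR, not surface Rust):

```rust
#[repr(packed)]
struct Foo {
    dealign: u8,
    data: Vec<u8>,
}

fn main() {
    let foo = Foo { dealign: 0, data: vec![1, 2, 3] };
    // Move the field into an aligned local first, mirroring the
    // `drop_temp` the pass introduces; running `drop_in_place` on
    // `foo.data` directly could use a misaligned address.
    let drop_temp = foo.data;
    drop(drop_temp);
    // Copying a packed field by value remains fine after the partial move.
    let _byte = foo.dealign;
}
```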
pub struct AddMovesForPackedDrops; impl<'tcx> MirPass<'tcx> for AddMovesForPackedDrops { @@ -55,7 +54,7 @@ fn add_moves_for_packed_drops_patch<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) let mut patch = MirPatch::new(body); let param_env = tcx.param_env(def_id); - for (bb, data) in body.basic_blocks().iter_enumerated() { + for (bb, data) in body.basic_blocks.iter_enumerated() { let loc = Location { block: bb, statement_index: data.statements.len() }; let terminator = data.terminator(); @@ -65,9 +64,6 @@ fn add_moves_for_packed_drops_patch<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { add_move_for_packed_drop(tcx, body, &mut patch, terminator, loc, data.is_cleanup); } - TerminatorKind::DropAndReplace { .. } => { - span_bug!(terminator.source_info.span, "replace in AddMovesForPackedDrops"); - } _ => {} } } @@ -84,7 +80,7 @@ fn add_move_for_packed_drop<'tcx>( is_cleanup: bool, ) { debug!("add_move_for_packed_drop({:?} @ {:?})", terminator, loc); - let TerminatorKind::Drop { ref place, target, unwind } = terminator.kind else { + let TerminatorKind::Drop { ref place, target, unwind, replace } = terminator.kind else { unreachable!(); }; @@ -102,6 +98,11 @@ fn add_move_for_packed_drop<'tcx>( patch.add_assign(loc, Place::from(temp), Rvalue::Use(Operand::Move(*place))); patch.patch_terminator( loc.block, - TerminatorKind::Drop { place: Place::from(temp), target: storage_dead_block, unwind }, + TerminatorKind::Drop { + place: Place::from(temp), + target: storage_dead_block, + unwind, + replace, + }, ); } diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index b91ae083cf5..187d38b385b 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -10,29 +10,6 @@ use rustc_middle::ty::{self, Ty, TyCtxt}; pub struct AddRetag; -/// Determines whether this place is "stable": Whether, if we evaluate it again -/// after the assignment, we can be sure to obtain the same place value. -/// (Concurrent accesses by other threads are no problem as these are anyway non-atomic -/// copies. Data races are UB.) -fn is_stable(place: PlaceRef<'_>) -> bool { - place.projection.iter().all(|elem| { - match elem { - // Which place this evaluates to can change with any memory write, - // so cannot assume this to be stable. - ProjectionElem::Deref => false, - // Array indices are interesting, but MIR building generates a *fresh* - // temporary for every array access, so the index cannot be changed as - // a side-effect. - ProjectionElem::Index { .. } | - // The rest is completely boring, they just offset by a constant. - ProjectionElem::Field { .. } | - ProjectionElem::ConstantIndex { .. } | - ProjectionElem::Subslice { .. } | - ProjectionElem::Downcast { .. } => true, - } - }) -} - /// Determine whether this type may contain a reference (or box), and thus needs retagging. /// We will only recurse `depth` times into Tuples/ADTs to bound the cost of this. fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> bool { @@ -79,47 +56,29 @@ impl<'tcx> MirPass<'tcx> for AddRetag { // We need an `AllCallEdges` pass before we can do any work. super::add_call_guards::AllCallEdges.run_pass(tcx, body); - let (span, arg_count) = (body.span, body.arg_count); let basic_blocks = body.basic_blocks.as_mut(); let local_decls = &body.local_decls; let needs_retag = |place: &Place<'tcx>| { - // FIXME: Instead of giving up for unstable places, we should introduce - // a temporary and retag on that. 
- is_stable(place.as_ref()) + !place.has_deref() // we're not really interested in stores to "outside" locations, they are hard to keep track of anyway && may_contain_reference(place.ty(&*local_decls, tcx).ty, /*depth*/ 3, tcx) && !local_decls[place.local].is_deref_temp() }; - let place_base_raw = |place: &Place<'tcx>| { - // If this is a `Deref`, get the type of what we are deref'ing. - let deref_base = - place.projection.iter().rposition(|p| matches!(p, ProjectionElem::Deref)); - if let Some(deref_base) = deref_base { - let base_proj = &place.projection[..deref_base]; - let ty = Place::ty_from(place.local, base_proj, &*local_decls, tcx).ty; - ty.is_unsafe_ptr() - } else { - // Not a deref, and thus not raw. - false - } - }; // PART 1 // Retag arguments at the beginning of the start block. { - // FIXME: Consider using just the span covering the function - // argument declaration. - let source_info = SourceInfo::outermost(span); // Gather all arguments, skip return value. - let places = local_decls - .iter_enumerated() - .skip(1) - .take(arg_count) - .map(|(local, _)| Place::from(local)) - .filter(needs_retag); + let places = local_decls.iter_enumerated().skip(1).take(body.arg_count).filter_map( + |(local, decl)| { + let place = Place::from(local); + needs_retag(&place).then_some((place, decl.source_info)) + }, + ); + // Emit their retags. basic_blocks[START_BLOCK].statements.splice( 0..0, - places.map(|place| Statement { + places.map(|(place, source_info)| Statement { source_info, kind: StatementKind::Retag(RetagKind::FnEntry, Box::new(place)), }), @@ -127,7 +86,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // PART 2 - // Retag return values of functions. Also escape-to-raw the argument of `drop`. + // Retag return values of functions. // We collect the return destinations because we cannot mutate while iterating. let returns = basic_blocks .iter_mut() @@ -141,7 +100,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // `Drop` is also a call, but it doesn't return anything so we are good. - TerminatorKind::Drop { .. } | TerminatorKind::DropAndReplace { .. } => None, + TerminatorKind::Drop { .. } => None, // Not a block ending in a Call -> ignore. _ => None, } @@ -159,30 +118,25 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // PART 3 - // Add retag after assignment. + // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not. for block_data in basic_blocks { - // We want to insert statements as we iterate. To this end, we + // We want to insert statements as we iterate. To this end, we // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (retag_kind, place) = match block_data.statements[i].kind { - // Retag-as-raw after escaping to a raw pointer, if the referent - // is not already a raw pointer. - StatementKind::Assign(box (lplace, Rvalue::AddressOf(_, ref rplace))) - if !place_base_raw(rplace) => - { - (RetagKind::Raw, lplace) - } // Retag after assignments of reference type. StatementKind::Assign(box (ref place, ref rvalue)) if needs_retag(place) => { - let kind = match rvalue { - Rvalue::Ref(_, borrow_kind, _) - if borrow_kind.allows_two_phase_borrow() => - { - RetagKind::TwoPhase - } - _ => RetagKind::Default, + let add_retag = match rvalue { + // Ptr-creating operations already do their own internal retagging, no + // need to also add a retag statement. + Rvalue::Ref(..) | Rvalue::AddressOf(..) 
=> false, + _ => true, }; - (kind, *place) + if add_retag { + (RetagKind::Default, *place) + } else { + continue; + } } // Do nothing for the rest _ => continue, diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs new file mode 100644 index 00000000000..1fe8ea07892 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_alignment.rs @@ -0,0 +1,242 @@ +use crate::MirPass; +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; +use rustc_middle::mir::*; +use rustc_middle::mir::{ + interpret::{ConstValue, Scalar}, + visit::{PlaceContext, Visitor}, +}; +use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut}; +use rustc_session::Session; + +pub struct CheckAlignment; + +impl<'tcx> MirPass<'tcx> for CheckAlignment { + fn is_enabled(&self, sess: &Session) -> bool { + sess.opts.debug_assertions + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + for block in (0..basic_blocks.len()).rev() { + let block = block.into(); + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = PointerFinder { + local_decls, + tcx, + pointers: Vec::new(), + def_id: body.source.def_id(), + }; + for (pointer, pointee_ty) in finder.find_pointers(statement) { + debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty); + + let new_block = split_block(basic_blocks, location); + insert_alignment_check( + tcx, + local_decls, + &mut basic_blocks[block], + pointer, + pointee_ty, + source_info, + new_block, + ); + } + } + } + } +} + +impl<'tcx, 'a> PointerFinder<'tcx, 'a> { + fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> { + self.pointers.clear(); + self.visit_statement(statement, Location::START); + core::mem::take(&mut self.pointers) + } +} + +struct PointerFinder<'tcx, 'a> { + local_decls: &'a mut LocalDecls<'tcx>, + tcx: TyCtxt<'tcx>, + def_id: DefId, + pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, +} + +impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::AddressOf(..) 
= rvalue { + // Ignore dereferences inside of an AddressOf + return; + } + self.super_rvalue(rvalue, location); + } + + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { + if let PlaceContext::NonUse(_) = context { + return; + } + if !place.is_indirect() { + return; + } + + let pointer = Place::from(place.local); + let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty; + + // We only want to check unsafe pointers + if !pointer_ty.is_unsafe_ptr() { + trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty); + return; + } + + let Some(pointee) = pointer_ty.builtin_deref(true) else { + debug!("Indirect but no builtin deref: {:?}", pointer_ty); + return; + }; + let mut pointee_ty = pointee.ty; + if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() { + pointee_ty = pointee_ty.sequence_element_type(self.tcx); + } + + if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) { + debug!("Unsafe pointer, but unsized: {:?}", pointer_ty); + return; + } + + if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_] + .contains(&pointee_ty) + { + debug!("Trivially aligned pointee type: {:?}", pointer_ty); + return; + } + + self.pointers.push((pointer, pointee_ty)) + } +} + +fn split_block( + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} + +fn insert_alignment_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + pointer: Place<'tcx>, + pointee_ty: Ty<'tcx>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + // Cast the pointer to a *const () + let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not }); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); + let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); + block_data + .statements + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); + + // Transmute the pointer to a usize (equivalent to `ptr.addr()`) + let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); + let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + block_data + .statements + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); + + // Get the alignment of the pointee + let alignment = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((alignment, rvalue))), + }); + + // Subtract 1 from the alignment to get the alignment mask + let alignment_mask = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + let one = Operand::Constant(Box::new(Constant { + span: source_info.span, + user_ty: None, + literal: ConstantKind::Val( + 
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), + tcx.types.usize, + ), + })); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + alignment_mask, + Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))), + ))), + }); + + // BitAnd the alignment mask with the pointer + let alignment_bits = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + alignment_bits, + Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))), + ), + ))), + }); + + // Check if the alignment bits are all zero + let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + let zero = Operand::Constant(Box::new(Constant { + span: source_info.span, + user_ty: None, + literal: ConstantKind::Val( + ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), + tcx.types.usize, + ), + })); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))), + ))), + }); + + // Set this block's terminator to our assert, continuing to new_block if we pass + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::MisalignedPointerDereference { + required: Operand::Copy(alignment), + found: Operand::Copy(addr), + }), + unwind: UnwindAction::Terminate, + }, + }); +} diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs index 8838b14c53a..b79150737d6 100644 --- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs +++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs @@ -1,11 +1,12 @@ -use rustc_errors::{DiagnosticBuilder, LintDiagnosticBuilder}; +use rustc_hir::HirId; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; use rustc_session::lint::builtin::CONST_ITEM_MUTATION; use rustc_span::def_id::DefId; +use rustc_span::Span; -use crate::MirLint; +use crate::{errors, MirLint}; pub struct CheckConstItemMutation; @@ -24,7 +25,7 @@ struct ConstMutationChecker<'a, 'tcx> { impl<'tcx> ConstMutationChecker<'_, 'tcx> { fn is_const_item(&self, local: Local) -> Option<DefId> { - if let Some(box LocalInfo::ConstRef { def_id }) = self.body.local_decls[local].local_info { + if let LocalInfo::ConstRef { def_id } = *self.body.local_decls[local].local_info() { Some(def_id) } else { None @@ -58,20 +59,21 @@ impl<'tcx> ConstMutationChecker<'_, 'tcx> { } } - fn lint_const_item_usage( + /// If we should lint on this usage, return the [`HirId`], source [`Span`] + /// and [`Span`] of the const item to use in the lint. + fn should_lint_const_item_usage( &self, place: &Place<'tcx>, const_item: DefId, location: Location, - decorate: impl for<'b> FnOnce(LintDiagnosticBuilder<'b, ()>) -> DiagnosticBuilder<'b, ()>, - ) { + ) -> Option<(HirId, Span, Span)> { // Don't lint on borrowing/assigning when a dereference is involved. 
// If we 'leave' the temporary via a dereference, we must // be modifying something else // // `unsafe { *FOO = 0; *BAR.field = 1; }` // `unsafe { &mut *FOO }` - // `unsafe { (*ARRAY)[0] = val; } + // `unsafe { (*ARRAY)[0] = val; }` if !place.projection.iter().any(|p| matches!(p, PlaceElem::Deref)) { let source_info = self.body.source_info(location); let lint_root = self.body.source_scopes[source_info.scope] @@ -80,16 +82,9 @@ impl<'tcx> ConstMutationChecker<'_, 'tcx> { .assert_crate_local() .lint_root; - self.tcx.struct_span_lint_hir( - CONST_ITEM_MUTATION, - lint_root, - source_info.span, - |lint| { - decorate(lint) - .span_note(self.tcx.def_span(const_item), "`const` item defined here") - .emit(); - }, - ); + Some((lint_root, source_info.span, self.tcx.def_span(const_item))) + } else { + None } } } @@ -101,12 +96,14 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { // Assigning directly to a constant (e.g. `FOO = true;`) is a hard error, // so emitting a lint would be redundant. if !lhs.projection.is_empty() { - if let Some(def_id) = self.is_const_item_without_destructor(lhs.local) { - self.lint_const_item_usage(&lhs, def_id, loc, |lint| { - let mut lint = lint.build("attempting to modify a `const` item"); - lint.note("each usage of a `const` item creates a new temporary; the original `const` item will not be modified"); - lint - }) + if let Some(def_id) = self.is_const_item_without_destructor(lhs.local) + && let Some((lint_root, span, item)) = self.should_lint_const_item_usage(&lhs, def_id, loc) { + self.tcx.emit_spanned_lint( + CONST_ITEM_MUTATION, + lint_root, + span, + errors::ConstMutate::Modify { konst: item } + ); } } // We are looking for MIR of the form: @@ -133,22 +130,31 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { // the `self` parameter of a method call (as the terminator of our current // BasicBlock). If so, we emit a more specific lint. 
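For reference, the warn-by-default `const_item_mutation` lint that these diagnostics feed fires on code like the following (taken in spirit from the lint's documentation):

```rust
const FOO: [i32; 1] = [0];

fn main() {
    // Warns: every use of `FOO` materializes a fresh temporary, so this
    // write goes to the temporary and the const item is unchanged.
    FOO[0] = 1;
    println!("{:?}", FOO); // still prints "[0]"
}
```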
let method_did = self.target_local.and_then(|target_local| { - crate::util::find_self_call(self.tcx, &self.body, target_local, loc.block) + rustc_middle::util::find_self_call( + self.tcx, + &self.body, + target_local, + loc.block, + ) }); let lint_loc = if method_did.is_some() { self.body.terminator_loc(loc.block) } else { loc }; - self.lint_const_item_usage(place, def_id, lint_loc, |lint| { - let mut lint = lint.build("taking a mutable reference to a `const` item"); - lint - .note("each usage of a `const` item creates a new temporary") - .note("the mutable reference will refer to this temporary, not the original `const` item"); - - if let Some((method_did, _substs)) = method_did { - lint.span_note(self.tcx.def_span(method_did), "mutable reference created due to call to this method"); - } - lint - }); + let method_call = if let Some((method_did, _)) = method_did { + Some(self.tcx.def_span(method_did)) + } else { + None + }; + if let Some((lint_root, span, item)) = + self.should_lint_const_item_usage(place, def_id, lint_loc) + { + self.tcx.emit_spanned_lint( + CONST_ITEM_MUTATION, + lint_root, + span, + errors::ConstMutate::MutBorrow { method_call, konst: item }, + ); + } } } self.super_rvalue(rvalue, loc); diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs index 2eb38941f1a..2e6cf603d59 100644 --- a/compiler/rustc_mir_transform/src/check_packed_ref.rs +++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs @@ -1,16 +1,9 @@ -use rustc_hir::def_id::LocalDefId; use rustc_middle::mir::visit::{PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::query::Providers; use rustc_middle::ty::{self, TyCtxt}; -use rustc_session::lint::builtin::UNALIGNED_REFERENCES; -use crate::util; use crate::MirLint; - -pub(crate) fn provide(providers: &mut Providers) { - *providers = Providers { unsafe_derive_on_repr_packed, ..*providers }; -} +use crate::{errors, util}; pub struct CheckPackedRef; @@ -30,23 +23,6 @@ struct PackedRefChecker<'a, 'tcx> { source_info: SourceInfo, } -fn unsafe_derive_on_repr_packed(tcx: TyCtxt<'_>, def_id: LocalDefId) { - let lint_hir_id = tcx.hir().local_def_id_to_hir_id(def_id); - - tcx.struct_span_lint_hir(UNALIGNED_REFERENCES, lint_hir_id, tcx.def_span(def_id), |lint| { - // FIXME: when we make this a hard error, this should have its - // own error code. - let message = if tcx.generics_of(def_id).own_requires_monomorphization() { - "`#[derive]` can't be used on a `#[repr(packed)]` struct with \ - type or const parameters (error E0133)" - } else { - "`#[derive]` can't be used on a `#[repr(packed)]` struct that \ - does not derive Copy (error E0133)" - }; - lint.build(message).emit(); - }); -} - impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> { fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { // Make sure we know where in the MIR we are. @@ -64,40 +40,15 @@ impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> { if context.is_borrow() { if util::is_disaligned(self.tcx, self.body, self.param_env, *place) { let def_id = self.body.source.instance.def_id(); - if let Some(impl_def_id) = self - .tcx - .impl_of_method(def_id) - .filter(|&def_id| self.tcx.is_builtin_derive(def_id)) + if let Some(impl_def_id) = self.tcx.impl_of_method(def_id) + && self.tcx.is_builtin_derived(impl_def_id) { - // If a method is defined in the local crate, - // the impl containing that method should also be. 
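The `UnalignedPackedRef` hard error emitted just below replaces what used to be the `unaligned_references` lint; the rejected pattern and its sanctioned workaround look roughly like this (illustrative sketch):

```rust
#[repr(packed)]
struct Packed {
    byte: u8,
    word: u32,
}

fn main() {
    let p = Packed { byte: 0, word: 0 };
    // let r = &p.word; // ERROR: reference to packed field is unaligned
    // Copying the field by value is fine, as is an unaligned read
    // through a raw pointer:
    let w = p.word;
    let w2 = unsafe { std::ptr::addr_of!(p.word).read_unaligned() };
    assert_eq!(w, w2);
}
```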
- self.tcx.ensure().unsafe_derive_on_repr_packed(impl_def_id.expect_local()); + // If we ever reach here it means that the generated derive + // code is somehow doing an unaligned reference, which it + // shouldn't do. + span_bug!(self.source_info.span, "builtin derive created an unaligned reference"); } else { - let source_info = self.source_info; - let lint_root = self.body.source_scopes[source_info.scope] - .local_data - .as_ref() - .assert_crate_local() - .lint_root; - self.tcx.struct_span_lint_hir( - UNALIGNED_REFERENCES, - lint_root, - source_info.span, - |lint| { - lint.build("reference to packed field is unaligned") - .note( - "fields of packed structs are not properly aligned, and creating \ - a misaligned reference is undefined behavior (even if that \ - reference is never dereferenced)", - ) - .help( - "copy the field contents to a local variable, or replace the \ - reference with a raw pointer and use `read_unaligned`/`write_unaligned` \ - (loads and stores via `*p` must be properly aligned even when using raw pointers)" - ) - .emit(); - }, - ); + self.tcx.sess.emit_err(errors::UnalignedPackedRef { span: self.source_info.span }); } } } diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index ded1f0462cb..2f851cd1eb5 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -1,19 +1,21 @@ -use rustc_data_structures::fx::FxHashMap; -use rustc_errors::struct_span_err; +use rustc_data_structures::unord::{UnordItems, UnordSet}; use rustc_hir as hir; +use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir::hir_id::HirId; use rustc_hir::intravisit; +use rustc_hir::{BlockCheckMode, ExprKind, Node}; use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; -use rustc_middle::ty::query::Providers; +use rustc_middle::mir::*; +use rustc_middle::query::Providers; use rustc_middle::ty::{self, TyCtxt}; -use rustc_middle::{lint, mir::*}; use rustc_session::lint::builtin::{UNSAFE_OP_IN_UNSAFE_FN, UNUSED_UNSAFE}; use rustc_session::lint::Level; -use std::collections::hash_map; use std::ops::Bound; +use crate::errors; + pub struct UnsafetyChecker<'a, 'tcx> { body: &'a Body<'tcx>, body_did: LocalDefId, @@ -23,10 +25,7 @@ pub struct UnsafetyChecker<'a, 'tcx> { param_env: ty::ParamEnv<'tcx>, /// Used `unsafe` blocks in this function. This is used for the "unused_unsafe" lint. - /// - /// The keys are the used `unsafe` blocks, the UnusedUnsafeKind indicates whether - /// or not any of the usages happen at a place that doesn't allow `unsafe_op_in_unsafe_fn`. - used_unsafe_blocks: FxHashMap<HirId, UsedUnsafeBlockData>, + used_unsafe_blocks: UnordSet<HirId>, } impl<'a, 'tcx> UnsafetyChecker<'a, 'tcx> { @@ -57,10 +56,9 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { | TerminatorKind::Drop { .. } | TerminatorKind::Yield { .. } | TerminatorKind::Assert { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::FalseEdge { .. } @@ -103,13 +101,16 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { | StatementKind::StorageLive(..) | StatementKind::StorageDead(..) | StatementKind::Retag { .. } - | StatementKind::AscribeUserType(..) + | StatementKind::PlaceMention(..) | StatementKind::Coverage(..) + | StatementKind::Intrinsic(..) 
+ | StatementKind::ConstEvalCounter | StatementKind::Nop => { // safe (at least as emitted during MIR construction) } - - StatementKind::CopyNonOverlapping(..) => unreachable!(), + // `AscribeUserType` just exists to help MIR borrowck. + // It has no semantics, and everything is already reported by `PlaceMention`. + StatementKind::AscribeUserType(..) => return, } self.super_statement(statement, location); } @@ -128,12 +129,10 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { } } &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => { + let def_id = def_id.expect_local(); let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = - self.tcx.unsafety_check_result(def_id.expect_local()); - self.register_violations( - violations, - used_unsafe_blocks.iter().map(|(&h, &d)| (h, d)), - ); + self.tcx.unsafety_check_result(def_id); + self.register_violations(violations, used_unsafe_blocks.items().copied()); } }, _ => {} @@ -141,6 +140,28 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { self.super_rvalue(rvalue, location); } + fn visit_operand(&mut self, op: &Operand<'tcx>, location: Location) { + if let Operand::Constant(constant) = op { + let maybe_uneval = match constant.literal { + ConstantKind::Val(..) | ConstantKind::Ty(_) => None, + ConstantKind::Unevaluated(uv, _) => Some(uv), + }; + + if let Some(uv) = maybe_uneval { + if uv.promoted.is_none() { + let def_id = uv.def; + if self.tcx.def_kind(def_id) == DefKind::InlineConst { + let local_def_id = def_id.expect_local(); + let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = + self.tcx.unsafety_check_result(local_def_id); + self.register_violations(violations, used_unsafe_blocks.items().copied()); + } + } + } + } + self.super_operand(op, location); + } + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { // On types with `scalar_valid_range`, prevent // * `&mut x.field` @@ -162,7 +183,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { // If the projection root is an artificial local that we introduced when // desugaring `static`, give a more specific error message // (avoid the general "raw pointer" clause below, that would only be confusing). - if let Some(box LocalInfo::StaticRef { def_id, .. }) = decl.local_info { + if let LocalInfo::StaticRef { def_id, .. } = *decl.local_info() { if self.tcx.is_mutable_static(def_id) { self.require_unsafe( UnsafetyViolationKind::General, @@ -219,14 +240,11 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { // We have to check the actual type of the assignment, as that determines if the // old value is being dropped. let assigned_ty = place.ty(&self.body.local_decls, self.tcx).ty; - if assigned_ty.needs_drop( - self.tcx, - self.tcx.param_env(base_ty.ty_adt_def().unwrap().did()), - ) { + if assigned_ty.needs_drop(self.tcx, self.param_env) { // This would be unsafe, but should be outright impossible since we reject such unions. 
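The "outright impossible" case above relies on the definition-time rule that union fields may not need dropping; a quick illustration of that rule (assuming stable behavior, with `ManuallyDrop` as the standard escape hatch):

```rust
use std::mem::ManuallyDrop;

union Data {
    // s: String, // ERROR: unions cannot contain fields that may need dropping
    s: ManuallyDrop<String>, // OK: drop glue is suppressed
    n: u32,
}

fn main() {
    let d = Data { n: 7 };
    // Reading a union field is itself an unsafe operation checked by this pass.
    let n = unsafe { d.n };
    println!("{n}");
}
```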
self.tcx.sess.delay_span_bug( self.source_info.span, - "union fields that need dropping should be impossible", + format!("union fields that need dropping should be impossible: {assigned_ty}") ); } } else { @@ -253,29 +271,15 @@ impl<'tcx> UnsafetyChecker<'_, 'tcx> { .lint_root; self.register_violations( [&UnsafetyViolation { source_info, lint_root, kind, details }], - [], + UnordItems::empty(), ); } fn register_violations<'a>( &mut self, violations: impl IntoIterator<Item = &'a UnsafetyViolation>, - new_used_unsafe_blocks: impl IntoIterator<Item = (HirId, UsedUnsafeBlockData)>, + new_used_unsafe_blocks: UnordItems<HirId, impl Iterator<Item = HirId>>, ) { - use UsedUnsafeBlockData::{AllAllowedInUnsafeFn, SomeDisallowedInUnsafeFn}; - - let update_entry = |this: &mut Self, hir_id, new_usage| { - match this.used_unsafe_blocks.entry(hir_id) { - hash_map::Entry::Occupied(mut entry) => { - if new_usage == SomeDisallowedInUnsafeFn { - *entry.get_mut() = SomeDisallowedInUnsafeFn; - } - } - hash_map::Entry::Vacant(entry) => { - entry.insert(new_usage); - } - }; - }; let safety = self.body.source_scopes[self.source_info.scope] .local_data .as_ref() @@ -302,22 +306,12 @@ impl<'tcx> UnsafetyChecker<'_, 'tcx> { } }), Safety::BuiltinUnsafe => {} - Safety::ExplicitUnsafe(hir_id) => violations.into_iter().for_each(|violation| { - update_entry( - self, - hir_id, - match self.tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, violation.lint_root).0 - { - Level::Allow => AllAllowedInUnsafeFn(violation.lint_root), - _ => SomeDisallowedInUnsafeFn, - }, - ) + Safety::ExplicitUnsafe(hir_id) => violations.into_iter().for_each(|_violation| { + self.used_unsafe_blocks.insert(hir_id); }), }; - new_used_unsafe_blocks - .into_iter() - .for_each(|(hir_id, usage_data)| update_entry(self, hir_id, usage_data)); + self.used_unsafe_blocks.extend_unord(new_used_unsafe_blocks); } fn check_mut_borrowing_layout_constrained_field( &mut self, @@ -343,7 +337,7 @@ impl<'tcx> UnsafetyChecker<'_, 'tcx> { } else if !place .ty(self.body, self.tcx) .ty - .is_freeze(self.tcx.at(self.source_info.span), self.param_env) + .is_freeze(self.tcx, self.param_env) { UnsafetyViolationDetails::BorrowOfLayoutConstrainedField } else { @@ -382,22 +376,7 @@ impl<'tcx> UnsafetyChecker<'_, 'tcx> { } pub(crate) fn provide(providers: &mut Providers) { - *providers = Providers { - unsafety_check_result: |tcx, def_id| { - if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { - tcx.unsafety_check_result_for_const_arg(def) - } else { - unsafety_check_result(tcx, ty::WithOptConstParam::unknown(def_id)) - } - }, - unsafety_check_result_for_const_arg: |tcx, (did, param_did)| { - unsafety_check_result( - tcx, - ty::WithOptConstParam { did, const_param_did: Some(param_did) }, - ) - }, - ..*providers - }; + *providers = Providers { unsafety_check_result, ..*providers }; } /// Context information for [`UnusedUnsafeVisitor`] traversal, @@ -414,47 +393,45 @@ enum Context { struct UnusedUnsafeVisitor<'a, 'tcx> { tcx: TyCtxt<'tcx>, - used_unsafe_blocks: &'a FxHashMap<HirId, UsedUnsafeBlockData>, + used_unsafe_blocks: &'a UnordSet<HirId>, context: Context, unused_unsafes: &'a mut Vec<(HirId, UnusedUnsafe)>, } impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> { fn visit_block(&mut self, block: &'tcx hir::Block<'tcx>) { - use UsedUnsafeBlockData::{AllAllowedInUnsafeFn, SomeDisallowedInUnsafeFn}; - if let hir::BlockCheckMode::UnsafeBlock(hir::UnsafeSource::UserProvided) = block.rules { let used = match self.tcx.lint_level_at_node(UNUSED_UNSAFE, 
block.hir_id) { - (Level::Allow, _) => Some(SomeDisallowedInUnsafeFn), - _ => self.used_unsafe_blocks.get(&block.hir_id).copied(), + (Level::Allow, _) => true, + _ => self.used_unsafe_blocks.contains(&block.hir_id), }; let unused_unsafe = match (self.context, used) { - (_, None) => UnusedUnsafe::Unused, - (Context::Safe, Some(_)) - | (Context::UnsafeFn(_), Some(SomeDisallowedInUnsafeFn)) => { + (_, false) => UnusedUnsafe::Unused, + (Context::Safe, true) | (Context::UnsafeFn(_), true) => { let previous_context = self.context; self.context = Context::UnsafeBlock(block.hir_id); intravisit::walk_block(self, block); self.context = previous_context; return; } - (Context::UnsafeFn(hir_id), Some(AllAllowedInUnsafeFn(lint_root))) => { - UnusedUnsafe::InUnsafeFn(hir_id, lint_root) - } - (Context::UnsafeBlock(hir_id), Some(_)) => UnusedUnsafe::InUnsafeBlock(hir_id), + (Context::UnsafeBlock(hir_id), true) => UnusedUnsafe::InUnsafeBlock(hir_id), }; self.unused_unsafes.push((block.hir_id, unused_unsafe)); } intravisit::walk_block(self, block); } + fn visit_inline_const(&mut self, c: &'tcx hir::ConstBlock) { + self.visit_body(self.tcx.hir().body(c.body)) + } + fn visit_fn( &mut self, fk: intravisit::FnKind<'tcx>, _fd: &'tcx hir::FnDecl<'tcx>, b: hir::BodyId, _s: rustc_span::Span, - _id: HirId, + _id: LocalDefId, ) { if matches!(fk, intravisit::FnKind::Closure) { self.visit_body(self.tcx.hir().body(b)) @@ -465,17 +442,17 @@ impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> { fn check_unused_unsafe( tcx: TyCtxt<'_>, def_id: LocalDefId, - used_unsafe_blocks: &FxHashMap<HirId, UsedUnsafeBlockData>, + used_unsafe_blocks: &UnordSet<HirId>, ) -> Vec<(HirId, UnusedUnsafe)> { - let hir_id = tcx.hir().local_def_id_to_hir_id(def_id); - let body_id = tcx.hir().maybe_body_owned_by(hir_id); + let body_id = tcx.hir().maybe_body_owned_by(def_id); let Some(body_id) = body_id else { debug!("check_unused_unsafe({:?}) - no body found", def_id); return vec![]; }; - let body = tcx.hir().body(body_id); + let body = tcx.hir().body(body_id); + let hir_id = tcx.hir().local_def_id_to_hir_id(def_id); let context = match tcx.hir().fn_sig_by_hir_id(hir_id) { Some(sig) if sig.header.unsafety == hir::Unsafety::Unsafe => Context::UnsafeFn(hir_id), _ => Context::Safe, @@ -499,23 +476,28 @@ fn check_unused_unsafe( unused_unsafes } -fn unsafety_check_result<'tcx>( - tcx: TyCtxt<'tcx>, - def: ty::WithOptConstParam<LocalDefId>, -) -> &'tcx UnsafetyCheckResult { +fn unsafety_check_result(tcx: TyCtxt<'_>, def: LocalDefId) -> &UnsafetyCheckResult { debug!("unsafety_violations({:?})", def); // N.B., this borrow is valid because all the consumers of // `mir_built` force this. 
let body = &tcx.mir_built(def).borrow(); - let param_env = tcx.param_env(def.did); + if body.is_custom_mir() { + return tcx.arena.alloc(UnsafetyCheckResult { + violations: Vec::new(), + used_unsafe_blocks: Default::default(), + unused_unsafes: Some(Vec::new()), + }); + } - let mut checker = UnsafetyChecker::new(body, def.did, tcx, param_env); + let param_env = tcx.param_env(def); + + let mut checker = UnsafetyChecker::new(body, def, tcx, param_env); checker.visit_body(&body); - let unused_unsafes = (!tcx.is_closure(def.did.to_def_id())) - .then(|| check_unused_unsafe(tcx, def.did, &checker.used_unsafe_blocks)); + let unused_unsafes = (!tcx.is_typeck_child(def.to_def_id())) + .then(|| check_unused_unsafe(tcx, def, &checker.used_unsafe_blocks)); tcx.arena.alloc(UnsafetyCheckResult { violations: checker.violations, @@ -526,88 +508,62 @@ fn unsafety_check_result<'tcx>( fn report_unused_unsafe(tcx: TyCtxt<'_>, kind: UnusedUnsafe, id: HirId) { let span = tcx.sess.source_map().guess_head_span(tcx.hir().span(id)); - tcx.struct_span_lint_hir(UNUSED_UNSAFE, id, span, |lint| { - let msg = "unnecessary `unsafe` block"; - let mut db = lint.build(msg); - db.span_label(span, msg); - match kind { - UnusedUnsafe::Unused => {} - UnusedUnsafe::InUnsafeBlock(id) => { - db.span_label( - tcx.sess.source_map().guess_head_span(tcx.hir().span(id)), - "because it's nested under this `unsafe` block", - ); - } - UnusedUnsafe::InUnsafeFn(id, usage_lint_root) => { - db.span_label( - tcx.sess.source_map().guess_head_span(tcx.hir().span(id)), - "because it's nested under this `unsafe` fn", - ) - .note( - "this `unsafe` block does contain unsafe operations, \ - but those are already allowed in an `unsafe fn`", - ); - let (level, source) = - tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, usage_lint_root); - assert_eq!(level, Level::Allow); - lint::explain_lint_level_source( - UNSAFE_OP_IN_UNSAFE_FN, - Level::Allow, - source, - &mut db, - ); - } - } - - db.emit(); - }); + let nested_parent = if let UnusedUnsafe::InUnsafeBlock(id) = kind { + Some(tcx.sess.source_map().guess_head_span(tcx.hir().span(id))) + } else { + None + }; + tcx.emit_spanned_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); } pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { debug!("check_unsafety({:?})", def_id); - // closures are handled by their parent fn. - if tcx.is_closure(def_id.to_def_id()) { + // closures and inline consts are handled by their parent fn. + if tcx.is_typeck_child(def_id.to_def_id()) { return; } let UnsafetyCheckResult { violations, unused_unsafes, .. } = tcx.unsafety_check_result(def_id); for &UnsafetyViolation { source_info, lint_root, kind, details } in violations.iter() { - let (description, note) = details.description_and_note(); - - // Report an error. 
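The `errors::RequiresUnsafe` diagnostic constructed below is the structured form of the classic E0133 error; a minimal trigger (illustrative):

```rust
fn main() {
    let p = &1_i32 as *const i32;
    // let v = *p; // error[E0133]: dereference of raw pointer is unsafe
    //             // and requires unsafe function or block
    let v = unsafe { *p }; // OK inside an unsafe block
    println!("{v}");
}
```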
- let unsafe_fn_msg = - if unsafe_op_in_unsafe_fn_allowed(tcx, lint_root) { " function or" } else { "" }; + let details = errors::RequiresUnsafeDetail { violation: details, span: source_info.span }; match kind { UnsafetyViolationKind::General => { - // once - struct_span_err!( - tcx.sess, - source_info.span, - E0133, - "{} is unsafe and requires unsafe{} block", - description, - unsafe_fn_msg, - ) - .span_label(source_info.span, description) - .note(note) - .emit(); + let op_in_unsafe_fn_allowed = unsafe_op_in_unsafe_fn_allowed(tcx, lint_root); + let note_non_inherited = tcx.hir().parent_iter(lint_root).find(|(id, node)| { + if let Node::Expr(block) = node + && let ExprKind::Block(block, _) = block.kind + && let BlockCheckMode::UnsafeBlock(_) = block.rules + { + true + } + else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id) + && sig.header.is_unsafe() + { + true + } else { + false + } + }); + let enclosing = if let Some((id, _)) = note_non_inherited { + Some(tcx.sess.source_map().guess_head_span(tcx.hir().span(id))) + } else { + None + }; + tcx.sess.emit_err(errors::RequiresUnsafe { + span: source_info.span, + enclosing, + details, + op_in_unsafe_fn_allowed, + }); } - UnsafetyViolationKind::UnsafeFn => tcx.struct_span_lint_hir( + UnsafetyViolationKind::UnsafeFn => tcx.emit_spanned_lint( UNSAFE_OP_IN_UNSAFE_FN, lint_root, source_info.span, - |lint| { - lint.build(&format!( - "{} is unsafe and requires unsafe block (error E0133)", - description, - )) - .span_label(source_info.span, description) - .note(note) - .emit(); - }, + errors::UnsafeOpInUnsafeFn { details }, ), } } diff --git a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs index 611d29a4ee2..d435d3ee69b 100644 --- a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs +++ b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs @@ -1,39 +1,44 @@ -//! This module provides a pass to replacing the following statements with -//! [`Nop`]s +//! This module provides a pass that removes parts of MIR that are no longer relevant after +//! analysis phase and borrowck. In particular, it removes false edges, user type annotations and +//! replaces following statements with [`Nop`]s: //! //! - [`AscribeUserType`] //! - [`FakeRead`] //! - [`Assign`] statements with a [`Shallow`] borrow //! -//! The `CleanFakeReadsAndBorrows` "pass" is actually implemented as two -//! traversals (aka visits) of the input MIR. The first traversal, -//! `DeleteAndRecordFakeReads`, deletes the fake reads and finds the -//! temporaries read by [`ForMatchGuard`] reads, and `DeleteFakeBorrows` -//! deletes the initialization of those temporaries. -//! //! [`AscribeUserType`]: rustc_middle::mir::StatementKind::AscribeUserType -//! [`Shallow`]: rustc_middle::mir::BorrowKind::Shallow -//! [`FakeRead`]: rustc_middle::mir::StatementKind::FakeRead //! [`Assign`]: rustc_middle::mir::StatementKind::Assign -//! [`ForMatchGuard`]: rustc_middle::mir::FakeReadCause::ForMatchGuard +//! [`FakeRead`]: rustc_middle::mir::StatementKind::FakeRead //! [`Nop`]: rustc_middle::mir::StatementKind::Nop +//! 
[`Shallow`]: rustc_middle::mir::BorrowKind::Shallow use crate::MirPass; -use rustc_middle::mir::visit::MutVisitor; -use rustc_middle::mir::{Body, BorrowKind, Location, Rvalue}; -use rustc_middle::mir::{Statement, StatementKind}; +use rustc_middle::mir::{Body, BorrowKind, Rvalue, StatementKind, TerminatorKind}; use rustc_middle::ty::TyCtxt; -pub struct CleanupNonCodegenStatements; +pub struct CleanupPostBorrowck; -pub struct DeleteNonCodegenStatements<'tcx> { - tcx: TyCtxt<'tcx>, -} +impl<'tcx> MirPass<'tcx> for CleanupPostBorrowck { + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + for basic_block in body.basic_blocks.as_mut() { + for statement in basic_block.statements.iter_mut() { + match statement.kind { + StatementKind::AscribeUserType(..) + | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Shallow, _))) + | StatementKind::FakeRead(..) => statement.make_nop(), + _ => (), + } + } + let terminator = basic_block.terminator_mut(); + match terminator.kind { + TerminatorKind::FalseEdge { real_target, .. } + | TerminatorKind::FalseUnwind { real_target, .. } => { + terminator.kind = TerminatorKind::Goto { target: real_target }; + } + _ => {} + } + } -impl<'tcx> MirPass<'tcx> for CleanupNonCodegenStatements { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let mut delete = DeleteNonCodegenStatements { tcx }; - delete.visit_body(body); body.user_type_annotations.raw.clear(); for decl in &mut body.local_decls { @@ -41,19 +46,3 @@ impl<'tcx> MirPass<'tcx> for CleanupNonCodegenStatements { } } } - -impl<'tcx> MutVisitor<'tcx> for DeleteNonCodegenStatements<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } - - fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { - match statement.kind { - StatementKind::AscribeUserType(..) - | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Shallow, _))) - | StatementKind::FakeRead(..) 
=> statement.make_nop(), - _ => (), - } - self.super_statement(statement, location); - } -} diff --git a/compiler/rustc_mir_transform/src/const_debuginfo.rs b/compiler/rustc_mir_transform/src/const_debuginfo.rs index 6f0ae4f07ab..f662ce645b0 100644 --- a/compiler/rustc_mir_transform/src/const_debuginfo.rs +++ b/compiler/rustc_mir_transform/src/const_debuginfo.rs @@ -10,19 +10,19 @@ use rustc_middle::{ }; use crate::MirPass; -use rustc_index::{bit_set::BitSet, vec::IndexVec}; +use rustc_index::{bit_set::BitSet, IndexVec}; pub struct ConstDebugInfo; impl<'tcx> MirPass<'tcx> for ConstDebugInfo { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() > 0 + sess.mir_opt_level() > 0 } fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("running ConstDebugInfo on {:?}", body.source); - for (local, constant) in find_optimization_oportunities(body) { + for (local, constant) in find_optimization_opportunities(body) { for debuginfo in &mut body.var_debug_info { if let VarDebugInfoContents::Place(p) = debuginfo.value { if p.local == local && p.projection.is_empty() { @@ -45,7 +45,7 @@ struct LocalUseVisitor { local_assignment_locations: IndexVec<Local, Option<Location>>, } -fn find_optimization_oportunities<'tcx>(body: &Body<'tcx>) -> Vec<(Local, Constant<'tcx>)> { +fn find_optimization_opportunities<'tcx>(body: &Body<'tcx>) -> Vec<(Local, Constant<'tcx>)> { let mut visitor = LocalUseVisitor { local_mutating_uses: IndexVec::from_elem(0, &body.local_decls), local_assignment_locations: IndexVec::from_elem(None, &body.local_decls), diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs index 5acf939f06b..024bea62098 100644 --- a/compiler/rustc_mir_transform/src/const_goto.rs +++ b/compiler/rustc_mir_transform/src/const_goto.rs @@ -28,7 +28,7 @@ pub struct ConstGoto; impl<'tcx> MirPass<'tcx> for ConstGoto { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 4 + sess.mir_opt_level() >= 2 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -57,18 +57,27 @@ impl<'tcx> MirPass<'tcx> for ConstGoto { } impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { + if data.is_cleanup { + // Because of the restrictions around control flow in cleanup blocks, we don't perform + // this optimization at all in such blocks. + return; + } + self.super_basic_block_data(block, data); + } + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { let _: Option<_> = try { let target = terminator.kind.as_goto()?; // We only apply this optimization if the last statement is a const assignment - let last_statement = self.body.basic_blocks()[location.block].statements.last()?; + let last_statement = self.body.basic_blocks[location.block].statements.last()?; if let (place, Rvalue::Use(Operand::Constant(_const))) = last_statement.kind.as_assign()? { // We found a constant being assigned to `place`. // Now check that the target of this Goto switches on this place. - let target_bb = &self.body.basic_blocks()[target]; + let target_bb = &self.body.basic_blocks[target]; // The `StorageDead(..)` statement does not affect the functionality of mir. // We can move this part of the statement up to the predecessor. 
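A source-level shape that can produce the MIR pattern this visitor looks for, namely a block that assigns a constant and then jumps straight to a `SwitchInt` on the same place (illustrative; the pass itself works purely on MIR and this code is not guaranteed to trigger it):

```rust
fn classify(flag: bool) -> u32 {
    // Each `if` arm assigns a known constant to `x` and then jumps to
    // the block that switches on `x`; ConstGoto can thread each arm
    // directly to the matching `match` target.
    let x = if flag { true } else { false };
    match x {
        true => 1,
        false => 0,
    }
}

fn main() {
    assert_eq!(classify(true), 1);
    assert_eq!(classify(false), 0);
}
```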
@@ -82,8 +91,9 @@ impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { } let target_bb_terminator = target_bb.terminator(); - let (discr, switch_ty, targets) = target_bb_terminator.kind.as_switch()?; + let (discr, targets) = target_bb_terminator.kind.as_switch()?; if discr.place() == Some(*place) { + let switch_ty = place.ty(self.body.local_decls(), self.tcx).ty; // We now know that the Switch matches on the const place, and it is statementless // Now find which value in the Switch matches the const value. let const_value = diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs index acd9e605353..1d43dbda0aa 100644 --- a/compiler/rustc_mir_transform/src/const_prop.rs +++ b/compiler/rustc_mir_transform/src/const_prop.rs @@ -1,36 +1,30 @@ //! Propagates constants for early reporting of statically known //! assertion failures -use std::cell::Cell; +use either::Right; -use rustc_ast::Mutability; +use rustc_const_eval::const_eval::CheckAlignment; +use rustc_const_eval::ReportErrorExt; use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; use rustc_index::bit_set::BitSet; -use rustc_index::vec::IndexVec; +use rustc_index::{IndexSlice, IndexVec}; use rustc_middle::mir::visit::{ MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor, }; -use rustc_middle::mir::{ - BasicBlock, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, LocalKind, Location, - Operand, Place, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, - RETURN_PLACE, -}; +use rustc_middle::mir::*; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; -use rustc_middle::ty::subst::{InternalSubsts, Subst}; -use rustc_middle::ty::{ - self, ConstKind, EarlyBinder, Instance, ParamEnv, Ty, TyCtxt, TypeVisitable, -}; -use rustc_span::{def_id::DefId, Span}; -use rustc_target::abi::{self, HasDataLayout, Size, TargetDataLayout}; +use rustc_middle::ty::InternalSubsts; +use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt}; +use rustc_span::{def_id::DefId, Span, DUMMY_SP}; +use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout}; use rustc_target::spec::abi::Abi as CallAbi; -use rustc_trait_selection::traits; use crate::MirPass; use rustc_const_eval::interpret::{ - self, compile_time_machine, AllocId, ConstAllocation, ConstValue, CtfeValidationMode, Frame, - ImmTy, Immediate, InterpCx, InterpResult, LocalState, LocalValue, MemoryKind, OpTy, PlaceTy, - Pointer, Scalar, ScalarMaybeUninit, StackPopCleanup, StackPopUnwind, + self, compile_time_machine, AllocId, ConstAllocation, ConstValue, Frame, ImmTy, Immediate, + InterpCx, InterpResult, LocalValue, MemoryKind, OpTy, PlaceTy, Pointer, Scalar, + StackPopCleanup, }; /// The maximum number of bytes that we'll allocate space for a local or the return value. @@ -44,6 +38,7 @@ macro_rules! throw_machine_stop_str { ($($tt:tt)*) => {{ // We make a new local type for it. The type itself does not carry any information, // but its vtable (for the `MachineStopType` trait) does. + #[derive(Debug)] struct Zst; // Printing this type shows the desired string. impl std::fmt::Display for Zst { @@ -51,7 +46,17 @@ macro_rules! 
throw_machine_stop_str { write!(f, $($tt)*) } } - impl rustc_middle::mir::interpret::MachineStopType for Zst {} + + impl rustc_middle::mir::interpret::MachineStopType for Zst { + fn diagnostic_message(&self) -> rustc_errors::DiagnosticMessage { + self.to_string().into() + } + + fn add_args( + self: Box<Self>, + _: &mut dyn FnMut(std::borrow::Cow<'static, str>, rustc_errors::DiagnosticArgValue<'static>), + ) {} + } throw_machine_stop!(Zst) }}; } @@ -60,7 +65,7 @@ pub struct ConstProp; impl<'tcx> MirPass<'tcx> for ConstProp { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 + sess.mir_opt_level() >= 2 } #[instrument(skip(self, tcx), level = "debug")] @@ -82,7 +87,7 @@ impl<'tcx> MirPass<'tcx> for ConstProp { return; } - let is_generator = tcx.type_of(def_id.to_def_id()).is_generator(); + let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator(); // FIXME(welseywiser) const prop doesn't work on generators because of query cycles // computing their layout. if is_generator { @@ -90,50 +95,11 @@ impl<'tcx> MirPass<'tcx> for ConstProp { return; } - // Check if it's even possible to satisfy the 'where' clauses - // for this item. - // This branch will never be taken for any normal function. - // However, it's possible to `#!feature(trivial_bounds)]` to write - // a function with impossible to satisfy clauses, e.g.: - // `fn foo() where String: Copy {}` - // - // We don't usually need to worry about this kind of case, - // since we would get a compilation error if the user tried - // to call it. However, since we can do const propagation - // even without any calls to the function, we need to make - // sure that it even makes sense to try to evaluate the body. - // If there are unsatisfiable where clauses, then all bets are - // off, and we just give up. - // - // We manually filter the predicates, skipping anything that's not - // "global". We are in a potentially generic context - // (e.g. we are evaluating a function without substituting generic - // parameters, so this filtering serves two purposes: - // - // 1. We skip evaluating any predicates that we would - // never be able prove are unsatisfiable (e.g. `<T as Foo>` - // 2. We avoid trying to normalize predicates involving generic - // parameters (e.g. `<T as Foo>::MyItem`). This can confuse - // the normalization code (leading to cycle errors), since - // it's usually never invoked in this way. - let predicates = tcx - .predicates_of(def_id.to_def_id()) - .predicates - .iter() - .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None }); - if traits::impossible_predicates( - tcx, - traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(), - ) { - trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id); - return; - } - trace!("ConstProp starting for {:?}", def_id); let dummy_body = &Body::new( body.source, - body.basic_blocks().clone(), + (*body.basic_blocks).to_owned(), body.source_scopes.clone(), body.local_decls.clone(), Default::default(), @@ -149,31 +115,31 @@ impl<'tcx> MirPass<'tcx> for ConstProp { // That would require a uniform one-def no-mutation analysis // and RPO (or recursing when needing the value of a local). let mut optimization_finder = ConstPropagator::new(body, dummy_body, tcx); - optimization_finder.visit_body(body); + + // Traverse the body in reverse post-order, to ensure that `FullConstProp` locals are + // assigned before being read. 
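Reverse post-order has the property that every block is visited before its successors (back edges aside), which is what guarantees `FullConstProp` locals are written before any block that reads them is processed. A minimal sketch of computing such an order over a toy adjacency-list CFG (all names hypothetical, not the compiler's own API):

```rust
// Post-order DFS over a small CFG; reversing the result yields an
// order in which each block precedes its non-back-edge successors.
fn postorder(succ: &[Vec<usize>], start: usize) -> Vec<usize> {
    fn dfs(n: usize, succ: &[Vec<usize>], seen: &mut [bool], out: &mut Vec<usize>) {
        if seen[n] {
            return;
        }
        seen[n] = true;
        for &s in &succ[n] {
            dfs(s, succ, seen, out);
        }
        out.push(n);
    }
    let mut seen = vec![false; succ.len()];
    let mut out = Vec::new();
    dfs(start, succ, &mut seen, &mut out);
    out
}

fn main() {
    // Diamond CFG: 0 -> {1, 2}, 1 -> 3, 2 -> 3.
    let succ = vec![vec![1, 2], vec![3], vec![3], vec![]];
    let mut rpo = postorder(&succ, 0);
    rpo.reverse();
    assert_eq!(rpo, vec![0, 2, 1, 3]); // entry first, join block last
}
```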
+ let postorder = body.basic_blocks.postorder().to_vec(); + for bb in postorder.into_iter().rev() { + let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb]; + optimization_finder.visit_basic_block_data(bb, data); + } trace!("ConstProp done for {:?}", def_id); } } -struct ConstPropMachine<'mir, 'tcx> { +pub struct ConstPropMachine<'mir, 'tcx> { /// The virtual call stack. stack: Vec<Frame<'mir, 'tcx>>, - /// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end. - written_only_inside_own_block_locals: FxHashSet<Local>, - /// Locals that need to be cleared after every block terminates. - only_propagate_inside_block_locals: BitSet<Local>, - can_const_prop: IndexVec<Local, ConstPropMode>, + pub written_only_inside_own_block_locals: FxHashSet<Local>, + pub can_const_prop: IndexVec<Local, ConstPropMode>, } impl ConstPropMachine<'_, '_> { - fn new( - only_propagate_inside_block_locals: BitSet<Local>, - can_const_prop: IndexVec<Local, ConstPropMode>, - ) -> Self { + pub fn new(can_const_prop: IndexVec<Local, ConstPropMode>) -> Self { Self { stack: Vec::new(), written_only_inside_own_block_locals: Default::default(), - only_propagate_inside_block_locals, can_const_prop, } } @@ -185,6 +151,29 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> type MemoryKind = !; + #[inline(always)] + fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment { + // We do not check for alignment to avoid having to carry an `Align` + // in `ConstValue::ByRef`. + CheckAlignment::No + } + + #[inline(always)] + fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { + false // for now, we don't enforce validity + } + fn alignment_check_failed( + ecx: &InterpCx<'mir, 'tcx, Self>, + _has: Align, + _required: Align, + _check: CheckAlignment, + ) -> InterpResult<'tcx, ()> { + span_bug!( + ecx.cur_span(), + "`alignment_check_failed` called when no alignment check requested" + ) + } + fn load_mir( _ecx: &InterpCx<'mir, 'tcx, Self>, _instance: ty::InstanceDef<'tcx>, @@ -199,7 +188,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> _args: &[OpTy<'tcx>], _destination: &PlaceTy<'tcx>, _target: Option<BasicBlock>, - _unwind: StackPopUnwind, + _unwind: UnwindAction, ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { Ok(None) } @@ -210,7 +199,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> _args: &[OpTy<'tcx>], _destination: &PlaceTy<'tcx>, _target: Option<BasicBlock>, - _unwind: StackPopUnwind, + _unwind: UnwindAction, ) -> InterpResult<'tcx> { throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp") } @@ -218,7 +207,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> fn assert_panic( _ecx: &mut InterpCx<'mir, 'tcx, Self>, _msg: &rustc_middle::mir::AssertMessage<'tcx>, - _unwind: Option<rustc_middle::mir::BasicBlock>, + _unwind: rustc_middle::mir::UnwindAction, ) -> InterpResult<'tcx> { bug!("panics terminators are not evaluated in ConstProp") } @@ -233,39 +222,22 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp") } - fn access_local<'a>( - frame: &'a Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>, - local: Local, - ) -> InterpResult<'tcx, &'a interpret::Operand<Self::PointerTag>> { - let l = &frame.locals[local]; - - if matches!( - 
l.value, - LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)) - ) { - // For us "uninit" means "we don't know its value, might be initiailized or not". - // So stop here. - throw_machine_stop_str!("tried to access alocal with unknown value ") - } - - l.access() - } - fn access_local_mut<'a>( ecx: &'a mut InterpCx<'mir, 'tcx, Self>, frame: usize, local: Local, - ) -> InterpResult<'tcx, &'a mut interpret::Operand<Self::PointerTag>> { - if ecx.machine.can_const_prop[local] == ConstPropMode::NoPropagation { - throw_machine_stop_str!("tried to write to a local that is marked as not propagatable") - } - if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) { - trace!( - "mutating local {:?} which is restricted to its block. \ - Will remove it from const-prop after block is finished.", - local - ); - ecx.machine.written_only_inside_own_block_locals.insert(local); + ) -> InterpResult<'tcx, &'a mut interpret::Operand<Self::Provenance>> { + assert_eq!(frame, 0); + match ecx.machine.can_const_prop[local] { + ConstPropMode::NoPropagation => { + throw_machine_stop_str!( + "tried to write to a local that is marked as not propagatable" + ) + } + ConstPropMode::OnlyInsideOwnBlock => { + ecx.machine.written_only_inside_own_block_locals.insert(local); + } + ConstPropMode::FullConstProp => {} } ecx.machine.stack[frame].locals[local].access_mut() } @@ -274,7 +246,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> _tcx: TyCtxt<'tcx>, _machine: &Self, _alloc_id: AllocId, - alloc: ConstAllocation<'tcx, Self::PointerTag, Self::AllocExtra>, + alloc: ConstAllocation<'tcx>, _static_def_id: Option<DefId>, is_write: bool, ) -> InterpResult<'tcx> { @@ -283,7 +255,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> } // If the static allocation is mutable, then we can't const prop it as its content // might be different at runtime. - if alloc.inner().mutability == Mutability::Mut { + if alloc.inner().mutability.is_mut() { throw_machine_stop_str!("can't access mutable globals in ConstProp"); } @@ -309,14 +281,14 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> #[inline(always)] fn stack<'a>( ecx: &'a InterpCx<'mir, 'tcx, Self>, - ) -> &'a [Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>] { + ) -> &'a [Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] { &ecx.machine.stack } #[inline(always)] fn stack_mut<'a>( ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>> { + ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> { &mut ecx.machine.stack } } @@ -326,10 +298,7 @@ struct ConstPropagator<'mir, 'tcx> { ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, - local_decls: &'mir IndexVec<Local, LocalDecl<'tcx>>, - // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store - // the last known `SourceInfo` here and just keep revisiting it. 
- source_info: Option<SourceInfo>, + local_decls: &'mir IndexSlice<Local, LocalDecl<'tcx>>, } impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { @@ -373,27 +342,21 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let param_env = tcx.param_env_reveal_all_normalized(def_id); let can_const_prop = CanConstProp::check(tcx, param_env, body); - let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len()); - for (l, mode) in can_const_prop.iter_enumerated() { - if *mode == ConstPropMode::OnlyInsideOwnBlock { - only_propagate_inside_block_locals.insert(l); - } - } let mut ecx = InterpCx::new( tcx, tcx.def_span(def_id), param_env, - ConstPropMachine::new(only_propagate_inside_block_locals, can_const_prop), + ConstPropMachine::new(can_const_prop), ); let ret_layout = ecx - .layout_of(EarlyBinder(body.return_ty()).subst(tcx, substs)) + .layout_of(body.bound_return_ty().subst(tcx, substs)) .ok() // Don't bother allocating memory for large values. // I don't know how return types can seem to be unsized but this happens in the // `type/type-unsatisfiable.rs` test. .filter(|ret_layout| { - !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) }) .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); @@ -410,28 +373,28 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { ) .expect("failed to push initial stack frame"); - ConstPropagator { - ecx, - tcx, - param_env, - local_decls: &dummy_body.local_decls, - source_info: None, - } + ConstPropagator { ecx, tcx, param_env, local_decls: &dummy_body.local_decls } } fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { let op = match self.ecx.eval_place_to_op(place, None) { - Ok(op) => op, + Ok(op) => { + if matches!(*op, interpret::Operand::Immediate(Immediate::Uninit)) { + // Make sure nobody accidentally uses this value. + return None; + } + op + } Err(e) => { - trace!("get_const failed: {}", e); + trace!("get_const failed: {:?}", e.into_kind().debug()); return None; } }; // Try to read the local as an immediate so that if it is representable as a scalar, we can // handle it as such, but otherwise, just return the value as is. - Some(match self.ecx.read_immediate_raw(&op, /*force*/ false) { - Ok(Ok(imm)) => imm.into(), + Some(match self.ecx.read_immediate_raw(&op) { + Ok(Right(imm)) => imm.into(), _ => op, }) } @@ -439,47 +402,26 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { /// Remove `local` from the pool of `Locals`. Allows writing to them, /// but not reading from them anymore. fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { - ecx.frame_mut().locals[local] = LocalState { - value: LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)), - layout: Cell::new(None), - }; - } - - fn use_ecx<F, T>(&mut self, f: F) -> Option<T> - where - F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, - { - match f(self) { - Ok(val) => Some(val), - Err(error) => { - trace!("InterpCx operation failed: {:?}", error); - // Some errors shouldn't come up because creating them causes - // an allocation, which we should avoid. When that happens, - // dedicated error variants should be introduced instead. 
- assert!( - !error.kind().formatted_string(), - "const-prop encountered formatting error: {}", - error - ); - None - } - } + ecx.frame_mut().locals[local].value = + LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)); + ecx.machine.written_only_inside_own_block_locals.remove(&local); } /// Returns the value, if any, of evaluating `c`. fn eval_constant(&mut self, c: &Constant<'tcx>) -> Option<OpTy<'tcx>> { // FIXME we need to revisit this for #67176 - if c.needs_subst() { + if c.has_param() { return None; } - self.ecx.mir_const_to_op(&c.literal, None).ok() + // No span, we don't want errors to be shown. + self.ecx.eval_mir_constant(&c.literal, None, None).ok() } /// Returns the value, if any, of evaluating `place`. fn eval_place(&mut self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { trace!("eval_place(place={:?})", place); - self.use_ecx(|this| this.ecx.eval_place_to_op(place, None)) + self.ecx.eval_place_to_op(place, None).ok() } /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` @@ -491,56 +433,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } } - fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>) -> Option<()> { - if self.use_ecx(|this| { - let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; - let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?; - Ok(overflow) - })? { - // `AssertKind` only has an `OverflowNeg` variant, so make sure that is - // appropriate to use. - assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - return None; - } - - Some(()) - } - - fn check_binary_op( - &mut self, - op: BinOp, - left: &Operand<'tcx>, - right: &Operand<'tcx>, - ) -> Option<()> { - let r = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?)); - let l = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); - // Check for exceeding shifts *even if* we cannot evaluate the LHS. - if op == BinOp::Shr || op == BinOp::Shl { - let r = r?; - // We need the type of the LHS. We cannot use `place_layout` as that is the type - // of the result, which for checked binops is not the same! - let left_ty = left.ty(self.local_decls, self.tcx); - let left_size = self.ecx.layout_of(left_ty).ok()?.size; - let right_size = r.layout.size; - let r_bits = r.to_scalar().ok(); - let r_bits = r_bits.and_then(|r| r.to_bits(right_size).ok()); - if r_bits.map_or(false, |b| b >= left_size.bits() as u128) { - return None; - } - } - - if let (Some(l), Some(r)) = (&l, &r) { - // The remaining operators are handled through `overflowing_binary_op`. - if self.use_ecx(|this| { - let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, l, r)?; - Ok(overflow) - })? { - return None; - } - } - Some(()) - } - fn propagate_operand(&mut self, operand: &mut Operand<'tcx>) { match *operand { Operand::Copy(l) | Operand::Move(l) => { @@ -552,14 +444,10 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // and use it to do const-prop here and everywhere else // where it makes sense. 
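// (The rewrite below, shown on a self-contained toy with invented types
// rather than the real `Operand`: a copy of a local whose value is known
// becomes a literal, so the read no longer depends on the local at all.)
// ```ignore (illustrative)
// use std::collections::HashMap;
//
// enum Op {
//     Copy(usize),
//     Const(i64),
// }
//
// fn propagate(op: &mut Op, known: &HashMap<usize, i64>) {
//     if let Op::Copy(local) = *op {
//         if let Some(&v) = known.get(&local) {
//             *op = Op::Const(v); // later passes never touch `local` again
//         }
//     }
// }
// ```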
if let interpret::Operand::Immediate(interpret::Immediate::Scalar( - ScalarMaybeUninit::Scalar(scalar), + scalar, )) = *value { - *operand = self.operand_from_scalar( - scalar, - value.layout.ty, - self.source_info.unwrap().span, - ); + *operand = self.operand_from_scalar(scalar, value.layout.ty); } } } @@ -567,7 +455,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } } - fn const_prop(&mut self, rvalue: &Rvalue<'tcx>, place: Place<'tcx>) -> Option<()> { + fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>) -> Option<()> { // Perform any special handling for specific Rvalue types. // Generally, checks here fall into one of two categories: // 1. Additional checking to provide useful lints to the user @@ -576,28 +464,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // 2. Working around bugs in other parts of the compiler // - In this case, we'll return `None` from this function to stop evaluation. match rvalue { - // Additional checking: give lints to the user if an overflow would occur. - // We do this here and not in the `Assert` terminator as that terminator is - // only sometimes emitted (overflow checks can be disabled), but we want to always - // lint. - Rvalue::UnaryOp(op, arg) => { - trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); - self.check_unary_op(*op, arg)?; - } - Rvalue::BinaryOp(op, box (left, right)) => { - trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); - self.check_binary_op(*op, left, right)?; - } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { - trace!( - "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", - op, - left, - right - ); - self.check_binary_op(*op, left, right)?; - } - // Do not try creating references (#67862) Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { trace!("skipping AddressOf | Ref for {:?}", place); @@ -617,7 +483,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { return None; } - // There's no other checking to do at this time. Rvalue::Aggregate(..) | Rvalue::Use(..) @@ -627,19 +492,26 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { | Rvalue::Cast(..) | Rvalue::ShallowInitBox(..) | Rvalue::Discriminant(..) - | Rvalue::NullaryOp(..) => {} + | Rvalue::NullaryOp(..) + | Rvalue::UnaryOp(..) + | Rvalue::BinaryOp(..) + | Rvalue::CheckedBinaryOp(..) 
=> {} } // FIXME we need to revisit this for #67176 - if rvalue.needs_subst() { + if rvalue.has_param() { return None; } - - if self.tcx.sess.mir_opt_level() >= 4 { - self.eval_rvalue_with_identities(rvalue, place) - } else { - self.use_ecx(|this| this.ecx.eval_rvalue_into_place(rvalue, place)) + if !rvalue + .ty(&self.ecx.frame().body.local_decls, *self.ecx.tcx) + .is_sized(*self.ecx.tcx, self.param_env) + { + // the interpreter doesn't support unsized locals (only unsized arguments), + // but rustc does (in a kinda broken way), so we have to skip them here + return None; } + + Some(()) } // Attempt to use algebraic identities to eliminate constant expressions @@ -648,67 +520,75 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { rvalue: &Rvalue<'tcx>, place: Place<'tcx>, ) -> Option<()> { - self.use_ecx(|this| match rvalue { + match rvalue { Rvalue::BinaryOp(op, box (left, right)) | Rvalue::CheckedBinaryOp(op, box (left, right)) => { - let l = this.ecx.eval_operand(left, None); - let r = this.ecx.eval_operand(right, None); + let l = self.ecx.eval_operand(left, None).and_then(|x| self.ecx.read_immediate(&x)); + let r = + self.ecx.eval_operand(right, None).and_then(|x| self.ecx.read_immediate(&x)); let const_arg = match (l, r) { - (Ok(ref x), Err(_)) | (Err(_), Ok(ref x)) => this.ecx.read_immediate(x)?, - (Err(e), Err(_)) => return Err(e), - (Ok(_), Ok(_)) => return this.ecx.eval_rvalue_into_place(rvalue, place), + (Ok(x), Err(_)) | (Err(_), Ok(x)) => x, // exactly one side is known + (Err(_), Err(_)) => return None, // neither side is known + (Ok(_), Ok(_)) => return self.ecx.eval_rvalue_into_place(rvalue, place).ok(), // both sides are known }; if !matches!(const_arg.layout.abi, abi::Abi::Scalar(..)) { // We cannot handle Scalar Pair stuff. - return this.ecx.eval_rvalue_into_place(rvalue, place); + // No point in calling `eval_rvalue_into_place`, since only one side is known + return None; } - let arg_value = const_arg.to_scalar()?.to_bits(const_arg.layout.size)?; - let dest = this.ecx.eval_place(place)?; + let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size).ok()?; + let dest = self.ecx.eval_place(place).ok()?; match op { - BinOp::BitAnd if arg_value == 0 => this.ecx.write_immediate(*const_arg, &dest), + BinOp::BitAnd if arg_value == 0 => { + self.ecx.write_immediate(*const_arg, &dest).ok() + } BinOp::BitOr if arg_value == const_arg.layout.size.truncate(u128::MAX) || (const_arg.layout.ty.is_bool() && arg_value == 1) => { - this.ecx.write_immediate(*const_arg, &dest) + self.ecx.write_immediate(*const_arg, &dest).ok() } BinOp::Mul if const_arg.layout.ty.is_integral() && arg_value == 0 => { if let Rvalue::CheckedBinaryOp(_, _) = rvalue { let val = Immediate::ScalarPair( - const_arg.to_scalar()?.into(), - Scalar::from_bool(false).into(), + const_arg.to_scalar(), + Scalar::from_bool(false), ); - this.ecx.write_immediate(val, &dest) + self.ecx.write_immediate(val, &dest).ok() } else { - this.ecx.write_immediate(*const_arg, &dest) + self.ecx.write_immediate(*const_arg, &dest).ok() } } - _ => this.ecx.eval_rvalue_into_place(rvalue, place), + _ => None, } } - _ => this.ecx.eval_rvalue_into_place(rvalue, place), - }) + _ => self.ecx.eval_rvalue_into_place(rvalue, place).ok(), + } } /// Creates a new `Operand::Constant` from a `Scalar` value - fn operand_from_scalar(&self, scalar: Scalar, ty: Ty<'tcx>, span: Span) -> Operand<'tcx> { + fn operand_from_scalar(&self, scalar: Scalar, ty: Ty<'tcx>) -> Operand<'tcx> { Operand::Constant(Box::new(Constant { - span, + span: DUMMY_SP, user_ty: 
None, literal: ConstantKind::from_scalar(self.tcx, scalar, ty), })) } - fn replace_with_const( - &mut self, - rval: &mut Rvalue<'tcx>, - value: &OpTy<'tcx>, - source_info: SourceInfo, - ) { + fn replace_with_const(&mut self, place: Place<'tcx>, rval: &mut Rvalue<'tcx>) { + // This will return None if the above `const_prop` invocation only "wrote" a + // type whose creation requires no write. E.g. a generator whose initial state + // consists solely of uninitialized memory (so it doesn't capture any locals). + let Some(ref value) = self.get_const(place) else { return }; + if !self.should_const_prop(value) { + return; + } + trace!("replacing {:?}={:?} with {:?}", place, rval, value); + if let Rvalue::Use(Operand::Constant(c)) = rval { match c.literal { ConstantKind::Ty(c) if matches!(c.kind(), ConstKind::Unevaluated(..)) => {} @@ -720,34 +600,15 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } trace!("attempting to replace {:?} with {:?}", rval, value); - if let Err(e) = self.ecx.const_validate_operand( - value, - vec![], - // FIXME: is ref tracking too expensive? - // FIXME: what is the point of ref tracking if we do not even check the tracked refs? - &mut interpret::RefTracking::empty(), - CtfeValidationMode::Regular, - ) { - trace!("validation error, attempt failed: {:?}", e); - return; - } - // FIXME> figure out what to do when read_immediate_raw fails - let imm = self.use_ecx(|this| this.ecx.read_immediate_raw(value, /*force*/ false)); + let imm = self.ecx.read_immediate_raw(value).ok(); - if let Some(Ok(imm)) = imm { + if let Some(Right(imm)) = imm { match *imm { - interpret::Immediate::Scalar(ScalarMaybeUninit::Scalar(scalar)) => { - *rval = Rvalue::Use(self.operand_from_scalar( - scalar, - value.layout.ty, - source_info.span, - )); + interpret::Immediate::Scalar(scalar) => { + *rval = Rvalue::Use(self.operand_from_scalar(scalar, value.layout.ty)); } - Immediate::ScalarPair( - ScalarMaybeUninit::Scalar(_), - ScalarMaybeUninit::Scalar(_), - ) => { + Immediate::ScalarPair(..) => { // Found a value represented as a pair. For now only do const-prop if the type // of `rvalue` is also a tuple with two scalars. // FIXME: enable the general case stated above ^. @@ -756,31 +617,29 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { if let ty::Tuple(types) = ty.kind() { // Only do it if tuple is also a pair with two scalars if let [ty1, ty2] = types[..] { - let alloc = self.use_ecx(|this| { - let ty_is_scalar = |ty| { - this.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar()) - == Some(true) - }; - if ty_is_scalar(ty1) && ty_is_scalar(ty2) { - let alloc = this - .ecx - .intern_with_temp_alloc(value.layout, |ecx, dest| { - ecx.write_immediate(*imm, dest) - }) - .unwrap(); - Ok(Some(alloc)) - } else { - Ok(None) - } - }); - - if let Some(Some(alloc)) = alloc { + let ty_is_scalar = |ty| { + self.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar()) + == Some(true) + }; + let alloc = if ty_is_scalar(ty1) && ty_is_scalar(ty2) { + let alloc = self + .ecx + .intern_with_temp_alloc(value.layout, |ecx, dest| { + ecx.write_immediate(*imm, dest) + }) + .unwrap(); + Some(alloc) + } else { + None + }; + + if let Some(alloc) = alloc { // Assign entire constant in a single statement. // We can't use aggregates, as we run after the aggregate-lowering `MirPhase`. 
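// (Why the detour through an allocation: aggregate rvalues were already
// lowered to per-field stores by this phase, so the only single-statement
// way to produce the pair is to materialize its bytes up front and point at
// them. A freestanding sketch with a hand-picked toy layout, not the real
// interner:)
// ```ignore (illustrative)
// fn intern_pair(a: u32, b: bool) -> Vec<u8> {
//     let mut alloc = vec![0u8; 8]; // toy layout: u32 at offset 0, bool at offset 4
//     alloc[0..4].copy_from_slice(&a.to_le_bytes());
//     alloc[4] = b as u8;
//     alloc // a "by-ref" constant would reference this buffer at offset 0
// }
// ```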
let const_val = ConstValue::ByRef { alloc, offset: Size::ZERO }; let literal = ConstantKind::Val(const_val, ty); *rval = Rvalue::Use(Operand::Constant(Box::new(Constant { - span: source_info.span, + span: DUMMY_SP, user_ty: None, literal, }))); @@ -802,43 +661,49 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } match **op { - interpret::Operand::Immediate(Immediate::Scalar(ScalarMaybeUninit::Scalar(s))) => { - s.try_to_int().is_ok() + interpret::Operand::Immediate(Immediate::Scalar(s)) => s.try_to_int().is_ok(), + interpret::Operand::Immediate(Immediate::ScalarPair(l, r)) => { + l.try_to_int().is_ok() && r.try_to_int().is_ok() } - interpret::Operand::Immediate(Immediate::ScalarPair( - ScalarMaybeUninit::Scalar(l), - ScalarMaybeUninit::Scalar(r), - )) => l.try_to_int().is_ok() && r.try_to_int().is_ok(), _ => false, } } + + fn ensure_not_propagated(&self, local: Local) { + if cfg!(debug_assertions) { + assert!( + self.get_const(local.into()).is_none() + || self + .layout_of(self.local_decls[local].ty) + .map_or(true, |layout| layout.is_zst()), + "failed to remove values for `{local:?}`, value={:?}", + self.get_const(local.into()), + ) + } + } } /// The mode that `ConstProp` is allowed to run in for a given `Local`. #[derive(Clone, Copy, Debug, PartialEq)] -enum ConstPropMode { +pub enum ConstPropMode { /// The `Local` can be propagated into and reads of this `Local` can also be propagated. FullConstProp, /// The `Local` can only be propagated into and from its own block. OnlyInsideOwnBlock, - /// The `Local` can be propagated into but reads cannot be propagated. - OnlyPropagateInto, /// The `Local` cannot be part of propagation at all. Any statement /// referencing it either for reading or writing will not get propagated. NoPropagation, } -struct CanConstProp { +pub struct CanConstProp { can_const_prop: IndexVec<Local, ConstPropMode>, // False at the beginning. Once set, no more assignments are allowed to that local. found_assignment: BitSet<Local>, - // Cache of locals' information - local_kinds: IndexVec<Local, LocalKind>, } impl CanConstProp { /// Returns true if `local` can be propagated - fn check<'tcx>( + pub fn check<'tcx>( tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, body: &Body<'tcx>, @@ -846,10 +711,6 @@ impl CanConstProp { let mut cpv = CanConstProp { can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls), found_assignment: BitSet::new_empty(body.local_decls.len()), - local_kinds: IndexVec::from_fn_n( - |local| body.local_kind(local), - body.local_decls.len(), - ), }; for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { let ty = body.local_decls[local].ty; @@ -862,37 +723,32 @@ impl CanConstProp { continue; } } - // Cannot use args at all - // Cannot use locals because if x < y { y - x } else { x - y } would - // lint for x != y - // FIXME(oli-obk): lint variables until they are used in a condition - // FIXME(oli-obk): lint if return value is constant - if cpv.local_kinds[local] == LocalKind::Arg { - *val = ConstPropMode::OnlyPropagateInto; - trace!( - "local {:?} can't be const propagated because it's a function argument", - local - ); - } else if cpv.local_kinds[local] == LocalKind::Var { - *val = ConstPropMode::OnlyInsideOwnBlock; - trace!( - "local {:?} will only be propagated inside its block, because it's a user variable", - local - ); - } + } + // Consider that arguments are assigned on entry. 
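// (Seeding the set this way means a store to an argument inside the body is
// observed as a second assignment, demoting the argument from full
// propagation exactly like any other multiply-assigned local. A toy version
// of that bookkeeping, with invented names:)
// ```ignore (illustrative)
// use std::collections::HashSet;
//
// fn record_assignment(
//     local: usize,
//     found: &mut HashSet<usize>,
//     fully_propagatable: &mut HashSet<usize>,
// ) {
//     if !found.insert(local) {
//         // Second observed write: no longer a single-assignment local.
//         fully_propagatable.remove(&local);
//     }
// }
// ```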
+ for arg in body.args_iter() {
+ cpv.found_assignment.insert(arg);
}
cpv.visit_body(&body);
cpv.can_const_prop
}
}
-impl Visitor<'_> for CanConstProp {
+impl<'tcx> Visitor<'tcx> for CanConstProp {
+ fn visit_place(&mut self, place: &Place<'tcx>, mut context: PlaceContext, loc: Location) {
+ use rustc_middle::mir::visit::PlaceContext::*;
+
+ // Dereferencing just reads the address of `place.local`.
+ if place.projection.first() == Some(&PlaceElem::Deref) {
+ context = NonMutatingUse(NonMutatingUseContext::Copy);
+ }
+
+ self.visit_local(place.local, context, loc);
+ self.visit_projection(place.as_ref(), context, loc);
+ }
+
 fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) {
 use rustc_middle::mir::visit::PlaceContext::*;
 match context {
- // Projections are fine, because `&mut foo.x` will be caught by
- // `MutatingUseContext::Borrow` elsewhere.
- MutatingUse(MutatingUseContext::Projection)
 // These are just stores, where the storing is not propagatable, but there may be later
 // mutations of the same local via `Store`
 | MutatingUse(MutatingUseContext::Call)
 | MutatingUse(MutatingUseContext::AsmOutput)
 | MutatingUse(MutatingUseContext::Deinit)
@@ -909,7 +765,6 @@ impl Visitor<'_> for CanConstProp {
 // states as applicable.
 ConstPropMode::OnlyInsideOwnBlock => {}
 ConstPropMode::NoPropagation => {}
- ConstPropMode::OnlyPropagateInto => {}
 other @ ConstPropMode::FullConstProp => {
 trace!(
 "local {:?} can't be propagated because of multiple assignments. Previous state: {:?}",
@@ -924,7 +779,7 @@ impl Visitor<'_> for CanConstProp {
 NonMutatingUse(NonMutatingUseContext::Copy)
 | NonMutatingUse(NonMutatingUseContext::Move)
 | NonMutatingUse(NonMutatingUseContext::Inspect)
- | NonMutatingUse(NonMutatingUseContext::Projection)
+ | NonMutatingUse(NonMutatingUseContext::PlaceMention)
 | NonUse(_) => {}

 // These could be propagated with a smarter analysis or just some careful thinking about
 // mutation.
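// (Concretely, once a local's address escapes, it can be mutated through
// memory without any direct assignment the visitor would see, so a cached
// value could silently go stale:)
// ```ignore (illustrative)
// let mut x = 1;
// let p = &mut x;
// *p = 2; // mutates `x` with no assignment *to `x`* in the MIR sense
// assert_eq!(x, 2); // propagating a remembered `x == 1` here would be wrong
// ```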
| NonMutatingUse(NonMutatingUseContext::SharedBorrow) | NonMutatingUse(NonMutatingUseContext::ShallowBorrow) - | NonMutatingUse(NonMutatingUseContext::UniqueBorrow) | NonMutatingUse(NonMutatingUseContext::AddressOf) | MutatingUse(MutatingUseContext::Borrow) | MutatingUse(MutatingUseContext::AddressOf) => { - trace!("local {:?} can't be propagaged because it's used: {:?}", local, context); + trace!("local {:?} can't be propagated because it's used: {:?}", local, context); self.can_const_prop[local] = ConstPropMode::NoPropagation; } + MutatingUse(MutatingUseContext::Projection) + | NonMutatingUse(NonMutatingUseContext::Projection) => bug!("visit_place should not pass {context:?} for {local:?}"), } } } @@ -952,12 +808,6 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { self.tcx } - fn visit_body(&mut self, body: &mut Body<'tcx>) { - for (bb, data) in body.basic_blocks_mut().iter_enumerated_mut() { - self.visit_basic_block_data(bb, data); - } - } - fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { self.super_operand(operand, location); @@ -968,129 +818,118 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { } } - fn visit_constant(&mut self, constant: &mut Constant<'tcx>, location: Location) { - trace!("visit_constant: {:?}", constant); - self.super_constant(constant, location); - self.eval_constant(constant); + fn process_projection_elem( + &mut self, + elem: PlaceElem<'tcx>, + _: Location, + ) -> Option<PlaceElem<'tcx>> { + if let PlaceElem::Index(local) = elem + && let Some(value) = self.get_const(local.into()) + && self.should_const_prop(&value) + && let interpret::Operand::Immediate(interpret::Immediate::Scalar(scalar)) = *value + && let Ok(offset) = scalar.to_target_usize(&self.tcx) + && let Some(min_length) = offset.checked_add(1) + { + Some(PlaceElem::ConstantIndex { offset, min_length, from_end: false }) + } else { + None + } + } + + fn visit_assign( + &mut self, + place: &mut Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + location: Location, + ) { + self.super_assign(place, rvalue, location); + + let Some(()) = self.check_rvalue(rvalue) else { return }; + + match self.ecx.machine.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { + if let Some(()) = self.eval_rvalue_with_identities(rvalue, *place) { + self.replace_with_const(*place, rvalue); + } else { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. 
+ Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + Self::remove_const(&mut self.ecx, place.local); + } + } + } } fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { trace!("visit_statement: {:?}", statement); - let source_info = statement.source_info; - self.source_info = Some(source_info); - if let StatementKind::Assign(box (place, ref mut rval)) = statement.kind { - let can_const_prop = self.ecx.machine.can_const_prop[place.local]; - if let Some(()) = self.const_prop(rval, place) { - // This will return None if the above `const_prop` invocation only "wrote" a - // type whose creation requires no write. E.g. a generator whose initial state - // consists solely of uninitialized memory (so it doesn't capture any locals). - if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) { - trace!("replacing {:?} with {:?}", rval, value); - self.replace_with_const(rval, value, source_info); - if can_const_prop == ConstPropMode::FullConstProp - || can_const_prop == ConstPropMode::OnlyInsideOwnBlock - { - trace!("propagated into {:?}", place); - } - } - match can_const_prop { - ConstPropMode::OnlyInsideOwnBlock => { - trace!( - "found local restricted to its block. \ - Will remove it from const-prop after block is finished. Local: {:?}", - place.local - ); - } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - trace!("can't propagate into {:?}", place); - if place.local != RETURN_PLACE { + + // We want to evaluate operands before any change to the assigned-to value, + // so we recurse first. + self.super_statement(statement, location); + + match statement.kind { + StatementKind::SetDiscriminant { ref place, .. } => { + match self.ecx.machine.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + if self.ecx.statement(statement).is_ok() { + trace!("propped discriminant into {:?}", place); + } else { Self::remove_const(&mut self.ecx, place.local); } } - ConstPropMode::FullConstProp => {} } - } else { - // Const prop failed, so erase the destination, ensuring that whatever happens - // from here on, does not know about the previous value. - // This is important in case we have - // ```rust - // let mut x = 42; - // x = SOME_MUTABLE_STATIC; - // // x must now be uninit - // ``` - // FIXME: we overzealously erase the entire local, because that's easier to - // implement. - trace!( - "propagation into {:?} failed. - Nuking the entire site from orbit, it's the only way to be sure", - place, - ); - Self::remove_const(&mut self.ecx, place.local); } - } else { - match statement.kind { - StatementKind::SetDiscriminant { ref place, .. 
} => { - match self.ecx.machine.can_const_prop[place.local] { - ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { - if self.use_ecx(|this| this.ecx.statement(statement)).is_some() { - trace!("propped discriminant into {:?}", place); - } else { - Self::remove_const(&mut self.ecx, place.local); - } - } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - Self::remove_const(&mut self.ecx, place.local); - } - } - } - StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { - let frame = self.ecx.frame_mut(); - frame.locals[local].value = - if let StatementKind::StorageLive(_) = statement.kind { - LocalValue::Live(interpret::Operand::Immediate( - interpret::Immediate::Uninit, - )) - } else { - LocalValue::Dead - }; - } - _ => {} + StatementKind::StorageLive(local) => { + Self::remove_const(&mut self.ecx, local); } + // We do not need to mark dead locals as such. For `FullConstProp` locals, + // this allows to propagate the single assigned value in this case: + // ``` + // let x = SOME_CONST; + // if a { + // f(copy x); + // StorageDead(x); + // } else { + // g(copy x); + // StorageDead(x); + // } + // ``` + // + // This may propagate a constant where the local would be uninit or dead. + // In both cases, this does not matter, as those reads would be UB anyway. + _ => {} } - - self.super_statement(statement, location); } fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { - let source_info = terminator.source_info; - self.source_info = Some(source_info); self.super_terminator(terminator, location); + match &mut terminator.kind { TerminatorKind::Assert { expected, ref mut cond, .. } => { - if let Some(ref value) = self.eval_operand(&cond) { + if let Some(ref value) = self.eval_operand(&cond) + && let Ok(value_const) = self.ecx.read_scalar(&value) + && self.should_const_prop(value) + { trace!("assertion on {:?} should be {:?}", value, expected); - let expected = ScalarMaybeUninit::from(Scalar::from_bool(*expected)); - let value_const = self.ecx.read_scalar(&value).unwrap(); - if expected != value_const { - // Poison all places this operand references so that further code - // doesn't use the invalid value - match cond { - Operand::Move(ref place) | Operand::Copy(ref place) => { - Self::remove_const(&mut self.ecx, place.local); - } - Operand::Constant(_) => {} - } - } else { - if self.should_const_prop(value) { - if let ScalarMaybeUninit::Scalar(scalar) = value_const { - *cond = self.operand_from_scalar( - scalar, - self.tcx.types.bool, - source_info.span, - ); - } - } - } + *cond = self.operand_from_scalar(value_const, self.tcx.types.bool); } } TerminatorKind::SwitchInt { ref mut discr, .. } => { @@ -1102,11 +941,10 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { // None of these have Operands to const-propagate. TerminatorKind::Goto { .. } | TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::Yield { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::FalseEdge { .. } @@ -1118,26 +956,38 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { // gated on `mir_opt_level=3`. TerminatorKind::Call { .. 
} => {}
}
+ }
+
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
+ self.super_basic_block_data(block, data);

 // We remove all Locals which are restricted in propagation to their containing blocks and
 // which were modified in the current block.
 // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`.
- let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
- for &local in locals.iter() {
+ let mut written_only_inside_own_block_locals =
+ std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
+
+ // This loop can get very hot for some bodies: it checks each local in each bb.
+ // To avoid this quadratic behaviour, we only clear the locals that were modified inside
+ // the current block.
+ for local in written_only_inside_own_block_locals.drain() {
+ debug_assert_eq!(
+ self.ecx.machine.can_const_prop[local],
+ ConstPropMode::OnlyInsideOwnBlock
+ );
 Self::remove_const(&mut self.ecx, local);
 }
- locals.clear();
- // Put it back so we reuse the heap of the storage
- self.ecx.machine.written_only_inside_own_block_locals = locals;
+ self.ecx.machine.written_only_inside_own_block_locals =
+ written_only_inside_own_block_locals;
+
 if cfg!(debug_assertions) {
- // Ensure we are correctly erasing locals with the non-debug-assert logic.
- for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
- assert!(
- self.get_const(local.into()).is_none()
- || self
- .layout_of(self.local_decls[local].ty)
- .map_or(true, |layout| layout.is_zst())
- )
+ for (local, &mode) in self.ecx.machine.can_const_prop.iter_enumerated() {
+ match mode {
+ ConstPropMode::FullConstProp => {}
+ ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => {
+ self.ensure_not_propagated(local);
+ }
+ }
}
}
}
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
index 49db140c474..759650fe4db 100644
--- a/compiler/rustc_mir_transform/src/const_prop_lint.rs
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -1,63 +1,40 @@
 //! Propagates constants for early reporting of statically known
assertion failures -use std::cell::Cell; +use std::fmt::Debug; -use rustc_ast::Mutability; -use rustc_data_structures::fx::FxHashSet; +use either::Left; + +use rustc_const_eval::interpret::Immediate; +use rustc_const_eval::interpret::{ + self, InterpCx, InterpResult, LocalValue, MemoryKind, OpTy, Scalar, StackPopCleanup, +}; +use rustc_const_eval::ReportErrorExt; use rustc_hir::def::DefKind; use rustc_hir::HirId; use rustc_index::bit_set::BitSet; -use rustc_index::vec::IndexVec; -use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::{ - AssertKind, BasicBlock, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, LocalKind, - Location, Operand, Place, Rvalue, SourceInfo, SourceScope, SourceScopeData, Statement, - StatementKind, Terminator, TerminatorKind, UnOp, RETURN_PLACE, -}; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; -use rustc_middle::ty::subst::{InternalSubsts, Subst}; +use rustc_middle::ty::InternalSubsts; use rustc_middle::ty::{ - self, ConstInt, ConstKind, EarlyBinder, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, - TypeVisitable, + self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt, }; -use rustc_session::lint; -use rustc_span::{def_id::DefId, Span}; +use rustc_span::Span; use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout}; -use rustc_target::spec::abi::Abi as CallAbi; use rustc_trait_selection::traits; +use crate::const_prop::CanConstProp; +use crate::const_prop::ConstPropMachine; +use crate::const_prop::ConstPropMode; +use crate::errors::AssertLint; use crate::MirLint; -use rustc_const_eval::const_eval::ConstEvalErr; -use rustc_const_eval::interpret::{ - self, compile_time_machine, AllocId, ConstAllocation, Frame, ImmTy, InterpCx, InterpResult, - LocalState, LocalValue, MemoryKind, OpTy, PlaceTy, Pointer, Scalar, ScalarMaybeUninit, - StackPopCleanup, StackPopUnwind, -}; /// The maximum number of bytes that we'll allocate space for a local or the return value. /// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just /// Severely regress performance. const MAX_ALLOC_LIMIT: u64 = 1024; -/// Macro for machine-specific `InterpError` without allocation. -/// (These will never be shown to the user, but they help diagnose ICEs.) -macro_rules! throw_machine_stop_str { - ($($tt:tt)*) => {{ - // We make a new local type for it. The type itself does not carry any information, - // but its vtable (for the `MachineStopType` trait) does. - struct Zst; - // Printing this type shows the desired string. - impl std::fmt::Display for Zst { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, $($tt)*) - } - } - impl rustc_middle::mir::interpret::MachineStopType for Zst {} - throw_machine_stop!(Zst) - }}; -} - pub struct ConstProp; impl<'tcx> MirLint<'tcx> for ConstProp { @@ -78,7 +55,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp { return; } - let is_generator = tcx.type_of(def_id.to_def_id()).is_generator(); + let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator(); // FIXME(welseywiser) const prop doesn't work on generators because of query cycles // computing their layout. 
if is_generator { @@ -117,10 +94,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp { .predicates .iter() .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None }); - if traits::impossible_predicates( - tcx, - traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(), - ) { + if traits::impossible_predicates(tcx, traits::elaborate(tcx, predicates).collect()) { trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id); return; } @@ -129,7 +103,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp { let dummy_body = &Body::new( body.source, - body.basic_blocks().clone(), + (*body.basic_blocks).to_owned(), body.source_scopes.clone(), body.local_decls.clone(), Default::default(), @@ -151,182 +125,13 @@ impl<'tcx> MirLint<'tcx> for ConstProp { } } -struct ConstPropMachine<'mir, 'tcx> { - /// The virtual call stack. - stack: Vec<Frame<'mir, 'tcx>>, - /// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end. - written_only_inside_own_block_locals: FxHashSet<Local>, - /// Locals that need to be cleared after every block terminates. - only_propagate_inside_block_locals: BitSet<Local>, - can_const_prop: IndexVec<Local, ConstPropMode>, -} - -impl ConstPropMachine<'_, '_> { - fn new( - only_propagate_inside_block_locals: BitSet<Local>, - can_const_prop: IndexVec<Local, ConstPropMode>, - ) -> Self { - Self { - stack: Vec::new(), - written_only_inside_own_block_locals: Default::default(), - only_propagate_inside_block_locals, - can_const_prop, - } - } -} - -impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> { - compile_time_machine!(<'mir, 'tcx>); - const PANIC_ON_ALLOC_FAIL: bool = true; // all allocations are small (see `MAX_ALLOC_LIMIT`) - - type MemoryKind = !; - - fn load_mir( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _instance: ty::InstanceDef<'tcx>, - ) -> InterpResult<'tcx, &'tcx Body<'tcx>> { - throw_machine_stop_str!("calling functions isn't supported in ConstProp") - } - - fn find_mir_or_eval_fn( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _abi: CallAbi, - _args: &[OpTy<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: StackPopUnwind, - ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { - Ok(None) - } - - fn call_intrinsic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _args: &[OpTy<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: StackPopUnwind, - ) -> InterpResult<'tcx> { - throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp") - } - - fn assert_panic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _msg: &rustc_middle::mir::AssertMessage<'tcx>, - _unwind: Option<rustc_middle::mir::BasicBlock>, - ) -> InterpResult<'tcx> { - bug!("panics terminators are not evaluated in ConstProp") - } - - fn binary_ptr_op( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _bin_op: BinOp, - _left: &ImmTy<'tcx>, - _right: &ImmTy<'tcx>, - ) -> InterpResult<'tcx, (Scalar, bool, Ty<'tcx>)> { - // We can't do this because aliasing of memory can differ between const eval and llvm - throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp") - } - - fn access_local<'a>( - frame: &'a Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>, - local: Local, - ) -> InterpResult<'tcx, &'a interpret::Operand<Self::PointerTag>> { - let l = &frame.locals[local]; - - if matches!( - l.value, - 
LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)) - ) { - // For us "uninit" means "we don't know its value, might be initiailized or not". - // So stop here. - throw_machine_stop_str!("tried to access a local with unknown value") - } - - l.access() - } - - fn access_local_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - frame: usize, - local: Local, - ) -> InterpResult<'tcx, &'a mut interpret::Operand<Self::PointerTag>> { - if ecx.machine.can_const_prop[local] == ConstPropMode::NoPropagation { - throw_machine_stop_str!("tried to write to a local that is marked as not propagatable") - } - if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) { - trace!( - "mutating local {:?} which is restricted to its block. \ - Will remove it from const-prop after block is finished.", - local - ); - ecx.machine.written_only_inside_own_block_locals.insert(local); - } - ecx.machine.stack[frame].locals[local].access_mut() - } - - fn before_access_global( - _tcx: TyCtxt<'tcx>, - _machine: &Self, - _alloc_id: AllocId, - alloc: ConstAllocation<'tcx, Self::PointerTag, Self::AllocExtra>, - _static_def_id: Option<DefId>, - is_write: bool, - ) -> InterpResult<'tcx> { - if is_write { - throw_machine_stop_str!("can't write to global"); - } - // If the static allocation is mutable, then we can't const prop it as its content - // might be different at runtime. - if alloc.inner().mutability == Mutability::Mut { - throw_machine_stop_str!("can't access mutable globals in ConstProp"); - } - - Ok(()) - } - - #[inline(always)] - fn expose_ptr( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _ptr: Pointer<AllocId>, - ) -> InterpResult<'tcx> { - throw_machine_stop_str!("exposing pointers isn't supported in ConstProp") - } - - #[inline(always)] - fn init_frame_extra( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - frame: Frame<'mir, 'tcx>, - ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> { - Ok(frame) - } - - #[inline(always)] - fn stack<'a>( - ecx: &'a InterpCx<'mir, 'tcx, Self>, - ) -> &'a [Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>] { - &ecx.machine.stack - } - - #[inline(always)] - fn stack_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>> { - &mut ecx.machine.stack - } -} - /// Finds optimization opportunities on the MIR. struct ConstPropagator<'mir, 'tcx> { ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, - source_scopes: &'mir IndexVec<SourceScope, SourceScopeData<'tcx>>, - local_decls: &'mir IndexVec<Local, LocalDecl<'tcx>>, - // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store - // the last known `SourceInfo` here and just keep revisiting it. 
- source_info: Option<SourceInfo>, + worklist: Vec<BasicBlock>, + visited_blocks: BitSet<BasicBlock>, } impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { @@ -370,27 +175,21 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let param_env = tcx.param_env_reveal_all_normalized(def_id); let can_const_prop = CanConstProp::check(tcx, param_env, body); - let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len()); - for (l, mode) in can_const_prop.iter_enumerated() { - if *mode == ConstPropMode::OnlyInsideOwnBlock { - only_propagate_inside_block_locals.insert(l); - } - } let mut ecx = InterpCx::new( tcx, tcx.def_span(def_id), param_env, - ConstPropMachine::new(only_propagate_inside_block_locals, can_const_prop), + ConstPropMachine::new(can_const_prop), ); let ret_layout = ecx - .layout_of(EarlyBinder(body.return_ty()).subst(tcx, substs)) + .layout_of(body.bound_return_ty().subst(tcx, substs)) .ok() // Don't bother allocating memory for large values. // I don't know how return types can seem to be unsized but this happens in the // `type/type-unsatisfiable.rs` test. .filter(|ret_layout| { - !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) }) .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); @@ -411,25 +210,38 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { ecx, tcx, param_env, - source_scopes: &dummy_body.source_scopes, - local_decls: &dummy_body.local_decls, - source_info: None, + worklist: vec![START_BLOCK], + visited_blocks: BitSet::new_empty(body.basic_blocks.len()), } } + fn body(&self) -> &'mir Body<'tcx> { + self.ecx.frame().body + } + + fn local_decls(&self) -> &'mir LocalDecls<'tcx> { + &self.body().local_decls + } + fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { let op = match self.ecx.eval_place_to_op(place, None) { - Ok(op) => op, + Ok(op) => { + if matches!(*op, interpret::Operand::Immediate(Immediate::Uninit)) { + // Make sure nobody accidentally uses this value. + return None; + } + op + } Err(e) => { - trace!("get_const failed: {}", e); + trace!("get_const failed: {:?}", e.into_kind().debug()); return None; } }; // Try to read the local as an immediate so that if it is representable as a scalar, we can // handle it as such, but otherwise, just return the value as is. - Some(match self.ecx.read_immediate_raw(&op, /*force*/ false) { - Ok(Ok(imm)) => imm.into(), + Some(match self.ecx.read_immediate_raw(&op) { + Ok(Left(imm)) => imm.into(), _ => op, }) } @@ -437,22 +249,21 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { /// Remove `local` from the pool of `Locals`. Allows writing to them, /// but not reading from them anymore. 
fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { - ecx.frame_mut().locals[local] = LocalState { - value: LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)), - layout: Cell::new(None), - }; + ecx.frame_mut().locals[local].value = + LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)); + ecx.machine.written_only_inside_own_block_locals.remove(&local); } fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { - source_info.scope.lint_root(self.source_scopes) + source_info.scope.lint_root(&self.body().source_scopes) } - fn use_ecx<F, T>(&mut self, source_info: SourceInfo, f: F) -> Option<T> + fn use_ecx<F, T>(&mut self, location: Location, f: F) -> Option<T> where F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, { // Overwrite the PC -- whatever the interpreter does to it does not make any sense anyway. - self.ecx.frame_mut().loc = Err(source_info.span); + self.ecx.frame_mut().loc = Left(location); match f(self) { Ok(val) => Some(val), Err(error) => { @@ -462,8 +273,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // dedicated error variants should be introduced instead. assert!( !error.kind().formatted_string(), - "const-prop encountered formatting error: {}", - error + "const-prop encountered formatting error: {error:?}", ); None } @@ -471,85 +281,46 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } /// Returns the value, if any, of evaluating `c`. - fn eval_constant(&mut self, c: &Constant<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + fn eval_constant(&mut self, c: &Constant<'tcx>, location: Location) -> Option<OpTy<'tcx>> { // FIXME we need to revisit this for #67176 - if c.needs_subst() { + if c.has_param() { return None; } - match self.ecx.mir_const_to_op(&c.literal, None) { - Ok(op) => Some(op), - Err(error) => { - let tcx = self.ecx.tcx.at(c.span); - let err = ConstEvalErr::new(&self.ecx, error, Some(c.span)); - if let Some(lint_root) = self.lint_root(source_info) { - let lint_only = match c.literal { - ConstantKind::Ty(ct) => match ct.kind() { - // Promoteds must lint and not error as the user didn't ask for them - ConstKind::Unevaluated(ty::Unevaluated { - def: _, - substs: _, - promoted: Some(_), - }) => true, - // Out of backwards compatibility we cannot report hard errors in unused - // generic functions using associated constants of the generic parameters. - _ => c.literal.needs_subst(), - }, - ConstantKind::Val(_, ty) => ty.needs_subst(), - }; - if lint_only { - // Out of backwards compatibility we cannot report hard errors in unused - // generic functions using associated constants of the generic parameters. - err.report_as_lint(tcx, "erroneous constant used", lint_root, Some(c.span)); - } else { - err.report_as_error(tcx, "erroneous constant used"); - } - } else { - err.report_as_error(tcx, "erroneous constant used"); - } - None - } - } + // Normalization needed b/c const prop lint runs in + // `mir_drops_elaborated_and_const_checked`, which happens before + // optimized MIR. Only after optimizing the MIR can we guarantee + // that the `RevealAll` pass has happened and that the body's consts + // are normalized, so any call to resolve before that needs to be + // manually normalized. + let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.literal).ok()?; + + self.use_ecx(location, |this| this.ecx.eval_mir_constant(&val, Some(c.span), None)) } /// Returns the value, if any, of evaluating `place`. 
- fn eval_place(&mut self, place: Place<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + fn eval_place(&mut self, place: Place<'tcx>, location: Location) -> Option<OpTy<'tcx>> { trace!("eval_place(place={:?})", place); - self.use_ecx(source_info, |this| this.ecx.eval_place_to_op(place, None)) + self.use_ecx(location, |this| this.ecx.eval_place_to_op(place, None)) } /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` /// or `eval_place`, depending on the variant of `Operand` used. - fn eval_operand(&mut self, op: &Operand<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + fn eval_operand(&mut self, op: &Operand<'tcx>, location: Location) -> Option<OpTy<'tcx>> { match *op { - Operand::Constant(ref c) => self.eval_constant(c, source_info), - Operand::Move(place) | Operand::Copy(place) => self.eval_place(place, source_info), + Operand::Constant(ref c) => self.eval_constant(c, location), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place, location), } } - fn report_assert_as_lint( - &self, - lint: &'static lint::Lint, - source_info: SourceInfo, - message: &'static str, - panic: AssertKind<impl std::fmt::Debug>, - ) { - if let Some(lint_root) = self.lint_root(source_info) { - self.tcx.struct_span_lint_hir(lint, lint_root, source_info.span, |lint| { - let mut err = lint.build(message); - err.span_label(source_info.span, format!("{:?}", panic)); - err.emit(); - }); + fn report_assert_as_lint(&self, source_info: &SourceInfo, lint: AssertLint<impl Debug>) { + if let Some(lint_root) = self.lint_root(*source_info) { + self.tcx.emit_spanned_lint(lint.lint(), lint_root, source_info.span, lint); } } - fn check_unary_op( - &mut self, - op: UnOp, - arg: &Operand<'tcx>, - source_info: SourceInfo, - ) -> Option<()> { - if let (val, true) = self.use_ecx(source_info, |this| { + fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { + if let (val, true) = self.use_ecx(location, |this| { let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?; Ok((val, overflow)) @@ -557,11 +328,13 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // `AssertKind` only has an `OverflowNeg` variant, so make sure that is // appropriate to use. assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); + let source_info = self.body().source_info(location); self.report_assert_as_lint( - lint::builtin::ARITHMETIC_OVERFLOW, source_info, - "this arithmetic operation will overflow", - AssertKind::OverflowNeg(val.to_const_int()), + AssertLint::ArithmeticOverflow( + source_info.span, + AssertKind::OverflowNeg(val.to_const_int()), + ), ); return None; } @@ -574,59 +347,59 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { op: BinOp, left: &Operand<'tcx>, right: &Operand<'tcx>, - source_info: SourceInfo, + location: Location, ) -> Option<()> { - let r = self.use_ecx(source_info, |this| { + let r = self.use_ecx(location, |this| { this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?) }); - let l = self.use_ecx(source_info, |this| { - this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?) - }); + let l = self + .use_ecx(location, |this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); // Check for exceeding shifts *even if* we cannot evaluate the LHS. - if op == BinOp::Shr || op == BinOp::Shl { - let r = r?; + if matches!(op, BinOp::Shr | BinOp::Shl) { + let r = r.clone()?; // We need the type of the LHS. 
We cannot use `place_layout` as that is the type // of the result, which for checked binops is not the same! - let left_ty = left.ty(self.local_decls, self.tcx); + let left_ty = left.ty(self.local_decls(), self.tcx); let left_size = self.ecx.layout_of(left_ty).ok()?.size; let right_size = r.layout.size; - let r_bits = r.to_scalar().ok(); - let r_bits = r_bits.and_then(|r| r.to_bits(right_size).ok()); - if r_bits.map_or(false, |b| b >= left_size.bits() as u128) { - debug!("check_binary_op: reporting assert for {:?}", source_info); + let r_bits = r.to_scalar().to_bits(right_size).ok(); + if r_bits.is_some_and(|b| b >= left_size.bits() as u128) { + debug!("check_binary_op: reporting assert for {:?}", location); + let source_info = self.body().source_info(location); + let panic = AssertKind::Overflow( + op, + match l { + Some(l) => l.to_const_int(), + // Invent a dummy value, the diagnostic ignores it anyway + None => ConstInt::new( + ScalarInt::try_from_uint(1_u8, left_size).unwrap(), + left_ty.is_signed(), + left_ty.is_ptr_sized_integral(), + ), + }, + r.to_const_int(), + ); self.report_assert_as_lint( - lint::builtin::ARITHMETIC_OVERFLOW, source_info, - "this arithmetic operation will overflow", - AssertKind::Overflow( - op, - match l { - Some(l) => l.to_const_int(), - // Invent a dummy value, the diagnostic ignores it anyway - None => ConstInt::new( - ScalarInt::try_from_uint(1_u8, left_size).unwrap(), - left_ty.is_signed(), - left_ty.is_ptr_sized_integral(), - ), - }, - r.to_const_int(), - ), + AssertLint::ArithmeticOverflow(source_info.span, panic), ); return None; } } - if let (Some(l), Some(r)) = (&l, &r) { + if let (Some(l), Some(r)) = (l, r) { // The remaining operators are handled through `overflowing_binary_op`. - if self.use_ecx(source_info, |this| { - let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, l, r)?; + if self.use_ecx(location, |this| { + let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, &l, &r)?; Ok(overflow) })? { + let source_info = self.body().source_info(location); self.report_assert_as_lint( - lint::builtin::ARITHMETIC_OVERFLOW, source_info, - "this arithmetic operation will overflow", - AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), + AssertLint::ArithmeticOverflow( + source_info.span, + AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), + ), ); return None; } @@ -634,12 +407,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { Some(()) } - fn const_prop( - &mut self, - rvalue: &Rvalue<'tcx>, - source_info: SourceInfo, - place: Place<'tcx>, - ) -> Option<()> { + fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) -> Option<()> { // Perform any special handling for specific Rvalue types. // Generally, checks here fall into one of two categories: // 1. Additional checking to provide useful lints to the user @@ -654,11 +422,11 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // lint. 
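
// [Illustrative aside, not part of the patch] The shift check above compares the
// right-hand operand against the bit width of the *left-hand* type, which is why
// the pass queries `left_ty`'s layout rather than the layout of the result (for
// checked binops the two differ). A minimal standalone sketch, with invented names:
fn shift_exceeds_bits(shift_amount: u128, lhs_bits: u64) -> bool {
    shift_amount >= u128::from(lhs_bits)
}

fn main() {
    assert!(shift_exceeds_bits(8, 8)); // `1u8 << 8` overflows: u8 has only 8 bits
    assert!(!shift_exceeds_bits(7, 8)); // `1u8 << 7` is fine
}
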
Rvalue::UnaryOp(op, arg) => { trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); - self.check_unary_op(*op, arg, source_info)?; + self.check_unary_op(*op, arg, location)?; } Rvalue::BinaryOp(op, box (left, right)) => { trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); - self.check_binary_op(*op, left, right, source_info)?; + self.check_binary_op(*op, left, right, location)?; } Rvalue::CheckedBinaryOp(op, box (left, right)) => { trace!( @@ -667,7 +435,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { left, right ); - self.check_binary_op(*op, left, right, source_info)?; + self.check_binary_op(*op, left, right, location)?; } // Do not try creating references (#67862) @@ -703,150 +471,107 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } // FIXME we need to revisit this for #67176 - if rvalue.needs_subst() { + if rvalue.has_param() { + return None; + } + if !rvalue.ty(self.local_decls(), self.tcx).is_sized(self.tcx, self.param_env) { + // the interpreter doesn't support unsized locals (only unsized arguments), + // but rustc does (in a kinda broken way), so we have to skip them here return None; } - self.use_ecx(source_info, |this| this.ecx.eval_rvalue_into_place(rvalue, place)) + Some(()) } -} - -/// The mode that `ConstProp` is allowed to run in for a given `Local`. -#[derive(Clone, Copy, Debug, PartialEq)] -enum ConstPropMode { - /// The `Local` can be propagated into and reads of this `Local` can also be propagated. - FullConstProp, - /// The `Local` can only be propagated into and from its own block. - OnlyInsideOwnBlock, - /// The `Local` can be propagated into but reads cannot be propagated. - OnlyPropagateInto, - /// The `Local` cannot be part of propagation at all. Any statement - /// referencing it either for reading or writing will not get propagated. - NoPropagation, -} -struct CanConstProp { - can_const_prop: IndexVec<Local, ConstPropMode>, - // False at the beginning. Once set, no more assignments are allowed to that local. 
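
// [Illustrative aside, not part of the patch] The `ConstPropMode` table above is
// consulted at every write: only `FullConstProp` and `OnlyInsideOwnBlock` locals
// may later have reads replaced by known constants. A standalone sketch with
// invented names:
#[derive(Clone, Copy, PartialEq)]
enum Mode {
    FullConstProp,
    OnlyInsideOwnBlock,
    OnlyPropagateInto,
    NoPropagation,
}

fn may_propagate_reads(mode: Mode) -> bool {
    matches!(mode, Mode::FullConstProp | Mode::OnlyInsideOwnBlock)
}

fn main() {
    assert!(may_propagate_reads(Mode::FullConstProp));
    assert!(!may_propagate_reads(Mode::OnlyPropagateInto));
    assert!(!may_propagate_reads(Mode::NoPropagation));
}
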
- found_assignment: BitSet<Local>, - // Cache of locals' information - local_kinds: IndexVec<Local, LocalKind>, -} + fn check_assertion( + &mut self, + expected: bool, + msg: &AssertKind<Operand<'tcx>>, + cond: &Operand<'tcx>, + location: Location, + ) -> Option<!> { + let value = &self.eval_operand(&cond, location)?; + trace!("assertion on {:?} should be {:?}", value, expected); + + let expected = Scalar::from_bool(expected); + let value_const = self.use_ecx(location, |this| this.ecx.read_scalar(&value))?; + + if expected != value_const { + // Poison all places this operand references so that further code + // doesn't use the invalid value + if let Some(place) = cond.place() { + Self::remove_const(&mut self.ecx, place.local); + } -impl CanConstProp { - /// Returns true if `local` can be propagated - fn check<'tcx>( - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - body: &Body<'tcx>, - ) -> IndexVec<Local, ConstPropMode> { - let mut cpv = CanConstProp { - can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls), - found_assignment: BitSet::new_empty(body.local_decls.len()), - local_kinds: IndexVec::from_fn_n( - |local| body.local_kind(local), - body.local_decls.len(), - ), - }; - for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { - let ty = body.local_decls[local].ty; - match tcx.layout_of(param_env.and(ty)) { - Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} - // Either the layout fails to compute, then we can't use this local anyway - // or the local is too large, then we don't want to. - _ => { - *val = ConstPropMode::NoPropagation; - continue; - } + enum DbgVal<T> { + Val(T), + Underscore, } - // Cannot use args at all - // Cannot use locals because if x < y { y - x } else { x - y } would - // lint for x != y - // FIXME(oli-obk): lint variables until they are used in a condition - // FIXME(oli-obk): lint if return value is constant - if cpv.local_kinds[local] == LocalKind::Arg { - *val = ConstPropMode::OnlyPropagateInto; - trace!( - "local {:?} can't be const propagated because it's a function argument", - local - ); - } else if cpv.local_kinds[local] == LocalKind::Var { - *val = ConstPropMode::OnlyInsideOwnBlock; - trace!( - "local {:?} will only be propagated inside its block, because it's a user variable", - local - ); + impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Val(val) => val.fmt(fmt), + Self::Underscore => fmt.write_str("_"), + } + } } + let mut eval_to_int = |op| { + // This can be `None` if the lhs wasn't const propagated and we just + // triggered the assert on the value of the rhs. + self.eval_operand(op, location) + .and_then(|op| self.ecx.read_immediate(&op).ok()) + .map_or(DbgVal::Underscore, |op| DbgVal::Val(op.to_const_int())) + }; + let msg = match msg { + AssertKind::DivisionByZero(op) => AssertKind::DivisionByZero(eval_to_int(op)), + AssertKind::RemainderByZero(op) => AssertKind::RemainderByZero(eval_to_int(op)), + AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => { + // Division overflow is *UB* in the MIR, and different than the + // other overflow checks. + AssertKind::Overflow(*bin_op, eval_to_int(op1), eval_to_int(op2)) + } + AssertKind::BoundsCheck { ref len, ref index } => { + let len = eval_to_int(len); + let index = eval_to_int(index); + AssertKind::BoundsCheck { len, index } + } + // Remaining overflow errors are already covered by checks on the binary operators. 
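
// [Illustrative aside, not part of the patch] `DbgVal` above renders either the
// evaluated constant or `_` when an operand was not const-propagated, so a
// bounds-check message can show `_` in place of an unknown length. A
// self-contained copy of the helper to show the resulting text:
enum DbgVal<T> {
    Val(T),
    Underscore,
}

impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Val(val) => val.fmt(fmt),
            Self::Underscore => fmt.write_str("_"),
        }
    }
}

fn main() {
    assert_eq!(format!("{:?}", DbgVal::Val(3_i32)), "3");
    assert_eq!(format!("{:?}", DbgVal::<i32>::Underscore), "_");
}
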
+ AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => return None, + // Need proper const propagator for these. + _ => return None, + }; + let source_info = self.body().source_info(location); + self.report_assert_as_lint( + source_info, + AssertLint::UnconditionalPanic(source_info.span, msg), + ); } - cpv.visit_body(&body); - cpv.can_const_prop + + None } -} -impl Visitor<'_> for CanConstProp { - fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { - use rustc_middle::mir::visit::PlaceContext::*; - match context { - // Projections are fine, because `&mut foo.x` will be caught by - // `MutatingUseContext::Borrow` elsewhere. - MutatingUse(MutatingUseContext::Projection) - // These are just stores, where the storing is not propagatable, but there may be later - // mutations of the same local via `Store` - | MutatingUse(MutatingUseContext::Call) - | MutatingUse(MutatingUseContext::AsmOutput) - | MutatingUse(MutatingUseContext::Deinit) - // Actual store that can possibly even propagate a value - | MutatingUse(MutatingUseContext::SetDiscriminant) - | MutatingUse(MutatingUseContext::Store) => { - if !self.found_assignment.insert(local) { - match &mut self.can_const_prop[local] { - // If the local can only get propagated in its own block, then we don't have - // to worry about multiple assignments, as we'll nuke the const state at the - // end of the block anyway, and inside the block we overwrite previous - // states as applicable. - ConstPropMode::OnlyInsideOwnBlock => {} - ConstPropMode::NoPropagation => {} - ConstPropMode::OnlyPropagateInto => {} - other @ ConstPropMode::FullConstProp => { - trace!( - "local {:?} can't be propagated because of multiple assignments. Previous state: {:?}", - local, other, - ); - *other = ConstPropMode::OnlyInsideOwnBlock; - } - } - } - } - // Reading constants is allowed an arbitrary number of times - NonMutatingUse(NonMutatingUseContext::Copy) - | NonMutatingUse(NonMutatingUseContext::Move) - | NonMutatingUse(NonMutatingUseContext::Inspect) - | NonMutatingUse(NonMutatingUseContext::Projection) - | NonUse(_) => {} - - // These could be propagated with a smarter analysis or just some careful thinking about - // whether they'd be fine right now. - MutatingUse(MutatingUseContext::Yield) - | MutatingUse(MutatingUseContext::Drop) - | MutatingUse(MutatingUseContext::Retag) - // These can't ever be propagated under any scheme, as we can't reason about indirect - // mutation. 
- | NonMutatingUse(NonMutatingUseContext::SharedBorrow) - | NonMutatingUse(NonMutatingUseContext::ShallowBorrow) - | NonMutatingUse(NonMutatingUseContext::UniqueBorrow) - | NonMutatingUse(NonMutatingUseContext::AddressOf) - | MutatingUse(MutatingUseContext::Borrow) - | MutatingUse(MutatingUseContext::AddressOf) => { - trace!("local {:?} can't be propagaged because it's used: {:?}", local, context); - self.can_const_prop[local] = ConstPropMode::NoPropagation; - } + fn ensure_not_propagated(&self, local: Local) { + if cfg!(debug_assertions) { + assert!( + self.get_const(local.into()).is_none() + || self + .layout_of(self.local_decls()[local].ty) + .map_or(true, |layout| layout.is_zst()), + "failed to remove values for `{local:?}`, value={:?}", + self.get_const(local.into()), + ) } } } impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { fn visit_body(&mut self, body: &Body<'tcx>) { - for (bb, data) in body.basic_blocks().iter_enumerated() { + while let Some(bb) = self.worklist.pop() { + if !self.visited_blocks.insert(bb) { + continue; + } + + let data = &body.basic_blocks[bb]; self.visit_basic_block_data(bb, data); } } @@ -858,198 +583,147 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { fn visit_constant(&mut self, constant: &Constant<'tcx>, location: Location) { trace!("visit_constant: {:?}", constant); self.super_constant(constant, location); - self.eval_constant(constant, self.source_info.unwrap()); + self.eval_constant(constant, location); + } + + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + self.super_assign(place, rvalue, location); + + let Some(()) = self.check_rvalue(rvalue, location) else { return }; + + match self.ecx.machine.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { + if self + .use_ecx(location, |this| this.ecx.eval_rvalue_into_place(rvalue, *place)) + .is_none() + { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. + Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + Self::remove_const(&mut self.ecx, place.local); + } + } + } } fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { trace!("visit_statement: {:?}", statement); - let source_info = statement.source_info; - self.source_info = Some(source_info); - if let StatementKind::Assign(box (place, ref rval)) = statement.kind { - let can_const_prop = self.ecx.machine.can_const_prop[place.local]; - if let Some(()) = self.const_prop(rval, source_info, place) { - match can_const_prop { - ConstPropMode::OnlyInsideOwnBlock => { - trace!( - "found local restricted to its block. \ - Will remove it from const-prop after block is finished. Local: {:?}", - place.local - ); - } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - trace!("can't propagate into {:?}", place); - if place.local != RETURN_PLACE { + + // We want to evaluate operands before any change to the assigned-to value, + // so we recurse first. 
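
// [Illustrative aside, not part of the patch] The rewritten `visit_body` above
// drives the visit from a worklist instead of scanning every block, so blocks
// are only visited once they are known to be reachable. A generic sketch of that
// traversal shape, with invented types:
use std::collections::HashSet;

fn traverse(successors: &[Vec<usize>], start: usize) -> Vec<usize> {
    let mut worklist = vec![start];
    let mut visited = HashSet::new();
    let mut order = Vec::new();
    while let Some(bb) = worklist.pop() {
        // Mirrors `!self.visited_blocks.insert(bb)` above: skip seen blocks.
        if !visited.insert(bb) {
            continue;
        }
        order.push(bb);
        worklist.extend(successors[bb].iter().copied());
    }
    order
}

fn main() {
    // 0 -> 1, 1 -> 2 (twice): block 2 is still visited only once.
    let succ = vec![vec![1], vec![2, 2], vec![]];
    assert_eq!(traverse(&succ, 0), vec![0, 1, 2]);
}
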
+ self.super_statement(statement, location); + + match statement.kind { + StatementKind::SetDiscriminant { ref place, .. } => { + match self.ecx.machine.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + if self.use_ecx(location, |this| this.ecx.statement(statement)).is_some() { + trace!("propped discriminant into {:?}", place); + } else { Self::remove_const(&mut self.ecx, place.local); } } - ConstPropMode::FullConstProp => {} } - } else { - // Const prop failed, so erase the destination, ensuring that whatever happens - // from here on, does not know about the previous value. - // This is important in case we have - // ```rust - // let mut x = 42; - // x = SOME_MUTABLE_STATIC; - // // x must now be uninit - // ``` - // FIXME: we overzealously erase the entire local, because that's easier to - // implement. - trace!( - "propagation into {:?} failed. - Nuking the entire site from orbit, it's the only way to be sure", - place, - ); - Self::remove_const(&mut self.ecx, place.local); } - } else { - match statement.kind { - StatementKind::SetDiscriminant { ref place, .. } => { - match self.ecx.machine.can_const_prop[place.local] { - ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { - if self - .use_ecx(source_info, |this| this.ecx.statement(statement)) - .is_some() - { - trace!("propped discriminant into {:?}", place); - } else { - Self::remove_const(&mut self.ecx, place.local); - } - } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - Self::remove_const(&mut self.ecx, place.local); - } - } - } - StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { - let frame = self.ecx.frame_mut(); - frame.locals[local].value = - if let StatementKind::StorageLive(_) = statement.kind { - LocalValue::Live(interpret::Operand::Immediate( - interpret::Immediate::Uninit, - )) - } else { - LocalValue::Dead - }; - } - _ => {} + StatementKind::StorageLive(local) => { + let frame = self.ecx.frame_mut(); + frame.locals[local].value = + LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)); + } + StatementKind::StorageDead(local) => { + let frame = self.ecx.frame_mut(); + frame.locals[local].value = LocalValue::Dead; } + _ => {} } - - self.super_statement(statement, location); } fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - let source_info = terminator.source_info; - self.source_info = Some(source_info); self.super_terminator(terminator, location); match &terminator.kind { TerminatorKind::Assert { expected, ref msg, ref cond, .. } => { - if let Some(ref value) = self.eval_operand(&cond, source_info) { - trace!("assertion on {:?} should be {:?}", value, expected); - let expected = ScalarMaybeUninit::from(Scalar::from_bool(*expected)); - let value_const = self.ecx.read_scalar(&value).unwrap(); - if expected != value_const { - enum DbgVal<T> { - Val(T), - Underscore, - } - impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Val(val) => val.fmt(fmt), - Self::Underscore => fmt.write_str("_"), - } - } - } - let mut eval_to_int = |op| { - // This can be `None` if the lhs wasn't const propagated and we just - // triggered the assert on the value of the rhs. 
- self.eval_operand(op, source_info).map_or(DbgVal::Underscore, |op| {
- DbgVal::Val(self.ecx.read_immediate(&op).unwrap().to_const_int())
- })
- };
- let msg = match msg {
- AssertKind::DivisionByZero(op) => {
- Some(AssertKind::DivisionByZero(eval_to_int(op)))
- }
- AssertKind::RemainderByZero(op) => {
- Some(AssertKind::RemainderByZero(eval_to_int(op)))
- }
- AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => {
- // Division overflow is *UB* in the MIR, and different than the
- // other overflow checks.
- Some(AssertKind::Overflow(
- *bin_op,
- eval_to_int(op1),
- eval_to_int(op2),
- ))
- }
- AssertKind::BoundsCheck { ref len, ref index } => {
- let len = eval_to_int(len);
- let index = eval_to_int(index);
- Some(AssertKind::BoundsCheck { len, index })
- }
- // Remaining overflow errors are already covered by checks on the binary operators.
- AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => None,
- // Need proper const propagator for these.
- _ => None,
- };
- // Poison all places this operand references so that further code
- // doesn't use the invalid value
- match cond {
- Operand::Move(ref place) | Operand::Copy(ref place) => {
- Self::remove_const(&mut self.ecx, place.local);
- }
- Operand::Constant(_) => {}
- }
- if let Some(msg) = msg {
- self.report_assert_as_lint(
- lint::builtin::UNCONDITIONAL_PANIC,
- source_info,
- "this operation will panic at runtime",
- msg,
- );
- }
- }
+ self.check_assertion(*expected, msg, cond, location);
+ }
+ TerminatorKind::SwitchInt { ref discr, ref targets } => {
+ if let Some(ref value) = self.eval_operand(&discr, location)
+ && let Some(value_const) = self.use_ecx(location, |this| this.ecx.read_scalar(&value))
+ && let Ok(constant) = value_const.try_to_int()
+ && let Ok(constant) = constant.to_bits(constant.size())
+ {
+ // We managed to evaluate the discriminant, so we know we only need to visit
+ // one target.
+ let target = targets.target_for_value(constant);
+ self.worklist.push(target);
+ return;
}
+ // We failed to evaluate the discriminant, so fall back to visiting all successors.
}
// None of these have Operands to const-propagate.
TerminatorKind::Goto { .. }
| TerminatorKind::Resume
- | TerminatorKind::Abort
+ | TerminatorKind::Terminate
| TerminatorKind::Return
| TerminatorKind::Unreachable
| TerminatorKind::Drop { .. }
- | TerminatorKind::DropAndReplace { .. }
| TerminatorKind::Yield { .. }
| TerminatorKind::GeneratorDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
- | TerminatorKind::SwitchInt { .. }
| TerminatorKind::Call { .. }
| TerminatorKind::InlineAsm { .. } => {}
}
+ self.worklist.extend(terminator.successors());
+ }
+
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
+ self.super_basic_block_data(block, data);
+
// We remove all Locals which are restricted in propagation to their containing blocks and
// which were modified in the current block.
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`.
- let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
- for &local in locals.iter() {
+ let mut written_only_inside_own_block_locals =
+ std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
+
+ // This loop can get very hot for some bodies: it checks each local in each bb.
+ // To avoid this quadratic behaviour, we only clear the locals that were modified inside
+ // the current block.
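
// [Illustrative aside, not part of the patch] When the `SwitchInt` discriminant
// above evaluates to a constant, only the matching target is pushed onto the
// worklist, so arms that cannot be taken are never visited. A sketch of the
// `target_for_value` lookup with invented types:
fn target_for_value(targets: &[(u128, usize)], otherwise: usize, value: u128) -> usize {
    targets.iter().find(|&&(v, _)| v == value).map(|&(_, bb)| bb).unwrap_or(otherwise)
}

fn main() {
    let targets = [(0, 10), (1, 11)];
    // A discriminant known to be 1 means only block 11 needs a visit.
    assert_eq!(target_for_value(&targets, 12, 1), 11);
    // An unknown value falls through to the `otherwise` block.
    assert_eq!(target_for_value(&targets, 12, 7), 12);
}
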
+ for local in written_only_inside_own_block_locals.drain() {
+ debug_assert_eq!(
+ self.ecx.machine.can_const_prop[local],
+ ConstPropMode::OnlyInsideOwnBlock
+ );
Self::remove_const(&mut self.ecx, local);
}
- locals.clear();
- // Put it back so we reuse the heap of the storage
- self.ecx.machine.written_only_inside_own_block_locals = locals;
+ self.ecx.machine.written_only_inside_own_block_locals =
+ written_only_inside_own_block_locals;
+
if cfg!(debug_assertions) {
- // Ensure we are correctly erasing locals with the non-debug-assert logic.
- for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
- assert!(
- self.get_const(local.into()).is_none()
- || self
- .layout_of(self.local_decls[local].ty)
- .map_or(true, |layout| layout.is_zst())
- )
+ for (local, &mode) in self.ecx.machine.can_const_prop.iter_enumerated() {
+ match mode {
+ ConstPropMode::FullConstProp => {}
+ ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => {
+ self.ensure_not_propagated(local);
+ }
+ }
}
}
}
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
new file mode 100644
index 00000000000..3df459dfa79
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -0,0 +1,182 @@
+use rustc_index::bit_set::BitSet;
+use rustc_index::IndexSlice;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::impls::borrowed_locals;
+
+use crate::ssa::SsaLocals;
+use crate::MirPass;
+
+/// Unify locals that copy each other.
+///
+/// We consider patterns of the form
+/// _a = rvalue
+/// _b = move? _a
+/// _c = move? _a
+/// _d = move? _c
+/// where each of the locals is only assigned once.
+///
+/// We want to replace all those locals by `_a`, either copied or moved.
+pub struct CopyProp;
+
+impl<'tcx> MirPass<'tcx> for CopyProp {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 1
+ }
+
+ #[instrument(level = "trace", skip(self, tcx, body))]
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!(def_id = ?body.source.def_id());
+ propagate_ssa(tcx, body);
+ }
+}
+
+fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let borrowed_locals = borrowed_locals(body);
+ let ssa = SsaLocals::new(body);
+
+ let fully_moved = fully_moved_locals(&ssa, body);
+ debug!(?fully_moved);
+
+ let mut storage_to_remove = BitSet::new_empty(fully_moved.domain_size());
+ for (local, &head) in ssa.copy_classes().iter_enumerated() {
+ if local != head {
+ storage_to_remove.insert(head);
+ }
+ }
+
+ let any_replacement = ssa.copy_classes().iter_enumerated().any(|(l, &h)| l != h);
+
+ Replacer {
+ tcx,
+ copy_classes: &ssa.copy_classes(),
+ fully_moved,
+ borrowed_locals,
+ storage_to_remove,
+ }
+ .visit_body_preserves_cfg(body);
+
+ if any_replacement {
+ crate::simplify::remove_unused_definitions(body);
+ }
+}
+
+/// `SsaLocals` computes equivalence classes between locals considering copy/move assignments.
+///
+/// This function also returns whether all the `move?` in the pattern are `move` and not copies.
+/// A local which is in the bitset can be replaced by `move _a`. Otherwise, it must be
+/// replaced by `copy _a`, as we cannot move multiple times from `_a`.
+///
+/// If an operand copies `_c`, it must happen before the assignment `_d = _c`, otherwise it is UB.
+/// This means that replacing it by a copy of `_a` is ok, since this copy happens before `_c` is
+/// moved, and therefore before `_d` is moved.
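
// [Illustrative aside, not part of the patch] The copy classes described above
// map every local to the representative it (transitively) copies from. A toy
// model, where `resolve` walks a chain to its head; the pass itself appears to
// store heads directly in `copy_classes`, so no walking is needed there:
fn resolve(copy_classes: &[usize], local: usize) -> usize {
    let mut l = local;
    while copy_classes[l] != l {
        l = copy_classes[l];
    }
    l
}

fn main() {
    // _0 = rvalue; _1 = move? _0; _2 = move? _0; _3 = move? _2
    let classes = [0, 0, 0, 2];
    assert_eq!(resolve(&classes, 3), 0); // _3 is replaced by _0
    assert_eq!(resolve(&classes, 1), 0); // so is _1
}
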
+#[instrument(level = "trace", skip(ssa, body))] +fn fully_moved_locals(ssa: &SsaLocals, body: &Body<'_>) -> BitSet<Local> { + let mut fully_moved = BitSet::new_filled(body.local_decls.len()); + + for (_, rvalue, _) in ssa.assignments(body) { + let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place)) + = rvalue + else { continue }; + + let Some(rhs) = place.as_local() else { continue }; + if !ssa.is_ssa(rhs) { + continue; + } + + if let Rvalue::Use(Operand::Copy(_)) | Rvalue::CopyForDeref(_) = rvalue { + fully_moved.remove(rhs); + } + } + + ssa.meet_copy_equivalence(&mut fully_moved); + + fully_moved +} + +/// Utility to help performing substitution of `*pattern` by `target`. +struct Replacer<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + fully_moved: BitSet<Local>, + storage_to_remove: BitSet<Local>, + borrowed_locals: BitSet<Local>, + copy_classes: &'a IndexSlice<Local, Local>, +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) { + let new_local = self.copy_classes[*local]; + match ctxt { + // Do not modify the local in storage statements. + PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {} + // The local should have been marked as non-SSA. + PlaceContext::MutatingUse(_) => assert_eq!(*local, new_local), + // We access the value. + _ => *local = new_local, + } + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) { + if let Some(new_projection) = self.process_projection(&place.projection, loc) { + place.projection = self.tcx().mk_place_elems(&new_projection); + } + + let observes_address = match ctxt { + PlaceContext::NonMutatingUse( + NonMutatingUseContext::SharedBorrow + | NonMutatingUseContext::ShallowBorrow + | NonMutatingUseContext::AddressOf, + ) => true, + // For debuginfo, merging locals is ok. + PlaceContext::NonUse(NonUseContext::VarDebugInfo) => { + self.borrowed_locals.contains(place.local) + } + _ => false, + }; + if observes_address && !place.is_indirect() { + // We observe the address of `place.local`. Do not replace it. + } else { + self.visit_local( + &mut place.local, + PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy), + loc, + ) + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) { + if let Operand::Move(place) = *operand + // A move out of a projection of a copy is equivalent to a copy of the original projection. + && !place.has_deref() + && !self.fully_moved.contains(place.local) + { + *operand = Operand::Copy(place); + } + self.super_operand(operand, loc); + } + + fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) { + // When removing storage statements, we need to remove both (#107511). + if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind + && self.storage_to_remove.contains(l) + { + stmt.make_nop(); + return + } + + self.super_statement(stmt, loc); + + // Do not leave tautological assignments around. 
+ if let StatementKind::Assign(box (lhs, ref rhs)) = stmt.kind + && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) = *rhs + && lhs == rhs + { + stmt.make_nop(); + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 45de0c28035..658e01d9310 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -520,7 +520,7 @@ impl<'a> BcbCounters<'a> { let mut found_loop_exit = false; for &branch in branches.iter() { if backedge_from_bcbs.iter().any(|&backedge_from_bcb| { - self.bcb_is_dominated_by(backedge_from_bcb, branch.target_bcb) + self.bcb_dominates(branch.target_bcb, backedge_from_bcb) }) { if let Some(reloop_branch) = some_reloop_branch { if reloop_branch.counter(&self.basic_coverage_blocks).is_none() { @@ -603,8 +603,8 @@ impl<'a> BcbCounters<'a> { } #[inline] - fn bcb_is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.basic_coverage_blocks.is_dominated_by(node, dom) + fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.basic_coverage_blocks.dominates(dom, node) } #[inline] diff --git a/compiler/rustc_mir_transform/src/coverage/debug.rs b/compiler/rustc_mir_transform/src/coverage/debug.rs index 0f8679b0bd6..6a3d42511ac 100644 --- a/compiler/rustc_mir_transform/src/coverage/debug.rs +++ b/compiler/rustc_mir_transform/src/coverage/debug.rs @@ -118,7 +118,7 @@ use rustc_middle::mir::spanview::{self, SpanViewable}; use rustc_data_structures::fx::FxHashMap; use rustc_middle::mir::coverage::*; -use rustc_middle::mir::{self, BasicBlock, TerminatorKind}; +use rustc_middle::mir::{self, BasicBlock}; use rustc_middle::ty::TyCtxt; use rustc_span::Span; @@ -292,10 +292,8 @@ impl DebugCounters { } pub fn some_block_label(&self, operand: ExpressionOperandId) -> Option<&String> { - self.some_counters.as_ref().map_or(None, |counters| { - counters - .get(&operand) - .map_or(None, |debug_counter| debug_counter.some_block_label.as_ref()) + self.some_counters.as_ref().and_then(|counters| { + counters.get(&operand).and_then(|debug_counter| debug_counter.some_block_label.as_ref()) }) } @@ -323,7 +321,10 @@ impl DebugCounters { String::new() }, self.format_operand(lhs), - if op == Op::Add { "+" } else { "-" }, + match op { + Op::Add => "+", + Op::Subtract => "-", + }, self.format_operand(rhs), ); } @@ -638,7 +639,7 @@ pub(super) fn dump_coverage_spanview<'tcx>( let def_id = mir_source.def_id(); let span_viewables = span_viewables(tcx, mir_body, basic_coverage_blocks, &coverage_spans); - let mut file = create_dump_file(tcx, "html", None, pass_name, &0, mir_source) + let mut file = create_dump_file(tcx, "html", false, pass_name, &0i32, mir_body) .expect("Unexpected error creating MIR spanview HTML file"); let crate_name = tcx.crate_name(def_id.krate); let item_name = tcx.def_path(def_id).to_filename_friendly_no_crate(); @@ -739,7 +740,7 @@ pub(super) fn dump_coverage_graphviz<'tcx>( .join("\n ") )); } - let mut file = create_dump_file(tcx, "dot", None, pass_name, &0, mir_source) + let mut file = create_dump_file(tcx, "dot", false, pass_name, &0i32, mir_body) .expect("Unexpected error creating BasicCoverageBlock graphviz DOT file"); graphviz_writer .write_graphviz(tcx, &mut file) @@ -795,7 +796,7 @@ fn bcb_to_string_sections<'tcx>( } let non_term_blocks = bcb_data.basic_blocks[0..len - 1] .iter() - .map(|&bb| format!("{:?}: {}", bb, 
term_type(&mir_body[bb].terminator().kind))) + .map(|&bb| format!("{:?}: {}", bb, mir_body[bb].terminator().kind.name())) .collect::<Vec<_>>(); if non_term_blocks.len() > 0 { sections.push(non_term_blocks.join("\n")); @@ -803,29 +804,7 @@ fn bcb_to_string_sections<'tcx>( sections.push(format!( "{:?}: {}", bcb_data.basic_blocks.last().unwrap(), - term_type(&bcb_data.terminator(mir_body).kind) + bcb_data.terminator(mir_body).kind.name(), )); sections } - -/// Returns a simple string representation of a `TerminatorKind` variant, independent of any -/// values it might hold. -pub(super) fn term_type(kind: &TerminatorKind<'_>) -> &'static str { - match kind { - TerminatorKind::Goto { .. } => "Goto", - TerminatorKind::SwitchInt { .. } => "SwitchInt", - TerminatorKind::Resume => "Resume", - TerminatorKind::Abort => "Abort", - TerminatorKind::Return => "Return", - TerminatorKind::Unreachable => "Unreachable", - TerminatorKind::Drop { .. } => "Drop", - TerminatorKind::DropAndReplace { .. } => "DropAndReplace", - TerminatorKind::Call { .. } => "Call", - TerminatorKind::Assert { .. } => "Assert", - TerminatorKind::Yield { .. } => "Yield", - TerminatorKind::GeneratorDrop => "GeneratorDrop", - TerminatorKind::FalseEdge { .. } => "FalseEdge", - TerminatorKind::FalseUnwind { .. } => "FalseUnwind", - TerminatorKind::InlineAsm { .. } => "InlineAsm", - } -} diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index 759ea7cd328..ea1223fbca6 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -5,10 +5,11 @@ use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::graph::dominators::{self, Dominators}; use rustc_data_structures::graph::{self, GraphSuccessors, WithNumNodes, WithStartNode}; use rustc_index::bit_set::BitSet; -use rustc_index::vec::IndexVec; +use rustc_index::{IndexSlice, IndexVec}; use rustc_middle::mir::coverage::*; use rustc_middle::mir::{self, BasicBlock, BasicBlockData, Terminator, TerminatorKind}; +use std::cmp::Ordering; use std::ops::{Index, IndexMut}; const ID_SEPARATOR: &str = ","; @@ -37,8 +38,7 @@ impl CoverageGraph { // `SwitchInt` to have multiple targets to the same destination `BasicBlock`, so // de-duplication is required. This is done without reordering the successors. - let bcbs_len = bcbs.len(); - let mut seen = IndexVec::from_elem_n(false, bcbs_len); + let mut seen = IndexVec::from_elem(false, &bcbs); let successors = IndexVec::from_fn_n( |bcb| { for b in seen.iter_mut() { @@ -60,7 +60,7 @@ impl CoverageGraph { bcbs.len(), ); - let mut predecessors = IndexVec::from_elem_n(Vec::new(), bcbs.len()); + let mut predecessors = IndexVec::from_elem(Vec::new(), &bcbs); for (bcb, bcb_successors) in successors.iter_enumerated() { for &successor in bcb_successors { predecessors[successor].push(bcb); @@ -112,7 +112,7 @@ impl CoverageGraph { if predecessors.len() > 1 { "predecessors.len() > 1".to_owned() } else { - format!("bb {} is not in precessors: {:?}", bb.index(), predecessors) + format!("bb {} is not in predecessors: {:?}", bb.index(), predecessors) } ); } @@ -123,7 +123,7 @@ impl CoverageGraph { match term.kind { TerminatorKind::Return { .. } - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Yield { .. } | TerminatorKind::SwitchInt { .. } => { // The `bb` has more than one _outgoing_ edge, or exits the function. 
Save the @@ -137,7 +137,7 @@ impl CoverageGraph { debug!(" because term.kind = {:?}", term.kind); // Note that this condition is based on `TerminatorKind`, even though it // theoretically boils down to `successors().len() != 1`; that is, either zero - // (e.g., `Return`, `Abort`) or multiple successors (e.g., `SwitchInt`), but + // (e.g., `Return`, `Terminate`) or multiple successors (e.g., `SwitchInt`), but // since the BCB CFG ignores things like unwind branches (which exist in the // `Terminator`s `successors()` list) checking the number of successors won't // work. @@ -156,7 +156,6 @@ impl CoverageGraph { | TerminatorKind::Resume | TerminatorKind::Unreachable | TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::Call { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::Assert { .. } @@ -177,10 +176,10 @@ impl CoverageGraph { fn add_basic_coverage_block( bcbs: &mut IndexVec<BasicCoverageBlock, BasicCoverageBlockData>, - bb_to_bcb: &mut IndexVec<BasicBlock, Option<BasicCoverageBlock>>, + bb_to_bcb: &mut IndexSlice<BasicBlock, Option<BasicCoverageBlock>>, basic_blocks: Vec<BasicBlock>, ) { - let bcb = BasicCoverageBlock::from_usize(bcbs.len()); + let bcb = bcbs.next_index(); for &bb in basic_blocks.iter() { bb_to_bcb[bb] = Some(bcb); } @@ -209,13 +208,17 @@ impl CoverageGraph { } #[inline(always)] - pub fn is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.dominators.as_ref().unwrap().is_dominated_by(node, dom) + pub fn dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.dominators.as_ref().unwrap().dominates(dom, node) } #[inline(always)] - pub fn dominators(&self) -> &Dominators<BasicCoverageBlock> { - self.dominators.as_ref().unwrap() + pub fn rank_partial_cmp( + &self, + a: BasicCoverageBlock, + b: BasicCoverageBlock, + ) -> Option<Ordering> { + self.dominators.as_ref().unwrap().rank_partial_cmp(a, b) } } @@ -282,9 +285,9 @@ impl graph::WithPredecessors for CoverageGraph { rustc_index::newtype_index! { /// A node in the control-flow graph of CoverageGraph. + #[debug_format = "bcb{}"] pub(super) struct BasicCoverageBlock { - DEBUG_FORMAT = "bcb{}", - const START_BCB = 0, + const START_BCB = 0; } } @@ -312,7 +315,7 @@ rustc_index::newtype_index! { /// to the BCB's primary counter or expression). /// /// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based -/// queries (`is_dominated_by()`, `predecessors`, `successors`, etc.) have branch (control flow) +/// queries (`dominates()`, `predecessors`, `successors`, etc.) have branch (control flow) /// significance. #[derive(Debug, Clone)] pub(super) struct BasicCoverageBlockData { @@ -538,29 +541,29 @@ impl TraverseCoverageGraphWithLoops { "TraverseCoverageGraphWithLoops::next - context_stack: {:?}", self.context_stack.iter().rev().collect::<Vec<_>>() ); - while let Some(next_bcb) = { - // Strip contexts with empty worklists from the top of the stack - while self.context_stack.last().map_or(false, |context| context.worklist.is_empty()) { + + while let Some(context) = self.context_stack.last_mut() { + if let Some(next_bcb) = context.worklist.pop() { + if !self.visited.insert(next_bcb) { + debug!("Already visited: {:?}", next_bcb); + continue; + } + debug!("Visiting {:?}", next_bcb); + if self.backedges[next_bcb].len() > 0 { + debug!("{:?} is a loop header! 
Start a new TraversalContext...", next_bcb); + self.context_stack.push(TraversalContext { + loop_backedges: Some((self.backedges[next_bcb].clone(), next_bcb)), + worklist: Vec::new(), + }); + } + self.extend_worklist(basic_coverage_blocks, next_bcb); + return Some(next_bcb); + } else { + // Strip contexts with empty worklists from the top of the stack self.context_stack.pop(); } - // Pop the next bcb off of the current context_stack. If none, all BCBs were visited. - self.context_stack.last_mut().map_or(None, |context| context.worklist.pop()) - } { - if !self.visited.insert(next_bcb) { - debug!("Already visited: {:?}", next_bcb); - continue; - } - debug!("Visiting {:?}", next_bcb); - if self.backedges[next_bcb].len() > 0 { - debug!("{:?} is a loop header! Start a new TraversalContext...", next_bcb); - self.context_stack.push(TraversalContext { - loop_backedges: Some((self.backedges[next_bcb].clone(), next_bcb)), - worklist: Vec::new(), - }); - } - self.extend_worklist(basic_coverage_blocks, next_bcb); - return Some(next_bcb); } + None } @@ -594,7 +597,7 @@ impl TraverseCoverageGraphWithLoops { // branching block would have given an `Expression` (or vice versa). let (some_successor_to_add, some_loop_header) = if let Some((_, loop_header)) = context.loop_backedges { - if basic_coverage_blocks.is_dominated_by(successor, loop_header) { + if basic_coverage_blocks.dominates(loop_header, successor) { (Some(successor), Some(loop_header)) } else { (None, None) @@ -652,29 +655,9 @@ pub(super) fn find_loop_backedges( let mut backedges = IndexVec::from_elem_n(Vec::<BasicCoverageBlock>::new(), num_bcbs); // Identify loops by their backedges. - // - // The computational complexity is bounded by: n(s) x d where `n` is the number of - // `BasicCoverageBlock` nodes (the simplified/reduced representation of the CFG derived from the - // MIR); `s` is the average number of successors per node (which is most likely less than 2, and - // independent of the size of the function, so it can be treated as a constant); - // and `d` is the average number of dominators per node. - // - // The average number of dominators depends on the size and complexity of the function, and - // nodes near the start of the function's control flow graph typically have less dominators - // than nodes near the end of the CFG. Without doing a detailed mathematical analysis, I - // think the resulting complexity has the characteristics of O(n log n). - // - // The overall complexity appears to be comparable to many other MIR transform algorithms, and I - // don't expect that this function is creating a performance hot spot, but if this becomes an - // issue, there may be ways to optimize the `is_dominated_by` algorithm (as indicated by an - // existing `FIXME` comment in that code), or possibly ways to optimize it's usage here, perhaps - // by keeping track of results for visited `BasicCoverageBlock`s if they can be used to short - // circuit downstream `is_dominated_by` checks. - // - // For now, that kind of optimization seems unnecessarily complicated. 
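
// [Illustrative aside, not part of the patch] The backedge test used in
// `find_loop_backedges` below reduces to a dominator query: an edge
// `bcb -> successor` is a backedge exactly when `successor` dominates `bcb`.
// A sketch over an invented dominance relation:
fn is_backedge(dominates: impl Fn(usize, usize) -> bool, from: usize, to: usize) -> bool {
    dominates(to, from)
}

fn main() {
    // Tiny loop 0 -> 1 -> 2 -> 1: block 1 dominates block 2.
    let dominates =
        |dom: usize, node: usize| matches!((dom, node), (0, _) | (1, 1) | (1, 2) | (2, 2));
    assert!(is_backedge(&dominates, 2, 1)); // 2 -> 1 closes the loop
    assert!(!is_backedge(&dominates, 1, 2)); // 1 -> 2 is a forward edge
}
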
for (bcb, _) in basic_coverage_blocks.iter_enumerated() { for &successor in &basic_coverage_blocks.successors[bcb] { - if basic_coverage_blocks.is_dominated_by(bcb, successor) { + if basic_coverage_blocks.dominates(successor, bcb) { let loop_header = successor; let backedge_from_bcb = bcb; debug!( @@ -713,7 +696,7 @@ impl< ShortCircuitPreorder { body, - visited: BitSet::new_empty(body.basic_blocks().len()), + visited: BitSet::new_empty(body.basic_blocks.len()), worklist, filtered_successors, } @@ -747,7 +730,7 @@ impl< } fn size_hint(&self) -> (usize, Option<usize>) { - let size = self.body.basic_blocks().len() - self.visited.count(); + let size = self.body.basic_blocks.len() - self.visited.count(); (size, Some(size)) } } diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index 2619626a567..076e714d703 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -16,7 +16,7 @@ use crate::MirPass; use rustc_data_structures::graph::WithNumNodes; use rustc_data_structures::sync::Lrc; -use rustc_index::vec::IndexVec; +use rustc_index::IndexVec; use rustc_middle::hir; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::mir::coverage::*; @@ -80,7 +80,7 @@ impl<'tcx> MirPass<'tcx> for InstrumentCoverage { return; } - match mir_body.basic_blocks()[mir::START_BLOCK].terminator().kind { + match mir_body.basic_blocks[mir::START_BLOCK].terminator().kind { TerminatorKind::Unreachable => { trace!("InstrumentCoverage skipped for unreachable `START_BLOCK`"); return; @@ -514,7 +514,7 @@ fn make_code_region( // Extend an empty span by one character so the region will be counted. let CharPos(char_pos) = start_col; if span.hi() == body_span.hi() { - start_col = CharPos(char_pos - 1); + start_col = CharPos(char_pos.saturating_sub(1)); } else { end_col = CharPos(char_pos + 1); } @@ -533,15 +533,16 @@ fn make_code_region( } } -fn fn_sig_and_body<'tcx>( - tcx: TyCtxt<'tcx>, +fn fn_sig_and_body( + tcx: TyCtxt<'_>, def_id: DefId, -) -> (Option<&'tcx rustc_hir::FnSig<'tcx>>, &'tcx rustc_hir::Body<'tcx>) { +) -> (Option<&rustc_hir::FnSig<'_>>, &rustc_hir::Body<'_>) { // FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back // to HIR for it. let hir_node = tcx.hir().get_if_local(def_id).expect("expected DefId is local"); - let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body"); - (hir::map::fn_sig(hir_node), tcx.hir().body(fn_body_id)) + let (_, fn_body_id) = + hir::map::associated_body(hir_node).expect("HIR node is a function with body"); + (hir_node.fn_sig(), tcx.hir().body(fn_body_id)) } fn get_body_span<'tcx>( @@ -576,5 +577,10 @@ fn get_body_span<'tcx>( fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 { // FIXME(cjgillot) Stop hashing HIR manually here. 
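
// [Illustrative aside, not part of the patch] The `saturating_sub` change in
// `make_code_region` above guards the column adjustment against underflow when
// an empty span sits at column 0:
fn main() {
    let char_pos: u32 = 0;
    // Plain `char_pos - 1` would panic in debug builds (and wrap in release);
    // saturating arithmetic pins the column at 0 instead.
    assert_eq!(char_pos.saturating_sub(1), 0);
    assert_eq!(5u32.saturating_sub(1), 4);
}
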
let owner = hir_body.id().hir_id.owner; - tcx.hir_owner_nodes(owner).unwrap().hash_including_bodies.to_smaller_hash() + tcx.hir_owner_nodes(owner) + .unwrap() + .opt_hash_including_bodies + .unwrap() + .to_smaller_hash() + .as_u64() } diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs index 9d02f58ae65..74b4b4a07c5 100644 --- a/compiler/rustc_mir_transform/src/coverage/query.rs +++ b/compiler/rustc_mir_transform/src/coverage/query.rs @@ -2,7 +2,7 @@ use super::*; use rustc_middle::mir::coverage::*; use rustc_middle::mir::{self, Body, Coverage, CoverageInfo}; -use rustc_middle::ty::query::Providers; +use rustc_middle::query::Providers; use rustc_middle::ty::{self, TyCtxt}; use rustc_span::def_id::DefId; @@ -84,7 +84,7 @@ impl CoverageVisitor { } fn visit_body(&mut self, body: &Body<'_>) { - for bb_data in body.basic_blocks().iter() { + for bb_data in body.basic_blocks.iter() { for statement in bb_data.statements.iter() { if let StatementKind::Coverage(box ref coverage) = statement.kind { if is_inlined(body, statement) { @@ -136,9 +136,9 @@ fn coverageinfo<'tcx>(tcx: TyCtxt<'tcx>, instance_def: ty::InstanceDef<'tcx>) -> coverage_visitor.info } -fn covered_code_regions<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Vec<&'tcx CodeRegion> { +fn covered_code_regions(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<&CodeRegion> { let body = mir_body(tcx, def_id); - body.basic_blocks() + body.basic_blocks .iter() .flat_map(|data| { data.statements.iter().filter_map(|statement| match statement.kind { @@ -163,8 +163,7 @@ fn is_inlined(body: &Body<'_>, statement: &Statement<'_>) -> bool { /// This function ensures we obtain the correct MIR for the given item irrespective of /// whether that means const mir or runtime mir. For `const fn` this opts for runtime /// mir. -fn mir_body<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx mir::Body<'tcx> { - let id = ty::WithOptConstParam::unknown(def_id); - let def = ty::InstanceDef::Item(id); +fn mir_body(tcx: TyCtxt<'_>, def_id: DefId) -> &mir::Body<'_> { + let def = ty::InstanceDef::Item(def_id); tcx.instance_mir(def) } diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 423e78317aa..d27200419e2 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -1,4 +1,3 @@ -use super::debug::term_type; use super::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB}; use itertools::Itertools; @@ -40,7 +39,7 @@ impl CoverageStatement { "{}: @{}.{}: {:?}", source_range_no_file(tcx, span), bb.index(), - term_type(&term.kind), + term.kind.name(), term.kind ) } @@ -63,7 +62,7 @@ impl CoverageStatement { /// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that /// is not part of the `CoverageSpan` bcb if the statement was included because it's `Span` matches /// or is subsumed by the `Span` associated with this `CoverageSpan`, and it's `BasicBlock` -/// `is_dominated_by()` the `BasicBlock`s in this `CoverageSpan`. +/// `dominates()` the `BasicBlock`s in this `CoverageSpan`. #[derive(Debug, Clone)] pub(super) struct CoverageSpan { pub span: Span, @@ -341,11 +340,11 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { if a.is_in_same_bcb(b) { Some(Ordering::Equal) } else { - // Sort equal spans by dominator relationship, in reverse order (so - // dominators always come after the dominated equal spans). 
When later - // comparing two spans in order, the first will either dominate the second, - // or they will have no dominator relationship. - self.basic_coverage_blocks.dominators().rank_partial_cmp(b.bcb, a.bcb) + // Sort equal spans by dominator relationship (so dominators always come + // before the dominated equal spans). When later comparing two spans in + // order, the first will either dominate the second, or they will have no + // dominator relationship. + self.basic_coverage_blocks.rank_partial_cmp(a.bcb, b.bcb) } } else { // Sort hi() in reverse order so shorter spans are attempted after longer spans. @@ -407,7 +406,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() { // Macros that expand to include branching (such as // `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or - // `trace!()) typically generate callee spans with identical + // `trace!()`) typically generate callee spans with identical // ranges (typically the full span of the macro) for all // `BasicBlocks`. This makes it impossible to distinguish // the condition (`if val1 != val2`) from the optional @@ -694,7 +693,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { /// `prev.span.hi()` will be greater than (further right of) `prev_original_span.hi()`. /// If prev.span() was split off to the right of a closure, prev.span().lo() will be /// greater than prev_original_span.lo(). The actual span of `prev_original_span` is - /// not as important as knowing that `prev()` **used to have the same span** as `curr(), + /// not as important as knowing that `prev()` **used to have the same span** as `curr()`, /// which means their sort order is still meaningful for determining the dominator /// relationship. /// @@ -705,12 +704,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { fn hold_pending_dups_unless_dominated(&mut self) { // Equal coverage spans are ordered by dominators before dominated (if any), so it should be // impossible for `curr` to dominate any previous `CoverageSpan`. - debug_assert!(!self.span_bcb_is_dominated_by(self.prev(), self.curr())); + debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev())); let initial_pending_count = self.pending_dups.len(); if initial_pending_count > 0 { let mut pending_dups = self.pending_dups.split_off(0); - pending_dups.retain(|dup| !self.span_bcb_is_dominated_by(self.curr(), dup)); + pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr())); self.pending_dups.append(&mut pending_dups); if self.pending_dups.len() < initial_pending_count { debug!( @@ -721,7 +720,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } } - if self.span_bcb_is_dominated_by(self.curr(), self.prev()) { + if self.span_bcb_dominates(self.prev(), self.curr()) { debug!( " different bcbs but SAME spans, and prev dominates curr. 
Discard prev={:?}", self.prev() @@ -787,8 +786,8 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } } - fn span_bcb_is_dominated_by(&self, covspan: &CoverageSpan, dom_covspan: &CoverageSpan) -> bool { - self.basic_coverage_blocks.is_dominated_by(covspan.bcb, dom_covspan.bcb) + fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool { + self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb) } } @@ -802,6 +801,8 @@ pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> | StatementKind::StorageDead(_) // Coverage should not be encountered, but don't inject coverage coverage | StatementKind::Coverage(_) + // Ignore `ConstEvalCounter`s + | StatementKind::ConstEvalCounter // Ignore `Nop`s | StatementKind::Nop => None, @@ -825,11 +826,12 @@ pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> // Retain spans from all other statements StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding` - | StatementKind::CopyNonOverlapping(..) + | StatementKind::Intrinsic(..) | StatementKind::Assign(_) | StatementKind::SetDiscriminant { .. } | StatementKind::Deinit(..) | StatementKind::Retag(_, _) + | StatementKind::PlaceMention(..) | StatementKind::AscribeUserType(_, _) => { Some(statement.source_info.span) } @@ -848,7 +850,6 @@ pub(super) fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Sp TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG | TerminatorKind::Assert { .. } | TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::SwitchInt { .. } // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`. | TerminatorKind::FalseEdge { .. } @@ -867,7 +868,7 @@ pub(super) fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Sp // Retain spans from all other terminators TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Yield { .. } | TerminatorKind::GeneratorDrop diff --git a/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml b/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml index f5e8b65656a..f753caa9124 100644 --- a/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml +++ b/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml @@ -5,4 +5,3 @@ edition = "2021" [lib] proc-macro = true -doctest = false diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs index 6380f03528a..90b58933df7 100644 --- a/compiler/rustc_mir_transform/src/coverage/tests.rs +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -25,7 +25,6 @@ //! to: `rustc_span::create_default_session_globals_then(|| { test_here(); })`. use super::counters; -use super::debug; use super::graph; use super::spans; @@ -34,10 +33,10 @@ use coverage_test_macros::let_bcb; use itertools::Itertools; use rustc_data_structures::graph::WithNumNodes; use rustc_data_structures::graph::WithSuccessors; -use rustc_index::vec::{Idx, IndexVec}; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::coverage::CoverageKind; use rustc_middle::mir::*; -use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty; use rustc_span::{self, BytePos, Pos, Span, DUMMY_SP}; // All `TEMP_BLOCK` targets should be replaced before calling `to_body() -> mir::Body`. 
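
// [Illustrative aside, not part of the patch] The `rank_partial_cmp` ordering
// discussed above can be modelled with a plain rank array in which a dominator
// has a smaller rank than the blocks it dominates, so sorting by rank puts
// dominators first. Invented data for illustration:
fn main() {
    // rank[bcb]: e.g. bcb0 dominates bcb2, which dominates bcb1.
    let rank = [0usize, 2, 1];
    let mut spans = vec![1usize, 2, 0]; // bcbs whose source spans are identical
    spans.sort_by_key(|&bcb| rank[bcb]);
    assert_eq!(spans, vec![0, 2, 1]); // dominators sort before the dominated
}
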
@@ -47,7 +46,6 @@ struct MockBlocks<'tcx> { blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, dummy_place: Place<'tcx>, next_local: usize, - bool_ty: Ty<'tcx>, } impl<'tcx> MockBlocks<'tcx> { @@ -56,7 +54,6 @@ impl<'tcx> MockBlocks<'tcx> { blocks: IndexVec::new(), dummy_place: Place { local: RETURN_PLACE, projection: ty::List::empty() }, next_local: 0, - bool_ty: TyCtxt::BOOL_TY_FOR_UNIT_TESTING, } } @@ -67,7 +64,7 @@ impl<'tcx> MockBlocks<'tcx> { } fn push(&mut self, kind: TerminatorKind<'tcx>) -> BasicBlock { - let next_lo = if let Some(last) = self.blocks.last() { + let next_lo = if let Some(last) = self.blocks.last_index() { self.blocks[last].terminator().source_info.span.hi() } else { BytePos(1) @@ -88,7 +85,6 @@ impl<'tcx> MockBlocks<'tcx> { TerminatorKind::Assert { ref mut target, .. } | TerminatorKind::Call { target: Some(ref mut target), .. } | TerminatorKind::Drop { ref mut target, .. } - | TerminatorKind::DropAndReplace { ref mut target, .. } | TerminatorKind::FalseEdge { real_target: ref mut target, .. } | TerminatorKind::FalseUnwind { real_target: ref mut target, .. } | TerminatorKind::Goto { ref mut target } @@ -143,7 +139,7 @@ impl<'tcx> MockBlocks<'tcx> { args: vec![], destination: self.dummy_place.clone(), target: Some(TEMP_BLOCK), - cleanup: None, + unwind: UnwindAction::Continue, from_hir_call: false, fn_span: DUMMY_SP, }, @@ -157,7 +153,6 @@ impl<'tcx> MockBlocks<'tcx> { fn switchint(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { let switchint_kind = TerminatorKind::SwitchInt { discr: Operand::Move(Place::from(self.new_temp())), - switch_ty: self.bool_ty, // just a dummy value targets: SwitchTargets::static_if(0, TEMP_BLOCK, TEMP_BLOCK), }; self.add_block_from(some_from_block, switchint_kind) @@ -172,11 +167,11 @@ impl<'tcx> MockBlocks<'tcx> { } } -fn debug_basic_blocks<'tcx>(mir_body: &Body<'tcx>) -> String { +fn debug_basic_blocks(mir_body: &Body<'_>) -> String { format!( "{:?}", mir_body - .basic_blocks() + .basic_blocks .iter_enumerated() .map(|(bb, data)| { let term = &data.terminator(); @@ -187,18 +182,17 @@ fn debug_basic_blocks<'tcx>(mir_body: &Body<'tcx>) -> String { TerminatorKind::Assert { target, .. } | TerminatorKind::Call { target: Some(target), .. } | TerminatorKind::Drop { target, .. } - | TerminatorKind::DropAndReplace { target, .. } | TerminatorKind::FalseEdge { real_target: target, .. } | TerminatorKind::FalseUnwind { real_target: target, .. } | TerminatorKind::Goto { target } | TerminatorKind::InlineAsm { destination: Some(target), .. } | TerminatorKind::Yield { resume: target, .. } => { - format!("{}{:?}:{} -> {:?}", sp, bb, debug::term_type(kind), target) + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), target) } TerminatorKind::SwitchInt { targets, .. 
} => { - format!("{}{:?}:{} -> {:?}", sp, bb, debug::term_type(kind), targets) + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), targets) } - _ => format!("{}{:?}:{}", sp, bb, debug::term_type(kind)), + _ => format!("{}{:?}:{}", sp, bb, kind.name()), } }) .collect::<Vec<_>>() @@ -213,14 +207,14 @@ fn print_mir_graphviz(name: &str, mir_body: &Body<'_>) { "digraph {} {{\n{}\n}}", name, mir_body - .basic_blocks() + .basic_blocks .iter_enumerated() .map(|(bb, data)| { format!( " {:?} [label=\"{:?}: {}\"];\n{}", bb, bb, - debug::term_type(&data.terminator().kind), + data.terminator().kind.name(), mir_body .basic_blocks .successors(bb) @@ -249,7 +243,7 @@ fn print_coverage_graphviz( " {:?} [label=\"{:?}: {}\"];\n{}", bcb, bcb, - debug::term_type(&bcb_data.terminator(mir_body).kind), + bcb_data.terminator(mir_body).kind.name(), basic_coverage_blocks .successors(bcb) .map(|successor| { format!(" {:?} -> {:?};", bcb, successor) }) @@ -653,7 +647,7 @@ fn test_traverse_coverage_with_loops() { fn synthesize_body_span_from_terminators(mir_body: &Body<'_>) -> Span { let mut some_span: Option<Span> = None; - for (_, data) in mir_body.basic_blocks().iter_enumerated() { + for (_, data) in mir_body.basic_blocks.iter_enumerated() { let term_span = data.terminator().source_info.span; if let Some(span) = some_span.as_mut() { *span = span.to(term_span); diff --git a/compiler/rustc_mir_transform/src/ctfe_limit.rs b/compiler/rustc_mir_transform/src/ctfe_limit.rs new file mode 100644 index 00000000000..bf5722b3d00 --- /dev/null +++ b/compiler/rustc_mir_transform/src/ctfe_limit.rs @@ -0,0 +1,58 @@ +//! A pass that inserts the `ConstEvalCounter` instruction into any blocks that have a back edge +//! (thus indicating there is a loop in the CFG), or whose terminator is a function call. +use crate::MirPass; + +use rustc_data_structures::graph::dominators::Dominators; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, Body, Statement, StatementKind, TerminatorKind, +}; +use rustc_middle::ty::TyCtxt; + +pub struct CtfeLimit; + +impl<'tcx> MirPass<'tcx> for CtfeLimit { + #[instrument(skip(self, _tcx, body))] + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let doms = body.basic_blocks.dominators(); + let indices: Vec<BasicBlock> = body + .basic_blocks + .iter_enumerated() + .filter_map(|(node, node_data)| { + if matches!(node_data.terminator().kind, TerminatorKind::Call { .. }) + // Back edges in a CFG indicate loops + || has_back_edge(&doms, node, &node_data) + { + Some(node) + } else { + None + } + }) + .collect(); + for index in indices { + insert_counter( + body.basic_blocks_mut() + .get_mut(index) + .expect("basic_blocks index {index} should exist"), + ); + } + } +} + +fn has_back_edge( + doms: &Dominators<BasicBlock>, + node: BasicBlock, + node_data: &BasicBlockData<'_>, +) -> bool { + if !doms.is_reachable(node) { + return false; + } + // Check if any of the dominators of the node are also the node's successor. + node_data.terminator().successors().any(|succ| doms.dominates(succ, node)) +} + +fn insert_counter(basic_block_data: &mut BasicBlockData<'_>) { + basic_block_data.statements.push(Statement { + source_info: basic_block_data.terminator().source_info, + kind: StatementKind::ConstEvalCounter, + }); +} diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs new file mode 100644 index 00000000000..78fb196358f --- /dev/null +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -0,0 +1,627 @@ +//! 
A constant propagation optimization pass based on dataflow analysis.
+//!
+//! Currently, this pass only propagates scalar values.
+
+use rustc_const_eval::const_eval::CheckAlignment;
+use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
+use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def::DefKind;
+use rustc_middle::mir::visit::{MutVisitor, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::layout::TyAndLayout;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{
+    Map, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
+};
+use rustc_mir_dataflow::{
+    lattice::FlatSet, Analysis, Results, ResultsVisitor, SwitchIntEdgeEffects,
+};
+use rustc_span::DUMMY_SP;
+use rustc_target::abi::{Align, FieldIdx, VariantIdx};
+
+use crate::MirPass;
+
+// These constants are somewhat random guesses and have not been optimized.
+// If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive).
+const BLOCK_LIMIT: usize = 100;
+const PLACE_LIMIT: usize = 100;
+
+pub struct DataflowConstProp;
+
+impl<'tcx> MirPass<'tcx> for DataflowConstProp {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 3
+    }
+
+    #[instrument(skip_all, level = "debug")]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!(def_id = ?body.source.def_id());
+        if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
+            debug!("aborted dataflow const prop due to too many basic blocks");
+            return;
+        }
+
+        // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
+        // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
+        // applications, where `h` is the height of the lattice. Because the height of our lattice
+        // is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However,
+        // because every transfer function application could traverse the whole map, this becomes
+        // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
+        // map nodes is strongly correlated to the number of tracked places, this becomes more or
+        // less `O(n)` if we place a constant limit on the number of tracked places.
+        let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
+
+        // Decide which places to track during the analysis.
+        let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
+
+        // Perform the actual dataflow analysis.
+        let analysis = ConstAnalysis::new(tcx, body, map);
+        let mut results = debug_span!("analyze")
+            .in_scope(|| analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint());
+
+        // Collect results and patch the body afterwards.
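Stepping back to the new `CtfeLimit` pass above: a back edge is an edge whose target dominates its source, so inserting a `ConstEvalCounter` into such blocks (and into blocks ending in a `Call`) lets the const-eval interpreter charge each loop iteration and each call against its step limit. A hedged sketch of the kind of ordinary const code this meters (nothing rustc-internal):

```rust
const fn helper(x: u32) -> u32 {
    x + 1
}

const fn metered() -> u32 {
    let mut i = 0;
    // The loop body ends with a back edge to the loop head (which dominates
    // it), so that block gets a `ConstEvalCounter`; the call to `helper` is
    // counted through its `Call` terminator as well.
    while i < 1_000 {
        i = helper(i);
    }
    i
}

const METERED: u32 = metered();
```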
+ let mut visitor = CollectAndPatch::new(tcx); + debug_span!("collect").in_scope(|| results.visit_reachable_with(body, &mut visitor)); + debug_span!("patch").in_scope(|| visitor.visit_body(body)); + } +} + +struct ConstAnalysis<'a, 'tcx> { + map: Map, + tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, + param_env: ty::ParamEnv<'tcx>, +} + +impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { + type Value = FlatSet<ScalarTy<'tcx>>; + + const NAME: &'static str = "ConstAnalysis"; + + fn map(&self) -> &Map { + &self.map + } + + fn handle_set_discriminant( + &self, + place: Place<'tcx>, + variant_index: VariantIdx, + state: &mut State<Self::Value>, + ) { + state.flood_discr(place.as_ref(), &self.map); + if self.map.find_discr(place.as_ref()).is_some() { + let enum_ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) { + state.assign_discr( + place.as_ref(), + ValueOrPlace::Value(FlatSet::Elem(discr)), + &self.map, + ); + } + } + } + + fn handle_assign( + &self, + target: Place<'tcx>, + rvalue: &Rvalue<'tcx>, + state: &mut State<Self::Value>, + ) { + match rvalue { + Rvalue::Aggregate(kind, operands) => { + // If we assign `target = Enum::Variant#0(operand)`, + // we must make sure that all `target as Variant#i` are `Top`. + state.flood(target.as_ref(), self.map()); + + let Some(target_idx) = self.map().find(target.as_ref()) else { return }; + + let (variant_target, variant_index) = match **kind { + AggregateKind::Tuple | AggregateKind::Closure(..) => (Some(target_idx), None), + AggregateKind::Adt(def_id, variant_index, ..) => { + match self.tcx.def_kind(def_id) { + DefKind::Struct => (Some(target_idx), None), + DefKind::Enum => ( + self.map.apply(target_idx, TrackElem::Variant(variant_index)), + Some(variant_index), + ), + _ => return, + } + } + _ => return, + }; + if let Some(variant_target_idx) = variant_target { + for (field_index, operand) in operands.iter().enumerate() { + if let Some(field) = self.map().apply( + variant_target_idx, + TrackElem::Field(FieldIdx::from_usize(field_index)), + ) { + let result = self.handle_operand(operand, state); + state.insert_idx(field, result, self.map()); + } + } + } + if let Some(variant_index) = variant_index + && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant) + { + // We are assigning the discriminant as part of an aggregate. + // This discriminant can only alias a variant field's value if the operand + // had an invalid value for that type. + // Using invalid values is UB, so we are allowed to perform the assignment + // without extra flooding. + let enum_ty = target.ty(self.local_decls, self.tcx).ty; + if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) { + state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map); + } + } + } + Rvalue::CheckedBinaryOp(op, box (left, right)) => { + // Flood everything now, so we can use `insert_value_idx` directly later. + state.flood(target.as_ref(), self.map()); + + let Some(target) = self.map().find(target.as_ref()) else { return }; + + let value_target = self.map().apply(target, TrackElem::Field(0_u32.into())); + let overflow_target = self.map().apply(target, TrackElem::Field(1_u32.into())); + + if value_target.is_some() || overflow_target.is_some() { + let (val, overflow) = self.binary_op(state, *op, left, right); + + if let Some(value_target) = value_target { + // We have flooded `target` earlier. 
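A `CheckedBinaryOp` produces a `(value, overflow)` pair, which is why the code above looks up `Field(0)` and `Field(1)` of the target separately. A hedged, source-level illustration (assuming overflow checks are enabled so `+` lowers to a checked operation; the MIR in the comments is approximate):

```rust
fn checked_add_demo(a: u8) -> u8 {
    // Lowers to roughly:
    //     _2 = CheckedAdd(_1, const 100_u8)            // a (u8, bool) pair
    //     assert(!move (_2.1: bool), ...) -> [success: bb1, ...]
    //     _0 = move (_2.0: u8)
    // The analysis tracks `_2.0` (the value) and `_2.1` (the overflow flag)
    // as two separate scalar places.
    a + 100
}
```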
+ state.insert_value_idx(value_target, val, self.map()); + } + if let Some(overflow_target) = overflow_target { + let overflow = match overflow { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(overflow) => { + self.wrap_scalar(Scalar::from_bool(overflow), self.tcx.types.bool) + } + FlatSet::Bottom => FlatSet::Bottom, + }; + // We have flooded `target` earlier. + state.insert_value_idx(overflow_target, overflow, self.map()); + } + } + } + _ => self.super_assign(target, rvalue, state), + } + } + + fn handle_rvalue( + &self, + rvalue: &Rvalue<'tcx>, + state: &mut State<Self::Value>, + ) -> ValueOrPlace<Self::Value> { + match rvalue { + Rvalue::Cast( + kind @ (CastKind::IntToInt + | CastKind::FloatToInt + | CastKind::FloatToFloat + | CastKind::IntToFloat), + operand, + ty, + ) => match self.eval_operand(operand, state) { + FlatSet::Elem(op) => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + self.ecx.int_to_int_or_float(&op, *ty) + } + CastKind::FloatToInt | CastKind::FloatToFloat => { + self.ecx.float_to_float_or_int(&op, *ty) + } + _ => unreachable!(), + } + .map(|result| ValueOrPlace::Value(self.wrap_immediate(result, *ty))) + .unwrap_or(ValueOrPlace::TOP), + _ => ValueOrPlace::TOP, + }, + Rvalue::BinaryOp(op, box (left, right)) => { + // Overflows must be ignored here. + let (val, _overflow) = self.binary_op(state, *op, left, right); + ValueOrPlace::Value(val) + } + Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) { + FlatSet::Elem(value) => self + .ecx + .unary_op(*op, &value) + .map(|val| ValueOrPlace::Value(self.wrap_immty(val))) + .unwrap_or(ValueOrPlace::Value(FlatSet::Top)), + FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom), + FlatSet::Top => ValueOrPlace::Value(FlatSet::Top), + }, + Rvalue::Discriminant(place) => { + ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map())) + } + _ => self.super_rvalue(rvalue, state), + } + } + + fn handle_constant( + &self, + constant: &Constant<'tcx>, + _state: &mut State<Self::Value>, + ) -> Self::Value { + constant + .literal + .eval(self.tcx, self.param_env) + .try_to_scalar() + .map(|value| FlatSet::Elem(ScalarTy(value, constant.ty()))) + .unwrap_or(FlatSet::Top) + } + + fn handle_switch_int( + &self, + discr: &Operand<'tcx>, + apply_edge_effects: &mut impl SwitchIntEdgeEffects<State<Self::Value>>, + ) { + // FIXME: The dataflow framework only provides the state if we call `apply()`, which makes + // this more inefficient than it has to be. + let mut discr_value = None; + let mut handled = false; + apply_edge_effects.apply(|state, target| { + let discr_value = match discr_value { + Some(value) => value, + None => { + let value = match self.handle_operand(discr, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, self.map()), + }; + let result = match value { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(ScalarTy(scalar, _)) => { + let int = scalar.assert_int(); + FlatSet::Elem(int.assert_bits(int.size())) + } + FlatSet::Bottom => FlatSet::Bottom, + }; + discr_value = Some(result); + result + } + }; + + let FlatSet::Elem(choice) = discr_value else { + // Do nothing if we don't know which branch will be taken. + return + }; + + if target.value.map(|n| n == choice).unwrap_or(!handled) { + // Branch is taken. Has no effect on state. + handled = true; + } else { + // Branch is not taken. 
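The `handle_switch_int` logic prunes `SwitchInt` edges: once the discriminant is a known constant, only the matching target keeps a reachable entry state. A hedged example of source code where this applies (assuming the analysis has already established the constant):

```rust
fn prune_demo() -> i32 {
    let x = 1; // the analysis learns `x == 1`
    // The `match` lowers to a `SwitchInt` on `x`; only the edge for `1` keeps
    // a live state, while the `0` and fallthrough edges are marked unreachable.
    match x {
        0 => 10,
        1 => 20,
        _ => 30,
    }
}
```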
+ state.mark_unreachable(); + } + }) + } +} + +#[derive(Clone, PartialEq, Eq)] +struct ScalarTy<'tcx>(Scalar, Ty<'tcx>); + +impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // This is used for dataflow visualization, so we return something more concise. + std::fmt::Display::fmt(&ConstantKind::Val(ConstValue::Scalar(self.0), self.1), f) + } +} + +impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { + pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self { + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + Self { + map, + tcx, + local_decls: &body.local_decls, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), + param_env: param_env, + } + } + + fn binary_op( + &self, + state: &mut State<FlatSet<ScalarTy<'tcx>>>, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + ) -> (FlatSet<ScalarTy<'tcx>>, FlatSet<bool>) { + let left = self.eval_operand(left, state); + let right = self.eval_operand(right, state); + match (left, right) { + (FlatSet::Elem(left), FlatSet::Elem(right)) => { + match self.ecx.overflowing_binary_op(op, &left, &right) { + Ok((val, overflow, ty)) => (self.wrap_scalar(val, ty), FlatSet::Elem(overflow)), + _ => (FlatSet::Top, FlatSet::Top), + } + } + (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), + (_, _) => { + // Could attempt some algebraic simplifications here. + (FlatSet::Top, FlatSet::Top) + } + } + } + + fn eval_operand( + &self, + op: &Operand<'tcx>, + state: &mut State<FlatSet<ScalarTy<'tcx>>>, + ) -> FlatSet<ImmTy<'tcx>> { + let value = match self.handle_operand(op, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, &self.map), + }; + match value { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(ScalarTy(scalar, ty)) => self + .tcx + .layout_of(self.param_env.and(ty)) + .map(|layout| FlatSet::Elem(ImmTy::from_scalar(scalar, layout))) + .unwrap_or(FlatSet::Top), + FlatSet::Bottom => FlatSet::Bottom, + } + } + + fn eval_discriminant( + &self, + enum_ty: Ty<'tcx>, + variant_index: VariantIdx, + ) -> Option<ScalarTy<'tcx>> { + if !enum_ty.is_enum() { + return None; + } + let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?; + let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?; + let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?; + Some(ScalarTy(discr_value, discr.ty)) + } + + fn wrap_scalar(&self, scalar: Scalar, ty: Ty<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + FlatSet::Elem(ScalarTy(scalar, ty)) + } + + fn wrap_immediate(&self, imm: Immediate, ty: Ty<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + match imm { + Immediate::Scalar(scalar) => self.wrap_scalar(scalar, ty), + _ => FlatSet::Top, + } + } + + fn wrap_immty(&self, val: ImmTy<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + self.wrap_immediate(*val, val.layout.ty) + } +} + +struct CollectAndPatch<'tcx> { + tcx: TyCtxt<'tcx>, + + /// For a given MIR location, this stores the values of the operands used by that location. In + /// particular, this is before the effect, such that the operands of `_1 = _1 + _2` are + /// properly captured. (This may become UB soon, but it is currently emitted even by safe code.) + before_effect: FxHashMap<(Location, Place<'tcx>), ScalarTy<'tcx>>, + + /// Stores the assigned values for assignments where the Rvalue is constant. 
+ assignments: FxHashMap<Location, ScalarTy<'tcx>>, +} + +impl<'tcx> CollectAndPatch<'tcx> { + fn new(tcx: TyCtxt<'tcx>) -> Self { + Self { tcx, before_effect: FxHashMap::default(), assignments: FxHashMap::default() } + } + + fn make_operand(&self, scalar: ScalarTy<'tcx>) -> Operand<'tcx> { + Operand::Constant(Box::new(Constant { + span: DUMMY_SP, + user_ty: None, + literal: ConstantKind::Val(ConstValue::Scalar(scalar.0), scalar.1), + })) + } +} + +impl<'mir, 'tcx> + ResultsVisitor<'mir, 'tcx, Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>> + for CollectAndPatch<'tcx> +{ + type FlowState = State<FlatSet<ScalarTy<'tcx>>>; + + fn visit_statement_before_primary_effect( + &mut self, + results: &Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + state: &Self::FlowState, + statement: &'mir Statement<'tcx>, + location: Location, + ) { + match &statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + OperandCollector { state, visitor: self, map: &results.analysis.0.map } + .visit_rvalue(rvalue, location); + } + _ => (), + } + } + + fn visit_statement_after_primary_effect( + &mut self, + results: &Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + state: &Self::FlowState, + statement: &'mir Statement<'tcx>, + location: Location, + ) { + match statement.kind { + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(_)))) => { + // Don't overwrite the assignment if it already uses a constant (to keep the span). + } + StatementKind::Assign(box (place, _)) => { + match state.get(place.as_ref(), &results.analysis.0.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.assignments.insert(location, value); + } + FlatSet::Bottom => { + // This assignment is either unreachable, or an uninitialized value is assigned. 
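Pulling the collection rules together: operand values are recorded *before* each statement's effect (so for `x = x + y` the recorded right-hand `x` is the pre-assignment value), and assignments that already use a constant operand are left untouched to preserve their spans. A hedged end-to-end illustration (the exact patches depend on what the analysis can prove):

```rust
fn collect_and_patch_demo() -> i32 {
    let mut x = 5; // `_1 = const 5_i32` already uses a constant: kept as-is
    let y = x;     // `_2 = _1` has a known value: patchable to `_2 = const 5_i32`
    // Roughly `_1 = Add(_1, _2)` in MIR: `before_effect` records the old value
    // `5` for the right-hand `_1`, and the assignment itself is patchable to
    // `_1 = const 10_i32`.
    x = x + y;
    x
}
```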
+ } + } + } + _ => (), + } + } + + fn visit_terminator_before_primary_effect( + &mut self, + results: &Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>, + state: &Self::FlowState, + terminator: &'mir Terminator<'tcx>, + location: Location, + ) { + OperandCollector { state, visitor: self, map: &results.analysis.0.map } + .visit_terminator(terminator, location); + } +} + +impl<'tcx> MutVisitor<'tcx> for CollectAndPatch<'tcx> { + fn tcx<'a>(&'a self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let Some(value) = self.assignments.get(&location) { + match &mut statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + *rvalue = Rvalue::Use(self.make_operand(value.clone())); + } + _ => bug!("found assignment info for non-assign statement"), + } + } else { + self.super_statement(statement, location); + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + if let Some(value) = self.before_effect.get(&(location, *place)) { + *operand = self.make_operand(value.clone()); + } + } + _ => (), + } + } +} + +struct OperandCollector<'tcx, 'map, 'a> { + state: &'a State<FlatSet<ScalarTy<'tcx>>>, + visitor: &'a mut CollectAndPatch<'tcx>, + map: &'map Map, +} + +impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> { + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + match self.state.get(place.as_ref(), self.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.visitor.before_effect.insert((location, *place), value); + } + FlatSet::Bottom => (), + } + } + _ => (), + } + } +} + +struct DummyMachine; + +impl<'mir, 'tcx> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine { + rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>); + type MemoryKind = !; + const PANIC_ON_ALLOC_FAIL: bool = true; + + fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment { + unimplemented!() + } + + fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { + unimplemented!() + } + fn alignment_check_failed( + _ecx: &InterpCx<'mir, 'tcx, Self>, + _has: Align, + _required: Align, + _check: CheckAlignment, + ) -> interpret::InterpResult<'tcx, ()> { + unimplemented!() + } + + fn find_mir_or_eval_fn( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _abi: rustc_target::spec::abi::Abi, + _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>], + _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, + _target: Option<BasicBlock>, + _unwind: UnwindAction, + ) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { + unimplemented!() + } + + fn call_intrinsic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>], + _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, + _target: Option<BasicBlock>, + _unwind: UnwindAction, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn assert_panic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _msg: &rustc_middle::mir::AssertMessage<'tcx>, + _unwind: UnwindAction, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn binary_ptr_op( + _ecx: &InterpCx<'mir, 'tcx, Self>, + 
_bin_op: BinOp, + _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + ) -> interpret::InterpResult<'tcx, (interpret::Scalar<Self::Provenance>, bool, Ty<'tcx>)> { + throw_unsup!(Unsupported("".into())) + } + + fn expose_ptr( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _ptr: interpret::Pointer<Self::Provenance>, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn init_frame_extra( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _frame: rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance>, + ) -> interpret::InterpResult< + 'tcx, + rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, + > { + unimplemented!() + } + + fn stack<'a>( + _ecx: &'a InterpCx<'mir, 'tcx, Self>, + ) -> &'a [rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] + { + unimplemented!() + } + + fn stack_mut<'a>( + _ecx: &'a mut InterpCx<'mir, 'tcx, Self>, + ) -> &'a mut Vec< + rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, + > { + unimplemented!() + } +} diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs index 9163672f570..7bc5183a00a 100644 --- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -52,7 +52,9 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS | StatementKind::StorageLive(_) | StatementKind::StorageDead(_) | StatementKind::Coverage(_) - | StatementKind::CopyNonOverlapping(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(_) | StatementKind::Nop => (), StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { @@ -70,6 +72,8 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS for Location { block, statement_index } in patch { bbs[block].statements[statement_index].make_nop(); } + + crate::simplify::simplify_locals(body, tcx) } pub struct DeadStoreElimination; diff --git a/compiler/rustc_mir_transform/src/deaggregator.rs b/compiler/rustc_mir_transform/src/deaggregator.rs deleted file mode 100644 index b93fe5879f4..00000000000 --- a/compiler/rustc_mir_transform/src/deaggregator.rs +++ /dev/null @@ -1,49 +0,0 @@ -use crate::util::expand_aggregate; -use crate::MirPass; -use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; - -pub struct Deaggregator; - -impl<'tcx> MirPass<'tcx> for Deaggregator { - fn phase_change(&self) -> Option<MirPhase> { - Some(MirPhase::Deaggregated) - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); - for bb in basic_blocks { - bb.expand_statements(|stmt| { - // FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL). - match stmt.kind { - // FIXME(#48193) Deaggregate arrays when it's cheaper to do so. 
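For context on the pass being deleted here: `expand_aggregate` rewrote a non-array aggregate assignment into one assignment per field (plus a `SetDiscriminant` for enum variants). An illustrative before/after, with the old expansion spelled out in comments (approximate MIR; the pass could be removed because later phases now accept aggregate rvalues directly):

```rust
struct Foo {
    a: i32,
    b: i32,
}

fn build(x: i32, y: i32) -> Foo {
    // The removed pass would have expanded `_0 = Foo { a: _1, b: _2 }` into
    // field-by-field stores, roughly:
    //     (_0.0: i32) = move _1;
    //     (_0.1: i32) = move _2;
    Foo { a: x, b: y }
}
```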
- StatementKind::Assign(box ( - _, - Rvalue::Aggregate(box AggregateKind::Array(_), _), - )) => { - return None; - } - StatementKind::Assign(box (_, Rvalue::Aggregate(_, _))) => {} - _ => return None, - } - - let stmt = stmt.replace_nop(); - let source_info = stmt.source_info; - let StatementKind::Assign(box (lhs, Rvalue::Aggregate(kind, operands))) = stmt.kind else { - bug!(); - }; - - Some(expand_aggregate( - lhs, - operands.into_iter().map(|op| { - let ty = op.ty(&body.local_decls, tcx); - (op, ty) - }), - *kind, - source_info, - tcx, - )) - }); - } - } -} diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs new file mode 100644 index 00000000000..a133c9d4782 --- /dev/null +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -0,0 +1,220 @@ +//! Deduces supplementary parameter attributes from MIR. +//! +//! Deduced parameter attributes are those that can only be soundly determined by examining the +//! body of the function instead of just the signature. These can be useful for optimization +//! purposes on a best-effort basis. We compute them here and store them into the crate metadata so +//! dependent crates can use them. + +use rustc_hir::def_id::LocalDefId; +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::{Body, Location, Operand, Place, Terminator, TerminatorKind, RETURN_PLACE}; +use rustc_middle::ty::{self, DeducedParamAttrs, Ty, TyCtxt}; +use rustc_session::config::OptLevel; + +/// A visitor that determines which arguments have been mutated. We can't use the mutability field +/// on LocalDecl for this because it has no meaning post-optimization. +struct DeduceReadOnly { + /// Each bit is indexed by argument number, starting at zero (so 0 corresponds to local decl + /// 1). The bit is true if the argument may have been mutated or false if we know it hasn't + /// been up to the point we're at. + mutable_args: BitSet<usize>, +} + +impl DeduceReadOnly { + /// Returns a new DeduceReadOnly instance. + fn new(arg_count: usize) -> Self { + Self { mutable_args: BitSet::new_empty(arg_count) } + } +} + +impl<'tcx> Visitor<'tcx> for DeduceReadOnly { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { + // We're only interested in arguments. + if place.local == RETURN_PLACE || place.local.index() > self.mutable_args.domain_size() { + return; + } + + let mark_as_mutable = match context { + PlaceContext::MutatingUse(..) => { + // This is a mutation, so mark it as such. + true + } + PlaceContext::NonMutatingUse(NonMutatingUseContext::AddressOf) => { + // Whether mutating though a `&raw const` is allowed is still undecided, so we + // disable any sketchy `readonly` optimizations for now. + // But we only need to do this if the pointer would point into the argument. + !place.is_indirect() + } + PlaceContext::NonMutatingUse(..) | PlaceContext::NonUse(..) => { + // Not mutating, so it's fine. + false + } + }; + + if mark_as_mutable { + self.mutable_args.insert(place.local.index() - 1); + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // OK, this is subtle. 
Suppose that we're trying to deduce whether `x` in `f` is read-only + // and we have the following: + // + // fn f(x: BigStruct) { g(x) } + // fn g(mut y: BigStruct) { y.foo = 1 } + // + // If, at the generated MIR level, `f` turned into something like: + // + // fn f(_1: BigStruct) -> () { + // let mut _0: (); + // bb0: { + // _0 = g(move _1) -> bb1; + // } + // ... + // } + // + // then it would be incorrect to mark `x` (i.e. `_1`) as `readonly`, because `g`'s write to + // its copy of the indirect parameter would actually be a write directly to the pointer that + // `f` passes. Note that function arguments are the only situation in which this problem can + // arise: every other use of `move` in MIR doesn't actually write to the value it moves + // from. + // + // Anyway, right now this situation doesn't actually arise in practice. Instead, the MIR for + // that function looks like this: + // + // fn f(_1: BigStruct) -> () { + // let mut _0: (); + // let mut _2: BigStruct; + // bb0: { + // _2 = move _1; + // _0 = g(move _2) -> bb1; + // } + // ... + // } + // + // Because of that extra move that MIR construction inserts, `x` (i.e. `_1`) can *in + // practice* safely be marked `readonly`. + // + // To handle the possibility that other optimizations (for example, destination propagation) + // might someday generate MIR like the first example above, we panic upon seeing an argument + // to *our* function that is directly moved into *another* function as an argument. Having + // eliminated that problematic case, we can safely treat moves as copies in this analysis. + // + // In the future, if MIR optimizations cause arguments of a caller to be directly moved into + // the argument of a callee, we can just add that argument to `mutated_args` instead of + // panicking. + // + // Note that, because the problematic MIR is never actually generated, we can't add a test + // case for this. + + if let TerminatorKind::Call { ref args, .. } = terminator.kind { + for arg in args { + if let Operand::Move(place) = *arg { + let local = place.local; + if place.is_indirect() + || local == RETURN_PLACE + || local.index() > self.mutable_args.domain_size() + { + continue; + } + + self.mutable_args.insert(local.index() - 1); + } + } + }; + + self.super_terminator(terminator, location); + } +} + +/// Returns true if values of a given type will never be passed indirectly, regardless of ABI. +fn type_will_always_be_passed_directly(ty: Ty<'_>) -> bool { + matches!( + ty.kind(), + ty::Bool + | ty::Char + | ty::Float(..) + | ty::Int(..) + | ty::RawPtr(..) + | ty::Ref(..) + | ty::Slice(..) + | ty::Uint(..) + ) +} + +/// Returns the deduced parameter attributes for a function. +/// +/// Deduced parameter attributes are those that can only be soundly determined by examining the +/// body of the function instead of just the signature. These can be useful for optimization +/// purposes on a best-effort basis. We compute them here and store them into the crate metadata so +/// dependent crates can use them. +pub fn deduced_param_attrs<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: LocalDefId, +) -> &'tcx [DeducedParamAttrs] { + // This computation is unfortunately rather expensive, so don't do it unless we're optimizing. + // Also skip it in incremental mode. + if tcx.sess.opts.optimize == OptLevel::No || tcx.sess.opts.incremental.is_some() { + return &[]; + } + + // If the Freeze language item isn't present, then don't bother. 
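A hedged example of what the deduction buys (assuming an optimized, non-incremental build and a type that is passed indirectly, so the attribute matters): `a` is only ever read and is `Freeze`, so it is eligible for `readonly`; `b` sees a `MutatingUse` and lands in `mutable_args`:

```rust
// Large arrays are typically passed indirectly, making the attribute relevant.
pub fn sum_first(a: [u8; 64], mut b: [u8; 64]) -> u8 {
    b[0] = b[0].wrapping_add(1); // mutation: `b` is recorded in `mutable_args`
    a[0].wrapping_add(b[0])      // `a` is only read: deducible as read-only
}
```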
+ if tcx.lang_items().freeze_trait().is_none() { + return &[]; + } + + // Codegen won't use this information for anything if all the function parameters are passed + // directly. Detect that and bail, for compilation speed. + let fn_ty = tcx.type_of(def_id).subst_identity(); + if matches!(fn_ty.kind(), ty::FnDef(..)) { + if fn_ty + .fn_sig(tcx) + .inputs() + .skip_binder() + .iter() + .cloned() + .all(type_will_always_be_passed_directly) + { + return &[]; + } + } + + // Don't deduce any attributes for functions that have no MIR. + if !tcx.is_mir_available(def_id) { + return &[]; + } + + // Grab the optimized MIR. Analyze it to determine which arguments have been mutated. + let body: &Body<'tcx> = tcx.optimized_mir(def_id); + let mut deduce_read_only = DeduceReadOnly::new(body.arg_count); + deduce_read_only.visit_body(body); + + // Set the `readonly` attribute for every argument that we concluded is immutable and that + // contains no UnsafeCells. + // + // FIXME: This is overly conservative around generic parameters: `is_freeze()` will always + // return false for them. For a description of alternatives that could do a better job here, + // see [1]. + // + // [1]: https://github.com/rust-lang/rust/pull/103172#discussion_r999139997 + let param_env = tcx.param_env_reveal_all_normalized(def_id); + let mut deduced_param_attrs = tcx.arena.alloc_from_iter( + body.local_decls.iter().skip(1).take(body.arg_count).enumerate().map( + |(arg_index, local_decl)| DeducedParamAttrs { + read_only: !deduce_read_only.mutable_args.contains(arg_index) + && local_decl.ty.is_freeze(tcx, param_env), + }, + ), + ); + + // Trailing parameters past the size of the `deduced_param_attrs` array are assumed to have the + // default set of attributes, so we don't have to store them explicitly. Pop them off to save a + // few bytes in metadata. + while deduced_param_attrs.last() == Some(&DeducedParamAttrs::default()) { + let last_index = deduced_param_attrs.len() - 1; + deduced_param_attrs = &mut deduced_param_attrs[0..last_index]; + } + + deduced_param_attrs +} diff --git a/compiler/rustc_mir_transform/src/deduplicate_blocks.rs b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs index d1977ed49fe..909116a77f5 100644 --- a/compiler/rustc_mir_transform/src/deduplicate_blocks.rs +++ b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs @@ -58,7 +58,7 @@ fn find_duplicates(body: &Body<'_>) -> FxHashMap<BasicBlock, BasicBlock> { let mut duplicates = FxHashMap::default(); let bbs_to_go_through = - body.basic_blocks().iter_enumerated().filter(|(_, bbd)| !bbd.is_cleanup).count(); + body.basic_blocks.iter_enumerated().filter(|(_, bbd)| !bbd.is_cleanup).count(); let mut same_hashes = FxHashMap::with_capacity_and_hasher(bbs_to_go_through, Default::default()); @@ -71,8 +71,7 @@ fn find_duplicates(body: &Body<'_>) -> FxHashMap<BasicBlock, BasicBlock> { // When we see bb1, we see that it is a duplicate of bb3, and therefore insert it in the duplicates list // with replacement bb3. // When the duplicates are removed, we will end up with only bb3. - for (bb, bbd) in body.basic_blocks().iter_enumerated().rev().filter(|(_, bbd)| !bbd.is_cleanup) - { + for (bb, bbd) in body.basic_blocks.iter_enumerated().rev().filter(|(_, bbd)| !bbd.is_cleanup) { // Basic blocks can get really big, so to avoid checking for duplicates in basic blocks // that are unlikely to have duplicates, we stop early. The early bail number has been // found experimentally by eprintln while compiling the crates in the rustc-perf suite. 
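The trailing-default trimming above is plain suffix truncation: entries past the last non-default one carry no information, because absent entries are treated as default. Restated with mock types (not the rustc ones) to make the behavior concrete:

```rust
#[derive(Clone, Debug, Default, PartialEq)]
struct MockAttrs {
    read_only: bool,
}

fn trim_trailing_defaults(mut attrs: Vec<MockAttrs>) -> Vec<MockAttrs> {
    // Drop default entries from the end; readers treat missing entries as default.
    while attrs.last() == Some(&MockAttrs::default()) {
        attrs.pop();
    }
    attrs
}

fn main() {
    let attrs =
        vec![MockAttrs { read_only: true }, MockAttrs::default(), MockAttrs::default()];
    // Only the prefix up to the last non-default entry survives.
    assert_eq!(trim_trailing_defaults(attrs).len(), 1);
}
```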
diff --git a/compiler/rustc_mir_transform/src/deref_separator.rs b/compiler/rustc_mir_transform/src/deref_separator.rs index a00bb16f7ac..a39026751a7 100644 --- a/compiler/rustc_mir_transform/src/deref_separator.rs +++ b/compiler/rustc_mir_transform/src/deref_separator.rs @@ -1,5 +1,5 @@ use crate::MirPass; -use rustc_index::vec::IndexVec; +use rustc_index::IndexVec; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::NonUseContext::VarDebugInfo; use rustc_middle::mir::visit::{MutVisitor, PlaceContext}; @@ -28,8 +28,6 @@ impl<'tcx> MutVisitor<'tcx> for DerefChecker<'tcx> { let mut last_len = 0; let mut last_deref_idx = 0; - let mut prev_temp: Option<Local> = None; - for (idx, elem) in place.projection[0..].iter().enumerate() { if *elem == ProjectionElem::Deref { last_deref_idx = idx; @@ -39,14 +37,12 @@ impl<'tcx> MutVisitor<'tcx> for DerefChecker<'tcx> { for (idx, (p_ref, p_elem)) in place.iter_projections().enumerate() { if !p_ref.projection.is_empty() && p_elem == ProjectionElem::Deref { let ty = p_ref.ty(&self.local_decls, self.tcx).ty; - let temp = self.patcher.new_local_with_info( + let temp = self.patcher.new_internal_with_info( ty, self.local_decls[p_ref.local].source_info.span, - Some(Box::new(LocalInfo::DerefTemp)), + LocalInfo::DerefTemp, ); - self.patcher.add_statement(loc, StatementKind::StorageLive(temp)); - // We are adding current p_ref's projections to our // temp value, excluding projections we already covered. let deref_place = Place::from(place_local) @@ -66,22 +62,8 @@ impl<'tcx> MutVisitor<'tcx> for DerefChecker<'tcx> { Place::from(temp).project_deeper(&place.projection[idx..], self.tcx); *place = temp_place; } - - // We are destroying the previous temp since it's no longer used. - if let Some(prev_temp) = prev_temp { - self.patcher.add_statement(loc, StatementKind::StorageDead(prev_temp)); - } - - prev_temp = Some(temp); } } - - // Since we won't be able to reach final temp, we destroy it outside the loop. - if let Some(prev_temp) = prev_temp { - let last_loc = - Location { block: loc.block, statement_index: loc.statement_index + 1 }; - self.patcher.add_statement(last_loc, StatementKind::StorageDead(prev_temp)); - } } } } @@ -90,7 +72,7 @@ pub fn deref_finder<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let patch = MirPatch::new(body); let mut checker = DerefChecker { tcx, patcher: patch, local_decls: body.local_decls.clone() }; - for (bb, data) in body.basic_blocks_mut().iter_enumerated_mut() { + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { checker.visit_basic_block_data(bb, data); } @@ -100,6 +82,5 @@ pub fn deref_finder<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { impl<'tcx> MirPass<'tcx> for Derefer { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { deref_finder(tcx, body); - body.phase = MirPhase::Derefered; } } diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 33572068f5c..78758e2db28 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -20,7 +20,8 @@ //! values or the return place `_0`. On a very high level, independent of the actual implementation //! details, it does the following: //! -//! 1) Identify `dest = src;` statements that can be soundly eliminated. +//! 1) Identify `dest = src;` statements with values for `dest` and `src` whose storage can soundly +//! be merged. //! 
2) Replace all mentions of `src` with `dest` ("unifying" them and propagating the destination //! backwards). //! 3) Delete the `dest = src;` statement (by making it a `nop`). @@ -29,44 +30,80 @@ //! //! ## Soundness //! -//! Given an `Assign` statement `dest = src;`, where `dest` is a `Place` and `src` is an `Rvalue`, -//! there are a few requirements that must hold for the optimization to be sound: +//! We have a pair of places `p` and `q`, whose memory we would like to merge. In order for this to +//! be sound, we need to check a number of conditions: //! -//! * `dest` must not contain any *indirection* through a pointer. It must access part of the base -//! local. Otherwise it might point to arbitrary memory that is hard to track. +//! * `p` and `q` must both be *constant* - it does not make much sense to talk about merging them +//! if they do not consistently refer to the same place in memory. This is satisfied if they do +//! not contain any indirection through a pointer or any indexing projections. //! -//! It must also not contain any indexing projections, since those take an arbitrary `Local` as -//! the index, and that local might only be initialized shortly before `dest` is used. +//! * We need to make sure that the goal of "merging the memory" is actually structurally possible +//! in MIR. For example, even if all the other conditions are satisfied, there is no way to +//! "merge" `_5.foo` and `_6.bar`. For now, we ensure this by requiring that both `p` and `q` are +//! locals with no further projections. Future iterations of this pass should improve on this. //! -//! * `src` must be a bare `Local` without any indirections or field projections (FIXME: Is this a -//! fundamental restriction or just current impl state?). It can be copied or moved by the -//! assignment. +//! * Finally, we want `p` and `q` to use the same memory - however, we still need to make sure that +//! each of them has enough "ownership" of that memory to continue "doing its job." More +//! precisely, what we will check is that whenever the program performs a write to `p`, then it +//! does not currently care about what the value in `q` is (and vice versa). We formalize the +//! notion of "does not care what the value in `q` is" by checking the *liveness* of `q`. //! -//! * The `dest` and `src` locals must never be [*live*][liveness] at the same time. If they are, it -//! means that they both hold a (potentially different) value that is needed by a future use of -//! the locals. Unifying them would overwrite one of the values. +//! Because of the difficulty of computing liveness of places that have their address taken, we do +//! not even attempt to do it. Any places that are in a local that has its address taken is +//! excluded from the optimization. //! -//! Note that computing liveness of locals that have had their address taken is more difficult: -//! Short of doing full escape analysis on the address/pointer/reference, the pass would need to -//! assume that any operation that can potentially involve opaque user code (such as function -//! calls, destructors, and inline assembly) may access any local that had its address taken -//! before that point. +//! The first two conditions are simple structural requirements on the `Assign` statements that can +//! be trivially checked. The third requirement however is more difficult and costly to check. //! -//! Here, the first two conditions are simple structural requirements on the `Assign` statements -//! that can be trivially checked. 
The liveness requirement however is more difficult and costly to -//! check. +//! ## Future Improvements +//! +//! There are a number of ways in which this pass could be improved in the future: +//! +//! * Merging storage liveness ranges instead of removing storage statements completely. This may +//! improve stack usage. +//! +//! * Allow merging locals into places with projections, eg `_5` into `_6.foo`. +//! +//! * Liveness analysis with more precision than whole locals at a time. The smaller benefit of this +//! is that it would allow us to dest prop at "sub-local" levels in some cases. The bigger benefit +//! of this is that such liveness analysis can report more accurate results about whole locals at +//! a time. For example, consider: +//! +//! ```ignore (syntax-highlighting-only) +//! _1 = u; +//! // unrelated code +//! _1.f1 = v; +//! _2 = _1.f1; +//! ``` +//! +//! Because the current analysis only thinks in terms of locals, it does not have enough +//! information to report that `_1` is dead in the "unrelated code" section. +//! +//! * Liveness analysis enabled by alias analysis. This would allow us to not just bail on locals +//! that ever have their address taken. Of course that requires actually having alias analysis +//! (and a model to build it on), so this might be a bit of a ways off. +//! +//! * Various perf improvements. There are a bunch of comments in here marked `PERF` with ideas for +//! how to do things more efficiently. However, the complexity of the pass as a whole should be +//! kept in mind. //! //! ## Previous Work //! -//! A [previous attempt] at implementing an optimization like this turned out to be a significant -//! regression in compiler performance. Fixing the regressions introduced a lot of undesirable -//! complexity to the implementation. +//! A [previous attempt][attempt 1] at implementing an optimization like this turned out to be a +//! significant regression in compiler performance. Fixing the regressions introduced a lot of +//! undesirable complexity to the implementation. +//! +//! A [subsequent approach][attempt 2] tried to avoid the costly computation by limiting itself to +//! acyclic CFGs, but still turned out to be far too costly to run due to suboptimal performance +//! within individual basic blocks, requiring a walk across the entire block for every assignment +//! found within the block. For the `tuple-stress` benchmark, which has 458745 statements in a +//! single block, this proved to be far too costly. //! -//! A [subsequent approach] tried to avoid the costly computation by limiting itself to acyclic -//! CFGs, but still turned out to be far too costly to run due to suboptimal performance within -//! individual basic blocks, requiring a walk across the entire block for every assignment found -//! within the block. For the `tuple-stress` benchmark, which has 458745 statements in a single -//! block, this proved to be far too costly. +//! [Another approach after that][attempt 3] was much closer to correct, but had some soundness +//! issues - it was failing to consider stores outside live ranges, and failed to uphold some of the +//! requirements that MIR has for non-overlapping places within statements. However, it also had +//! performance issues caused by `O(l² * s)` runtime, where `l` is the number of locals and `s` is +//! the number of statements and terminators. //! //! Since the first attempt at this, the compiler has improved dramatically, and new analysis //! 
frameworks have been added that should make this approach viable without requiring a limited @@ -74,8 +111,7 @@ //! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards //! analyses efficiently. //! - Layout optimizations for generators have been added to improve code generation for -//! async/await, which are very similar in spirit to what this optimization does. Both walk the -//! MIR and record conflicting uses of locals in a `BitMatrix`. +//! async/await, which are very similar in spirit to what this optimization does. //! //! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that //! this destination propagation pass handles, proving that similar optimizations can be performed @@ -87,256 +123,213 @@ //! it replaces the eliminated assign statements with `nop`s and leaves unused locals behind. //! //! [liveness]: https://en.wikipedia.org/wiki/Live_variable_analysis -//! [previous attempt]: https://github.com/rust-lang/rust/pull/47954 -//! [subsequent approach]: https://github.com/rust-lang/rust/pull/71003 +//! [attempt 1]: https://github.com/rust-lang/rust/pull/47954 +//! [attempt 2]: https://github.com/rust-lang/rust/pull/71003 +//! [attempt 3]: https://github.com/rust-lang/rust/pull/72632 + +use std::collections::hash_map::{Entry, OccupiedEntry}; +use crate::simplify::remove_dead_blocks; use crate::MirPass; -use itertools::Itertools; -use rustc_data_structures::unify::{InPlaceUnificationTable, UnifyKey}; -use rustc_index::{ - bit_set::{BitMatrix, BitSet}, - vec::IndexVec, -}; +use rustc_data_structures::fx::FxHashMap; +use rustc_index::bit_set::BitSet; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::{dump_mir, PassWhere}; use rustc_middle::mir::{ - traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, PlaceElem, - Rvalue, Statement, StatementKind, Terminator, TerminatorKind, + traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, Rvalue, + Statement, StatementKind, TerminatorKind, }; use rustc_middle::ty::TyCtxt; -use rustc_mir_dataflow::impls::{borrowed_locals, MaybeInitializedLocals, MaybeLiveLocals}; -use rustc_mir_dataflow::Analysis; - -// Empirical measurements have resulted in some observations: -// - Running on a body with a single block and 500 locals takes barely any time -// - Running on a body with ~400 blocks and ~300 relevant locals takes "too long" -// ...so we just limit both to somewhat reasonable-ish looking values. -const MAX_LOCALS: usize = 500; -const MAX_BLOCKS: usize = 250; +use rustc_mir_dataflow::impls::MaybeLiveLocals; +use rustc_mir_dataflow::{Analysis, ResultsCursor}; pub struct DestinationPropagation; impl<'tcx> MirPass<'tcx> for DestinationPropagation { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // FIXME(#79191, #82678): This is unsound. - // - // Only run at mir-opt-level=3 or higher for now (we don't fix up debuginfo and remove - // storage statements at the moment). - sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() >= 3 + // For now, only run at MIR opt level 3. Two things need to be changed before this can be + // turned on by default: + // 1. Because of the overeager removal of storage statements, this can cause stack space + // regressions. This opt is not the place to fix this though, it's a more general + // problem in MIR. + // 2. Despite being an overall perf improvement, this still causes a 30% regression in + // keccak. 
We can temporarily fix this by bounding function size, but in the long term + // we should fix this by being smarter about invalidating analysis results. + sess.mir_opt_level() >= 3 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let def_id = body.source.def_id(); + let mut allocations = Allocations::default(); + trace!(func = ?tcx.def_path_str(def_id)); - let candidates = find_candidates(body); - if candidates.is_empty() { - debug!("{:?}: no dest prop candidates, done", def_id); - return; - } - - // Collect all locals we care about. We only compute conflicts for these to save time. - let mut relevant_locals = BitSet::new_empty(body.local_decls.len()); - for CandidateAssignment { dest, src, loc: _ } in &candidates { - relevant_locals.insert(dest.local); - relevant_locals.insert(*src); - } + let borrowed = rustc_mir_dataflow::impls::borrowed_locals(body); - // This pass unfortunately has `O(l² * s)` performance, where `l` is the number of locals - // and `s` is the number of statements and terminators in the function. - // To prevent blowing up compile times too much, we bail out when there are too many locals. - let relevant = relevant_locals.count(); - debug!( - "{:?}: {} locals ({} relevant), {} blocks", - def_id, - body.local_decls.len(), - relevant, - body.basic_blocks().len() - ); - if relevant > MAX_LOCALS { - warn!( - "too many candidate locals in {:?} ({}, max is {}), not optimizing", - def_id, relevant, MAX_LOCALS + // In order to avoid having to collect data for every single pair of locals in the body, we + // do not allow doing more than one merge for places that are derived from the same local at + // once. To avoid missed opportunities, we instead iterate to a fixed point - we'll refer to + // each of these iterations as a "round." + // + // Reaching a fixed point could in theory take up to `min(l, s)` rounds - however, we do not + // expect to see MIR like that. To verify this, a test was run against `[rust-lang/regex]` - + // the average MIR body saw 1.32 full iterations of this loop. The most that was hit were 30 + // for a single function. Only 80/2801 (2.9%) of functions saw at least 5. + // + // [rust-lang/regex]: + // https://github.com/rust-lang/regex/tree/b5372864e2df6a2f5e543a556a62197f50ca3650 + let mut round_count = 0; + loop { + // PERF: Can we do something smarter than recalculating the candidates and liveness + // results? + let mut candidates = find_candidates( + body, + &borrowed, + &mut allocations.candidates, + &mut allocations.candidates_reverse, ); - return; - } - if body.basic_blocks().len() > MAX_BLOCKS { - warn!( - "too many blocks in {:?} ({}, max is {}), not optimizing", - def_id, - body.basic_blocks().len(), - MAX_BLOCKS + trace!(?candidates); + let mut live = MaybeLiveLocals + .into_engine(tcx, body) + .iterate_to_fixpoint() + .into_results_cursor(body); + dest_prop_mir_dump(tcx, body, &mut live, round_count); + + FilterInformation::filter_liveness( + &mut candidates, + &mut live, + &mut allocations.write_info, + body, ); - return; - } - let mut conflicts = Conflicts::build(tcx, body, &relevant_locals); + // Because we do not update liveness information, it is unsound to use a local for more + // than one merge operation within a single round of optimizations. We store here which + // ones we have already used. 
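A hedged sketch of what one round of this loop achieves on a trivial body (which local survives the merge, and how many rounds run, depends on the liveness results):

```rust
fn round_demo() -> i32 {
    let x = compute(); // `_1 = compute()`
    let y = x;         // candidate `_2 = _1`: the two are never live at once
    y + 1              // after merging, uses of one local are rewritten to the
                       // other, and the self-assignment becomes a `nop`
}

fn compute() -> i32 {
    41
}
```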
+ let mut merged_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len()); - let mut replacements = Replacements::new(body.local_decls.len()); - for candidate @ CandidateAssignment { dest, src, loc } in candidates { - // Merge locals that don't conflict. - if !conflicts.can_unify(dest.local, src) { - debug!("at assignment {:?}, conflict {:?} vs. {:?}", loc, dest.local, src); - continue; - } + // This is the set of merges we will apply this round. It is a subset of the candidates. + let mut merges = FxHashMap::default(); - if replacements.for_src(candidate.src).is_some() { - debug!("src {:?} already has replacement", candidate.src); - continue; + for (src, candidates) in candidates.c.iter() { + if merged_locals.contains(*src) { + continue; + } + let Some(dest) = + candidates.iter().find(|dest| !merged_locals.contains(**dest)) else { + continue; + }; + if !tcx.consider_optimizing(|| { + format!("{} round {}", tcx.def_path_str(def_id), round_count) + }) { + break; + } + merges.insert(*src, *dest); + merged_locals.insert(*src); + merged_locals.insert(*dest); } + trace!(merging = ?merges); - if !tcx.consider_optimizing(|| { - format!("DestinationPropagation {:?} {:?}", def_id, candidate) - }) { + if merges.is_empty() { break; } + round_count += 1; - replacements.push(candidate); - conflicts.unify(candidate.src, candidate.dest.local); + apply_merges(body, tcx, &merges, &merged_locals); } - replacements.flatten(tcx); - - debug!("replacements {:?}", replacements.map); - - Replacer { tcx, replacements, place_elem_cache: Vec::new() }.visit_body(body); - - // FIXME fix debug info - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone)] -struct UnifyLocal(Local); + if round_count != 0 { + // Merging can introduce overlap between moved arguments and/or call destination in an + // unreachable code, which validator considers to be ill-formed. + remove_dead_blocks(tcx, body); + } -impl From<Local> for UnifyLocal { - fn from(l: Local) -> Self { - Self(l) + trace!(round_count); } } -impl UnifyKey for UnifyLocal { - type Value = (); - #[inline] - fn index(&self) -> u32 { - self.0.as_u32() - } - #[inline] - fn from_index(u: u32) -> Self { - Self(Local::from_u32(u)) - } - fn tag() -> &'static str { - "UnifyLocal" - } +/// Container for the various allocations that we need. +/// +/// We store these here and hand out `&mut` access to them, instead of dropping and recreating them +/// frequently. Everything with a `&'alloc` lifetime points into here. +#[derive(Default)] +struct Allocations { + candidates: FxHashMap<Local, Vec<Local>>, + candidates_reverse: FxHashMap<Local, Vec<Local>>, + write_info: WriteInfo, + // PERF: Do this for `MaybeLiveLocals` allocations too. } -struct Replacements<'tcx> { - /// Maps locals to their replacement. - map: IndexVec<Local, Option<Place<'tcx>>>, - - /// Whose locals' live ranges to kill. - kill: BitSet<Local>, +#[derive(Debug)] +struct Candidates<'alloc> { + /// The set of candidates we are considering in this optimization. + /// + /// We will always merge the key into at most one of its values. + /// + /// Whether a place ends up in the key or the value does not correspond to whether it appears as + /// the lhs or rhs of any assignment. As a matter of fact, the places in here might never appear + /// in an assignment at all. 
This happens because if we see an assignment like this: + /// + /// ```ignore (syntax-highlighting-only) + /// _1.0 = _2.0 + /// ``` + /// + /// We will still report that we would like to merge `_1` and `_2` in an attempt to allow us to + /// remove that assignment. + c: &'alloc mut FxHashMap<Local, Vec<Local>>, + /// A reverse index of the `c` set; if the `c` set contains `a => Place { local: b, proj }`, + /// then this contains `b => a`. + // PERF: Possibly these should be `SmallVec`s? + reverse: &'alloc mut FxHashMap<Local, Vec<Local>>, } -impl<'tcx> Replacements<'tcx> { - fn new(locals: usize) -> Self { - Self { map: IndexVec::from_elem_n(None, locals), kill: BitSet::new_empty(locals) } - } - - fn push(&mut self, candidate: CandidateAssignment<'tcx>) { - trace!("Replacements::push({:?})", candidate); - let entry = &mut self.map[candidate.src]; - assert!(entry.is_none()); - - *entry = Some(candidate.dest); - self.kill.insert(candidate.src); - self.kill.insert(candidate.dest.local); - } +////////////////////////////////////////////////////////// +// Merging +// +// Applies the actual optimization - /// Applies the stored replacements to all replacements, until no replacements would result in - /// locals that need further replacements when applied. - fn flatten(&mut self, tcx: TyCtxt<'tcx>) { - // Note: This assumes that there are no cycles in the replacements, which is enforced via - // `self.unified_locals`. Otherwise this can cause an infinite loop. - - for local in self.map.indices() { - if let Some(replacement) = self.map[local] { - // Substitute the base local of `replacement` until fixpoint. - let mut base = replacement.local; - let mut reversed_projection_slices = Vec::with_capacity(1); - while let Some(replacement_for_replacement) = self.map[base] { - base = replacement_for_replacement.local; - reversed_projection_slices.push(replacement_for_replacement.projection); - } - - let projection: Vec<_> = reversed_projection_slices - .iter() - .rev() - .flat_map(|projs| projs.iter()) - .chain(replacement.projection.iter()) - .collect(); - let projection = tcx.intern_place_elems(&projection); - - // Replace with the final `Place`. 
- self.map[local] = Some(Place { local: base, projection }); - } - } - } - - fn for_src(&self, src: Local) -> Option<Place<'tcx>> { - self.map[src] - } +fn apply_merges<'tcx>( + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, + merges: &FxHashMap<Local, Local>, + merged_locals: &BitSet<Local>, +) { + let mut merger = Merger { tcx, merges, merged_locals }; + merger.visit_body_preserves_cfg(body); } -struct Replacer<'tcx> { +struct Merger<'a, 'tcx> { tcx: TyCtxt<'tcx>, - replacements: Replacements<'tcx>, - place_elem_cache: Vec<PlaceElem<'tcx>>, + merges: &'a FxHashMap<Local, Local>, + merged_locals: &'a BitSet<Local>, } -impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { +impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> { fn tcx(&self) -> TyCtxt<'tcx> { self.tcx } - fn visit_local(&mut self, local: &mut Local, context: PlaceContext, location: Location) { - if context.is_use() && self.replacements.for_src(*local).is_some() { - bug!( - "use of local {:?} should have been replaced by visit_place; context={:?}, loc={:?}", - local, - context, - location, - ); + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) { + if let Some(dest) = self.merges.get(local) { + *local = *dest; } } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - if let Some(replacement) = self.replacements.for_src(place.local) { - // Rebase `place`s projections onto `replacement`'s. - self.place_elem_cache.clear(); - self.place_elem_cache.extend(replacement.projection.iter().chain(place.projection)); - let projection = self.tcx.intern_place_elems(&self.place_elem_cache); - let new_place = Place { local: replacement.local, projection }; - - debug!("Replacer: {:?} -> {:?}", place, new_place); - *place = new_place; - } - - self.super_place(place, context, location); - } - fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { - self.super_statement(statement, location); - match &statement.kind { - // FIXME: Don't delete storage statements, merge the live ranges instead + // FIXME: Don't delete storage statements, but "merge" the storage ranges instead. StatementKind::StorageDead(local) | StatementKind::StorageLive(local) - if self.replacements.kill.contains(*local) => + if self.merged_locals.contains(*local) => { - statement.make_nop() + statement.make_nop(); + return; } - + _ => (), + }; + self.super_statement(statement, location); + match &statement.kind { StatementKind::Assign(box (dest, rvalue)) => { match rvalue { - Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + Rvalue::CopyForDeref(place) + | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { // These might've been turned into self-assignments by the replacement // (this includes the original statement we wanted to eliminate). if dest == place { @@ -353,524 +346,436 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { } } -struct Conflicts<'a> { - relevant_locals: &'a BitSet<Local>, - - /// The conflict matrix. It is always symmetric and the adjacency matrix of the corresponding - /// conflict graph. - matrix: BitMatrix<Local, Local>, - - /// Preallocated `BitSet` used by `unify`. - unify_cache: BitSet<Local>, - - /// Tracks locals that have been merged together to prevent cycles and propagate conflicts. 
- unified_locals: InPlaceUnificationTable<UnifyLocal>, +////////////////////////////////////////////////////////// +// Liveness filtering +// +// This section enforces bullet point 2 + +struct FilterInformation<'a, 'body, 'alloc, 'tcx> { + body: &'body Body<'tcx>, + live: &'a mut ResultsCursor<'body, 'tcx, MaybeLiveLocals>, + candidates: &'a mut Candidates<'alloc>, + write_info: &'alloc mut WriteInfo, + at: Location, } -impl<'a> Conflicts<'a> { - fn build<'tcx>( - tcx: TyCtxt<'tcx>, - body: &'_ Body<'tcx>, - relevant_locals: &'a BitSet<Local>, - ) -> Self { - // We don't have to look out for locals that have their address taken, since - // `find_candidates` already takes care of that. - - let conflicts = BitMatrix::from_row_n( - &BitSet::new_empty(body.local_decls.len()), - body.local_decls.len(), - ); - - let mut init = MaybeInitializedLocals - .into_engine(tcx, body) - .iterate_to_fixpoint() - .into_results_cursor(body); - let mut live = - MaybeLiveLocals.into_engine(tcx, body).iterate_to_fixpoint().into_results_cursor(body); - - let mut reachable = None; - dump_mir(tcx, None, "DestinationPropagation-dataflow", &"", body, |pass_where, w| { - let reachable = reachable.get_or_insert_with(|| traversal::reachable_as_bitset(body)); - - match pass_where { - PassWhere::BeforeLocation(loc) if reachable.contains(loc.block) => { - init.seek_before_primary_effect(loc); - live.seek_after_primary_effect(loc); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - PassWhere::AfterTerminator(bb) if reachable.contains(bb) => { - let loc = body.terminator_loc(bb); - init.seek_after_primary_effect(loc); - live.seek_before_primary_effect(loc); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - - PassWhere::BeforeBlock(bb) if reachable.contains(bb) => { - init.seek_to_block_start(bb); - live.seek_to_block_start(bb); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - - PassWhere::BeforeCFG | PassWhere::AfterCFG | PassWhere::AfterLocation(_) => {} - - PassWhere::BeforeLocation(_) | PassWhere::AfterTerminator(_) => { - writeln!(w, " // init: <unreachable>")?; - writeln!(w, " // live: <unreachable>")?; - } - - PassWhere::BeforeBlock(_) => { - writeln!(w, " // init: <unreachable>")?; - writeln!(w, " // live: <unreachable>")?; - } +// We first implement some utility functions which we will expose removing candidates according to +// different needs. Throughout the liveness filtering, the `candidates` are only ever accessed +// through these methods, and not directly. +impl<'alloc> Candidates<'alloc> { + /// Just `Vec::retain`, but the condition is inverted and we add debugging output + fn vec_filter_candidates( + src: Local, + v: &mut Vec<Local>, + mut f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + v.retain(|dest| { + let remove = f(*dest); + if remove == CandidateFilter::Remove { + trace!("eliminating {:?} => {:?} due to conflict at {:?}", src, dest, at); } - - Ok(()) + remove == CandidateFilter::Keep }); + } - let mut this = Self { - relevant_locals, - matrix: conflicts, - unify_cache: BitSet::new_empty(body.local_decls.len()), - unified_locals: { - let mut table = InPlaceUnificationTable::new(); - // Pre-fill table with all locals (this creates N nodes / "connected" components, - // "graph"-ically speaking). 
- for local in 0..body.local_decls.len() { - assert_eq!(table.new_key(()), UnifyLocal(Local::from_usize(local))); - } - table - }, - }; - - let mut live_and_init_locals = Vec::new(); - - // Visit only reachable basic blocks. The exact order is not important. - for (block, data) in traversal::preorder(body) { - // We need to observe the dataflow state *before* all possible locations (statement or - // terminator) in each basic block, and then observe the state *after* the terminator - // effect is applied. As long as neither `init` nor `borrowed` has a "before" effect, - // we will observe all possible dataflow states. - - // Since liveness is a backwards analysis, we need to walk the results backwards. To do - // that, we first collect in the `MaybeInitializedLocals` results in a forwards - // traversal. - - live_and_init_locals.resize_with(data.statements.len() + 1, || { - BitSet::new_empty(body.local_decls.len()) - }); - - // First, go forwards for `MaybeInitializedLocals` and apply intra-statement/terminator - // conflicts. - for (i, statement) in data.statements.iter().enumerate() { - this.record_statement_conflicts(statement); - - let loc = Location { block, statement_index: i }; - init.seek_before_primary_effect(loc); + /// `vec_filter_candidates` but for an `Entry` + fn entry_filter_candidates( + mut entry: OccupiedEntry<'_, Local, Vec<Local>>, + p: Local, + f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + let candidates = entry.get_mut(); + Self::vec_filter_candidates(p, candidates, f, at); + if candidates.len() == 0 { + entry.remove(); + } + } - live_and_init_locals[i].clone_from(init.get()); + /// For all candidates `(p, q)` or `(q, p)` removes the candidate if `f(q)` says to do so + fn filter_candidates_by( + &mut self, + p: Local, + mut f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + // Cover the cases where `p` appears as a `src` + if let Entry::Occupied(entry) = self.c.entry(p) { + Self::entry_filter_candidates(entry, p, &mut f, at); + } + // And the cases where `p` appears as a `dest` + let Some(srcs) = self.reverse.get_mut(&p) else { + return; + }; + // We use `retain` here to remove the elements from the reverse set if we've removed the + // matching candidate in the forward set. + srcs.retain(|src| { + if f(*src) == CandidateFilter::Keep { + return true; } + let Entry::Occupied(entry) = self.c.entry(*src) else { + return false; + }; + Self::entry_filter_candidates( + entry, + *src, + |dest| { + if dest == p { CandidateFilter::Remove } else { CandidateFilter::Keep } + }, + at, + ); + false + }); + } +} - this.record_terminator_conflicts(data.terminator()); - let term_loc = Location { block, statement_index: data.statements.len() }; - init.seek_before_primary_effect(term_loc); - live_and_init_locals[term_loc.statement_index].clone_from(init.get()); - - // Now, go backwards and union with the liveness results. - for statement_index in (0..=data.statements.len()).rev() { - let loc = Location { block, statement_index }; - live.seek_after_primary_effect(loc); - - live_and_init_locals[statement_index].intersect(live.get()); +#[derive(Copy, Clone, PartialEq, Eq)] +enum CandidateFilter { + Keep, + Remove, +} - trace!("record conflicts at {:?}", loc); +impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> { + /// Filters the set of candidates to remove those that conflict. + /// + /// The steps we take are exactly those that are outlined at the top of the file. 
For each + /// statement/terminator, we collect the set of locals that are written to in that + /// statement/terminator, and then we remove all pairs of candidates that contain one such local + /// and another one that is live. + /// + /// We need to be careful about the ordering of operations within each statement/terminator + /// here. Many statements might write and read from more than one place, and we need to consider + /// them all. The strategy for doing this is as follows: We first gather all the places that are + /// written to within the statement/terminator via `WriteInfo`. Then, we use the liveness + /// analysis from *before* the statement/terminator (in the control flow sense) to eliminate + /// candidates - this is because we want to conservatively treat a pair of locals that is both + /// read and written in the statement/terminator to be conflicting, and the liveness analysis + /// before the statement/terminator will correctly report locals that are read in the + /// statement/terminator to be live. We are additionally conservative by treating all written to + /// locals as also being read from. + fn filter_liveness<'b>( + candidates: &mut Candidates<'alloc>, + live: &mut ResultsCursor<'b, 'tcx, MaybeLiveLocals>, + write_info_alloc: &'alloc mut WriteInfo, + body: &'b Body<'tcx>, + ) { + let mut this = FilterInformation { + body, + live, + candidates, + // We don't actually store anything at this scope, we just keep things here to be able + // to reuse the allocation. + write_info: write_info_alloc, + // Doesn't matter what we put here, will be overwritten before being used + at: Location::START, + }; + this.internal_filter_liveness(); + } - this.record_dataflow_conflicts(&mut live_and_init_locals[statement_index]); + fn internal_filter_liveness(&mut self) { + for (block, data) in traversal::preorder(self.body) { + self.at = Location { block, statement_index: data.statements.len() }; + self.live.seek_after_primary_effect(self.at); + self.write_info.for_terminator(&data.terminator().kind); + self.apply_conflicts(); + + for (i, statement) in data.statements.iter().enumerate().rev() { + self.at = Location { block, statement_index: i }; + self.live.seek_after_primary_effect(self.at); + self.write_info.for_statement(&statement.kind, self.body); + self.apply_conflicts(); } - - init.seek_to_block_end(block); - live.seek_to_block_end(block); - let mut conflicts = init.get().clone(); - conflicts.intersect(live.get()); - trace!("record conflicts at end of {:?}", block); - - this.record_dataflow_conflicts(&mut conflicts); } - - this } - fn record_dataflow_conflicts(&mut self, new_conflicts: &mut BitSet<Local>) { - // Remove all locals that are not candidates. - new_conflicts.intersect(self.relevant_locals); - - for local in new_conflicts.iter() { - self.matrix.union_row_with(&new_conflicts, local); + fn apply_conflicts(&mut self) { + let writes = &self.write_info.writes; + for p in writes { + let other_skip = self.write_info.skip_pair.and_then(|(a, b)| { + if a == *p { + Some(b) + } else if b == *p { + Some(a) + } else { + None + } + }); + self.candidates.filter_candidates_by( + *p, + |q| { + if Some(q) == other_skip { + return CandidateFilter::Keep; + } + // It is possible that a local may be live for less than the + // duration of a statement This happens in the case of function + // calls or inline asm. Because of this, we also mark locals as + // conflicting when both of them are written to in the same + // statement. 
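+                    // For example (illustrative): in a call terminator
+                    // `_1 = f(move _2)`, both `_1` and `_2` end up in `writes`,
+                    // so a candidate pair (_1, _2) is filtered out here even
+                    // though `_2` is no longer live after the call.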
+ if self.live.contains(q) || writes.contains(&q) { + CandidateFilter::Remove + } else { + CandidateFilter::Keep + } + }, + self.at, + ); } } +} + +/// Describes where a statement/terminator writes to +#[derive(Default, Debug)] +struct WriteInfo { + writes: Vec<Local>, + /// If this pair of locals is a candidate pair, completely skip processing it during this + /// statement. All other candidates are unaffected. + skip_pair: Option<(Local, Local)>, +} - fn record_local_conflict(&mut self, a: Local, b: Local, why: &str) { - trace!("conflict {:?} <-> {:?} due to {}", a, b, why); - self.matrix.insert(a, b); - self.matrix.insert(b, a); +impl WriteInfo { + fn for_statement<'tcx>(&mut self, statement: &StatementKind<'tcx>, body: &Body<'tcx>) { + self.reset(); + match statement { + StatementKind::Assign(box (lhs, rhs)) => { + self.add_place(*lhs); + match rhs { + Rvalue::Use(op) => { + self.add_operand(op); + self.consider_skipping_for_assign_use(*lhs, op, body); + } + Rvalue::Repeat(op, _) => { + self.add_operand(op); + } + Rvalue::Cast(_, op, _) + | Rvalue::UnaryOp(_, op) + | Rvalue::ShallowInitBox(op, _) => { + self.add_operand(op); + } + Rvalue::BinaryOp(_, ops) | Rvalue::CheckedBinaryOp(_, ops) => { + for op in [&ops.0, &ops.1] { + self.add_operand(op); + } + } + Rvalue::Aggregate(_, ops) => { + for op in ops { + self.add_operand(op); + } + } + Rvalue::ThreadLocalRef(_) + | Rvalue::NullaryOp(_, _) + | Rvalue::Ref(_, _, _) + | Rvalue::AddressOf(_, _) + | Rvalue::Len(_) + | Rvalue::Discriminant(_) + | Rvalue::CopyForDeref(_) => (), + } + } + // Retags are technically also reads, but reporting them as a write suffices + StatementKind::SetDiscriminant { place, .. } + | StatementKind::Deinit(place) + | StatementKind::Retag(_, place) => { + self.add_place(**place); + } + StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::Nop + | StatementKind::Coverage(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::PlaceMention(_) => (), + StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { + bug!("{:?} not found in this MIR phase", statement) + } + } } - /// Records locals that must not overlap during the evaluation of `stmt`. These locals conflict - /// and must not be merged. - fn record_statement_conflicts(&mut self, stmt: &Statement<'_>) { - match &stmt.kind { - // While the left and right sides of an assignment must not overlap, we do not mark - // conflicts here as that would make this optimization useless. When we optimize, we - // eliminate the resulting self-assignments automatically. - StatementKind::Assign(_) => {} - - StatementKind::SetDiscriminant { .. } - | StatementKind::Deinit(..) - | StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Retag(..) - | StatementKind::FakeRead(..) - | StatementKind::AscribeUserType(..) - | StatementKind::Coverage(..) - | StatementKind::CopyNonOverlapping(..) 
- | StatementKind::Nop => {} + fn consider_skipping_for_assign_use<'tcx>( + &mut self, + lhs: Place<'tcx>, + rhs: &Operand<'tcx>, + body: &Body<'tcx>, + ) { + let Some(rhs) = rhs.place() else { + return + }; + if let Some(pair) = places_to_candidate_pair(lhs, rhs, body) { + self.skip_pair = Some(pair); } } - fn record_terminator_conflicts(&mut self, term: &Terminator<'_>) { - match &term.kind { - TerminatorKind::DropAndReplace { - place: dropped_place, - value, - target: _, - unwind: _, - } => { - if let Some(place) = value.place() - && !place.is_indirect() - && !dropped_place.is_indirect() - { - self.record_local_conflict( - place.local, - dropped_place.local, - "DropAndReplace operand overlap", - ); - } + fn for_terminator<'tcx>(&mut self, terminator: &TerminatorKind<'tcx>) { + self.reset(); + match terminator { + TerminatorKind::SwitchInt { discr: op, .. } + | TerminatorKind::Assert { cond: op, .. } => { + self.add_operand(op); } - TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { - if let Some(place) = value.place() { - if !place.is_indirect() && !resume_arg.is_indirect() { - self.record_local_conflict( - place.local, - resume_arg.local, - "Yield operand overlap", - ); - } + TerminatorKind::Call { destination, func, args, .. } => { + self.add_place(*destination); + self.add_operand(func); + for arg in args { + self.add_operand(arg); } } - TerminatorKind::Call { - func, - args, - destination, - target: _, - cleanup: _, - from_hir_call: _, - fn_span: _, - } => { - // No arguments may overlap with the destination. - for arg in args.iter().chain(Some(func)) { - if let Some(place) = arg.place() { - if !place.is_indirect() && !destination.is_indirect() { - self.record_local_conflict( - destination.local, - place.local, - "call dest/arg overlap", - ); + TerminatorKind::InlineAsm { operands, .. } => { + for asm_operand in operands { + match asm_operand { + InlineAsmOperand::In { value, .. } => { + self.add_operand(value); } - } - } - } - TerminatorKind::InlineAsm { - template: _, - operands, - options: _, - line_spans: _, - destination: _, - cleanup: _, - } => { - // The intended semantics here aren't documented, we just assume that nothing that - // could be written to by the assembly may overlap with any other operands. - for op in operands { - match op { - InlineAsmOperand::Out { reg: _, late: _, place: Some(dest_place) } - | InlineAsmOperand::InOut { - reg: _, - late: _, - in_value: _, - out_place: Some(dest_place), - } => { - // For output place `place`, add all places accessed by the inline asm. - for op in operands { - match op { - InlineAsmOperand::In { reg: _, value } => { - if let Some(p) = value.place() - && !p.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - p.local, - dest_place.local, - "asm! operand overlap", - ); - } - } - InlineAsmOperand::Out { - reg: _, - late: _, - place: Some(place), - } => { - if !place.is_indirect() && !dest_place.is_indirect() { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! operand overlap", - ); - } - } - InlineAsmOperand::InOut { - reg: _, - late: _, - in_value, - out_place, - } => { - if let Some(place) = in_value.place() - && !place.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! operand overlap", - ); - } - - if let Some(place) = out_place - && !place.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! 
operand overlap", - ); - } - } - InlineAsmOperand::Out { reg: _, late: _, place: None } - | InlineAsmOperand::Const { value: _ } - | InlineAsmOperand::SymFn { value: _ } - | InlineAsmOperand::SymStatic { def_id: _ } => {} - } + InlineAsmOperand::Out { place, .. } => { + if let Some(place) = place { + self.add_place(*place); } } - InlineAsmOperand::InOut { - reg: _, - late: _, - in_value: _, - out_place: None, + // Note that the `late` field in `InOut` is about whether the registers used + // for these things overlap, and is of absolutely no interest to us. + InlineAsmOperand::InOut { in_value, out_place, .. } => { + if let Some(place) = out_place { + self.add_place(*place); + } + self.add_operand(in_value); } - | InlineAsmOperand::In { reg: _, value: _ } - | InlineAsmOperand::Out { reg: _, late: _, place: None } - | InlineAsmOperand::Const { value: _ } - | InlineAsmOperand::SymFn { value: _ } - | InlineAsmOperand::SymStatic { def_id: _ } => {} + InlineAsmOperand::Const { .. } + | InlineAsmOperand::SymFn { .. } + | InlineAsmOperand::SymStatic { .. } => (), } } } - TerminatorKind::Goto { .. } - | TerminatorKind::SwitchInt { .. } | TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::Drop { .. } - | TerminatorKind::Assert { .. } + | TerminatorKind::Unreachable { .. } => (), + TerminatorKind::Drop { .. } => { + // `Drop`s create a `&mut` and so are not considered + } + TerminatorKind::Yield { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::FalseEdge { .. } - | TerminatorKind::FalseUnwind { .. } => {} + | TerminatorKind::FalseUnwind { .. } => { + bug!("{:?} not found in this MIR phase", terminator) + } } } - /// Checks whether `a` and `b` may be merged. Returns `false` if there's a conflict. - fn can_unify(&mut self, a: Local, b: Local) -> bool { - // After some locals have been unified, their conflicts are only tracked in the root key, - // so look that up. - let a = self.unified_locals.find(a).0; - let b = self.unified_locals.find(b).0; - - if a == b { - // Already merged (part of the same connected component). - return false; - } - - if self.matrix.contains(a, b) { - // Conflict (derived via dataflow, intra-statement conflicts, or inherited from another - // local during unification). - return false; - } - - true + fn add_place(&mut self, place: Place<'_>) { + self.writes.push(place.local); } - /// Merges the conflicts of `a` and `b`, so that each one inherits all conflicts of the other. - /// - /// `can_unify` must have returned `true` for the same locals, or this may panic or lead to - /// miscompiles. - /// - /// This is called when the pass makes the decision to unify `a` and `b` (or parts of `a` and - /// `b`) and is needed to ensure that future unification decisions take potentially newly - /// introduced conflicts into account. - /// - /// For an example, assume we have locals `_0`, `_1`, `_2`, and `_3`. There are these conflicts: - /// - /// * `_0` <-> `_1` - /// * `_1` <-> `_2` - /// * `_3` <-> `_0` - /// - /// We then decide to merge `_2` with `_3` since they don't conflict. Then we decide to merge - /// `_2` with `_0`, which also doesn't have a conflict in the above list. However `_2` is now - /// `_3`, which does conflict with `_0`. - fn unify(&mut self, a: Local, b: Local) { - trace!("unify({:?}, {:?})", a, b); - - // Get the root local of the connected components. 
The root local stores the conflicts of - // all locals in the connected component (and *is stored* as the conflicting local of other - // locals). - let a = self.unified_locals.find(a).0; - let b = self.unified_locals.find(b).0; - assert_ne!(a, b); - - trace!("roots: a={:?}, b={:?}", a, b); - trace!("{:?} conflicts: {:?}", a, self.matrix.iter(a).format(", ")); - trace!("{:?} conflicts: {:?}", b, self.matrix.iter(b).format(", ")); - - self.unified_locals.union(a, b); - - let root = self.unified_locals.find(a).0; - assert!(root == a || root == b); - - // Make all locals that conflict with `a` also conflict with `b`, and vice versa. - self.unify_cache.clear(); - for conflicts_with_a in self.matrix.iter(a) { - self.unify_cache.insert(conflicts_with_a); - } - for conflicts_with_b in self.matrix.iter(b) { - self.unify_cache.insert(conflicts_with_b); - } - for conflicts_with_a_or_b in self.unify_cache.iter() { - // Set both `a` and `b` for this local's row. - self.matrix.insert(conflicts_with_a_or_b, a); - self.matrix.insert(conflicts_with_a_or_b, b); + fn add_operand<'tcx>(&mut self, op: &Operand<'tcx>) { + match op { + // FIXME(JakobDegen): In a previous version, the `Move` case was incorrectly treated as + // being a read only. This was unsound, however we cannot add a regression test because + // it is not possible to set this off with current MIR. Once we have that ability, a + // regression test should be added. + Operand::Move(p) => self.add_place(*p), + Operand::Copy(_) | Operand::Constant(_) => (), } + } - // Write the locals `a` conflicts with to `b`'s row. - self.matrix.union_rows(a, b); - // Write the locals `b` conflicts with to `a`'s row. - self.matrix.union_rows(b, a); + fn reset(&mut self) { + self.writes.clear(); + self.skip_pair = None; } } -/// A `dest = {move} src;` statement at `loc`. +///////////////////////////////////////////////////// +// Candidate accumulation + +/// If the pair of places is being considered for merging, returns the candidate which would be +/// merged in order to accomplish this. /// -/// We want to consider merging `dest` and `src` due to this assignment. -#[derive(Debug, Copy, Clone)] -struct CandidateAssignment<'tcx> { - /// Does not contain indirection or indexing (so the only local it contains is the place base). - dest: Place<'tcx>, - src: Local, - loc: Location, +/// The contract here is in one direction - there is a guarantee that merging the locals that are +/// outputted by this function would result in an assignment between the inputs becoming a +/// self-assignment. However, there is no guarantee that the returned pair is actually suitable for +/// merging - candidate collection must still check this independently. +/// +/// This output is unique for each unordered pair of input places. +fn places_to_candidate_pair<'tcx>( + a: Place<'tcx>, + b: Place<'tcx>, + body: &Body<'tcx>, +) -> Option<(Local, Local)> { + let (mut a, mut b) = if a.projection.len() == 0 && b.projection.len() == 0 { + (a.local, b.local) + } else { + return None; + }; + + // By sorting, we make sure we're input order independent + if a > b { + std::mem::swap(&mut a, &mut b); + } + + // We could now return `(a, b)`, but then we miss some candidates in the case where `a` can't be + // used as a `src`. 
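+    // For example (illustrative): if `a` is the return place `_0`, which
+    // `is_local_required` reports as non-removable, the swap below makes `_0`
+    // the `dest` of the candidate pair rather than the `src`.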
+ if is_local_required(a, body) { + std::mem::swap(&mut a, &mut b); + } + // We could check `is_local_required` again here, but there's no need - after all, we make no + // promise that the candidate pair is actually valid + Some((a, b)) } -/// Scans the MIR for assignments between locals that we might want to consider merging. +/// Collects the candidates for merging /// -/// This will filter out assignments that do not match the right form (as described in the top-level -/// comment) and also throw out assignments that involve a local that has its address taken or is -/// otherwise ineligible (eg. locals used as array indices are ignored because we cannot propagate -/// arbitrary places into array indices). -fn find_candidates<'tcx>(body: &Body<'tcx>) -> Vec<CandidateAssignment<'tcx>> { - let mut visitor = FindAssignments { - body, - candidates: Vec::new(), - ever_borrowed_locals: borrowed_locals(body), - locals_used_as_array_index: locals_used_as_array_index(body), - }; +/// This is responsible for enforcing the first and third bullet point. +fn find_candidates<'alloc, 'tcx>( + body: &Body<'tcx>, + borrowed: &BitSet<Local>, + candidates: &'alloc mut FxHashMap<Local, Vec<Local>>, + candidates_reverse: &'alloc mut FxHashMap<Local, Vec<Local>>, +) -> Candidates<'alloc> { + candidates.clear(); + candidates_reverse.clear(); + let mut visitor = FindAssignments { body, candidates, borrowed }; visitor.visit_body(body); - visitor.candidates + // Deduplicate candidates + for (_, cands) in candidates.iter_mut() { + cands.sort(); + cands.dedup(); + } + // Generate the reverse map + for (src, cands) in candidates.iter() { + for dest in cands.iter().copied() { + candidates_reverse.entry(dest).or_default().push(*src); + } + } + Candidates { c: candidates, reverse: candidates_reverse } } -struct FindAssignments<'a, 'tcx> { +struct FindAssignments<'a, 'alloc, 'tcx> { body: &'a Body<'tcx>, - candidates: Vec<CandidateAssignment<'tcx>>, - ever_borrowed_locals: BitSet<Local>, - locals_used_as_array_index: BitSet<Local>, + candidates: &'alloc mut FxHashMap<Local, Vec<Local>>, + borrowed: &'a BitSet<Local>, } -impl<'tcx> Visitor<'tcx> for FindAssignments<'_, 'tcx> { - fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { +impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { if let StatementKind::Assign(box ( - dest, - Rvalue::Use(Operand::Copy(src) | Operand::Move(src)), + lhs, + Rvalue::CopyForDeref(rhs) | Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), )) = &statement.kind { - // `dest` must not have pointer indirection. - if dest.is_indirect() { + let Some((src, dest)) = places_to_candidate_pair(*lhs, *rhs, self.body) else { return; - } - - // `src` must be a plain local. - if !src.projection.is_empty() { - return; - } + }; - // Since we want to replace `src` with `dest`, `src` must not be required. - if is_local_required(src.local, self.body) { + // As described at the top of the file, we do not go near things that have their address + // taken. + if self.borrowed.contains(src) || self.borrowed.contains(dest) { return; } - // Can't optimize if either local ever has their address taken. This optimization does - // liveness analysis only based on assignments, and a local can be live even if its - // never assigned to again, because a reference to it might be live. - // FIXME: This can be smarter and take `StorageDead` into account (which invalidates - // borrows). 
- if self.ever_borrowed_locals.contains(dest.local) - || self.ever_borrowed_locals.contains(src.local) - { + // Also, we need to make sure that MIR actually allows the `src` to be removed + if is_local_required(src, self.body) { return; } - assert_ne!(dest.local, src.local, "self-assignments are UB"); - - // We can't replace locals occurring in `PlaceElem::Index` for now. - if self.locals_used_as_array_index.contains(src.local) { - return; - } - - for elem in dest.projection { - if let PlaceElem::Index(_) = elem { - // `dest` contains an indexing projection. - return; - } - } - - self.candidates.push(CandidateAssignment { - dest: *dest, - src: src.local, - loc: location, - }); + // We may insert duplicates here, but that's fine + self.candidates.entry(src).or_default().push(dest); } } } @@ -882,36 +787,50 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, 'tcx> { fn is_local_required(local: Local, body: &Body<'_>) -> bool { match body.local_kind(local) { LocalKind::Arg | LocalKind::ReturnPointer => true, - LocalKind::Var | LocalKind::Temp => false, + LocalKind::Temp => false, } } -/// `PlaceElem::Index` only stores a `Local`, so we can't replace that with a full `Place`. -/// -/// Collect locals used as indices so we don't generate candidates that are impossible to apply -/// later. -fn locals_used_as_array_index(body: &Body<'_>) -> BitSet<Local> { - let mut visitor = IndexCollector { locals: BitSet::new_empty(body.local_decls.len()) }; - visitor.visit_body(body); - visitor.locals -} +///////////////////////////////////////////////////////// +// MIR Dump -struct IndexCollector { - locals: BitSet<Local>, -} +fn dest_prop_mir_dump<'body, 'tcx>( + tcx: TyCtxt<'tcx>, + body: &'body Body<'tcx>, + live: &mut ResultsCursor<'body, 'tcx, MaybeLiveLocals>, + round: usize, +) { + let mut reachable = None; + dump_mir(tcx, false, "DestinationPropagation-dataflow", &round, body, |pass_where, w| { + let reachable = reachable.get_or_insert_with(|| traversal::reachable_as_bitset(body)); + + match pass_where { + PassWhere::BeforeLocation(loc) if reachable.contains(loc.block) => { + live.seek_after_primary_effect(loc); + writeln!(w, " // live: {:?}", live.get())?; + } + PassWhere::AfterTerminator(bb) if reachable.contains(bb) => { + let loc = body.terminator_loc(bb); + live.seek_before_primary_effect(loc); + writeln!(w, " // live: {:?}", live.get())?; + } -impl<'tcx> Visitor<'tcx> for IndexCollector { - fn visit_projection_elem( - &mut self, - local: Local, - proj_base: &[PlaceElem<'tcx>], - elem: PlaceElem<'tcx>, - context: PlaceContext, - location: Location, - ) { - if let PlaceElem::Index(i) = elem { - self.locals.insert(i); + PassWhere::BeforeBlock(bb) if reachable.contains(bb) => { + live.seek_to_block_start(bb); + writeln!(w, " // live: {:?}", live.get())?; + } + + PassWhere::BeforeCFG | PassWhere::AfterCFG | PassWhere::AfterLocation(_) => {} + + PassWhere::BeforeLocation(_) | PassWhere::AfterTerminator(_) => { + writeln!(w, " // live: <unreachable>")?; + } + + PassWhere::BeforeBlock(_) => { + writeln!(w, " // live: <unreachable>")?; + } } - self.super_projection_elem(local, proj_base, elem, context, location); - } + + Ok(()) + }); } diff --git a/compiler/rustc_mir_transform/src/dump_mir.rs b/compiler/rustc_mir_transform/src/dump_mir.rs index 6b995141a2b..746e3d9652d 100644 --- a/compiler/rustc_mir_transform/src/dump_mir.rs +++ b/compiler/rustc_mir_transform/src/dump_mir.rs @@ -1,6 +1,5 @@ //! This pass just dumps MIR at a specified point. 
-use std::borrow::Cow; use std::fs::File; use std::io; @@ -8,20 +7,20 @@ use crate::MirPass; use rustc_middle::mir::write_mir_pretty; use rustc_middle::mir::Body; use rustc_middle::ty::TyCtxt; -use rustc_session::config::{OutputFilenames, OutputType}; +use rustc_session::config::OutputType; pub struct Marker(pub &'static str); impl<'tcx> MirPass<'tcx> for Marker { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(self.0) + fn name(&self) -> &'static str { + self.0 } fn run_pass(&self, _tcx: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} } -pub fn emit_mir(tcx: TyCtxt<'_>, outputs: &OutputFilenames) -> io::Result<()> { - let path = outputs.path(OutputType::Mir); +pub fn emit_mir(tcx: TyCtxt<'_>) -> io::Result<()> { + let path = tcx.output_filenames(()).path(OutputType::Mir); let mut f = io::BufWriter::new(File::create(&path)?); write_mir_pretty(tcx, None, &mut f)?; Ok(()) diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs index dba42f7aff0..8a7b027ddda 100644 --- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -104,8 +104,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { let mut should_cleanup = false; // Also consider newly generated bbs in the same pass - for i in 0..body.basic_blocks().len() { - let bbs = body.basic_blocks(); + for i in 0..body.basic_blocks.len() { + let bbs = &*body.basic_blocks; let parent = BasicBlock::from_usize(i); let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue @@ -121,7 +121,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { let TerminatorKind::SwitchInt { discr: parent_op, - switch_ty: parent_ty, targets: parent_targets } = &bbs[parent].terminator().kind else { unreachable!() @@ -132,6 +131,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { Operand::Copy(x) => Operand::Copy(*x), Operand::Constant(x) => Operand::Constant(x.clone()), }; + let parent_ty = parent_op.ty(body.local_decls(), tcx); let statements_before = bbs[parent].statements.len(); let parent_end = Location { block: parent, statement_index: statements_before }; @@ -153,7 +153,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { // create temp to store inequality comparison between the two discriminants, `_t` in // example above let nequal = BinOp::Ne; - let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty); + let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty); let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span); patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp)); @@ -181,7 +181,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { kind: TerminatorKind::SwitchInt { // switch on the first discriminant, so we can mark the second one as dead discr: parent_op, - switch_ty: opt_data.child_ty, targets: eq_targets, }, })); @@ -193,12 +192,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { let false_case = eq_bb; patch.patch_terminator( parent, - TerminatorKind::if_( - tcx, - Operand::Move(Place::from(comp_temp)), - true_case, - false_case, - ), + TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case), ); // generate StorageDead for the second_discriminant_temp not in use anymore @@ -316,14 +310,14 @@ fn evaluate_candidate<'tcx>( body: &Body<'tcx>, parent: BasicBlock, ) -> Option<OptimizationData<'tcx>> { - let bbs = body.basic_blocks(); + let bbs = &body.basic_blocks; let TerminatorKind::SwitchInt { 
targets, - switch_ty: parent_ty, - .. + discr: parent_discr, } = &bbs[parent].terminator().kind else { return None }; + let parent_ty = parent_discr.ty(body.local_decls(), tcx); let parent_dest = { let poss = targets.otherwise(); // If the fallthrough on the parent is trivially unreachable, we can let the @@ -339,12 +333,12 @@ fn evaluate_candidate<'tcx>( let (_, child) = targets.iter().next()?; let child_terminator = &bbs[child].terminator(); let TerminatorKind::SwitchInt { - switch_ty: child_ty, targets: child_targets, - .. + discr: child_discr, } = &child_terminator.kind else { return None }; + let child_ty = child_discr.ty(body.local_decls(), tcx); if child_ty != parent_ty { return None; } @@ -372,7 +366,7 @@ fn evaluate_candidate<'tcx>( Some(OptimizationData { destination, child_place: *child_place, - child_ty: *child_ty, + child_ty, child_source: child_terminator.source_info, }) } diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index 44e3945d6fc..f31653caa49 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -4,12 +4,12 @@ use crate::MirPass; use rustc_hir::def_id::DefId; -use rustc_index::vec::Idx; +use rustc_index::Idx; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::*; -use rustc_middle::ty::subst::Subst; use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_target::abi::FieldIdx; /// Constructs the types used when accessing a Box's pointer pub fn build_ptr_tys<'tcx>( @@ -18,24 +18,24 @@ pub fn build_ptr_tys<'tcx>( unique_did: DefId, nonnull_did: DefId, ) -> (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>) { - let substs = tcx.intern_substs(&[pointee.into()]); - let unique_ty = tcx.bound_type_of(unique_did).subst(tcx, substs); - let nonnull_ty = tcx.bound_type_of(nonnull_did).subst(tcx, substs); + let substs = tcx.mk_substs(&[pointee.into()]); + let unique_ty = tcx.type_of(unique_did).subst(tcx, substs); + let nonnull_ty = tcx.type_of(nonnull_did).subst(tcx, substs); let ptr_ty = tcx.mk_imm_ptr(pointee); (unique_ty, nonnull_ty, ptr_ty) } -// Constructs the projection needed to access a Box's pointer +/// Constructs the projection needed to access a Box's pointer pub fn build_projection<'tcx>( unique_ty: Ty<'tcx>, nonnull_ty: Ty<'tcx>, ptr_ty: Ty<'tcx>, ) -> [PlaceElem<'tcx>; 3] { [ - PlaceElem::Field(Field::new(0), unique_ty), - PlaceElem::Field(Field::new(0), nonnull_ty), - PlaceElem::Field(Field::new(0), ptr_ty), + PlaceElem::Field(FieldIdx::new(0), unique_ty), + PlaceElem::Field(FieldIdx::new(0), nonnull_ty), + PlaceElem::Field(FieldIdx::new(0), ptr_ty), ] } @@ -69,10 +69,7 @@ impl<'tcx, 'a> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'tcx, 'a> { let (unique_ty, nonnull_ty, ptr_ty) = build_ptr_tys(tcx, base_ty.boxed_ty(), self.unique_did, self.nonnull_did); - let ptr_local = self.patch.new_temp(ptr_ty, source_info.span); - self.local_decls.push(LocalDecl::new(ptr_ty, source_info.span)); - - self.patch.add_statement(location, StatementKind::StorageLive(ptr_local)); + let ptr_local = self.patch.new_internal(ptr_ty, source_info.span); self.patch.add_assign( location, @@ -84,11 +81,6 @@ impl<'tcx, 'a> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'tcx, 'a> { ); place.local = ptr_local; - - self.patch.add_statement( - Location { block: location.block, statement_index: location.statement_index + 1 }, - StatementKind::StorageDead(ptr_local), - ); } self.super_place(place, context, 
location); @@ -100,13 +92,14 @@ pub struct ElaborateBoxDerefs; impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { if let Some(def_id) = tcx.lang_items().owned_box() { - let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[0].did; + let unique_did = + tcx.adt_def(def_id).non_enum_variant().fields[FieldIdx::from_u32(0)].did; - let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else { + let Some(nonnull_def) = tcx.type_of(unique_did).subst_identity().ty_adt_def() else { span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") }; - let nonnull_did = nonnull_def.non_enum_variant().fields[0].did; + let nonnull_did = nonnull_def.non_enum_variant().fields[FieldIdx::from_u32(0)].did; let patch = MirPatch::new(body); @@ -115,34 +108,8 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { let mut visitor = ElaborateBoxDerefVisitor { tcx, unique_did, nonnull_did, local_decls, patch }; - for (block, BasicBlockData { statements, terminator, .. }) in - body.basic_blocks.as_mut().iter_enumerated_mut() - { - let mut index = 0; - for statement in statements { - let location = Location { block, statement_index: index }; - visitor.visit_statement(statement, location); - index += 1; - } - - if let Some(terminator) = terminator - && !matches!(terminator.kind, TerminatorKind::Yield{..}) - { - let location = Location { block, statement_index: index }; - visitor.visit_terminator(terminator, location); - } - - let location = Location { block, statement_index: index }; - match terminator { - // yielding into a box is handled when lowering generators - Some(Terminator { kind: TerminatorKind::Yield { value, .. }, .. }) => { - visitor.visit_operand(value, location); - } - Some(terminator) => { - visitor.visit_terminator(terminator, location); - } - None => {} - } + for (block, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + visitor.visit_basic_block_data(block, data); } visitor.patch.apply(body); @@ -173,7 +140,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { if let Some(mut new_projections) = new_projections { new_projections.extend_from_slice(&place.projection[last_deref..]); - place.projection = tcx.intern_place_elems(&new_projections); + place.projection = tcx.mk_place_elems(&new_projections); } } } diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index 71ab6dee1b6..fda0e1023f7 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -1,7 +1,7 @@ use crate::deref_separator::deref_finder; use crate::MirPass; -use rustc_data_structures::fx::FxHashMap; use rustc_index::bit_set::BitSet; +use rustc_index::IndexVec; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; use rustc_middle::ty::{self, TyCtxt}; @@ -15,41 +15,62 @@ use rustc_mir_dataflow::MoveDataParamEnv; use rustc_mir_dataflow::{on_all_children_bits, on_all_drop_children_bits}; use rustc_mir_dataflow::{Analysis, ResultsCursor}; use rustc_span::Span; -use rustc_target::abi::VariantIdx; +use rustc_target::abi::{FieldIdx, VariantIdx}; use std::fmt; +/// During MIR building, Drop terminators are inserted in every place where a drop may occur. +/// However, in this phase, the presence of these terminators does not guarantee that a destructor will run, +/// as the target of the drop may be uninitialized. 
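+/// For example, whether the destructor for `x` below runs depends on `cond`:
+///
+/// ```ignore (illustrative)
+/// let x: String;
+/// if cond {
+///     x = String::new();
+/// }
+/// // a `Drop` terminator for `x` is built here, but it must only run
+/// // the destructor when `cond` was true
+/// ```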
+/// In general, the compiler cannot determine at compile time whether a destructor will run or not.
+///
+/// At a high level, this pass refines Drop to only run the destructor if the
+/// target is initialized. The way this is achieved is by inserting drop flags for every variable
+/// that may be dropped, and then using those flags to determine whether a destructor should run.
+/// Once this is complete, Drop terminators in the MIR correspond to a call to the "drop glue" or
+/// "drop shim" for the type of the dropped place.
+///
+/// This pass relies on dropped places having an associated move path, which is then used to determine
+/// the initialization status of the place and its descendants.
+/// It's worth noting that a MIR containing a Drop without an associated move path is probably ill-formed,
+/// as it would allow running a destructor on a place behind a reference:
+///
+/// ```text
+/// fn drop_term<T>(t: &mut T) {
+///     mir!(
+///         {
+///             Drop(*t, exit)
+///         }
+///         exit = {
+///             Return()
+///         }
+///     )
+/// }
+/// ```
 pub struct ElaborateDrops;
 
 impl<'tcx> MirPass<'tcx> for ElaborateDrops {
-    fn phase_change(&self) -> Option<MirPhase> {
-        Some(MirPhase::DropsLowered)
-    }
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         debug!("elaborate_drops({:?} @ {:?})", body.source, body.span);
-        let mut un_derefer = UnDerefer { tcx: tcx, derefer_sidetable: Default::default() };
-        un_derefer.ref_finder(body);
         let def_id = body.source.def_id();
         let param_env = tcx.param_env_reveal_all_normalized(def_id);
-        let move_data = match MoveData::gather_moves(body, tcx, param_env) {
+        let (side_table, move_data) = match MoveData::gather_moves(body, tcx, param_env) {
             Ok(move_data) => move_data,
             Err((move_data, _)) => {
                 tcx.sess.delay_span_bug(
                     body.span,
                     "No `move_errors` should be allowed in MIR borrowck",
                 );
-                move_data
+                (Default::default(), move_data)
             }
         };
+        let un_derefer = UnDerefer { tcx: tcx, derefer_sidetable: side_table };
         let elaborate_patch = {
-            let body = &*body;
             let env = MoveDataParamEnv { move_data, param_env };
-            let dead_unwinds = find_dead_unwinds(tcx, body, &env, &un_derefer);
+            remove_dead_unwinds(tcx, body, &env, &un_derefer);
             let inits = MaybeInitializedPlaces::new(tcx, body, &env)
                 .into_engine(tcx, body)
-                .dead_unwinds(&dead_unwinds)
                 .pass_name("elaborate_drops")
                 .iterate_to_fixpoint()
                 .into_results_cursor(body);
@@ -57,19 +78,22 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
             let uninits = MaybeUninitializedPlaces::new(tcx, body, &env)
                 .mark_inactive_variants_as_uninit()
                 .into_engine(tcx, body)
-                .dead_unwinds(&dead_unwinds)
                 .pass_name("elaborate_drops")
                 .iterate_to_fixpoint()
                 .into_results_cursor(body);
+            let reachable = traversal::reachable_as_bitset(body);
+
+            let drop_flags = IndexVec::from_elem(None, &env.move_data.move_paths);
             ElaborateDropsCtxt {
                 tcx,
                 body,
                 env: &env,
                 init_data: InitializationData { inits, uninits },
-                drop_flags: Default::default(),
+                drop_flags,
                 patch: MirPatch::new(body),
                 un_derefer: un_derefer,
+                reachable,
             }
             .elaborate()
         };
@@ -78,43 +102,41 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
     }
 }
-/// Returns the set of basic blocks whose unwind edges are known
-/// to not be reachable, because they are `drop` terminators
+/// Removes unwind edges which are known to be unreachable, because they are in `drop` terminators
 /// that can't drop anything.
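+/// For example (illustrative), a `Drop` whose place is known by the dataflow
+/// analysis to be definitely uninitialized can never run a destructor, so its
+/// `Cleanup` unwind edge is dead and is rewritten to `UnwindAction::Unreachable`.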
-fn find_dead_unwinds<'tcx>( +fn remove_dead_unwinds<'tcx>( tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, + body: &mut Body<'tcx>, env: &MoveDataParamEnv<'tcx>, und: &UnDerefer<'tcx>, -) -> BitSet<BasicBlock> { - debug!("find_dead_unwinds({:?})", body.span); +) { + debug!("remove_dead_unwinds({:?})", body.span); // We only need to do this pass once, because unwind edges can only // reach cleanup blocks, which can't have unwind edges themselves. - let mut dead_unwinds = BitSet::new_empty(body.basic_blocks().len()); + let mut dead_unwinds = Vec::new(); let mut flow_inits = MaybeInitializedPlaces::new(tcx, body, &env) .into_engine(tcx, body) - .pass_name("find_dead_unwinds") + .pass_name("remove_dead_unwinds") .iterate_to_fixpoint() .into_results_cursor(body); - for (bb, bb_data) in body.basic_blocks().iter_enumerated() { + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { let place = match bb_data.terminator().kind { - TerminatorKind::Drop { ref place, unwind: Some(_), .. } - | TerminatorKind::DropAndReplace { ref place, unwind: Some(_), .. } => { + TerminatorKind::Drop { ref place, unwind: UnwindAction::Cleanup(_), .. } => { und.derefer(place.as_ref(), body).unwrap_or(*place) } _ => continue, }; - debug!("find_dead_unwinds @ {:?}: {:?}", bb, bb_data); + debug!("remove_dead_unwinds @ {:?}: {:?}", bb, bb_data); let LookupResult::Exact(path) = env.move_data.rev_lookup.find(place.as_ref()) else { - debug!("find_dead_unwinds: has parent; skipping"); + debug!("remove_dead_unwinds: has parent; skipping"); continue; }; flow_inits.seek_before_primary_effect(body.terminator_loc(bb)); debug!( - "find_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}", + "remove_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}", bb, place, path, @@ -126,13 +148,22 @@ fn find_dead_unwinds<'tcx>( maybe_live |= flow_inits.contains(child); }); - debug!("find_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live); + debug!("remove_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live); if !maybe_live { - dead_unwinds.insert(bb); + dead_unwinds.push(bb); } } - dead_unwinds + if dead_unwinds.is_empty() { + return; + } + + let basic_blocks = body.basic_blocks.as_mut(); + for &bb in dead_unwinds.iter() { + if let Some(unwind) = basic_blocks[bb].terminator_mut().unwind_mut() { + *unwind = UnwindAction::Unreachable; + } + } } struct InitializationData<'mir, 'tcx> { @@ -222,7 +253,7 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { } } - fn field_subpath(&self, path: Self::Path, field: Field) -> Option<Self::Path> { + fn field_subpath(&self, path: Self::Path, field: FieldIdx) -> Option<Self::Path> { rustc_mir_dataflow::move_path_children_matching(self.ctxt.move_data(), path, |e| match e { ProjectionElem::Field(idx, _) => idx == field, _ => false, @@ -263,9 +294,10 @@ struct ElaborateDropsCtxt<'a, 'tcx> { body: &'a Body<'tcx>, env: &'a MoveDataParamEnv<'tcx>, init_data: InitializationData<'a, 'tcx>, - drop_flags: FxHashMap<MovePathIndex, Local>, + drop_flags: IndexVec<MovePathIndex, Option<Local>>, patch: MirPatch<'tcx>, un_derefer: UnDerefer<'tcx>, + reachable: BitSet<BasicBlock>, } impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { @@ -281,11 +313,11 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { let tcx = self.tcx; let patch = &mut self.patch; debug!("create_drop_flag({:?})", self.body.span); - self.drop_flags.entry(index).or_insert_with(|| patch.new_internal(tcx.types.bool, span)); + self.drop_flags[index].get_or_insert_with(|| patch.new_internal(tcx.types.bool, span)); } fn drop_flag(&mut self, 
index: MovePathIndex) -> Option<Place<'tcx>> { - self.drop_flags.get(&index).map(|t| Place::from(*t)) + self.drop_flags[index].map(Place::from) } /// create a patch that elaborates all drops in the input @@ -304,11 +336,13 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { } fn collect_drop_flags(&mut self) { - for (bb, data) in self.body.basic_blocks().iter_enumerated() { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } let terminator = data.terminator(); let place = match terminator.kind { - TerminatorKind::Drop { ref place, .. } - | TerminatorKind::DropAndReplace { ref place, .. } => { + TerminatorKind::Drop { ref place, .. } => { self.un_derefer.derefer(place.as_ref(), self.body).unwrap_or(*place) } _ => continue, @@ -332,7 +366,7 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { if maybe_dead { self.tcx.sess.delay_span_bug( terminator.source_info.span, - &format!( + format!( "drop of untracked, uninitialized value {:?}, place {:?} ({:?})", bb, place, path ), @@ -359,137 +393,65 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { } fn elaborate_drops(&mut self) { - for (bb, data) in self.body.basic_blocks().iter_enumerated() { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } let loc = Location { block: bb, statement_index: data.statements.len() }; let terminator = data.terminator(); - let resume_block = self.patch.resume_block(); match terminator.kind { - TerminatorKind::Drop { mut place, target, unwind } => { + TerminatorKind::Drop { mut place, target, unwind, replace } => { if let Some(new_place) = self.un_derefer.derefer(place.as_ref(), self.body) { place = new_place; } self.init_data.seek_before(loc); match self.move_data().rev_lookup.find(place.as_ref()) { - LookupResult::Exact(path) => elaborate_drop( - &mut Elaborator { ctxt: self }, - terminator.source_info, - place, - path, - target, - if data.is_cleanup { + LookupResult::Exact(path) => { + let unwind = if data.is_cleanup { Unwind::InCleanup } else { - Unwind::To(Option::unwrap_or(unwind, resume_block)) - }, - bb, - ), + match unwind { + UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup), + UnwindAction::Continue => Unwind::To(self.patch.resume_block()), + UnwindAction::Unreachable => { + Unwind::To(self.patch.unreachable_cleanup_block()) + } + UnwindAction::Terminate => { + Unwind::To(self.patch.terminate_block()) + } + } + }; + elaborate_drop( + &mut Elaborator { ctxt: self }, + terminator.source_info, + place, + path, + target, + unwind, + bb, + ) + } LookupResult::Parent(..) => { - self.tcx.sess.delay_span_bug( - terminator.source_info.span, - &format!("drop of untracked value {:?}", bb), - ); + if !replace { + self.tcx.sess.delay_span_bug( + terminator.source_info.span, + format!("drop of untracked value {:?}", bb), + ); + } + // A drop and replace behind a pointer/array/whatever. + // The borrow checker requires that these locations are initialized before the assignment, + // so we just leave an unconditional drop. + assert!(!data.is_cleanup); } } } - TerminatorKind::DropAndReplace { mut place, ref value, target, unwind } => { - assert!(!data.is_cleanup); - - if let Some(new_place) = self.un_derefer.derefer(place.as_ref(), self.body) { - place = new_place; - } - self.elaborate_replace(loc, place, value, target, unwind); - } _ => continue, } } } - /// Elaborate a MIR `replace` terminator. This instruction - /// is not directly handled by codegen, and therefore - /// must be desugared. 
- /// - /// The desugaring drops the location if needed, and then writes - /// the value (including setting the drop flag) over it in *both* arms. - /// - /// The `replace` terminator can also be called on places that - /// are not tracked by elaboration (for example, - /// `replace x[i] <- tmp0`). The borrow checker requires that - /// these locations are initialized before the assignment, - /// so we just generate an unconditional drop. - fn elaborate_replace( - &mut self, - loc: Location, - place: Place<'tcx>, - value: &Operand<'tcx>, - target: BasicBlock, - unwind: Option<BasicBlock>, - ) { - let bb = loc.block; - let data = &self.body[bb]; - let terminator = data.terminator(); - assert!(!data.is_cleanup, "DropAndReplace in unwind path not supported"); - - let assign = Statement { - kind: StatementKind::Assign(Box::new((place, Rvalue::Use(value.clone())))), - source_info: terminator.source_info, - }; - - let unwind = unwind.unwrap_or_else(|| self.patch.resume_block()); - let unwind = self.patch.new_block(BasicBlockData { - statements: vec![assign.clone()], - terminator: Some(Terminator { - kind: TerminatorKind::Goto { target: unwind }, - ..*terminator - }), - is_cleanup: true, - }); - - let target = self.patch.new_block(BasicBlockData { - statements: vec![assign], - terminator: Some(Terminator { kind: TerminatorKind::Goto { target }, ..*terminator }), - is_cleanup: false, - }); - - match self.move_data().rev_lookup.find(place.as_ref()) { - LookupResult::Exact(path) => { - debug!("elaborate_drop_and_replace({:?}) - tracked {:?}", terminator, path); - self.init_data.seek_before(loc); - elaborate_drop( - &mut Elaborator { ctxt: self }, - terminator.source_info, - place, - path, - target, - Unwind::To(unwind), - bb, - ); - on_all_children_bits(self.tcx, self.body, self.move_data(), path, |child| { - self.set_drop_flag( - Location { block: target, statement_index: 0 }, - child, - DropFlagState::Present, - ); - self.set_drop_flag( - Location { block: unwind, statement_index: 0 }, - child, - DropFlagState::Present, - ); - }); - } - LookupResult::Parent(parent) => { - // drop and replace behind a pointer/array/whatever. The location - // must be initialized. 
- debug!("elaborate_drop_and_replace({:?}) - untracked {:?}", terminator, parent); - self.patch.patch_terminator( - bb, - TerminatorKind::Drop { place, target, unwind: Some(unwind) }, - ); - } - } - } - fn constant_bool(&self, span: Span, val: bool) -> Rvalue<'tcx> { Rvalue::Use(Operand::Constant(Box::new(Constant { span, @@ -499,7 +461,7 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { } fn set_drop_flag(&mut self, loc: Location, path: MovePathIndex, val: DropFlagState) { - if let Some(&flag) = self.drop_flags.get(&path) { + if let Some(flag) = self.drop_flags[path] { let span = self.patch.source_info_for_location(self.body, loc).span; let val = self.constant_bool(span, val.value()); self.patch.add_assign(loc, Place::from(flag), val); @@ -510,15 +472,21 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { let loc = Location::START; let span = self.patch.source_info_for_location(self.body, loc).span; let false_ = self.constant_bool(span, false); - for flag in self.drop_flags.values() { + for flag in self.drop_flags.iter().flatten() { self.patch.add_assign(loc, Place::from(*flag), false_.clone()); } } fn drop_flags_for_fn_rets(&mut self) { - for (bb, data) in self.body.basic_blocks().iter_enumerated() { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } if let TerminatorKind::Call { - destination, target: Some(tgt), cleanup: Some(_), .. + destination, + target: Some(tgt), + unwind: UnwindAction::Cleanup(_), + .. } = data.terminator().kind { assert!(!self.patch.is_patched(bb)); @@ -551,26 +519,19 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { // drop flags by themselves, to avoid the drop flags being // clobbered before they are read. - for (bb, data) in self.body.basic_blocks().iter_enumerated() { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } debug!("drop_flags_for_locs({:?})", data); for i in 0..(data.statements.len() + 1) { debug!("drop_flag_for_locs: stmt {}", i); - let mut allow_initializations = true; if i == data.statements.len() { match data.terminator().kind { TerminatorKind::Drop { .. } => { // drop elaboration should handle that by itself continue; } - TerminatorKind::DropAndReplace { .. } => { - // this contains the move of the source and - // the initialization of the destination. We - // only want the former - the latter is handled - // by the elaboration code and must be done - // *after* the destination is dropped. - assert!(self.patch.is_patched(bb)); - allow_initializations = false; - } TerminatorKind::Resume => { // It is possible for `Resume` to be patched // (in particular it can be patched to be replaced with @@ -587,19 +548,19 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { self.body, self.env, loc, - |path, ds| { - if ds == DropFlagState::Absent || allow_initializations { - self.set_drop_flag(loc, path, ds) - } - }, + |path, ds| self.set_drop_flag(loc, path, ds), ) } // There may be a critical edge after this call, // so mark the return as initialized *before* the // call. - if let TerminatorKind::Call { destination, target: Some(_), cleanup: None, .. } = - data.terminator().kind + if let TerminatorKind::Call { + destination, + target: Some(_), + unwind: UnwindAction::Continue | UnwindAction::Unreachable | UnwindAction::Terminate, + .. 
+ } = data.terminator().kind { assert!(!self.patch.is_patched(bb)); diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs new file mode 100644 index 00000000000..22f71bb0851 --- /dev/null +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -0,0 +1,252 @@ +use rustc_errors::{ + DecorateLint, DiagnosticBuilder, DiagnosticMessage, EmissionGuarantee, Handler, IntoDiagnostic, +}; +use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; +use rustc_middle::mir::{AssertKind, UnsafetyViolationDetails}; +use rustc_session::lint::{self, Lint}; +use rustc_span::Span; + +#[derive(LintDiagnostic)] +pub(crate) enum ConstMutate { + #[diag(mir_transform_const_modify)] + #[note] + Modify { + #[note(mir_transform_const_defined_here)] + konst: Span, + }, + #[diag(mir_transform_const_mut_borrow)] + #[note] + #[note(mir_transform_note2)] + MutBorrow { + #[note(mir_transform_note3)] + method_call: Option<Span>, + #[note(mir_transform_const_defined_here)] + konst: Span, + }, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_unaligned_packed_ref, code = "E0793")] +#[note] +#[note(mir_transform_note_ub)] +#[help] +pub(crate) struct UnalignedPackedRef { + #[primary_span] + pub span: Span, +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_unused_unsafe)] +pub(crate) struct UnusedUnsafe { + #[label(mir_transform_unused_unsafe)] + pub span: Span, + #[label] + pub nested_parent: Option<Span>, +} + +pub(crate) struct RequiresUnsafe { + pub span: Span, + pub details: RequiresUnsafeDetail, + pub enclosing: Option<Span>, + pub op_in_unsafe_fn_allowed: bool, +} + +// The primary message for this diagnostic should be '{$label} is unsafe and...', +// so we need to eagerly translate the label here, which isn't supported by the derive API +// We could also exhaustively list out the primary messages for all unsafe violations, +// but this would result in a lot of duplication. 
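As the comment above says, the primary message for `RequiresUnsafe` interpolates an already-rendered label, so the label has to be translated eagerly and handed over as an argument before the `impl` below builds the diagnostic. A minimal, dependency-free sketch of that pattern — `Label` and `render` are illustrative stand-ins, not `rustc_errors` APIs:

```rust
use std::collections::HashMap;

// Two of the violation labels, stripped down to plain strings.
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum Label {
    CallToUnsafeFunction,
    UseOfInlineAssembly,
}

fn label_text(label: Label) -> &'static str {
    match label {
        Label::CallToUnsafeFunction => "call to unsafe function",
        Label::UseOfInlineAssembly => "use of inline assembly",
    }
}

// Tiny stand-in for Fluent interpolation: replace each `{$name}` placeholder.
fn render(template: &str, args: &HashMap<&str, String>) -> String {
    let mut out = template.to_string();
    for (name, value) in args {
        out = out.replace(&format!("{{${name}}}"), value.as_str());
    }
    out
}

fn main() {
    // Render the label first, then pass it as the `details` argument of the
    // primary message -- the same order of operations as `into_diagnostic`.
    let details = label_text(Label::CallToUnsafeFunction).to_string();
    let mut args = HashMap::new();
    args.insert("details", details);
    let msg = render("{$details} is unsafe and requires unsafe block", &args);
    assert_eq!(msg, "call to unsafe function is unsafe and requires unsafe block");
}
```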
+impl<'sess, G: EmissionGuarantee> IntoDiagnostic<'sess, G> for RequiresUnsafe { + #[track_caller] + fn into_diagnostic(self, handler: &'sess Handler) -> DiagnosticBuilder<'sess, G> { + let mut diag = + handler.struct_diagnostic(crate::fluent_generated::mir_transform_requires_unsafe); + diag.code(rustc_errors::DiagnosticId::Error("E0133".to_string())); + diag.set_span(self.span); + diag.span_label(self.span, self.details.label()); + diag.note(self.details.note()); + let desc = handler.eagerly_translate_to_string(self.details.label(), [].into_iter()); + diag.set_arg("details", desc); + diag.set_arg("op_in_unsafe_fn_allowed", self.op_in_unsafe_fn_allowed); + if let Some(sp) = self.enclosing { + diag.span_label(sp, crate::fluent_generated::mir_transform_not_inherited); + } + diag + } +} + +#[derive(Copy, Clone)] +pub(crate) struct RequiresUnsafeDetail { + pub span: Span, + pub violation: UnsafetyViolationDetails, +} + +impl RequiresUnsafeDetail { + fn note(self) -> DiagnosticMessage { + use UnsafetyViolationDetails::*; + match self.violation { + CallToUnsafeFunction => crate::fluent_generated::mir_transform_call_to_unsafe_note, + UseOfInlineAssembly => crate::fluent_generated::mir_transform_use_of_asm_note, + InitializingTypeWith => { + crate::fluent_generated::mir_transform_initializing_valid_range_note + } + CastOfPointerToInt => crate::fluent_generated::mir_transform_const_ptr2int_note, + UseOfMutableStatic => crate::fluent_generated::mir_transform_use_of_static_mut_note, + UseOfExternStatic => crate::fluent_generated::mir_transform_use_of_extern_static_note, + DerefOfRawPointer => crate::fluent_generated::mir_transform_deref_ptr_note, + AccessToUnionField => crate::fluent_generated::mir_transform_union_access_note, + MutationOfLayoutConstrainedField => { + crate::fluent_generated::mir_transform_mutation_layout_constrained_note + } + BorrowOfLayoutConstrainedField => { + crate::fluent_generated::mir_transform_mutation_layout_constrained_borrow_note + } + CallToFunctionWith => crate::fluent_generated::mir_transform_target_feature_call_note, + } + } + + fn label(self) -> DiagnosticMessage { + use UnsafetyViolationDetails::*; + match self.violation { + CallToUnsafeFunction => crate::fluent_generated::mir_transform_call_to_unsafe_label, + UseOfInlineAssembly => crate::fluent_generated::mir_transform_use_of_asm_label, + InitializingTypeWith => { + crate::fluent_generated::mir_transform_initializing_valid_range_label + } + CastOfPointerToInt => crate::fluent_generated::mir_transform_const_ptr2int_label, + UseOfMutableStatic => crate::fluent_generated::mir_transform_use_of_static_mut_label, + UseOfExternStatic => crate::fluent_generated::mir_transform_use_of_extern_static_label, + DerefOfRawPointer => crate::fluent_generated::mir_transform_deref_ptr_label, + AccessToUnionField => crate::fluent_generated::mir_transform_union_access_label, + MutationOfLayoutConstrainedField => { + crate::fluent_generated::mir_transform_mutation_layout_constrained_label + } + BorrowOfLayoutConstrainedField => { + crate::fluent_generated::mir_transform_mutation_layout_constrained_borrow_label + } + CallToFunctionWith => crate::fluent_generated::mir_transform_target_feature_call_label, + } + } +} + +pub(crate) struct UnsafeOpInUnsafeFn { + pub details: RequiresUnsafeDetail, +} + +impl<'a> DecorateLint<'a, ()> for UnsafeOpInUnsafeFn { + #[track_caller] + fn decorate_lint<'b>( + self, + diag: &'b mut DiagnosticBuilder<'a, ()>, + ) -> &'b mut DiagnosticBuilder<'a, ()> { + let desc = diag + .handler() + .expect("lint 
should not yet be emitted") + .eagerly_translate_to_string(self.details.label(), [].into_iter()); + diag.set_arg("details", desc); + diag.span_label(self.details.span, self.details.label()); + diag.note(self.details.note()); + diag + } + + fn msg(&self) -> DiagnosticMessage { + crate::fluent_generated::mir_transform_unsafe_op_in_unsafe_fn + } +} + +pub(crate) enum AssertLint<P> { + ArithmeticOverflow(Span, AssertKind<P>), + UnconditionalPanic(Span, AssertKind<P>), +} + +impl<'a, P: std::fmt::Debug> DecorateLint<'a, ()> for AssertLint<P> { + fn decorate_lint<'b>( + self, + diag: &'b mut DiagnosticBuilder<'a, ()>, + ) -> &'b mut DiagnosticBuilder<'a, ()> { + let span = self.span(); + let assert_kind = self.panic(); + let message = assert_kind.diagnostic_message(); + assert_kind.add_args(&mut |name, value| { + diag.set_arg(name, value); + }); + diag.span_label(span, message); + + diag + } + + fn msg(&self) -> DiagnosticMessage { + match self { + AssertLint::ArithmeticOverflow(..) => { + crate::fluent_generated::mir_transform_arithmetic_overflow + } + AssertLint::UnconditionalPanic(..) => { + crate::fluent_generated::mir_transform_operation_will_panic + } + } + } +} + +impl<P> AssertLint<P> { + pub fn lint(&self) -> &'static Lint { + match self { + AssertLint::ArithmeticOverflow(..) => lint::builtin::ARITHMETIC_OVERFLOW, + AssertLint::UnconditionalPanic(..) => lint::builtin::UNCONDITIONAL_PANIC, + } + } + pub fn span(&self) -> Span { + match self { + AssertLint::ArithmeticOverflow(sp, _) | AssertLint::UnconditionalPanic(sp, _) => *sp, + } + } + pub fn panic(self) -> AssertKind<P> { + match self { + AssertLint::ArithmeticOverflow(_, p) | AssertLint::UnconditionalPanic(_, p) => p, + } + } +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_ffi_unwind_call)] +pub(crate) struct FfiUnwindCall { + #[label(mir_transform_ffi_unwind_call)] + pub span: Span, + pub foreign: bool, +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_fn_item_ref)] +pub(crate) struct FnItemRef { + #[suggestion(code = "{sugg}", applicability = "unspecified")] + pub span: Span, + pub sugg: String, + pub ident: String, +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_must_not_suspend)] +pub(crate) struct MustNotSupend<'a> { + #[label] + pub yield_sp: Span, + #[subdiagnostic] + pub reason: Option<MustNotSuspendReason>, + #[help] + pub src_sp: Span, + pub pre: &'a str, + pub def_path: String, + pub post: &'a str, +} + +#[derive(Subdiagnostic)] +#[note(mir_transform_note)] +pub(crate) struct MustNotSuspendReason { + #[primary_span] + pub span: Span, + pub reason: String, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_simd_shuffle_last_const)] +pub(crate) struct SimdShuffleLastConst { + #[primary_span] + pub span: Span, +} diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index 7728fdaffb0..58cc161ddcc 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -1,12 +1,15 @@ -use rustc_hir::def_id::{CrateNum, LocalDefId, LOCAL_CRATE}; +use rustc_hir::def_id::{LocalDefId, LOCAL_CRATE}; use rustc_middle::mir::*; +use rustc_middle::query::LocalCrate; +use rustc_middle::query::Providers; use rustc_middle::ty::layout; -use rustc_middle::ty::query::Providers; use rustc_middle::ty::{self, TyCtxt}; use rustc_session::lint::builtin::FFI_UNWIND_CALLS; use rustc_target::spec::abi::Abi; use rustc_target::spec::PanicStrategy; +use crate::errors; + fn abi_can_unwind(abi: Abi) -> bool { use 
Abi::*; match abi { @@ -47,9 +50,9 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { return false; } - let body = &*tcx.mir_built(ty::WithOptConstParam::unknown(local_def_id)).borrow(); + let body = &*tcx.mir_built(local_def_id).borrow(); - let body_ty = tcx.type_of(def_id); + let body_ty = tcx.type_of(def_id).skip_binder(); let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, @@ -65,7 +68,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let mut tainted = false; - for block in body.basic_blocks() { + for block in body.basic_blocks.iter() { if block.is_cleanup { continue; } @@ -106,15 +109,13 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { .lint_root; let span = terminator.source_info.span; - tcx.struct_span_lint_hir(FFI_UNWIND_CALLS, lint_root, span, |lint| { - let msg = match fn_def_id { - Some(_) => "call to foreign function with FFI-unwind ABI", - None => "call to function pointer with FFI-unwind ABI", - }; - let mut db = lint.build(msg); - db.span_label(span, msg); - db.emit(); - }); + let foreign = fn_def_id.is_some(); + tcx.emit_spanned_lint( + FFI_UNWIND_CALLS, + lint_root, + span, + errors::FfiUnwindCall { span, foreign }, + ); tainted = true; } @@ -123,9 +124,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { tainted } -fn required_panic_strategy(tcx: TyCtxt<'_>, cnum: CrateNum) -> Option<PanicStrategy> { - assert_eq!(cnum, LOCAL_CRATE); - +fn required_panic_strategy(tcx: TyCtxt<'_>, _: LocalCrate) -> Option<PanicStrategy> { if tcx.is_panic_runtime(LOCAL_CRATE) { return Some(tcx.sess.panic_strategy()); } diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index 2e4fe1e3e5d..b1c9c4acc40 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -1,18 +1,13 @@ use itertools::Itertools; -use rustc_errors::Applicability; use rustc_hir::def_id::DefId; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; -use rustc_middle::ty::{ - self, - subst::{GenericArgKind, Subst, SubstsRef}, - EarlyBinder, PredicateKind, Ty, TyCtxt, -}; +use rustc_middle::ty::{self, EarlyBinder, PredicateKind, SubstsRef, Ty, TyCtxt}; use rustc_session::lint::builtin::FUNCTION_ITEM_REFERENCES; use rustc_span::{symbol::sym, Span}; use rustc_target::spec::abi::Abi; -use crate::MirLint; +use crate::{errors, MirLint}; pub struct FunctionItemReferences; @@ -38,7 +33,7 @@ impl<'tcx> Visitor<'tcx> for FunctionItemRefChecker<'_, 'tcx> { args, destination: _, target: _, - cleanup: _, + unwind: _, from_hir_call: _, fn_span: _, } = &terminator.kind @@ -49,14 +44,12 @@ impl<'tcx> Visitor<'tcx> for FunctionItemRefChecker<'_, 'tcx> { // Handle calls to `transmute` if self.tcx.is_diagnostic_item(sym::transmute, def_id) { let arg_ty = args[0].ty(self.body, self.tcx); - for generic_inner_ty in arg_ty.walk() { - if let GenericArgKind::Type(inner_ty) = generic_inner_ty.unpack() { - if let Some((fn_id, fn_substs)) = - FunctionItemRefChecker::is_fn_ref(inner_ty) - { - let span = self.nth_arg_span(&args, 0); - self.emit_lint(fn_id, fn_substs, source_info, span); - } + for inner_ty in arg_ty.walk().filter_map(|arg| arg.as_type()) { + if let Some((fn_id, fn_substs)) = + FunctionItemRefChecker::is_fn_ref(inner_ty) + { + let span = self.nth_arg_span(&args, 0); + 
self.emit_lint(fn_id, fn_substs, source_info, span); } } } else { @@ -83,27 +76,25 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { for bound in bounds { if let Some(bound_ty) = self.is_pointer_trait(&bound.kind().skip_binder()) { // Get the argument types as they appear in the function signature. - let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs(); + let arg_defs = self.tcx.fn_sig(def_id).subst_identity().skip_binder().inputs(); for (arg_num, arg_def) in arg_defs.iter().enumerate() { // For all types reachable from the argument type in the fn sig - for generic_inner_ty in arg_def.walk() { - if let GenericArgKind::Type(inner_ty) = generic_inner_ty.unpack() { - // If the inner type matches the type bound by `Pointer` - if inner_ty == bound_ty { - // Do a substitution using the parameters from the callsite - let subst_ty = EarlyBinder(inner_ty).subst(self.tcx, substs_ref); - if let Some((fn_id, fn_substs)) = - FunctionItemRefChecker::is_fn_ref(subst_ty) - { - let mut span = self.nth_arg_span(args, arg_num); - if span.from_expansion() { - // The operand's ctxt wouldn't display the lint since it's inside a macro so - // we have to use the callsite's ctxt. - let callsite_ctxt = span.source_callsite().ctxt(); - span = span.with_ctxt(callsite_ctxt); - } - self.emit_lint(fn_id, fn_substs, source_info, span); + for inner_ty in arg_def.walk().filter_map(|arg| arg.as_type()) { + // If the inner type matches the type bound by `Pointer` + if inner_ty == bound_ty { + // Do a substitution using the parameters from the callsite + let subst_ty = EarlyBinder::bind(inner_ty).subst(self.tcx, substs_ref); + if let Some((fn_id, fn_substs)) = + FunctionItemRefChecker::is_fn_ref(subst_ty) + { + let mut span = self.nth_arg_span(args, arg_num); + if span.from_expansion() { + // The operand's ctxt wouldn't display the lint since it's inside a macro so + // we have to use the callsite's ctxt. + let callsite_ctxt = span.source_callsite().ctxt(); + span = span.with_ctxt(callsite_ctxt); } + self.emit_lint(fn_id, fn_substs, source_info, span); } } } @@ -114,12 +105,10 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { /// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type. fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> { - if let ty::PredicateKind::Trait(predicate) = bound { - if self.tcx.is_diagnostic_item(sym::Pointer, predicate.def_id()) { - Some(predicate.trait_ref.self_ty()) - } else { - None - } + if let ty::PredicateKind::Clause(ty::Clause::Trait(predicate)) = bound { + self.tcx + .is_diagnostic_item(sym::Pointer, predicate.def_id()) + .then(|| predicate.trait_ref.self_ty()) } else { None } @@ -165,7 +154,8 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { .as_ref() .assert_crate_local() .lint_root; - let fn_sig = self.tcx.fn_sig(fn_id); + // FIXME: use existing printing routines to print the function signature + let fn_sig = self.tcx.fn_sig(fn_id).subst(self.tcx, fn_substs); let unsafety = fn_sig.unsafety().prefix_str(); let abi = match fn_sig.abi() { Abi::Rust => String::from(""), @@ -183,23 +173,21 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { let num_args = fn_sig.inputs().map_bound(|inputs| inputs.len()).skip_binder(); let variadic = if fn_sig.c_variadic() { ", ..." 
} else { "" }; let ret = if fn_sig.output().skip_binder().is_unit() { "" } else { " -> _" }; - self.tcx.struct_span_lint_hir(FUNCTION_ITEM_REFERENCES, lint_root, span, |lint| { - lint.build("taking a reference to a function item does not give a function pointer") - .span_suggestion( - span, - &format!("cast `{}` to obtain a function pointer", ident), - format!( - "{} as {}{}fn({}{}){}", - if params.is_empty() { ident } else { format!("{}::<{}>", ident, params) }, - unsafety, - abi, - vec!["_"; num_args].join(", "), - variadic, - ret, - ), - Applicability::Unspecified, - ) - .emit(); - }); + let sugg = format!( + "{} as {}{}fn({}{}){}", + if params.is_empty() { ident.clone() } else { format!("{}::<{}>", ident, params) }, + unsafety, + abi, + vec!["_"; num_args].join(", "), + variadic, + ret, + ); + + self.tcx.emit_spanned_lint( + FUNCTION_ITEM_REFERENCES, + lint_root, + span, + errors::FnItemRef { span, sugg, ident }, + ); } } diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 91ecf387922..fe3f8ed047a 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -11,10 +11,10 @@ //! generator in the MIR, since it is used to create the drop glue for the generator. We'd get //! infinite recursion otherwise. //! -//! This pass creates the implementation for the Generator::resume function and the drop shim -//! for the generator based on the MIR input. It converts the generator argument from Self to -//! &mut Self adding derefs in the MIR as needed. It computes the final layout of the generator -//! struct which looks like this: +//! This pass creates the implementation for either the `Generator::resume` or `Future::poll` +//! function and the drop shim for the generator based on the MIR input. +//! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed. +//! It computes the final layout of the generator struct which looks like this: //! First upvars are stored //! It is followed by the generator state field. //! Then finally the MIR locals which are live across a suspension point are stored. @@ -32,14 +32,15 @@ //! 2 - Generator has been poisoned //! //! It also rewrites `return x` and `yield y` as setting a new generator state and returning -//! GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively. +//! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`, +//! or `Poll::Ready(x)` and `Poll::Pending` respectively. //! MIR locals which are live across a suspension point are moved to the generator struct //! with references to them being updated with references to the generator struct. //! //! The pass creates two functions which have a switch on the generator state giving //! the action to take. //! -//! One of them is the implementation of Generator::resume. +//! One of them is the implementation of `Generator::resume` / `Future::poll`. //! For generators with state 0 (unresumed) it starts the execution of the generator. //! For generators with state 1 (returned) and state 2 (poisoned) it panics. //! Otherwise it continues the execution from the last suspension point. @@ -50,26 +51,30 @@ //! Otherwise it drops all the values in scope at the last suspension point. 
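Since the module docs above are dense, here is a hand-written analogue of the state machine this pass produces for a tiny generator — a sketch only: it shows the reserved discriminants 0/1/2 and one saved local, while the real layout also stores upvars and addresses fields through `GeneratorSavedLocal` indices:

```rust
// What `Generator::resume` compiles down to, written by hand (illustrative).
enum GeneratorState<Y, R> {
    Yielded(Y),
    Complete(R),
}

// Discriminants: 0 = unresumed, 1 = returned, 2 = poisoned, 3+ = suspension points.
struct Counter {
    state: u32,
    current: u32, // a local live across the suspension point, moved into the struct
}

impl Counter {
    fn resume(&mut self) -> GeneratorState<u32, &'static str> {
        match self.state {
            0 => {
                // First resume: start executing, then suspend at state 3.
                self.current = 0;
                self.state = 3;
                GeneratorState::Yielded(self.current)
            }
            3 if self.current < 2 => {
                self.current += 1;
                GeneratorState::Yielded(self.current)
            }
            3 => {
                // `return`: set the RETURNED state and report completion.
                self.state = 1;
                GeneratorState::Complete("done")
            }
            // Resuming on the returned or poisoned state panics, as described above.
            1 => panic!("generator resumed after completion"),
            2 => panic!("generator resumed after panicking"),
            _ => unreachable!(),
        }
    }
}

fn main() {
    let mut g = Counter { state: 0, current: 0 };
    loop {
        match g.resume() {
            GeneratorState::Yielded(v) => println!("yielded {v}"),
            GeneratorState::Complete(r) => {
                println!("complete: {r}");
                break;
            }
        }
    }
}
```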
use crate::deref_separator::deref_finder; +use crate::errors; use crate::simplify; -use crate::util::expand_aggregate; use crate::MirPass; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; +use rustc_hir::GeneratorKind; use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet}; -use rustc_index::vec::{Idx, IndexVec}; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::dump_mir; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::subst::{Subst, SubstsRef}; -use rustc_middle::ty::GeneratorSubsts; use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt}; +use rustc_middle::ty::{GeneratorSubsts, SubstsRef}; use rustc_mir_dataflow::impls::{ MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, }; use rustc_mir_dataflow::storage::always_storage_live_locals; use rustc_mir_dataflow::{self, Analysis}; -use rustc_target::abi::VariantIdx; +use rustc_span::def_id::{DefId, LocalDefId}; +use rustc_span::symbol::sym; +use rustc_span::Span; +use rustc_target::abi::{FieldIdx, VariantIdx}; use rustc_target::spec::PanicStrategy; use std::{iter, ops}; @@ -122,7 +127,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]), + projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]), }, self.tcx, ); @@ -158,8 +163,8 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Field( - Field::new(0), + projection: self.tcx().mk_place_elems(&[ProjectionElem::Field( + FieldIdx::new(0), self.ref_gen_ty, )]), }, @@ -183,7 +188,7 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx let mut new_projection = new_base.projection.to_vec(); new_projection.append(&mut place.projection.to_vec()); - place.projection = tcx.intern_place_elems(&new_projection); + place.projection = tcx.mk_place_elems(&new_projection); } const SELF_ARG: Local = Local::from_u32(1); @@ -216,6 +221,7 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, + is_async_kind: bool, state_adt_ref: AdtDef<'tcx>, state_substs: SubstsRef<'tcx>, @@ -240,28 +246,52 @@ struct TransformVisitor<'tcx> { } impl<'tcx> TransformVisitor<'tcx> { - // Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single - // element tuple variants, so we can just write to the downcasted first field and then set the + // Make a `GeneratorState` or `Poll` variant assignment. + // + // `core::ops::GeneratorState` only has single element tuple variants, + // so we can just write to the downcasted first field and then set the // discriminant to the appropriate variant. 
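The `make_state` function below starts by picking a variant index; pulled out as a standalone table (an illustrative sketch, not compiler code), the mapping is:

```rust
/// Which variant a `return`/`yield` becomes, per the `VariantIdx::new(..)`
/// match in `make_state`:
/// `GeneratorState` = { 0: Yielded, 1: Complete }, `Poll` = { 0: Ready, 1: Pending }.
fn variant_index(is_return: bool, is_async_kind: bool) -> u32 {
    match (is_return, is_async_kind) {
        (true, false) => 1,  // `return x` => GeneratorState::Complete(x)
        (false, false) => 0, // `yield y`  => GeneratorState::Yielded(y)
        (true, true) => 0,   // `return x` => Poll::Ready(x)
        (false, true) => 1,  // `yield`    => Poll::Pending (carries no value)
    }
}

fn main() {
    // A sync generator's `return` fills variant 1, an async body's variant 0.
    assert_eq!(variant_index(true, false), 1);
    assert_eq!(variant_index(true, true), 0);
}
```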
fn make_state( &self, - idx: VariantIdx, val: Operand<'tcx>, source_info: SourceInfo, - ) -> impl Iterator<Item = Statement<'tcx>> { + is_return: bool, + statements: &mut Vec<Statement<'tcx>>, + ) { + let idx = VariantIdx::new(match (is_return, self.is_async_kind) { + (true, false) => 1, // GeneratorState::Complete + (false, false) => 0, // GeneratorState::Yielded + (true, true) => 0, // Poll::Ready + (false, true) => 1, // Poll::Pending + }); + let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None); + + // `Poll::Pending` + if self.is_async_kind && idx == VariantIdx::new(1) { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + + // FIXME(swatinem): assert that `val` is indeed unit? + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + + // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); - let ty = self - .tcx - .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did) - .subst(self.tcx, self.state_substs); - expand_aggregate( - Place::return_place(), - std::iter::once((val, ty)), - kind, + + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), [val].into()), + ))), source_info, - self.tcx, - ) + }); } // Create a Place referencing a generator struct field @@ -269,9 +299,9 @@ impl<'tcx> TransformVisitor<'tcx> { let self_place = Place::from(SELF_ARG); let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index); let mut projection = base.projection.to_vec(); - projection.push(ProjectionElem::Field(Field::new(idx), ty)); + projection.push(ProjectionElem::Field(FieldIdx::new(idx), ty)); - Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) } + Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) } } // Create a statement which changes the discriminant @@ -332,22 +362,19 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { }); let ret_val = match data.terminator().kind { - TerminatorKind::Return => Some(( - VariantIdx::new(1), - None, - Operand::Move(Place::from(self.new_ret_local)), - None, - )), + TerminatorKind::Return => { + Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None)) + } TerminatorKind::Yield { ref value, resume, resume_arg, drop } => { - Some((VariantIdx::new(0), Some((resume, resume_arg)), value.clone(), drop)) + Some((false, Some((resume, resume_arg)), value.clone(), drop)) } _ => None, }; - if let Some((state_idx, resume, v, drop)) = ret_val { + if let Some((is_return, resume, v, drop)) = ret_val { let source_info = data.terminator().source_info; // We must assign the value first in case it gets declared dead below - data.statements.extend(self.make_state(state_idx, v, source_info)); + self.make_state(v, source_info, is_return, &mut data.statements); let state = if let Some((resume, mut resume_arg)) = resume { // Yield let state = RESERVED_VARIANTS + self.suspension_points.len(); @@ -401,7 +428,7 @@ fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span)); let pin_adt_ref = tcx.adt_def(pin_did); - let substs = tcx.intern_substs(&[ref_gen_ty.into()]); + let substs = tcx.mk_substs(&[ref_gen_ty.into()]); let pin_ref_gen_ty = 
tcx.mk_adt(pin_adt_ref, substs); // Replace the by ref generator argument @@ -432,6 +459,104 @@ fn replace_local<'tcx>( new_local } +/// Transforms the `body` of the generator applying the following transforms: +/// +/// - Eliminates all the `get_context` calls that async lowering created. +/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`). +/// +/// The `Local`s that have their types replaced are: +/// - The `resume` argument itself. +/// - The argument to `get_context`. +/// - The yielded value of a `yield`. +/// +/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the +/// `get_context` function is being used to convert that back to a `&mut Context<'_>`. +/// +/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection, +/// but rather directly use `&mut Context<'_>`, however that would currently +/// lead to higher-kinded lifetime errors. +/// See <https://github.com/rust-lang/rust/issues/105501>. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `ResumeTy` indirection for the time being, and that indirection +/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`. +fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let context_mut_ref = tcx.mk_task_context(); + + // replace the type of the `resume` argument + replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref); + + let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); + + for bb in START_BLOCK..body.basic_blocks.next_index() { + let bb_data = &body[bb]; + if bb_data.is_cleanup { + continue; + } + + match &bb_data.terminator().kind { + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(body, tcx); + if let ty::FnDef(def_id, _) = *func_ty.kind() { + if def_id == get_context_def_id { + let local = eliminate_get_context_call(&mut body[bb]); + replace_resume_ty_local(tcx, body, local, context_mut_ref); + } + } else { + continue; + } + } + TerminatorKind::Yield { resume_arg, .. } => { + replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref); + } + _ => {} + } + } +} + +fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local { + let terminator = bb_data.terminator.take().unwrap(); + if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind { + let arg = args.pop().unwrap(); + let local = arg.place().unwrap().local; + + let arg = Rvalue::Use(arg); + let assign = Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new((destination, arg))), + }; + bb_data.statements.push(assign); + bb_data.terminator = Some(Terminator { + source_info: terminator.source_info, + kind: TerminatorKind::Goto { target: target.unwrap() }, + }); + local + } else { + bug!(); + } +} + +#[cfg_attr(not(debug_assertions), allow(unused))] +fn replace_resume_ty_local<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + local: Local, + context_mut_ref: Ty<'tcx>, +) { + let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref); + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` in MIR. 
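+ // For context (a close paraphrase of the standard library's definition at
+ // the time of this change -- treat it as illustrative): `ResumeTy` erases
+ // the lifetimes of `&mut Context<'_>` behind a raw pointer, and
+ // `get_context` restores them, which is exactly the indirection being
+ // compiled away here.
+ //
+ //     pub struct ResumeTy(NonNull<Context<'static>>);
+ //
+ //     pub unsafe fn get_context<'a, 'b>(cx: ResumeTy) -> &'a mut Context<'b> {
+ //         &mut *cx.0.as_ptr().cast()
+ //     }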
+ #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = local_ty.kind() { + let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", local_ty); + }; + } +} + struct LivenessInfo { /// Which locals are live across any suspension point. saved_locals: GeneratorSavedLocals, @@ -462,7 +587,7 @@ fn locals_live_across_suspend_points<'tcx>( // Calculate when MIR locals have live storage. This gives us an upper bound of their // lifetimes. - let mut storage_live = MaybeStorageLive::new(always_live_locals.clone()) + let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals)) .into_engine(tcx, body_ref) .iterate_to_fixpoint() .into_results_cursor(body_ref); @@ -472,16 +597,15 @@ fn locals_live_across_suspend_points<'tcx>( let borrowed_locals_results = MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint(); - let mut borrowed_locals_cursor = - rustc_mir_dataflow::ResultsCursor::new(body_ref, &borrowed_locals_results); + let mut borrowed_locals_cursor = borrowed_locals_results.cloned_results_cursor(body_ref); // Calculate the MIR locals that we actually need to keep storage around // for. - let requires_storage_results = MaybeRequiresStorage::new(body, &borrowed_locals_results) - .into_engine(tcx, body_ref) - .iterate_to_fixpoint(); - let mut requires_storage_cursor = - rustc_mir_dataflow::ResultsCursor::new(body_ref, &requires_storage_results); + let mut requires_storage_results = + MaybeRequiresStorage::new(borrowed_locals_results.cloned_results_cursor(body)) + .into_engine(tcx, body_ref) + .iterate_to_fixpoint(); + let mut requires_storage_cursor = requires_storage_results.as_results_cursor(body_ref); // Calculate the liveness of MIR locals ignoring borrows. let mut liveness = MaybeLiveLocals @@ -490,12 +614,12 @@ fn locals_live_across_suspend_points<'tcx>( .iterate_to_fixpoint() .into_results_cursor(body_ref); - let mut storage_liveness_map = IndexVec::from_elem(None, body.basic_blocks()); + let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks); let mut live_locals_at_suspension_points = Vec::new(); let mut source_info_at_suspension_points = Vec::new(); let mut live_locals_at_any_suspension_point = BitSet::new_empty(body.local_decls.len()); - for (block, data) in body.basic_blocks().iter_enumerated() { + for (block, data) in body.basic_blocks.iter_enumerated() { if let TerminatorKind::Yield { .. 
} = data.terminator().kind { let loc = Location { block, statement_index: data.statements.len() }; @@ -622,7 +746,7 @@ fn compute_storage_conflicts<'mir, 'tcx>( body: &'mir Body<'tcx>, saved_locals: &GeneratorSavedLocals, always_live_locals: BitSet<Local>, - requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>, + mut requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'_, 'mir, 'tcx>>, ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> { assert_eq!(body.local_decls.len(), saved_locals.domain_size()); @@ -677,13 +801,14 @@ struct StorageConflictVisitor<'mir, 'tcx, 's> { local_conflicts: BitMatrix<Local, Local>, } -impl<'mir, 'tcx> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx> +impl<'mir, 'tcx, R> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx, R> for StorageConflictVisitor<'mir, 'tcx, '_> { type FlowState = BitSet<Local>; fn visit_statement_before_primary_effect( &mut self, + _results: &R, state: &Self::FlowState, _statement: &'mir Statement<'tcx>, loc: Location, @@ -693,6 +818,7 @@ impl<'mir, 'tcx> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx> fn visit_terminator_before_primary_effect( &mut self, + _results: &R, state: &Self::FlowState, _terminator: &'mir Terminator<'tcx>, loc: Location, @@ -704,7 +830,7 @@ impl<'mir, 'tcx> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx> impl StorageConflictVisitor<'_, '_, '_> { fn apply_state(&mut self, flow_state: &BitSet<Local>, loc: Location) { // Ignore unreachable blocks. - if self.body.basic_blocks()[loc.block].terminator().kind == TerminatorKind::Unreachable { + if self.body.basic_blocks[loc.block].terminator().kind == TerminatorKind::Unreachable { return; } @@ -728,7 +854,7 @@ fn sanitize_witness<'tcx>( body: &Body<'tcx>, witness: Ty<'tcx>, upvars: Vec<Ty<'tcx>>, - saved_locals: &GeneratorSavedLocals, + layout: &GeneratorLayout<'tcx>, ) { let did = body.source.def_id(); let param_env = tcx.param_env(did); @@ -741,37 +867,42 @@ fn sanitize_witness<'tcx>( _ => { tcx.sess.delay_span_bug( body.span, - &format!("unexpected generator witness type {:?}", witness.kind()), + format!("unexpected generator witness type {:?}", witness.kind()), ); return; } }; - for (local, decl) in body.local_decls.iter_enumerated() { - // Ignore locals which are internal or not saved between yields. 
- if !saved_locals.contains(local) || decl.internal { + let mut mismatches = Vec::new(); + for fty in &layout.field_tys { + if fty.ignore_for_traits { continue; } - let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty); + let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty); // Sanity check that typeck knows about the type of locals which are // live across a suspension point if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) { - span_bug!( - body.span, - "Broken MIR: generator contains type {} in MIR, \ - but typeck only knows about {} and {:?}", - decl_ty, - allowed, - allowed_upvars - ); + mismatches.push(decl_ty); } } + + if !mismatches.is_empty() { + span_bug!( + body.span, + "Broken MIR: generator contains type {:?} in MIR, \ + but typeck only knows about {} and {:?}", + mismatches, + allowed, + allowed_upvars + ); + } } fn compute_layout<'tcx>( + tcx: TyCtxt<'tcx>, liveness: LivenessInfo, - body: &mut Body<'tcx>, + body: &Body<'tcx>, ) -> ( FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>, GeneratorLayout<'tcx>, @@ -789,9 +920,39 @@ fn compute_layout<'tcx>( let mut locals = IndexVec::<GeneratorSavedLocal, _>::new(); let mut tys = IndexVec::<GeneratorSavedLocal, _>::new(); for (saved_local, local) in saved_locals.iter_enumerated() { - locals.push(local); - tys.push(body.local_decls[local].ty); debug!("generator saved local {:?} => {:?}", saved_local, local); + + locals.push(local); + let decl = &body.local_decls[local]; + debug!(?decl); + + let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir { + // Do not `assert_crate_local` here, as post-borrowck cleanup may have already cleared + // the information. This is alright, since `ignore_for_traits` is only relevant when + // this code runs on pre-cleanup MIR, and `ignore_for_traits = false` is the safer + // default. + match decl.local_info { + // Do not include raw pointers created from accessing `static` items, as those could + // well be re-created by another access to the same static. + ClearCrossCrate::Set(box LocalInfo::StaticRef { is_thread_local, .. }) => { + !is_thread_local + } + // Fake borrows are only read by fake reads, so do not have any reality in + // post-analysis MIR. + ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true, + _ => false, + } + } else { + // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that + // MIR building may introduce. This leads to wrongly ignored types, but this is + // necessary for internal consistency and to avoid ICEs. + decl.internal + }; + let decl = + GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + debug!(?decl); + + tys.push(decl); } // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. @@ -809,7 +970,7 @@ fn compute_layout<'tcx>( // Build the generator variant field list. // Create a map from local indices to generator struct indices. - let mut variant_fields: IndexVec<VariantIdx, IndexVec<Field, GeneratorSavedLocal>> = + let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, GeneratorSavedLocal>> = iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect(); let mut remap = FxHashMap::default(); for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() { @@ -821,7 +982,7 @@ fn compute_layout<'tcx>( // just use the first one here. That's fine; fields do not move // around inside generators, so it doesn't matter which variant // index we access them by. 
- remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx)); + remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); } variant_fields.push(fields); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); @@ -831,6 +992,7 @@ fn compute_layout<'tcx>( let layout = GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts }; + debug!(?layout); (remap, layout, storage_liveness) } @@ -849,11 +1011,7 @@ fn insert_switch<'tcx>( let (assign, discr) = transform.get_discr(body); let switch_targets = SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block); - let switch = TerminatorKind::SwitchInt { - discr: Operand::Move(discr), - switch_ty: transform.discr_ty, - targets: switch_targets, - }; + let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets }; let source_info = SourceInfo::outermost(body.span); body.basic_blocks_mut().raw.insert( @@ -886,9 +1044,12 @@ fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let mut elaborator = DropShimElaborator { body, patch: MirPatch::new(body), tcx, param_env }; - for (block, block_data) in body.basic_blocks().iter_enumerated() { + for (block, block_data) in body.basic_blocks.iter_enumerated() { let (target, unwind, source_info) = match block_data.terminator() { - Terminator { source_info, kind: TerminatorKind::Drop { place, target, unwind } } => { + Terminator { + source_info, + kind: TerminatorKind::Drop { place, target, unwind, replace: _ }, + } => { if let Some(local) = place.as_local() { if local == SELF_ARG { (target, unwind, source_info) @@ -904,7 +1065,12 @@ fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let unwind = if block_data.is_cleanup { Unwind::InCleanup } else { - Unwind::To(unwind.unwrap_or_else(|| elaborator.patch.resume_block())) + Unwind::To(match *unwind { + UnwindAction::Cleanup(tgt) => tgt, + UnwindAction::Continue => elaborator.patch.resume_block(), + UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(), + UnwindAction::Terminate => elaborator.patch.terminate_block(), + }) }; elaborate_drop( &mut elaborator, @@ -957,22 +1123,12 @@ fn create_generator_drop_shim<'tcx>( tcx.mk_ptr(ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }), source_info, ); - if tcx.sess.opts.unstable_opts.mir_emit_retag { - // Alias tracking must know we changed the type - body.basic_blocks_mut()[START_BLOCK].statements.insert( - 0, - Statement { - source_info, - kind: StatementKind::Retag(RetagKind::Raw, Box::new(Place::from(SELF_ARG))), - }, - ) - } // Make sure we remove dead blocks to remove // unrelated code from the resume part of the function simplify::remove_dead_blocks(tcx, &mut body); - dump_mir(tcx, None, "generator_drop", &0, &body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(())); body } @@ -991,7 +1147,7 @@ fn insert_panic_block<'tcx>( body: &mut Body<'tcx>, message: AssertMessage<'tcx>, ) -> BasicBlock { - let assert_block = BasicBlock::new(body.basic_blocks().len()); + let assert_block = BasicBlock::new(body.basic_blocks.len()); let term = TerminatorKind::Assert { cond: Operand::Constant(Box::new(Constant { span: body.span, @@ -999,9 +1155,9 @@ fn insert_panic_block<'tcx>( literal: ConstantKind::from_bool(tcx, false), })), expected: true, - msg: message, + msg: Box::new(message), target: assert_block, - cleanup: None, + unwind: UnwindAction::Continue, 
}; let source_info = SourceInfo::outermost(body.span); @@ -1016,12 +1172,12 @@ fn insert_panic_block<'tcx>( fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool { // Returning from a function with an uninhabited return type is undefined behavior. - if tcx.conservative_is_privately_uninhabited(param_env.and(body.return_ty())) { + if body.return_ty().is_privately_uninhabited(tcx, param_env) { return false; } // If there's a return terminator the function may return. - for block in body.basic_blocks() { + for block in body.basic_blocks.iter() { if let TerminatorKind::Return = block.terminator().kind { return true; } @@ -1038,12 +1194,12 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool { } // Unwinds can only start at certain terminators. - for block in body.basic_blocks() { + for block in body.basic_blocks.iter() { match block.terminator().kind { // These never unwind. TerminatorKind::Goto { .. } | TerminatorKind::SwitchInt { .. } - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::GeneratorDrop @@ -1060,7 +1216,6 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool { // These may unwind. TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::Call { .. } | TerminatorKind::InlineAsm { .. } | TerminatorKind::Assert { .. } => return true, @@ -1103,8 +1258,8 @@ fn create_generator_resume_function<'tcx>( } else if !block.is_cleanup { // Any terminators that *can* unwind but don't have an unwind target set are also // pointed at our poisoning block (unless they're part of the cleanup path). - if let Some(unwind @ None) = block.terminator_mut().unwind_mut() { - *unwind = Some(poison_block); + if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() { + *unwind = UnwindAction::Cleanup(poison_block); } } } @@ -1115,7 +1270,7 @@ fn create_generator_resume_function<'tcx>( use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn}; // Jump to the entry point on the unresumed - cases.insert(0, (UNRESUMED, BasicBlock::new(0))); + cases.insert(0, (UNRESUMED, START_BLOCK)); // Panic when resumed on the returned or poisoned state let generator_kind = body.generator_kind().unwrap(); @@ -1143,14 +1298,18 @@ fn create_generator_resume_function<'tcx>( // unrelated code from the drop part of the function simplify::remove_dead_blocks(tcx, body); - dump_mir(tcx, None, "generator_resume", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(())); } fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock { let return_block = insert_term_block(body, TerminatorKind::Return); - let term = - TerminatorKind::Drop { place: Place::from(SELF_ARG), target: return_block, unwind: None }; + let term = TerminatorKind::Drop { + place: Place::from(SELF_ARG), + target: return_block, + unwind: UnwindAction::Continue, + replace: false, + }; let source_info = SourceInfo::outermost(body.span); // Create a block to destroy an unresumed generators. This can only destroy upvars. 
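The hunks above lower the new `UnwindAction` edges to the drop elaborator's `Unwind` the same way in several places; for reference, here is that mapping collected into one function, with simplified standalone types rather than the rustc definitions:

```rust
// Simplified stand-ins for the MIR types involved.
#[derive(Clone, Copy)]
enum UnwindAction {
    Continue,
    Unreachable,
    Terminate,
    Cleanup(usize), // target basic block index
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum Unwind {
    To(usize),
    InCleanup,
}

/// The mapping used by `elaborate_drops` and `elaborate_generator_drops`;
/// `resume`, `unreachable` and `terminate` stand for the `MirPatch` blocks
/// returned by `resume_block`, `unreachable_cleanup_block` and `terminate_block`.
fn lower_unwind(
    in_cleanup_block: bool,
    unwind: UnwindAction,
    resume: usize,
    unreachable: usize,
    terminate: usize,
) -> Unwind {
    if in_cleanup_block {
        // A drop already on the unwind path cannot unwind again.
        return Unwind::InCleanup;
    }
    match unwind {
        UnwindAction::Cleanup(bb) => Unwind::To(bb),
        UnwindAction::Continue => Unwind::To(resume),
        UnwindAction::Unreachable => Unwind::To(unreachable),
        UnwindAction::Terminate => Unwind::To(terminate),
    }
}

fn main() {
    assert_eq!(lower_unwind(false, UnwindAction::Continue, 7, 8, 9), Unwind::To(7));
    assert_eq!(lower_unwind(true, UnwindAction::Cleanup(3), 7, 8, 9), Unwind::InCleanup);
}
```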
@@ -1182,8 +1341,6 @@ fn create_cases<'tcx>( transform: &TransformVisitor<'tcx>, operation: Operation, ) -> Vec<(usize, BasicBlock)> { - let tcx = transform.tcx; - let source_info = SourceInfo::outermost(body.span); transform @@ -1216,85 +1373,13 @@ fn create_cases<'tcx>( if operation == Operation::Resume { // Move the resume argument to the destination place of the `Yield` terminator let resume_arg = Local::new(2); // 0 = return, 1 = self - - // handle `box yield` properly - let box_place = if let [projection @ .., ProjectionElem::Deref] = - &**point.resume_arg.projection - { - let box_place = - Place::from(point.resume_arg.local).project_deeper(projection, tcx); - - let box_ty = box_place.ty(&body.local_decls, tcx).ty; - - if box_ty.is_box() { Some((box_place, box_ty)) } else { None } - } else { - None - }; - - if let Some((box_place, box_ty)) = box_place { - let unique_did = box_ty - .ty_adt_def() - .expect("expected Box to be an Adt") - .non_enum_variant() - .fields[0] - .did; - - let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else { - span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") - }; - - let nonnull_did = nonnull_def.non_enum_variant().fields[0].did; - - let (unique_ty, nonnull_ty, ptr_ty) = - crate::elaborate_box_derefs::build_ptr_tys( - tcx, - box_ty.boxed_ty(), - unique_did, - nonnull_did, - ); - - let ptr_local = body.local_decls.push(LocalDecl::new(ptr_ty, body.span)); - - statements.push(Statement { - source_info, - kind: StatementKind::StorageLive(ptr_local), - }); - - statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(ptr_local), - Rvalue::Use(Operand::Copy(box_place.project_deeper( - &crate::elaborate_box_derefs::build_projection( - unique_ty, nonnull_ty, ptr_ty, - ), - tcx, - ))), - ))), - }); - - statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(ptr_local) - .project_deeper(&[ProjectionElem::Deref], tcx), - Rvalue::Use(Operand::Move(resume_arg.into())), - ))), - }); - - statements.push(Statement { - source_info, - kind: StatementKind::StorageDead(ptr_local), - }); - } else { - statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - point.resume_arg, - Rvalue::Use(Operand::Move(resume_arg.into())), - ))), - }); - } + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + point.resume_arg, + Rvalue::Use(Operand::Move(resume_arg.into())), + ))), + }); } // Then jump to the real target @@ -1313,11 +1398,43 @@ fn create_cases<'tcx>( .collect() } -impl<'tcx> MirPass<'tcx> for StateTransform { - fn phase_change(&self) -> Option<MirPhase> { - Some(MirPhase::GeneratorsLowered) - } +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_generator_witnesses<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: LocalDefId, +) -> Option<GeneratorLayout<'tcx>> { + assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir); + + let (body, _) = tcx.mir_promoted(def_id); + let body = body.borrow(); + let body = &*body; + + // The first argument is the generator type passed by value + let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + // Get the interior types and substs which typeck computed + let movable = match *gen_ty.kind() { + ty::Generator(_, _, movability) => movability == hir::Movability::Movable, + ty::Error(_) => return None, + _ => span_bug!(body.span, "unexpected generator type {}", gen_ty), + }; + + // When first entering the generator, move the resume argument into its new local. 
+ let always_live_locals = always_storage_live_locals(&body); + + let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto generator struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body); + check_suspend_tys(tcx, &generator_layout, &body); + + Some(generator_layout) +} + +impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(yield_ty) = body.yield_ty() else { // This only applies to generators @@ -1329,45 +1446,60 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // The first argument is the generator type passed by value let gen_ty = body.local_decls.raw[1].ty; - // Get the interior types and substs which typeck computed - let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() { + // Get the discriminant type and substs which typeck computed + let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); ( - substs.upvar_tys().collect(), - substs.witness(), substs.discr_ty(tcx), + substs.upvar_tys().collect::<Vec<_>>(), + substs.witness(), movability == hir::Movability::Movable, ) } _ => { - tcx.sess - .delay_span_bug(body.span, &format!("unexpected generator type {}", gen_ty)); + tcx.sess.delay_span_bug(body.span, format!("unexpected generator type {}", gen_ty)); return; } }; - // Compute GeneratorState<yield_ty, return_ty> - let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_))); + let (state_adt_ref, state_substs) = if is_async_kind { + // Compute Poll<return_ty> + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_substs = tcx.mk_substs(&[body.return_ty().into()]); + (poll_adt_ref, poll_substs) + } else { + // Compute GeneratorState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_substs) + }; let ret_ty = tcx.mk_adt(state_adt_ref, state_substs); // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local // RETURN_PLACE then is a fresh unused local with type ret_ty. let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx); + // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. + if is_async_kind { + transform_async_context(tcx, body); + } + // We also replace the resume argument and insert an `Assign`. // This is needed because the resume argument `_2` might be live across a `yield`, in which // case there is no `Assign` to it that the transform can turn into a store to the generator // state. After the yield the slot in the generator state would then be uninitialized. 
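For intuition, a source-level sketch of the failure mode the comment above describes — `Gen` is a hypothetical hand-rolled generator, while the real transform performs the equivalent rewrite on MIR:

```rust
struct Gen {
    state: u32,
    saved: u32, // the resume argument, live across a yield
}

impl Gen {
    fn resume(&mut self, arg: u32) -> Option<u32> {
        // The equivalent of the inserted `Assign`: store the fresh resume
        // argument into the generator state *before* anything else runs, so
        // the value observed after the suspension below is not stale.
        self.saved = arg;
        match self.state {
            0 => {
                self.state = 1; // suspend: the next resume continues from here
                Some(self.saved)
            }
            1 => {
                self.state = 2;
                Some(self.saved + 100) // uses the latest resume value
            }
            _ => None,
        }
    }
}

fn main() {
    let mut g = Gen { state: 0, saved: 0 };
    assert_eq!(g.resume(1), Some(1));
    assert_eq!(g.resume(2), Some(102)); // 2 was stored on entry, not 1
    assert_eq!(g.resume(3), None);
}
```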
let resume_local = Local::new(2); - let new_resume_local = - replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx); + let resume_ty = + if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty }; + let new_resume_local = replace_local(resume_local, resume_ty, body, tcx); // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); - let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; + let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; stmts.insert( 0, Statement { @@ -1384,8 +1516,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); - sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals); - if tcx.sess.opts.unstable_opts.validate_mir { let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { assigned_local: None, @@ -1399,16 +1529,24 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto generator struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body); + + if tcx.sess.opts.unstable_opts.validate_mir + && !tcx.sess.opts.unstable_opts.drop_tracking_mir + { + sanitize_witness(tcx, body, interior, upvars, &layout); + } let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); // Run the transformation which converts Places from Local to generator struct // accesses for locals in `remap`. // It also rewrites `return x` and `yield y` as writing a new generator state and returning - // GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively. + // either GeneratorState::Complete(x) and GeneratorState::Yielded(y), + // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. let mut transform = TransformVisitor { tcx, + is_async_kind, state_adt_ref, state_substs, remap, @@ -1424,6 +1562,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { body.arg_count = 2; // self, resume arg body.spread_arg = None; + // The original arguments to the function are no longer arguments, mark them as such. + // Otherwise they'll conflict with our new arguments, which although they don't have + // argument_index set, will get emitted as unnamed arguments. + for var in &mut body.var_debug_info { + var.argument_index = None; + } + body.generator.as_mut().unwrap().yield_ty = None; body.generator.as_mut().unwrap().generator_layout = Some(layout); @@ -1432,21 +1577,21 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // This is expanded to a drop ladder in `elaborate_generator_drops`. let drop_clean = insert_clean_drop(body); - dump_mir(tcx, None, "generator_pre-elab", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(())); // Expand `drop(generator_struct)` to a drop ladder which destroys upvars. // If any upvars are moved out of, drop elaboration will handle upvar destruction. // However we need to also elaborate the code generated by `insert_clean_drop`. 
elaborate_generator_drops(tcx, body); - dump_mir(tcx, None, "generator_post-transform", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(())); // Create a copy of our MIR and use it to create the drop shim for the generator let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean); body.generator.as_mut().unwrap().generator_drop = Some(drop_shim); - // Create the Generator::resume function + // Create the Generator::resume / Future::poll function create_generator_resume_function(tcx, transform, body, can_return); // Run derefer to fix Derefs that are not in the first place @@ -1529,8 +1674,10 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { | StatementKind::StorageDead(_) | StatementKind::Retag(..) | StatementKind::AscribeUserType(..) + | StatementKind::PlaceMention(..) | StatementKind::Coverage(..) - | StatementKind::CopyNonOverlapping(..) + | StatementKind::Intrinsic(..) + | StatementKind::ConstEvalCounter | StatementKind::Nop => {} } } @@ -1544,7 +1691,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { args, destination, target: Some(_), - cleanup: _, + unwind: _, from_hir_call: _, fn_span: _, } => { @@ -1567,11 +1714,10 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { | TerminatorKind::Goto { .. } | TerminatorKind::SwitchInt { .. } | TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::Assert { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::FalseEdge { .. } @@ -1579,3 +1725,199 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { } } } + +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) { + let mut linted_tys = FxHashSet::default(); + + // We want a user-facing param-env. + let param_env = tcx.param_env(body.source.def_id()); + + for (variant, yield_source_info) in + layout.variant_fields.iter().zip(&layout.variant_source_info) + { + debug!(?variant); + for &local in variant { + let decl = &layout.field_tys[local]; + debug!(?decl); + + if !decl.ignore_for_traits && linted_tys.insert(decl.ty) { + let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue }; + + check_must_not_suspend_ty( + tcx, + decl.ty, + hir_id, + param_env, + SuspendCheckData { + source_span: decl.source_info.span, + yield_span: yield_source_info.span, + plural_len: 1, + ..Default::default() + }, + ); + } + } + } +} + +#[derive(Default)] +struct SuspendCheckData<'a> { + source_span: Span, + yield_span: Span, + descr_pre: &'a str, + descr_post: &'a str, + plural_len: usize, +} + +// Returns whether it emitted a diagnostic or not +// Note that this fn and the proceeding one are based on the code +// for creating must_use diagnostics +// +// Note that this technique was chosen over things like a `Suspend` marker trait +// as it is simpler and has precedent in the compiler +fn check_must_not_suspend_ty<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + hir_id: hir::HirId, + param_env: ty::ParamEnv<'tcx>, + data: SuspendCheckData<'_>, +) -> bool { + if ty.is_unit() { + return false; + } + + let plural_suffix = pluralize!(data.plural_len); + + debug!("Checking must_not_suspend for {}", ty); + + match *ty.kind() { + ty::Adt(..) 
if ty.is_box() => { + let boxed_ty = ty.boxed_ty(); + let descr_pre = &format!("{}boxed ", data.descr_pre); + check_must_not_suspend_ty( + tcx, + boxed_ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), + // FIXME: support adding the attribute to TAITs + ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => { + let mut has_emitted = false; + for &(predicate, _) in tcx.explicit_item_bounds(def).skip_binder() { + // We only look at the `DefId`, so it is safe to skip the binder here. + if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) = + predicate.kind().skip_binder() + { + let def_id = poly_trait_predicate.trait_ref.def_id; + let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_pre, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Dynamic(binder, _, _) => { + let mut has_emitted = false; + for predicate in binder.iter() { + if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() { + let def_id = trait_ref.def_id; + let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Tuple(fields) => { + let mut has_emitted = false; + for (i, ty) in fields.iter().enumerate() { + let descr_post = &format!(" in tuple element {i}"); + if check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + } + } + has_emitted + } + ty::Array(ty, len) => { + let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { + descr_pre, + plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1, + ..data + }, + ) + } + // If drop tracking is enabled, we want to look through references, since the referent + // may not be considered live across the await point. + ty::Ref(_region, ty, _mutability) => { + let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + _ => false, + } +} + +fn check_must_not_suspend_def( + tcx: TyCtxt<'_>, + def_id: DefId, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) { + let reason = attr.value_str().map(|s| errors::MustNotSuspendReason { + span: data.source_span, + reason: s.as_str().to_string(), + }); + tcx.emit_spanned_lint( + rustc_session::lint::builtin::MUST_NOT_SUSPEND, + hir_id, + data.source_span, + errors::MustNotSupend { + yield_sp: data.yield_span, + reason, + src_sp: data.source_span, + pre: data.descr_pre, + def_path: tcx.def_path_str(def_id), + post: data.descr_post, + }, + ); + + true + } else { + false + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index dc5d5cee879..5487b5987e0 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -1,19 +1,20 @@ //! 
Inlining pass for MIR functions use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; -use rustc_const_eval::transform::validate::equal_up_to_regions; +use rustc_hir::def_id::DefId; use rustc_index::bit_set::BitSet; -use rustc_index::vec::Idx; +use rustc_index::Idx; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::subst::Subst; -use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; +use rustc_middle::ty::TypeVisitableExt; +use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; use rustc_session::config::OptLevel; -use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span}; +use rustc_target::abi::{FieldIdx, FIRST_VARIANT}; use rustc_target::spec::abi::Abi; -use super::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::util; use crate::MirPass; use std::iter; use std::ops::{Range, RangeFrom}; @@ -25,7 +26,7 @@ const CALL_PENALTY: usize = 25; const LANDINGPAD_PENALTY: usize = 50; const RESUME_PENALTY: usize = 45; -const UNKNOWN_SIZE_COST: usize = 10; +const TOP_DOWN_DEPTH_LIMIT: usize = 5; pub struct Inline; @@ -93,7 +94,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { history: Vec::new(), changed: false, }; - let blocks = BasicBlock::new(0)..body.basic_blocks().next_index(); + let blocks = START_BLOCK..body.basic_blocks.next_index(); this.process_blocks(body, blocks); this.changed } @@ -103,14 +104,26 @@ struct Inliner<'tcx> { param_env: ParamEnv<'tcx>, /// Caller codegen attributes. codegen_fn_attrs: &'tcx CodegenFnAttrs, - /// Stack of inlined Instances. - history: Vec<ty::Instance<'tcx>>, + /// Stack of inlined instances. + /// We only check the `DefId` and not the substs because we want to + /// avoid inlining cases of polymorphic recursion. + /// The number of `DefId`s is finite, so checking history is enough + /// to ensure that we do not loop endlessly while inlining. + history: Vec<DefId>, /// Indicates that the caller body has been modified. changed: bool, } impl<'tcx> Inliner<'tcx> { fn process_blocks(&mut self, caller_body: &mut Body<'tcx>, blocks: Range<BasicBlock>) { + // How many callsites in this body are we allowed to inline? We need to limit this in order + // to prevent super-linear growth in MIR size + let inline_limit = match self.history.len() { + 0 => usize::MAX, + 1..=TOP_DOWN_DEPTH_LIMIT => 1, + _ => return, + }; + let mut inlined_count = 0; for bb in blocks { let bb_data = &caller_body[bb]; if bb_data.is_cleanup { @@ -132,9 +145,16 @@ impl<'tcx> Inliner<'tcx> { Ok(new_blocks) => { debug!("inlined {}", callsite.callee); self.changed = true; - self.history.push(callsite.callee); + + self.history.push(callsite.callee.def_id()); self.process_blocks(caller_body, new_blocks); self.history.pop(); + + inlined_count += 1; + if inlined_count == inline_limit { + debug!("inline count reached"); + return; + } } } } @@ -150,8 +170,20 @@ impl<'tcx> Inliner<'tcx> { ) -> Result<std::ops::Range<BasicBlock>, &'static str> { let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id()); self.check_codegen_attributes(callsite, callee_attrs)?; + + let terminator = caller_body[callsite.block].terminator.as_ref().unwrap(); + let TerminatorKind::Call { args, destination, .. 
} = &terminator.kind else { bug!() }; + let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty; + for arg in args { + if !arg.ty(&caller_body.local_decls, self.tcx).is_sized(self.tcx, self.param_env) { + // We do not allow inlining functions with unsized params. Inlining these functions + // could create unsized locals, which are unsound and being phased out. + return Err("Call has unsized argument"); + } + } + self.check_mir_is_available(caller_body, &callsite.callee)?; - let callee_body = self.tcx.instance_mir(callsite.callee.def); + let callee_body = try_instance_mir(self.tcx, callsite.callee.def)?; self.check_mir_body(callsite, callee_body, callee_attrs)?; if !self.tcx.consider_optimizing(|| { @@ -163,7 +195,7 @@ impl<'tcx> Inliner<'tcx> { let Ok(callee_body) = callsite.callee.try_subst_mir_and_normalize_erasing_regions( self.tcx, self.param_env, - callee_body.clone(), + ty::EarlyBinder::bind(callee_body.clone()), ) else { return Err("failed to normalize callee body"); }; @@ -171,11 +203,8 @@ impl<'tcx> Inliner<'tcx> { // Check call signature compatibility. // Normally, this shouldn't be required, but trait normalization failure can create a // validation ICE. - let terminator = caller_body[callsite.block].terminator.as_ref().unwrap(); - let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() }; - let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty; let output_type = callee_body.return_ty(); - if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) { + if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) { trace!(?output_type, ?destination_ty); return Err("failed to normalize return type"); } @@ -195,7 +224,7 @@ impl<'tcx> Inliner<'tcx> { arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args)) { let input_type = callee_body.local_decls[input].ty; - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize tuple argument type"); } @@ -204,16 +233,16 @@ impl<'tcx> Inliner<'tcx> { for (arg, input) in args.iter().zip(callee_body.args_iter()) { let input_type = callee_body.local_decls[input].ty; let arg_ty = arg.ty(&caller_body.local_decls, self.tcx); - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize argument type"); } } } - let old_blocks = caller_body.basic_blocks().next_index(); + let old_blocks = caller_body.basic_blocks.next_index(); self.inline_call(caller_body, &callsite, callee_body); - let new_blocks = old_blocks..caller_body.basic_blocks().next_index(); + let new_blocks = old_blocks..caller_body.basic_blocks.next_index(); Ok(new_blocks) } @@ -246,12 +275,14 @@ impl<'tcx> Inliner<'tcx> { // not get any optimizations run on it. Any subsequent inlining may cause cycles, but we // do not need to catch this here, we can wait until the inliner decides to continue // inlining a second time. - InstanceDef::VtableShim(_) + InstanceDef::VTableShim(_) | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::DropGlue(..) - | InstanceDef::CloneShim(..) => return Ok(()), + | InstanceDef::CloneShim(..) + | InstanceDef::ThreadLocalShim(..) + | InstanceDef::FnPtrAddrShim(..) 
=> return Ok(()), } if self.tcx.is_constructor(callee_def_id) { @@ -296,7 +327,7 @@ impl<'tcx> Inliner<'tcx> { ) -> Option<CallSite<'tcx>> { // Only consider direct calls to functions let terminator = bb_data.terminator(); - if let TerminatorKind::Call { ref func, target, .. } = terminator.kind { + if let TerminatorKind::Call { ref func, target, fn_span, .. } = terminator.kind { let func_ty = func.ty(caller_body, self.tcx); if let ty::FnDef(def_id, substs) = *func_ty.kind() { // To resolve an instance its substs have to be fully normalized. @@ -308,19 +339,14 @@ impl<'tcx> Inliner<'tcx> { return None; } - if self.history.contains(&callee) { + if self.history.contains(&callee.def_id()) { return None; } - let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs); + let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs); + let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; - return Some(CallSite { - callee, - fn_sig, - block: bb, - target, - source_info: terminator.source_info, - }); + return Some(CallSite { callee, fn_sig, block: bb, target, source_info }); } } @@ -334,14 +360,8 @@ impl<'tcx> Inliner<'tcx> { callsite: &CallSite<'tcx>, callee_attrs: &CodegenFnAttrs, ) -> Result<(), &'static str> { - match callee_attrs.inline { - InlineAttr::Never => return Err("never inline hint"), - InlineAttr::Always | InlineAttr::Hint => {} - InlineAttr::None => { - if self.tcx.sess.mir_opt_level() <= 2 { - return Err("at mir-opt-level=2, only #[inline] is inlined"); - } - } + if let InlineAttr::Never = callee_attrs.inline { + return Err("never inline hint"); } // Only inline local functions if they would be eligible for cross-crate @@ -358,10 +378,6 @@ impl<'tcx> Inliner<'tcx> { return Err("C variadic"); } - if callee_attrs.flags.contains(CodegenFnAttrFlags::NAKED) { - return Err("naked"); - } - if callee_attrs.flags.contains(CodegenFnAttrFlags::COLD) { return Err("cold"); } @@ -370,7 +386,12 @@ impl<'tcx> Inliner<'tcx> { return Err("incompatible sanitizer set"); } - if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set { + // Two functions are compatible if the callee has no attribute (meaning + // that it's codegen agnostic), or sets an attribute that is identical + // to this function's attribute. + if callee_attrs.instruction_set.is_some() + && callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set + { return Err("incompatible instruction set"); } @@ -403,124 +424,62 @@ impl<'tcx> Inliner<'tcx> { // Give a bonus functions with a small number of blocks, // We normally have two or three blocks for even // very small functions. - if callee_body.basic_blocks().len() <= 3 { + if callee_body.basic_blocks.len() <= 3 { threshold += threshold / 4; } debug!(" final inline threshold = {}", threshold); // FIXME: Give a bonus to functions with only a single caller - let mut first_block = true; - let mut cost = 0; - // Traverse the MIR manually so we can account for the effects of - // inlining on the CFG. + let mut checker = CostChecker { + tcx: self.tcx, + param_env: self.param_env, + instance: callsite.callee, + callee_body, + cost: 0, + validation: Ok(()), + }; + + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. 
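The cost accounting that `CostChecker` centralizes is additive over the visited blocks; a rough sketch in terms of the penalty constants defined at the top of the file (illustrative arithmetic, not the actual implementation):

```ignore (illustrative)
// A plain statement or terminator adds INSTR_COST, a call adds
// CALL_PENALTY, an unwind edge adds LANDINGPAD_PENALTY, and a Resume
// terminator adds RESUME_PENALTY. For a block with three ordinary
// statements and one call carrying a cleanup edge:
let cost = 3 * INSTR_COST + CALL_PENALTY + LANDINGPAD_PENALTY;
// Inlining goes ahead only if the total stays within the threshold
// (which is bumped by a quarter for bodies of at most three blocks),
// unless the callee is #[inline(always)].
```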
let mut work_list = vec![START_BLOCK]; - let mut visited = BitSet::new_empty(callee_body.basic_blocks().len()); + let mut visited = BitSet::new_empty(callee_body.basic_blocks.len()); while let Some(bb) = work_list.pop() { if !visited.insert(bb.index()) { continue; } - let blk = &callee_body.basic_blocks()[bb]; - - for stmt in &blk.statements { - // Don't count StorageLive/StorageDead in the inlining cost. - match stmt.kind { - StatementKind::StorageLive(_) - | StatementKind::StorageDead(_) - | StatementKind::Deinit(_) - | StatementKind::Nop => {} - _ => cost += INSTR_COST, - } - } - let term = blk.terminator(); - let mut is_drop = false; - match term.kind { - TerminatorKind::Drop { ref place, target, unwind } - | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => { - is_drop = true; - work_list.push(target); - // If the place doesn't actually need dropping, treat it like - // a regular goto. - let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty); - if ty.needs_drop(tcx, self.param_env) { - cost += CALL_PENALTY; - if let Some(unwind) = unwind { - cost += LANDINGPAD_PENALTY; - work_list.push(unwind); - } - } else { - cost += INSTR_COST; - } - } - - TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. } - if first_block => - { - // If the function always diverges, don't inline - // unless the cost is zero - threshold = 0; - } - - TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => { - if let ty::FnDef(def_id, _) = - *callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind() - { - // Don't give intrinsics the extra penalty for calls - if tcx.is_intrinsic(def_id) { - cost += INSTR_COST; - } else { - cost += CALL_PENALTY; - } - } else { - cost += CALL_PENALTY; - } - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::Assert { cleanup, .. } => { - cost += CALL_PENALTY; - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::Resume => cost += RESUME_PENALTY, - TerminatorKind::InlineAsm { cleanup, .. } => { - cost += INSTR_COST; - - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; - } - } - _ => cost += INSTR_COST, - } + let blk = &callee_body.basic_blocks[bb]; + checker.visit_basic_block_data(bb, blk); - if !is_drop { - for succ in term.successors() { - work_list.push(succ); + let term = blk.terminator(); + if let TerminatorKind::Drop { ref place, target, unwind, replace: _ } = term.kind { + work_list.push(target); + + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = callsite + .callee + .subst_mir(self.tcx, ty::EarlyBinder::bind(&place.ty(callee_body, tcx).ty)); + if ty.needs_drop(tcx, self.param_env) && let UnwindAction::Cleanup(unwind) = unwind { + work_list.push(unwind); } - } - - first_block = false; - } - - // Count up the cost of local variables and temps, if we know the size - // use that, otherwise we use a moderately-large dummy cost. - - let ptr_size = tcx.data_layout.pointer_size.bytes(); - - for v in callee_body.vars_and_temps_iter() { - let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty); - // Cost of the var is the size in machine-words, if we know - // it. - if let Some(size) = type_size_of(tcx, self.param_env, ty) { - cost += ((size + ptr_size - 1) / ptr_size) as usize; + } else if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set + && matches!(term.kind, TerminatorKind::InlineAsm { .. 
}) + { + // During the attribute checking stage we allow a callee with no + // instruction_set assigned to count as compatible with a function that does + // assign one. However, during this stage we require an exact match when any + // inline-asm is detected. LLVM will still possibly do an inline later on + // if the no-attribute function ends up with the same instruction set anyway. + return Err("Cannot move inline-asm across instruction sets"); } else { - cost += UNKNOWN_SIZE_COST; + work_list.extend(term.successors()) } } + // Abort if type validation found anything fishy. + checker.validation?; + + let cost = checker.cost; if let InlineAttr::Always = callee_attrs.inline { debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost); Ok(()) @@ -541,7 +500,7 @@ impl<'tcx> Inliner<'tcx> { ) { let terminator = caller_body[callsite.block].terminator.take().unwrap(); match terminator.kind { - TerminatorKind::Call { args, destination, cleanup, .. } => { + TerminatorKind::Call { args, destination, unwind, .. } => { // If the call is something like `a[*i] = f(i)`, where // `i : &mut usize`, then just duplicating the `a[*i]` // Place could result in two different locations if `f` @@ -576,31 +535,35 @@ impl<'tcx> Inliner<'tcx> { destination }; + // Always create a local to hold the destination, as `RETURN_PLACE` may appear + // where a full `Place` is not allowed. + let (remap_destination, destination_local) = if let Some(d) = dest.as_local() { + (false, d) + } else { + ( + true, + self.new_call_temp( + caller_body, + &callsite, + destination.ty(caller_body, self.tcx).ty, + ), + ) + }; + // Copy the arguments if needed. let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body); - let mut expn_data = ExpnData::default( - ExpnKind::Inlined, - callsite.source_info.span, - self.tcx.sess.edition(), - None, - None, - ); - expn_data.def_site = callee_body.span; - let expn_data = - self.tcx.with_stable_hashing_context(|hcx| LocalExpnId::fresh(expn_data, hcx)); let mut integrator = Integrator { args: &args, new_locals: Local::new(caller_body.local_decls.len()).., new_scopes: SourceScope::new(caller_body.source_scopes.len()).., - new_blocks: BasicBlock::new(caller_body.basic_blocks().len()).., - destination: dest, + new_blocks: BasicBlock::new(caller_body.basic_blocks.len()).., + destination: destination_local, callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(), callsite, - cleanup_block: cleanup, + cleanup_block: unwind, in_cleanup_block: false, tcx: self.tcx, - expn_data, always_live_locals: BitSet::new_filled(callee_body.local_decls.len()), }; @@ -611,7 +574,9 @@ impl<'tcx> Inliner<'tcx> { // If there are any locals without storage markers, give them storage only for the // duration of the call. for local in callee_body.vars_and_temps_iter() { - if integrator.always_live_locals.contains(local) { + if !callee_body.local_decls[local].internal + && integrator.always_live_locals.contains(local) + { let new_local = integrator.map_local(local); caller_body[callsite.block].statements.push(Statement { source_info: callsite.source_info, @@ -623,8 +588,20 @@ impl<'tcx> Inliner<'tcx> { // To avoid repeated O(n) insert, push any new statements to the end and rotate // the slice once. 
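The push-and-rotate idiom this comment refers to avoids quadratic behavior; a self-contained sketch with a plain `Vec` standing in for the statement list:

```rust
fn main() {
    let mut stmts = vec!["a", "b", "c"];
    // k separate front-inserts would each cost O(n); instead, append the
    // k new elements and rotate once, which is O(n + k) overall:
    stmts.push("new1");
    stmts.push("new2");
    stmts.rotate_right(2); // move the two appended elements to the front
    assert_eq!(stmts, ["new1", "new2", "a", "b", "c"]);
}
```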
let mut n = 0; + if remap_destination { + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Move(destination_local.into())), + ))), + }); + n += 1; + } for local in callee_body.vars_and_temps_iter().rev() { - if integrator.always_live_locals.contains(local) { + if !callee_body.local_decls[local].internal + && integrator.always_live_locals.contains(local) + { let new_local = integrator.map_local(local); caller_body[block].statements.push(Statement { source_info: callsite.source_info, @@ -652,11 +629,11 @@ impl<'tcx> Inliner<'tcx> { // `required_consts`, here we may not only have `ConstKind::Unevaluated` // because we are calling `subst_and_normalize_erasing_regions`. caller_body.required_consts.extend( - callee_body.required_consts.iter().copied().filter(|&ct| { - match ct.literal.const_for_ty() { - Some(ct) => matches!(ct.kind(), ConstKind::Unevaluated(_)), - None => true, + callee_body.required_consts.iter().copied().filter(|&ct| match ct.literal { + ConstantKind::Ty(_) => { + bug!("should never encounter ty::UnevaluatedConst in `required_consts`") } + ConstantKind::Val(..) | ConstantKind::Unevaluated(..) => true, }), ); } @@ -713,7 +690,7 @@ impl<'tcx> Inliner<'tcx> { // The `tmp0`, `tmp1`, and `tmp2` in our example above. let tuple_tmp_args = tuple_tys.iter().enumerate().map(|(i, ty)| { // This is e.g., `tuple_tmp.0` in our example above. - let tuple_field = Operand::Move(tcx.mk_place_field(tuple, Field::new(i), ty)); + let tuple_field = Operand::Move(tcx.mk_place_field(tuple, FieldIdx::new(i), ty)); // Spill to a local to make e.g., `tmp0`. self.create_temp_if_necessary(tuple_field, callsite, caller_body) @@ -782,12 +759,174 @@ impl<'tcx> Inliner<'tcx> { } } -fn type_size_of<'tcx>( +/// Verify that the callee body is compatible with the caller. +/// +/// This visitor mostly computes the inlining cost, +/// but also needs to verify that types match because of normalization failure. +struct CostChecker<'b, 'tcx> { tcx: TyCtxt<'tcx>, - param_env: ty::ParamEnv<'tcx>, - ty: Ty<'tcx>, -) -> Option<u64> { - tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes()) + param_env: ParamEnv<'tcx>, + cost: usize, + callee_body: &'b Body<'tcx>, + instance: ty::Instance<'tcx>, + validation: Result<(), &'static str>, +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + // Don't count StorageLive/StorageDead in the inlining cost. + match statement.kind { + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Deinit(_) + | StatementKind::Nop => {} + _ => self.cost += INSTR_COST, + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + let tcx = self.tcx; + match terminator.kind { + TerminatorKind::Drop { ref place, unwind, .. } => { + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = self + .instance + .subst_mir(tcx, ty::EarlyBinder::bind(&place.ty(self.callee_body, tcx).ty)); + if ty.needs_drop(tcx, self.param_env) { + self.cost += CALL_PENALTY; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } else { + self.cost += INSTR_COST; + } + } + TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. 
} => { + let fn_ty = self.instance.subst_mir(tcx, ty::EarlyBinder::bind(&f.literal.ty())); + self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) { + // Don't give intrinsics the extra penalty for calls + INSTR_COST + } else { + CALL_PENALTY + }; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Assert { unwind, .. } => { + self.cost += CALL_PENALTY; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Resume => self.cost += RESUME_PENALTY, + TerminatorKind::InlineAsm { unwind, .. } => { + self.cost += INSTR_COST; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + _ => self.cost += INSTR_COST, + } + + self.super_terminator(terminator, location); + } + + /// This method duplicates code from MIR validation in an attempt to detect type mismatches due + /// to normalization failure. + fn visit_projection_elem( + &mut self, + local: Local, + proj_base: &[PlaceElem<'tcx>], + elem: PlaceElem<'tcx>, + context: PlaceContext, + location: Location, + ) { + if let ProjectionElem::Field(f, ty) = elem { + let parent = Place { local, projection: self.tcx.mk_place_elems(proj_base) }; + let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx); + let check_equal = |this: &mut Self, f_ty| { + if !util::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) { + trace!(?ty, ?f_ty); + this.validation = Err("failed to normalize projection type"); + return; + } + }; + + let kind = match parent_ty.ty.kind() { + &ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => { + self.tcx.type_of(def_id).subst(self.tcx, substs).kind() + } + kind => kind, + }; + + match kind { + ty::Tuple(fields) => { + let Some(f_ty) = fields.get(f.as_usize()) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, *f_ty); + } + ty::Adt(adt_def, substs) => { + let var = parent_ty.variant_index.unwrap_or(FIRST_VARIANT); + let Some(field) = adt_def.variant(var).fields.get(f) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, field.ty(self.tcx, substs)); + } + ty::Closure(_, substs) => { + let substs = substs.as_closure(); + let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, f_ty); + } + &ty::Generator(def_id, substs, _) => { + let f_ty = if let Some(var) = parent_ty.variant_index { + let gen_body = if def_id == self.callee_body.source.def_id() { + self.callee_body + } else { + self.tcx.optimized_mir(def_id) + }; + + let Some(layout) = gen_body.generator_layout() else { + self.validation = Err("malformed MIR"); + return; + }; + + let Some(&local) = layout.variant_fields[var].get(f) else { + self.validation = Err("malformed MIR"); + return; + }; + + let Some(f_ty) = layout.field_tys.get(local) else { + self.validation = Err("malformed MIR"); + return; + }; + + f_ty.ty + } else { + let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else { + self.validation = Err("malformed MIR"); + return; + }; + + f_ty + }; + + check_equal(self, f_ty); + } + _ => self.validation = Err("malformed MIR"), + } + } + + self.super_projection_elem(local, proj_base, elem, context, location); + } } /** @@ -802,20 +941,19 @@ struct Integrator<'a, 'tcx> { new_locals: RangeFrom<Local>, new_scopes: RangeFrom<SourceScope>, new_blocks: RangeFrom<BasicBlock>, - destination: Place<'tcx>, + destination: Local, 
callsite_scope: SourceScopeData<'tcx>, callsite: &'a CallSite<'tcx>, - cleanup_block: Option<BasicBlock>, + cleanup_block: UnwindAction, in_cleanup_block: bool, tcx: TyCtxt<'tcx>, - expn_data: LocalExpnId, always_live_locals: BitSet<Local>, } impl Integrator<'_, '_> { fn map_local(&self, local: Local) -> Local { let new = if local == RETURN_PLACE { - self.destination.local + self.destination } else { let idx = local.index() - 1; if idx < self.args.len() { @@ -839,6 +977,24 @@ impl Integrator<'_, '_> { trace!("mapping block `{:?}` to `{:?}`", block, new); new } + + fn map_unwind(&self, unwind: UnwindAction) -> UnwindAction { + if self.in_cleanup_block { + match unwind { + UnwindAction::Cleanup(_) | UnwindAction::Continue => { + bug!("cleanup on cleanup block"); + } + UnwindAction::Unreachable | UnwindAction::Terminate => return unwind, + } + } + + match unwind { + UnwindAction::Unreachable | UnwindAction::Terminate => unwind, + UnwindAction::Cleanup(target) => UnwindAction::Cleanup(self.map_block(target)), + // Add an unwind edge to the original call's cleanup block + UnwindAction::Continue => self.cleanup_block, + } + } } impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { @@ -876,32 +1032,6 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { *scope = self.map_scope(*scope); } - fn visit_span(&mut self, span: &mut Span) { - // Make sure that all spans track the fact that they were inlined. - *span = span.fresh_expansion(self.expn_data); - } - - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - for elem in place.projection { - // FIXME: Make sure that return place is not used in an indexing projection, since it - // won't be rebased as it is supposed to be. - assert_ne!(ProjectionElem::Index(RETURN_PLACE), elem); - } - - // If this is the `RETURN_PLACE`, we need to rebase any projections onto it. - let dest_proj_len = self.destination.projection.len(); - if place.local == RETURN_PLACE && dest_proj_len > 0 { - let mut projs = Vec::with_capacity(dest_proj_len + place.projection.len()); - projs.extend(self.destination.projection); - projs.extend(place.projection); - - place.projection = self.tcx.intern_place_elems(&*projs); - } - // Handles integrating any locals that occur in the base - // or projections - self.super_place(place, context, location) - } - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { self.in_cleanup_block = data.is_cleanup; self.super_basic_block_data(block, data); @@ -944,38 +1074,19 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { *tgt = self.map_block(*tgt); } } - TerminatorKind::Drop { ref mut target, ref mut unwind, .. } - | TerminatorKind::DropAndReplace { ref mut target, ref mut unwind, .. } => { + TerminatorKind::Drop { ref mut target, ref mut unwind, .. } => { *target = self.map_block(*target); - if let Some(tgt) = *unwind { - *unwind = Some(self.map_block(tgt)); - } else if !self.in_cleanup_block { - // Unless this drop is in a cleanup block, add an unwind edge to - // the original call's cleanup block - *unwind = self.cleanup_block; - } + *unwind = self.map_unwind(*unwind); } - TerminatorKind::Call { ref mut target, ref mut cleanup, .. } => { + TerminatorKind::Call { ref mut target, ref mut unwind, .. 
} => { if let Some(ref mut tgt) = *target { *tgt = self.map_block(*tgt); } - if let Some(tgt) = *cleanup { - *cleanup = Some(self.map_block(tgt)); - } else if !self.in_cleanup_block { - // Unless this call is in a cleanup block, add an unwind edge to - // the original call's cleanup block - *cleanup = self.cleanup_block; - } + *unwind = self.map_unwind(*unwind); } - TerminatorKind::Assert { ref mut target, ref mut cleanup, .. } => { + TerminatorKind::Assert { ref mut target, ref mut unwind, .. } => { *target = self.map_block(*target); - if let Some(tgt) = *cleanup { - *cleanup = Some(self.map_block(tgt)); - } else if !self.in_cleanup_block { - // Unless this assert is in a cleanup block, add an unwind edge to - // the original call's cleanup block - *cleanup = self.cleanup_block; - } + *unwind = self.map_unwind(*unwind); } TerminatorKind::Return => { terminator.kind = if let Some(tgt) = self.callsite.target { @@ -985,11 +1096,14 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { } } TerminatorKind::Resume => { - if let Some(tgt) = self.cleanup_block { - terminator.kind = TerminatorKind::Goto { target: tgt } - } + terminator.kind = match self.cleanup_block { + UnwindAction::Cleanup(tgt) => TerminatorKind::Goto { target: tgt }, + UnwindAction::Continue => TerminatorKind::Resume, + UnwindAction::Unreachable => TerminatorKind::Unreachable, + UnwindAction::Terminate => TerminatorKind::Terminate, + }; } - TerminatorKind::Abort => {} + TerminatorKind::Terminate => {} TerminatorKind::Unreachable => {} TerminatorKind::FalseEdge { ref mut real_target, ref mut imaginary_target } => { *real_target = self.map_block(*real_target); @@ -1000,15 +1114,36 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { { bug!("False unwinds should have been removed before inlining") } - TerminatorKind::InlineAsm { ref mut destination, ref mut cleanup, .. } => { + TerminatorKind::InlineAsm { ref mut destination, ref mut unwind, .. 
} => { if let Some(ref mut tgt) = *destination { *tgt = self.map_block(*tgt); - } else if !self.in_cleanup_block { - // Unless this inline asm is in a cleanup block, add an unwind edge to - // the original call's cleanup block - *cleanup = self.cleanup_block; } + *unwind = self.map_unwind(*unwind); } } } } + +#[instrument(skip(tcx), level = "debug")] +fn try_instance_mir<'tcx>( + tcx: TyCtxt<'tcx>, + instance: InstanceDef<'tcx>, +) -> Result<&'tcx Body<'tcx>, &'static str> { + match instance { + ty::InstanceDef::DropGlue(_, Some(ty)) => match ty.kind() { + ty::Adt(def, substs) => { + let fields = def.all_fields(); + for field in fields { + let field_ty = field.ty(tcx, substs); + if field_ty.has_param() && field_ty.has_projections() { + return Err("cannot build drop shim for polymorphic type"); + } + } + + Ok(tcx.instance_mir(instance)) + } + _ => Ok(tcx.instance_mir(instance)), + }, + _ => Ok(tcx.instance_mir(instance)), + } +} diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index a3a35f95071..8a10445f837 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -2,7 +2,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::mir::TerminatorKind; -use rustc_middle::ty::TypeVisitable; +use rustc_middle::ty::TypeVisitableExt; use rustc_middle::ty::{self, subst::SubstsRef, InstanceDef, TyCtxt}; use rustc_session::Limit; @@ -13,7 +13,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( tcx: TyCtxt<'tcx>, (root, target): (ty::Instance<'tcx>, LocalDefId), ) -> bool { - trace!(%root, target = %tcx.def_path_str(target.to_def_id())); + trace!(%root, target = %tcx.def_path_str(target)); let param_env = tcx.param_env_reveal_all_normalized(target); assert_ne!( root.def_id().expect_local(), @@ -44,7 +44,11 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( ) -> bool { trace!(%caller); for &(callee, substs) in tcx.mir_inliner_callees(caller.def) { - let Ok(substs) = caller.try_subst_mir_and_normalize_erasing_regions(tcx, param_env, substs) else { + let Ok(substs) = caller.try_subst_mir_and_normalize_erasing_regions( + tcx, + param_env, + ty::EarlyBinder::bind(substs), + ) else { trace!(?caller, ?param_env, ?substs, "cannot normalize, skipping"); continue; }; @@ -79,16 +83,20 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( // These have MIR and if that MIR is inlined, substituted and then inlining is run // again, a function item can end up getting inlined. Thus we'll be able to cause // a cycle that way - InstanceDef::VtableShim(_) + InstanceDef::VTableShim(_) | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} + + // This shim does not call any other functions, thus there can be no recursion. + InstanceDef::FnPtrAddrShim(..) => continue, InstanceDef::DropGlue(..) => { // FIXME: A not fully substituted drop shim can cause ICEs if one attempts to // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this // needs some more analysis. 
- if callee.needs_subst() { + if callee.has_param() { continue; } } @@ -144,8 +152,7 @@ pub(crate) fn mir_inliner_callees<'tcx>( let guard; let body = match (instance, instance.def_id().as_local()) { (InstanceDef::Item(_), Some(def_id)) => { - let def = ty::WithOptConstParam::unknown(def_id); - steal = tcx.mir_promoted(def).0; + steal = tcx.mir_promoted(def_id).0; guard = steal.borrow(); &*guard } @@ -153,7 +160,7 @@ pub(crate) fn mir_inliner_callees<'tcx>( _ => tcx.instance_mir(instance), }; let mut calls = FxIndexSet::default(); - for bb_data in body.basic_blocks() { + for bb_data in body.basic_blocks.iter() { let terminator = bb_data.terminator(); if let TerminatorKind::Call { func, .. } = &terminator.kind { let ty = func.ty(&body.local_decls, tcx); diff --git a/compiler/rustc_mir_transform/src/instcombine.rs b/compiler/rustc_mir_transform/src/instcombine.rs deleted file mode 100644 index 2f3c65869ef..00000000000 --- a/compiler/rustc_mir_transform/src/instcombine.rs +++ /dev/null @@ -1,203 +0,0 @@ -//! Performs various peephole optimizations. - -use crate::MirPass; -use rustc_hir::Mutability; -use rustc_middle::mir::{ - BinOp, Body, Constant, ConstantKind, LocalDecls, Operand, Place, ProjectionElem, Rvalue, - SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, -}; -use rustc_middle::ty::{self, TyCtxt}; - -pub struct InstCombine; - -impl<'tcx> MirPass<'tcx> for InstCombine { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() > 0 - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ctx = InstCombineContext { tcx, local_decls: &body.local_decls }; - for block in body.basic_blocks.as_mut() { - for statement in block.statements.iter_mut() { - match statement.kind { - StatementKind::Assign(box (_place, ref mut rvalue)) => { - ctx.combine_bool_cmp(&statement.source_info, rvalue); - ctx.combine_ref_deref(&statement.source_info, rvalue); - ctx.combine_len(&statement.source_info, rvalue); - } - _ => {} - } - } - - ctx.combine_primitive_clone( - &mut block.terminator.as_mut().unwrap(), - &mut block.statements, - ); - } - } -} - -struct InstCombineContext<'tcx, 'a> { - tcx: TyCtxt<'tcx>, - local_decls: &'a LocalDecls<'tcx>, -} - -impl<'tcx> InstCombineContext<'tcx, '_> { - fn should_combine(&self, source_info: &SourceInfo, rvalue: &Rvalue<'tcx>) -> bool { - self.tcx.consider_optimizing(|| { - format!("InstCombine - Rvalue: {:?} SourceInfo: {:?}", rvalue, source_info) - }) - } - - /// Transform boolean comparisons into logical operations. 
- fn combine_bool_cmp(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { - match rvalue { - Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => { - let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) { - // Transform "Eq(a, true)" ==> "a" - (BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())), - - // Transform "Ne(a, false)" ==> "a" - (BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())), - - // Transform "Eq(true, b)" ==> "b" - (BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())), - - // Transform "Ne(false, b)" ==> "b" - (BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())), - - // Transform "Eq(false, b)" ==> "Not(b)" - (BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), - - // Transform "Ne(true, b)" ==> "Not(b)" - (BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), - - // Transform "Eq(a, false)" ==> "Not(a)" - (BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), - - // Transform "Ne(a, true)" ==> "Not(a)" - (BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), - - _ => None, - }; - - if let Some(new) = new && self.should_combine(source_info, rvalue) { - *rvalue = new; - } - } - - _ => {} - } - } - - fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> { - let a = a.constant()?; - if a.literal.ty().is_bool() { a.literal.try_to_bool() } else { None } - } - - /// Transform "&(*a)" ==> "a". - fn combine_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { - if let Rvalue::Ref(_, _, place) = rvalue { - if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { - if let ty::Ref(_, _, Mutability::Not) = - base.ty(self.local_decls, self.tcx).ty.kind() - { - // The dereferenced place must have type `&_`, so that we don't copy `&mut _`. - } else { - return; - } - - if !self.should_combine(source_info, rvalue) { - return; - } - - *rvalue = Rvalue::Use(Operand::Copy(Place { - local: base.local, - projection: self.tcx.intern_place_elems(base.projection), - })); - } - } - } - - /// Transform "Len([_; N])" ==> "N". - fn combine_len(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { - if let Rvalue::Len(ref place) = *rvalue { - let place_ty = place.ty(self.local_decls, self.tcx).ty; - if let ty::Array(_, len) = *place_ty.kind() { - if !self.should_combine(source_info, rvalue) { - return; - } - - let literal = ConstantKind::from_const(len, self.tcx); - let constant = Constant { span: source_info.span, literal, user_ty: None }; - *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant))); - } - } - } - - fn combine_primitive_clone( - &self, - terminator: &mut Terminator<'tcx>, - statements: &mut Vec<Statement<'tcx>>, - ) { - let TerminatorKind::Call { func, args, destination, target, .. } = &mut terminator.kind - else { return }; - - // It's definitely not a clone if there are multiple arguments - if args.len() != 1 { - return; - } - - let Some(destination_block) = *target - else { return }; - - // Only bother looking more if it's easy to know what we're calling - let Some((fn_def_id, fn_substs)) = func.const_fn_def() - else { return }; - - // Clone needs one subst, so we can cheaply rule out other stuff - if fn_substs.len() != 1 { - return; - } - - // These types are easily available from locals, so check that before - // doing DefId lookups to figure out what we're actually calling. 
- let arg_ty = args[0].ty(self.local_decls, self.tcx); - - let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() - else { return }; - - if !inner_ty.is_trivially_pure_clone_copy() { - return; - } - - let trait_def_id = self.tcx.trait_of_item(fn_def_id); - if trait_def_id.is_none() || trait_def_id != self.tcx.lang_items().clone_trait() { - return; - } - - if !self.tcx.consider_optimizing(|| { - format!( - "InstCombine - Call: {:?} SourceInfo: {:?}", - (fn_def_id, fn_substs), - terminator.source_info - ) - }) { - return; - } - - let Some(arg_place) = args.pop().unwrap().place() - else { return }; - - statements.push(Statement { - source_info: terminator.source_info, - kind: StatementKind::Assign(Box::new(( - *destination, - Rvalue::Use(Operand::Copy( - arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx), - )), - ))), - }); - terminator.kind = TerminatorKind::Goto { target: destination_block }; - } -} diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs new file mode 100644 index 00000000000..e4dc617620e --- /dev/null +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -0,0 +1,305 @@ +//! Performs various peephole optimizations. + +use crate::simplify::simplify_duplicate_switch_targets; +use crate::MirPass; +use rustc_hir::Mutability; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::ValidityRequirement; +use rustc_middle::ty::{self, ParamEnv, SubstsRef, Ty, TyCtxt}; +use rustc_span::symbol::Symbol; +use rustc_target::abi::FieldIdx; + +pub struct InstSimplify; + +impl<'tcx> MirPass<'tcx> for InstSimplify { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ctx = InstSimplifyContext { + tcx, + local_decls: &body.local_decls, + param_env: tcx.param_env_reveal_all_normalized(body.source.def_id()), + }; + for block in body.basic_blocks.as_mut() { + for statement in block.statements.iter_mut() { + match statement.kind { + StatementKind::Assign(box (_place, ref mut rvalue)) => { + ctx.simplify_bool_cmp(&statement.source_info, rvalue); + ctx.simplify_ref_deref(&statement.source_info, rvalue); + ctx.simplify_len(&statement.source_info, rvalue); + ctx.simplify_cast(&statement.source_info, rvalue); + } + _ => {} + } + } + + ctx.simplify_primitive_clone( + &mut block.terminator.as_mut().unwrap(), + &mut block.statements, + ); + ctx.simplify_intrinsic_assert( + &mut block.terminator.as_mut().unwrap(), + &mut block.statements, + ); + simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap()); + } + } +} + +struct InstSimplifyContext<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, + param_env: ParamEnv<'tcx>, +} + +impl<'tcx> InstSimplifyContext<'tcx, '_> { + fn should_simplify(&self, source_info: &SourceInfo, rvalue: &Rvalue<'tcx>) -> bool { + self.tcx.consider_optimizing(|| { + format!("InstSimplify - Rvalue: {:?} SourceInfo: {:?}", rvalue, source_info) + }) + } + + /// Transform boolean comparisons into logical operations. 
+ fn simplify_bool_cmp(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + match rvalue { + Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => { + let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) { + // Transform "Eq(a, true)" ==> "a" + (BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())), + + // Transform "Ne(a, false)" ==> "a" + (BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())), + + // Transform "Eq(true, b)" ==> "b" + (BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())), + + // Transform "Ne(false, b)" ==> "b" + (BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())), + + // Transform "Eq(false, b)" ==> "Not(b)" + (BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), + + // Transform "Ne(true, b)" ==> "Not(b)" + (BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), + + // Transform "Eq(a, false)" ==> "Not(a)" + (BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), + + // Transform "Ne(a, true)" ==> "Not(a)" + (BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), + + _ => None, + }; + + if let Some(new) = new && self.should_simplify(source_info, rvalue) { + *rvalue = new; + } + } + + _ => {} + } + } + + fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> { + let a = a.constant()?; + if a.literal.ty().is_bool() { a.literal.try_to_bool() } else { None } + } + + /// Transform "&(*a)" ==> "a". + fn simplify_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Ref(_, _, place) = rvalue { + if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { + if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty { + return; + } + + if !self.should_simplify(source_info, rvalue) { + return; + } + + *rvalue = Rvalue::Use(Operand::Copy(Place { + local: base.local, + projection: self.tcx.mk_place_elems(base.projection), + })); + } + } + } + + /// Transform "Len([_; N])" ==> "N". + fn simplify_len(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Len(ref place) = *rvalue { + let place_ty = place.ty(self.local_decls, self.tcx).ty; + if let ty::Array(_, len) = *place_ty.kind() { + if !self.should_simplify(source_info, rvalue) { + return; + } + + let literal = ConstantKind::from_const(len, self.tcx); + let constant = Constant { span: source_info.span, literal, user_ty: None }; + *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant))); + } + } + } + + fn simplify_cast(&self, _source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Cast(kind, operand, cast_ty) = rvalue { + let operand_ty = operand.ty(self.local_decls, self.tcx); + if operand_ty == *cast_ty { + *rvalue = Rvalue::Use(operand.clone()); + } else if *kind == CastKind::Transmute { + // Transmuting an integer to another integer is just a signedness cast + if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) = (operand_ty.kind(), cast_ty.kind()) + && int.bit_width() == uint.bit_width() + { + // The width check isn't strictly necessary, as different widths + // are UB and thus we'd be allowed to turn it into a cast anyway. + // But let's keep the UB around for codegen to exploit later. + // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes, + // then the width check is necessary for big-endian correctness.) 
+ *kind = CastKind::IntToInt; + return; + } + + // Transmuting a transparent struct/union to a field's type is a projection + if let ty::Adt(adt_def, substs) = operand_ty.kind() + && adt_def.repr().transparent() + && (adt_def.is_struct() || adt_def.is_union()) + && let Some(place) = operand.place() + { + let variant = adt_def.non_enum_variant(); + for (i, field) in variant.fields.iter().enumerate() { + let field_ty = field.ty(self.tcx, substs); + if field_ty == *cast_ty { + let place = place.project_deeper(&[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)], self.tcx); + let operand = if operand.is_move() { Operand::Move(place) } else { Operand::Copy(place) }; + *rvalue = Rvalue::Use(operand); + return; + } + } + } + } + } + } + + fn simplify_primitive_clone( + &self, + terminator: &mut Terminator<'tcx>, + statements: &mut Vec<Statement<'tcx>>, + ) { + let TerminatorKind::Call { func, args, destination, target, .. } = &mut terminator.kind + else { return }; + + // It's definitely not a clone if there are multiple arguments + if args.len() != 1 { + return; + } + + let Some(destination_block) = *target + else { return }; + + // Only bother looking more if it's easy to know what we're calling + let Some((fn_def_id, fn_substs)) = func.const_fn_def() + else { return }; + + // Clone needs one subst, so we can cheaply rule out other stuff + if fn_substs.len() != 1 { + return; + } + + // These types are easily available from locals, so check that before + // doing DefId lookups to figure out what we're actually calling. + let arg_ty = args[0].ty(self.local_decls, self.tcx); + + let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() + else { return }; + + if !inner_ty.is_trivially_pure_clone_copy() { + return; + } + + let trait_def_id = self.tcx.trait_of_item(fn_def_id); + if trait_def_id.is_none() || trait_def_id != self.tcx.lang_items().clone_trait() { + return; + } + + if !self.tcx.consider_optimizing(|| { + format!( + "InstSimplify - Call: {:?} SourceInfo: {:?}", + (fn_def_id, fn_substs), + terminator.source_info + ) + }) { + return; + } + + let Some(arg_place) = args.pop().unwrap().place() + else { return }; + + statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Copy( + arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx), + )), + ))), + }); + terminator.kind = TerminatorKind::Goto { target: destination_block }; + } + + fn simplify_intrinsic_assert( + &self, + terminator: &mut Terminator<'tcx>, + _statements: &mut Vec<Statement<'tcx>>, + ) { + let TerminatorKind::Call { func, target, .. 
} = &mut terminator.kind else { return; };
+        let Some(target_block) = target else { return; };
+        let func_ty = func.ty(self.local_decls, self.tcx);
+        let Some((intrinsic_name, substs)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
+            return;
+        };
+        // The intrinsics we are interested in have one generic parameter
+        if substs.is_empty() {
+            return;
+        }
+        let ty = substs.type_at(0);
+
+        let known_is_valid = intrinsic_assert_panics(self.tcx, self.param_env, ty, intrinsic_name);
+        match known_is_valid {
+            // Either we don't know the layout, or this is not a validity assertion at all;
+            // don't touch it.
+            None => {}
+            Some(true) => {
+                // If we know the assert panics, indicate to later opts that the call diverges
+                *target = None;
+            }
+            Some(false) => {
+                // If we know the assert does not panic, turn the call into a Goto
+                terminator.kind = TerminatorKind::Goto { target: *target_block };
+            }
+        }
+    }
+}
+
+fn intrinsic_assert_panics<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+    ty: Ty<'tcx>,
+    intrinsic_name: Symbol,
+) -> Option<bool> {
+    let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?;
+    Some(!tcx.check_validity_requirement((requirement, param_env.and(ty))).ok()?)
+}
+
+fn resolve_rust_intrinsic<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    func_ty: Ty<'tcx>,
+) -> Option<(Symbol, SubstsRef<'tcx>)> {
+    if let ty::FnDef(def_id, substs) = *func_ty.kind() {
+        if tcx.is_intrinsic(def_id) {
+            return Some((tcx.item_name(def_id), substs));
+        }
+    }
+    None
+}
diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs
new file mode 100644
index 00000000000..430a6f6cef5
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/large_enums.rs
@@ -0,0 +1,299 @@
+use crate::rustc_middle::ty::util::IntTypeExt;
+use crate::MirPass;
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::mir::interpret::AllocId;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt};
+use rustc_session::Session;
+use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
+
+/// A pass that seeks to optimize unnecessary moves of large enum types, when there is a large
+/// enough size discrepancy between the variants.
+///
+/// For example, given these two variants:
+/// ```
+/// enum Example {
+///     Small,
+///     Large([u32; 1024]),
+/// }
+/// ```
+/// instead of emitting moves of the large variant, we perform a memcpy.
+/// Based on [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
+///
+/// In summary, the pass determines at runtime which enum variant is active,
+/// and instead of copying all the bytes of the largest possible variant,
+/// copies only the bytes of the currently active variant.
+pub struct EnumSizeOpt {
+    pub(crate) discrepancy: u64,
+}
+
+impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
+    fn is_enabled(&self, sess: &Session) -> bool {
+        sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
+    }
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // NOTE: This pass may produce different MIR based on the alignment of the target
+        // platform, but it will still be valid.
+ self.optim(tcx, body); + } +} + +impl EnumSizeOpt { + fn candidate<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + ty: Ty<'tcx>, + alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>, + ) -> Option<(AdtDef<'tcx>, usize, AllocId)> { + let adt_def = match ty.kind() { + ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def, + _ => return None, + }; + let layout = tcx.layout_of(param_env.and(ty)).ok()?; + let variants = match &layout.variants { + Variants::Single { .. } => return None, + Variants::Multiple { tag_encoding, .. } + if matches!(tag_encoding, TagEncoding::Niche { .. }) => + { + return None; + } + Variants::Multiple { variants, .. } if variants.len() <= 1 => return None, + Variants::Multiple { variants, .. } => variants, + }; + let min = variants.iter().map(|v| v.size).min().unwrap(); + let max = variants.iter().map(|v| v.size).max().unwrap(); + if max.bytes() - min.bytes() < self.discrepancy { + return None; + } + + let num_discrs = adt_def.discriminants(tcx).count(); + if variants.iter_enumerated().any(|(var_idx, _)| { + let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val; + (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs) + }) { + return None; + } + if let Some(alloc_id) = alloc_cache.get(&ty) { + return Some((*adt_def, num_discrs, *alloc_id)); + } + + let data_layout = tcx.data_layout(); + let ptr_sized_int = data_layout.ptr_sized_integer(); + let target_bytes = ptr_sized_int.size().bytes() as usize; + let mut data = vec![0; target_bytes * num_discrs]; + macro_rules! encode_store { + ($curr_idx: expr, $endian: expr, $bytes: expr) => { + let bytes = match $endian { + rustc_target::abi::Endian::Little => $bytes.to_le_bytes(), + rustc_target::abi::Endian::Big => $bytes.to_be_bytes(), + }; + for (i, b) in bytes.into_iter().enumerate() { + data[$curr_idx + i] = b; + } + }; + } + + for (var_idx, layout) in variants.iter_enumerated() { + let curr_idx = + target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize; + let sz = layout.size; + match ptr_sized_int { + rustc_target::abi::Integer::I32 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32); + } + rustc_target::abi::Integer::I64 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes()); + } + _ => unreachable!(), + }; + } + let alloc = interpret::Allocation::from_bytes( + data, + tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, + Mutability::Not, + ); + let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc)); + Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc))) + } + fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut alloc_cache = FxHashMap::default(); + let body_did = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(body_did); + + let blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + for bb in blocks { + bb.expand_statements(|st| { + if let StatementKind::Assign(box ( + lhs, + Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), + )) = &st.kind + { + let ty = lhs.ty(local_decls, tcx).ty; + + let source_info = st.source_info; + let span = source_info.span; + + let (adt_def, num_variants, alloc_id) = + self.candidate(tcx, param_env, ty, &mut alloc_cache)?; + let alloc = tcx.global_alloc(alloc_id).unwrap_memory(); + + let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64); + + let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span)); + let store_live = Statement { + source_info, 
+ kind: StatementKind::StorageLive(size_array_local), + }; + + let place = Place::from(size_array_local); + let constant_vals = Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + interpret::ConstValue::ByRef { alloc, offset: Size::ZERO }, + tmp_ty, + ), + }; + let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals))); + + let const_assign = Statement { + source_info, + kind: StatementKind::Assign(Box::new((place, rval))), + }; + + let discr_place = Place::from( + local_decls + .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)), + ); + + let store_discr = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + discr_place, + Rvalue::Discriminant(*rhs), + ))), + }; + + let discr_cast_place = + Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); + + let cast_discr = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + discr_cast_place, + Rvalue::Cast( + CastKind::IntToInt, + Operand::Copy(discr_place), + tcx.types.usize, + ), + ))), + }; + + let size_place = + Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); + + let store_size = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + size_place, + Rvalue::Use(Operand::Copy(Place { + local: size_array_local, + projection: tcx + .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]), + })), + ))), + }; + + let dst = + Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span))); + + let dst_ptr = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + dst, + Rvalue::AddressOf(Mutability::Mut, *lhs), + ))), + }; + + let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8); + let dst_cast_place = + Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span))); + + let dst_cast = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + dst_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty), + ))), + }; + + let src = + Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span))); + + let src_ptr = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + src, + Rvalue::AddressOf(Mutability::Not, *rhs), + ))), + }; + + let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8); + let src_cast_place = + Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span))); + + let src_cast = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + src_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty), + ))), + }; + + let deinit_old = + Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) }; + + let copy_bytes = Statement { + source_info, + kind: StatementKind::Intrinsic(Box::new( + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + src: Operand::Copy(src_cast_place), + dst: Operand::Copy(dst_cast_place), + count: Operand::Copy(size_place), + }), + )), + }; + + let store_dead = Statement { + source_info, + kind: StatementKind::StorageDead(size_array_local), + }; + let iter = [ + store_live, + const_assign, + store_discr, + cast_discr, + store_size, + dst_ptr, + dst_cast, + src_ptr, + src_cast, + deinit_old, + copy_bytes, + store_dead, + ] + .into_iter(); + + st.make_nop(); + Some(iter) + } else { + None + } + }); + } + } +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index d968a488519..7d9f6c38e36 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,15 +1,18 @@ 
#![allow(rustc::potential_query_instability)]
+#![deny(rustc::untranslatable_diagnostic)]
+#![deny(rustc::diagnostic_outside_of_impl)]
 #![feature(box_patterns)]
+#![feature(drain_filter)]
+#![feature(is_sorted)]
 #![feature(let_chains)]
-#![feature(let_else)]
 #![feature(map_try_insert)]
 #![feature(min_specialization)]
 #![feature(never_type)]
-#![feature(once_cell)]
 #![feature(option_get_or_insert_default)]
 #![feature(trusted_step)]
 #![feature(try_blocks)]
 #![feature(yeet_expr)]
+#![feature(if_let_guard)]
 #![recursion_limit = "256"]

 #[macro_use]
@@ -22,14 +25,20 @@ use rustc_const_eval::util;
 use rustc_data_structures::fx::FxIndexSet;
 use rustc_data_structures::steal::Steal;
 use rustc_hir as hir;
-use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_hir::def::DefKind;
+use rustc_hir::def_id::LocalDefId;
 use rustc_hir::intravisit::{self, Visitor};
-use rustc_index::vec::IndexVec;
+use rustc_index::IndexVec;
 use rustc_middle::mir::visit::Visitor as _;
-use rustc_middle::mir::{traversal, Body, ConstQualifs, MirPass, MirPhase, Promoted};
-use rustc_middle::ty::query::Providers;
-use rustc_middle::ty::{self, TyCtxt, TypeVisitable};
-use rustc_span::{Span, Symbol};
+use rustc_middle::mir::{
+    traversal, AnalysisPhase, Body, ClearCrossCrate, ConstQualifs, Constant, LocalDecl, MirPass,
+    MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, SourceInfo,
+    Statement, StatementKind, TerminatorKind, START_BLOCK,
+};
+use rustc_middle::query::Providers;
+use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
+use rustc_span::sym;
+use rustc_trait_selection::traits;

 #[macro_use]
 mod pass_manager;
@@ -43,15 +52,19 @@ mod add_retag;
 mod check_const_item_mutation;
 mod check_packed_ref;
 pub mod check_unsafety;
+mod remove_place_mention;
 // This pass is public to allow external drivers to perform MIR cleanup
 pub mod cleanup_post_borrowck;
 mod const_debuginfo;
 mod const_goto;
 mod const_prop;
 mod const_prop_lint;
+mod copy_prop;
 mod coverage;
+mod ctfe_limit;
+mod dataflow_const_prop;
 mod dead_store_elimination;
-mod deaggregator;
+mod deduce_param_attrs;
 mod deduplicate_blocks;
 mod deref_separator;
 mod dest_prop;
@@ -59,20 +72,21 @@ pub mod dump_mir;
 mod early_otherwise_branch;
 mod elaborate_box_derefs;
 mod elaborate_drops;
+mod errors;
 mod ffi_unwind_calls;
 mod function_item_references;
 mod generator;
 mod inline;
-mod instcombine;
+mod instsimplify;
+mod large_enums;
 mod lower_intrinsics;
 mod lower_slice_len;
-mod marker;
 mod match_branches;
 mod multiple_return_terminators;
 mod normalize_array_len;
 mod nrvo;
-// This pass is public to allow external drivers to perform MIR cleanup
-pub mod remove_false_edges;
+mod prettify;
+mod ref_prop;
 mod remove_noop_landing_pads;
 mod remove_storage_markers;
 mod remove_uninit_drops;
@@ -82,11 +96,13 @@ mod required_consts;
 mod reveal_all;
 mod separate_const_switch;
 mod shim;
+mod ssa;
+mod check_alignment;
 // This pass is public to allow external drivers to perform MIR cleanup
 pub mod simplify;
 mod simplify_branches;
 mod simplify_comparison_integral;
-mod simplify_try;
+mod sroa;
 mod uninhabited_enum_branching;
 mod unreachable_prop;

@@ -95,52 +111,94 @@ use rustc_const_eval::transform::promote_consts;
 use rustc_const_eval::transform::validate;
 use rustc_mir_dataflow::rustc_peek;

+use rustc_errors::{DiagnosticMessage, SubdiagnosticMessage};
+use rustc_fluent_macro::fluent_messages;
+
+fluent_messages!
{ "../messages.ftl" } + pub fn provide(providers: &mut Providers) { check_unsafety::provide(providers); - check_packed_ref::provide(providers); coverage::query::provide(providers); ffi_unwind_calls::provide(providers); shim::provide(providers); *providers = Providers { mir_keys, mir_const, - mir_const_qualif: |tcx, def_id| { - let def_id = def_id.expect_local(); - if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { - tcx.mir_const_qualif_const_arg(def) - } else { - mir_const_qualif(tcx, ty::WithOptConstParam::unknown(def_id)) - } - }, - mir_const_qualif_const_arg: |tcx, (did, param_did)| { - mir_const_qualif(tcx, ty::WithOptConstParam { did, const_param_did: Some(param_did) }) - }, + mir_const_qualif, mir_promoted, mir_drops_elaborated_and_const_checked, mir_for_ctfe, - mir_for_ctfe_of_const_arg, + mir_generator_witnesses: generator::mir_generator_witnesses, optimized_mir, is_mir_available, is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did), mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable, mir_inliner_callees: inline::cycle::mir_inliner_callees, - promoted_mir: |tcx, def_id| { - let def_id = def_id.expect_local(); - if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { - tcx.promoted_mir_of_const_arg(def) - } else { - promoted_mir(tcx, ty::WithOptConstParam::unknown(def_id)) - } - }, - promoted_mir_of_const_arg: |tcx, (did, param_did)| { - promoted_mir(tcx, ty::WithOptConstParam { did, const_param_did: Some(param_did) }) - }, + promoted_mir, + deduced_param_attrs: deduce_param_attrs::deduced_param_attrs, ..*providers }; } -fn is_mir_available(tcx: TyCtxt<'_>, def_id: DefId) -> bool { - let def_id = def_id.expect_local(); +fn remap_mir_for_const_eval_select<'tcx>( + tcx: TyCtxt<'tcx>, + mut body: Body<'tcx>, + context: hir::Constness, +) -> Body<'tcx> { + for bb in body.basic_blocks.as_mut().iter_mut() { + let terminator = bb.terminator.as_mut().expect("invalid terminator"); + match terminator.kind { + TerminatorKind::Call { + func: Operand::Constant(box Constant { ref literal, .. }), + ref mut args, + destination, + target, + unwind, + fn_span, + .. + } if let ty::FnDef(def_id, _) = *literal.ty().kind() + && tcx.item_name(def_id) == sym::const_eval_select + && tcx.is_intrinsic(def_id) => + { + let [tupled_args, called_in_const, called_at_rt]: [_; 3] = std::mem::take(args).try_into().unwrap(); + let ty = tupled_args.ty(&body.local_decls, tcx); + let fields = ty.tuple_fields(); + let num_args = fields.len(); + let func = if context == hir::Constness::Const { called_in_const } else { called_at_rt }; + let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = match tupled_args { + Operand::Constant(_) => { + // there is no good way of extracting a tuple arg from a constant (const generic stuff) + // so we just create a temporary and deconstruct that. 
+ let local = body.local_decls.push(LocalDecl::new(ty, fn_span)); + bb.statements.push(Statement { + source_info: SourceInfo::outermost(fn_span), + kind: StatementKind::Assign(Box::new((local.into(), Rvalue::Use(tupled_args.clone())))), + }); + (Operand::Move, local.into()) + } + Operand::Move(place) => (Operand::Move, place), + Operand::Copy(place) => (Operand::Copy, place), + }; + let place_elems = place.projection; + let arguments = (0..num_args).map(|x| { + let mut place_elems = place_elems.to_vec(); + place_elems.push(ProjectionElem::Field(x.into(), fields[x])); + let projection = tcx.mk_place_elems(&place_elems); + let place = Place { + local: place.local, + projection, + }; + method(place) + }).collect(); + terminator.kind = TerminatorKind::Call { func, args: arguments, destination, target, unwind, from_hir_call: false, fn_span }; + } + _ => {} + } + } + body +} + +fn is_mir_available(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { tcx.mir_keys(()).contains(&def_id) } @@ -154,32 +212,24 @@ fn mir_keys(tcx: TyCtxt<'_>, (): ()) -> FxIndexSet<LocalDefId> { // Additionally, tuple struct/variant constructors have MIR, but // they don't have a BodyId, so we need to build them separately. - struct GatherCtors<'a, 'tcx> { - tcx: TyCtxt<'tcx>, + struct GatherCtors<'a> { set: &'a mut FxIndexSet<LocalDefId>, } - impl<'tcx> Visitor<'tcx> for GatherCtors<'_, 'tcx> { - fn visit_variant_data( - &mut self, - v: &'tcx hir::VariantData<'tcx>, - _: Symbol, - _: &'tcx hir::Generics<'tcx>, - _: hir::HirId, - _: Span, - ) { - if let hir::VariantData::Tuple(_, hir_id) = *v { - self.set.insert(self.tcx.hir().local_def_id(hir_id)); + impl<'tcx> Visitor<'tcx> for GatherCtors<'_> { + fn visit_variant_data(&mut self, v: &'tcx hir::VariantData<'tcx>) { + if let hir::VariantData::Tuple(_, _, def_id) = *v { + self.set.insert(def_id); } intravisit::walk_struct_def(self, v) } } - tcx.hir().visit_all_item_likes_in_crate(&mut GatherCtors { tcx, set: &mut set }); + tcx.hir().visit_all_item_likes_in_crate(&mut GatherCtors { set: &mut set }); set } -fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> ConstQualifs { - let const_kind = tcx.hir().body_const_context(def.did); +fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { + let const_kind = tcx.hir().body_const_context(def); // No need to const-check a non-const `fn`. if const_kind.is_none() { @@ -188,7 +238,7 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> // N.B., this `borrow()` is guaranteed to be valid (i.e., the value // cannot yet be stolen), because `mir_promoted()`, which steals - // from `mir_const(), forces this query to execute before + // from `mir_const()`, forces this query to execute before // performing the steal. let body = &tcx.mir_const(def).borrow(); @@ -197,7 +247,7 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> return Default::default(); } - let ccx = check_consts::ConstCx { body, tcx, const_kind, param_env: tcx.param_env(def.did) }; + let ccx = check_consts::ConstCx { body, tcx, const_kind, param_env: tcx.param_env(def) }; let mut validator = check_consts::check::Checker::new(&ccx); validator.check_body(); @@ -208,29 +258,20 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> } /// Make MIR ready for const evaluation. This is run on all MIR, not just on consts! 
-fn mir_const<'tcx>( - tcx: TyCtxt<'tcx>, - def: ty::WithOptConstParam<LocalDefId>, -) -> &'tcx Steal<Body<'tcx>> { - if let Some(def) = def.try_upgrade(tcx) { - return tcx.mir_const(def); - } - +/// FIXME(oli-obk): it's unclear whether we still need this phase (and its corresponding query). +/// We used to have this for pre-miri MIR based const eval. +fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { // Unsafety check uses the raw mir, so make sure it is run. if !tcx.sess.opts.unstable_opts.thir_unsafeck { - if let Some(param_did) = def.const_param_did { - tcx.ensure().unsafety_check_result_for_const_arg((def.did, param_did)); - } else { - tcx.ensure().unsafety_check_result(def.did); - } + tcx.ensure_with_value().unsafety_check_result(def); } // has_ffi_unwind_calls query uses the raw mir, so make sure it is run. - tcx.ensure().has_ffi_unwind_calls(def.did); + tcx.ensure_with_value().has_ffi_unwind_calls(def); let mut body = tcx.mir_built(def).steal(); - rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + pass_manager::dump_mir_for_phase_change(tcx, &body); pm::run_passes( tcx, @@ -241,27 +282,23 @@ fn mir_const<'tcx>( &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), // What we need to do constant evaluation. - &simplify::SimplifyCfg::new("initial"), + &simplify::SimplifyCfg::Initial, &rustc_peek::SanityCheck, // Just a lint - &marker::PhaseChange(MirPhase::Const), ], + None, ); tcx.alloc_steal_mir(body) } /// Compute the main MIR body and the list of MIR bodies of the promoteds. -fn mir_promoted<'tcx>( - tcx: TyCtxt<'tcx>, - def: ty::WithOptConstParam<LocalDefId>, -) -> (&'tcx Steal<Body<'tcx>>, &'tcx Steal<IndexVec<Promoted, Body<'tcx>>>) { - if let Some(def) = def.try_upgrade(tcx) { - return tcx.mir_promoted(def); - } - +fn mir_promoted( + tcx: TyCtxt<'_>, + def: LocalDefId, +) -> (&Steal<Body<'_>>, &Steal<IndexVec<Promoted, Body<'_>>>) { // Ensure that we compute the `mir_const_qualif` for constants at // this point, before we steal the mir-const result. // Also this means promotion can rely on all const checks having been done. - let const_qualifs = tcx.mir_const_qualif_opt_const_arg(def); + let const_qualifs = tcx.mir_const_qualif(def); let mut body = tcx.mir_const(def).steal(); if let Some(error_reported) = const_qualifs.tainted_by_errors { body.tainted_by_errors = Some(error_reported); @@ -279,11 +316,8 @@ fn mir_promoted<'tcx>( pm::run_passes( tcx, &mut body, - &[ - &promote_pass, - &simplify::SimplifyCfg::new("promote-consts"), - &coverage::InstrumentCoverage, - ], + &[&promote_pass, &simplify::SimplifyCfg::PromoteConsts, &coverage::InstrumentCoverage], + Some(MirPhase::Analysis(AnalysisPhase::Initial)), ); let promoted = promote_pass.promoted_fragments.into_inner(); @@ -291,46 +325,28 @@ fn mir_promoted<'tcx>( } /// Compute the MIR that is used during CTFE (and thus has no optimizations run on it) -fn mir_for_ctfe<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx Body<'tcx> { - let did = def_id.expect_local(); - if let Some(def) = ty::WithOptConstParam::try_lookup(did, tcx) { - tcx.mir_for_ctfe_of_const_arg(def) - } else { - tcx.arena.alloc(inner_mir_for_ctfe(tcx, ty::WithOptConstParam::unknown(did))) - } +fn mir_for_ctfe(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &Body<'_> { + tcx.arena.alloc(inner_mir_for_ctfe(tcx, def_id)) } -/// Same as `mir_for_ctfe`, but used to get the MIR of a const generic parameter. 
-/// The docs on `WithOptConstParam` explain this a bit more, but the TLDR is that -/// we'd get cycle errors with `mir_for_ctfe`, because typeck would need to typeck -/// the const parameter while type checking the main body, which in turn would try -/// to type check the main body again. -fn mir_for_ctfe_of_const_arg<'tcx>( - tcx: TyCtxt<'tcx>, - (did, param_did): (LocalDefId, DefId), -) -> &'tcx Body<'tcx> { - tcx.arena.alloc(inner_mir_for_ctfe( - tcx, - ty::WithOptConstParam { did, const_param_did: Some(param_did) }, - )) -} - -fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> Body<'_> { +fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> { // FIXME: don't duplicate this between the optimized_mir/mir_for_ctfe queries - if tcx.is_constructor(def.did.to_def_id()) { + if tcx.is_constructor(def.to_def_id()) { // There's no reason to run all of the MIR passes on constructors when // we can just output the MIR we want directly. This also saves const // qualification and borrow checking the trouble of special casing // constructors. - return shim::build_adt_ctor(tcx, def.did.to_def_id()); + return shim::build_adt_ctor(tcx, def.to_def_id()); } let context = tcx .hir() - .body_const_context(def.did) + .body_const_context(def) .expect("mir_for_ctfe should not be used for runtime functions"); - let mut body = tcx.mir_drops_elaborated_and_const_checked(def).borrow().clone(); + let body = tcx.mir_drops_elaborated_and_const_checked(def).borrow().clone(); + + let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::Const); match context { // Do not const prop functions, either they get executed at runtime or exported to metadata, @@ -349,12 +365,13 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) - pm::run_passes( tcx, &mut body, - &[&const_prop::ConstProp, &marker::PhaseChange(MirPhase::Optimized)], + &[&const_prop::ConstProp], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), ); } } - debug_assert!(!body.has_free_regions(), "Free regions in MIR for CTFE"); + pm::run_passes(tcx, &mut body, &[&ctfe_limit::CtfeLimit], None); body } @@ -362,24 +379,19 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) - /// Obtain just the main MIR (no promoteds) and run some cleanups on it. This also runs /// mir borrowck *before* doing so in order to ensure that borrowck can be run and doesn't /// end up missing the source MIR due to stealing happening. -fn mir_drops_elaborated_and_const_checked<'tcx>( - tcx: TyCtxt<'tcx>, - def: ty::WithOptConstParam<LocalDefId>, -) -> &'tcx Steal<Body<'tcx>> { - if let Some(def) = def.try_upgrade(tcx) { - return tcx.mir_drops_elaborated_and_const_checked(def); +fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { + if tcx.sess.opts.unstable_opts.drop_tracking_mir + && let DefKind::Generator = tcx.def_kind(def) + { + tcx.ensure_with_value().mir_generator_witnesses(def); } + let mir_borrowck = tcx.mir_borrowck(def); - let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def); - - let is_fn_like = tcx.def_kind(def.did).is_fn_like(); + let is_fn_like = tcx.def_kind(def).is_fn_like(); if is_fn_like { - let did = def.did.to_def_id(); - let def = ty::WithOptConstParam::unknown(did); - // Do not compute the mir call graph without said call graph actually being used. 
if inline::Inline.is_enabled(&tcx.sess) {
-            let _ = tcx.mir_inliner_callees(ty::InstanceDef::Item(def));
+            tcx.ensure_with_value().mir_inliner_callees(ty::InstanceDef::Item(def.to_def_id()));
         }
     }

@@ -389,38 +401,100 @@ fn mir_drops_elaborated_and_const_checked<'tcx>(
         body.tainted_by_errors = Some(error_reported);
     }

-    // IMPORTANT
-    pm::run_passes(tcx, &mut body, &[&remove_false_edges::RemoveFalseEdges]);
+    // Check if it's even possible to satisfy the 'where' clauses
+    // for this item.
+    //
+    // This branch will never be taken for any normal function.
+    // However, it's possible to use `#![feature(trivial_bounds)]` to write
+    // a function with impossible-to-satisfy clauses, e.g.:
+    // `fn foo() where String: Copy {}`
+    //
+    // We don't usually need to worry about this kind of case,
+    // since we would get a compilation error if the user tried
+    // to call it. However, since we optimize even without any
+    // calls to the function, we need to make sure that it even
+    // makes sense to try to evaluate the body.
+    //
+    // If there are unsatisfiable where clauses, then all bets are
+    // off, and we just give up.
+    //
+    // We manually filter the predicates, skipping anything that's not
+    // "global". We are in a potentially generic context
+    // (e.g. we are evaluating a function without substituting generic
+    // parameters), so this filtering serves two purposes:
+    //
+    // 1. We skip evaluating any predicates that we would
+    // never be able to prove are unsatisfiable (e.g. `<T as Foo>`)
+    // 2. We avoid trying to normalize predicates involving generic
+    // parameters (e.g. `<T as Foo>::MyItem`). This can confuse
+    // the normalization code (leading to cycle errors), since
+    // it's usually never invoked in this way.
+    let predicates = tcx
+        .predicates_of(body.source.def_id())
+        .predicates
+        .iter()
+        .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
+    if traits::impossible_predicates(tcx, traits::elaborate(tcx, predicates).collect()) {
+        trace!("found unsatisfiable predicates for {:?}", body.source);
+        // Clear the body to only contain a single `unreachable` statement.
+        let bbs = body.basic_blocks.as_mut();
+        bbs.raw.truncate(1);
+        bbs[START_BLOCK].statements.clear();
+        bbs[START_BLOCK].terminator_mut().kind = TerminatorKind::Unreachable;
+        body.var_debug_info.clear();
+        body.local_decls.raw.truncate(body.arg_count + 1);
+    }
+
+    run_analysis_to_runtime_passes(tcx, &mut body);
+
+    tcx.alloc_steal_mir(body)
+}
+
+fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    assert!(body.phase == MirPhase::Analysis(AnalysisPhase::Initial));
+    let did = body.source.def_id();
+
+    debug!("analysis_mir_cleanup({:?})", did);
+    run_analysis_cleanup_passes(tcx, body);
+    assert!(body.phase == MirPhase::Analysis(AnalysisPhase::PostCleanup));

     // Do a little drop elaboration before const-checking if `const_precise_live_drops` is enabled.
if check_consts::post_drop_elaboration::checking_enabled(&ConstCx::new(tcx, &body)) { pm::run_passes( tcx, - &mut body, - &[ - &simplify::SimplifyCfg::new("remove-false-edges"), - &remove_uninit_drops::RemoveUninitDrops, - ], + body, + &[&remove_uninit_drops::RemoveUninitDrops, &simplify::SimplifyCfg::RemoveFalseEdges], + None, ); check_consts::post_drop_elaboration::check_live_drops(tcx, &body); // FIXME: make this a MIR lint } - run_post_borrowck_cleanup_passes(tcx, &mut body); - assert!(body.phase == MirPhase::Deaggregated); - tcx.alloc_steal_mir(body) + debug!("runtime_mir_lowering({:?})", did); + run_runtime_lowering_passes(tcx, body); + assert!(body.phase == MirPhase::Runtime(RuntimePhase::Initial)); + + debug!("runtime_mir_cleanup({:?})", did); + run_runtime_cleanup_passes(tcx, body); + assert!(body.phase == MirPhase::Runtime(RuntimePhase::PostCleanup)); } -/// After this series of passes, no lifetime analysis based on borrowing can be done. -fn run_post_borrowck_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - debug!("post_borrowck_cleanup({:?})", body.source.def_id()); +// FIXME(JakobDegen): Can we make these lists of passes consts? - let post_borrowck_cleanup: &[&dyn MirPass<'tcx>] = &[ - // Remove all things only needed by analysis - &simplify_branches::SimplifyConstCondition::new("initial"), +/// After this series of passes, no lifetime analysis based on borrowing can be done. +fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ + &cleanup_post_borrowck::CleanupPostBorrowck, &remove_noop_landing_pads::RemoveNoopLandingPads, - &cleanup_post_borrowck::CleanupNonCodegenStatements, - &simplify::SimplifyCfg::new("early-opt"), + &simplify::SimplifyCfg::EarlyOpt, &deref_separator::Derefer, + ]; + + pm::run_passes(tcx, body, passes, Some(MirPhase::Analysis(AnalysisPhase::PostCleanup))); +} + +/// Returns the sequence of passes that lowers analysis to runtime MIR. +fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ // These next passes must be executed together &add_call_guards::CriticalCallEdges, &elaborate_drops::ElaborateDrops, @@ -434,16 +508,28 @@ fn run_post_borrowck_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tc // `AddRetag` needs to run after `ElaborateDrops`. Otherwise it should run fairly late, // but before optimizations begin. &elaborate_box_derefs::ElaborateBoxDerefs, + &generator::StateTransform, &add_retag::AddRetag, - &lower_intrinsics::LowerIntrinsics, - &simplify::SimplifyCfg::new("elaborate-drops"), - // `Deaggregator` is conceptually part of MIR building, some backends rely on it happening - // and it can help optimizations. - &deaggregator::Deaggregator, &Lint(const_prop_lint::ConstProp), ]; + pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial))); +} + +/// Returns the sequence of passes that do the initial cleanup of runtime MIR. +fn run_runtime_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ + &lower_intrinsics::LowerIntrinsics, + &remove_place_mention::RemovePlaceMention, + &simplify::SimplifyCfg::ElaborateDrops, + ]; - pm::run_passes(tcx, body, post_borrowck_cleanup); + pm::run_passes(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::PostCleanup))); + + // Clear this by anticipation. 
Optimizations and runtime MIR have no reason to look + // into this information, which is meant for borrowck diagnostics. + for decl in &mut body.local_decls { + decl.local_info = ClearCrossCrate::Clear; + } } fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -451,73 +537,67 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { WithMinOptLevel(1, x) } - // Lowering generator control-flow and variables has to happen before we do anything else - // to them. We run some optimizations before that, because they may be harder to do on the state - // machine than on MIR with async primitives. + // The main optimizations that we do on MIR. pm::run_passes( tcx, body, &[ + &check_alignment::CheckAlignment, &reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode. &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first - &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering &unreachable_prop::UnreachablePropagation, &uninhabited_enum_branching::UninhabitedEnumBranching, - &o1(simplify::SimplifyCfg::new("after-uninhabited-enum-branching")), + &o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), &inline::Inline, - &generator::StateTransform, - ], - ); - - assert!(body.phase == MirPhase::GeneratorsLowered); - - // The main optimizations that we do on MIR. - pm::run_passes( - tcx, - body, - &[ &remove_storage_markers::RemoveStorageMarkers, &remove_zsts::RemoveZsts, + &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering &const_goto::ConstGoto, &remove_unneeded_drops::RemoveUnneededDrops, + &sroa::ScalarReplacementOfAggregates, &match_branches::MatchBranchSimplification, // inst combine is after MatchBranchSimplification to clean up Ne(_1, false) &multiple_return_terminators::MultipleReturnTerminators, - &instcombine::InstCombine, + &instsimplify::InstSimplify, + &simplify::SimplifyLocals::BeforeConstProp, + ©_prop::CopyProp, + &ref_prop::ReferencePropagation, + // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may + // destroy the SSA property. It should still happen before const-propagation, so the + // latter pass will leverage the created opportunities. &separate_const_switch::SeparateConstSwitch, - // - // FIXME(#70073): This pass is responsible for both optimization as well as some lints. &const_prop::ConstProp, + &dataflow_const_prop::DataflowConstProp, // // Const-prop runs unconditionally, but doesn't mutate the MIR at mir-opt-level=0. 
&const_debuginfo::ConstDebugInfo, - &o1(simplify_branches::SimplifyConstCondition::new("after-const-prop")), + &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), &early_otherwise_branch::EarlyOtherwiseBranch, &simplify_comparison_integral::SimplifyComparisonIntegral, - &simplify_try::SimplifyArmIdentity, - &simplify_try::SimplifyBranchSame, &dead_store_elimination::DeadStoreElimination, &dest_prop::DestinationPropagation, - &o1(simplify_branches::SimplifyConstCondition::new("final")), + &o1(simplify_branches::SimplifyConstCondition::Final), &o1(remove_noop_landing_pads::RemoveNoopLandingPads), - &o1(simplify::SimplifyCfg::new("final")), + &o1(simplify::SimplifyCfg::Final), &nrvo::RenameReturnPlace, - &simplify::SimplifyLocals, + &simplify::SimplifyLocals::Final, &multiple_return_terminators::MultipleReturnTerminators, &deduplicate_blocks::DeduplicateBlocks, + &large_enums::EnumSizeOpt { discrepancy: 128 }, // Some cleanup necessary at least for LLVM and potentially other codegen backends. &add_call_guards::CriticalCallEdges, - &marker::PhaseChange(MirPhase::Optimized), + // Cleanup for human readability, off by default. + &prettify::ReorderBasicBlocks, + &prettify::ReorderLocals, // Dump the end result for testing and debugging purposes. &dump_mir::Marker("PreCodegen"), ], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), ); } /// Optimize the MIR and prepare it for codegen. -fn optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, did: DefId) -> &'tcx Body<'tcx> { - let did = did.expect_local(); - assert_eq!(ty::WithOptConstParam::try_lookup(did, tcx), None); +fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> { tcx.arena.alloc(inner_optimized_mir(tcx, did)) } @@ -534,42 +614,32 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> { // Run the `mir_for_ctfe` query, which depends on `mir_drops_elaborated_and_const_checked` // which we are going to steal below. Thus we need to run `mir_for_ctfe` first, so it // computes and caches its result. - Some(hir::ConstContext::ConstFn) => tcx.ensure().mir_for_ctfe(did), + Some(hir::ConstContext::ConstFn) => tcx.ensure_with_value().mir_for_ctfe(did), None => {} Some(other) => panic!("do not use `optimized_mir` for constants: {:?}", other), } debug!("about to call mir_drops_elaborated..."); - let mut body = - tcx.mir_drops_elaborated_and_const_checked(ty::WithOptConstParam::unknown(did)).steal(); + let body = tcx.mir_drops_elaborated_and_const_checked(did).steal(); + let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst); debug!("body: {:#?}", body); run_optimization_passes(tcx, &mut body); - debug_assert!(!body.has_free_regions(), "Free regions in optimized MIR"); - body } /// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for /// constant evaluation once all substitutions become known. 
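+/// (An illustrative example of a promoted: in `fn f() -> &'static i32 { &(1 + 2) }`,
+/// the borrowed `&(1 + 2)` is lifted into its own "promoted" MIR body, which is
+/// then evaluated at compile time.)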
-fn promoted_mir<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    def: ty::WithOptConstParam<LocalDefId>,
-) -> &'tcx IndexVec<Promoted, Body<'tcx>> {
-    if tcx.is_constructor(def.did.to_def_id()) {
+fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> {
+    if tcx.is_constructor(def.to_def_id()) {
         return tcx.arena.alloc(IndexVec::new());
     }

-    let tainted_by_errors = tcx.mir_borrowck_opt_const_arg(def).tainted_by_errors;
+    tcx.ensure_with_value().mir_borrowck(def);
     let mut promoted = tcx.mir_promoted(def).1.steal();

     for body in &mut promoted {
-        if let Some(error_reported) = tainted_by_errors {
-            body.tainted_by_errors = Some(error_reported);
-        }
-        run_post_borrowck_cleanup_passes(tcx, body);
+        run_analysis_to_runtime_passes(tcx, body);
     }

-    debug_assert!(!promoted.has_free_regions(), "Free regions in promoted MIR");
-
     tcx.arena.alloc(promoted)
 }
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
index b7ba616510c..3a7d58f7125 100644
--- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs
+++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
@@ -1,11 +1,12 @@
 //! Lowers intrinsic calls
-use crate::MirPass;
+use crate::{errors, MirPass};
 use rustc_middle::mir::*;
 use rustc_middle::ty::subst::SubstsRef;
 use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_span::symbol::{sym, Symbol};
 use rustc_span::Span;
+use rustc_target::abi::{FieldIdx, VariantIdx};

 pub struct LowerIntrinsics;

@@ -46,12 +47,14 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
                     let mut args = args.drain(..);
                     block.statements.push(Statement {
                         source_info: terminator.source_info,
-                        kind: StatementKind::CopyNonOverlapping(Box::new(
-                            rustc_middle::mir::CopyNonOverlapping {
-                                src: args.next().unwrap(),
-                                dst: args.next().unwrap(),
-                                count: args.next().unwrap(),
-                            },
+                        kind: StatementKind::Intrinsic(Box::new(
+                            NonDivergingIntrinsic::CopyNonOverlapping(
+                                rustc_middle::mir::CopyNonOverlapping {
+                                    src: args.next().unwrap(),
+                                    dst: args.next().unwrap(),
+                                    count: args.next().unwrap(),
+                                },
+                            ),
                         )),
                     });
                     assert_eq!(
@@ -62,7 +65,54 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
                     drop(args);
                     terminator.kind = TerminatorKind::Goto { target };
                 }
-                sym::wrapping_add | sym::wrapping_sub | sym::wrapping_mul => {
+                sym::assume => {
+                    let target = target.unwrap();
+                    let mut args = args.drain(..);
+                    block.statements.push(Statement {
+                        source_info: terminator.source_info,
+                        kind: StatementKind::Intrinsic(Box::new(
+                            NonDivergingIntrinsic::Assume(args.next().unwrap()),
+                        )),
+                    });
+                    assert_eq!(
+                        args.next(),
+                        None,
+                        "Extra argument for assume intrinsic"
+                    );
+                    drop(args);
+                    terminator.kind = TerminatorKind::Goto { target };
+                }
+                sym::wrapping_add
+                | sym::wrapping_sub
+                | sym::wrapping_mul
+                | sym::unchecked_div
+                | sym::unchecked_rem => {
+                    let target = target.unwrap();
+                    let lhs;
+                    let rhs;
+                    {
+                        let mut args = args.drain(..);
+                        lhs = args.next().unwrap();
+                        rhs = args.next().unwrap();
+                    }
+                    let bin_op = match intrinsic_name {
+                        sym::wrapping_add => BinOp::Add,
+                        sym::wrapping_sub => BinOp::Sub,
+                        sym::wrapping_mul => BinOp::Mul,
+                        sym::unchecked_div => BinOp::Div,
+                        sym::unchecked_rem => BinOp::Rem,
+                        _ => bug!("unexpected intrinsic"),
+                    };
+                    block.statements.push(Statement {
+                        source_info: terminator.source_info,
+                        kind: StatementKind::Assign(Box::new((
+                            *destination,
+                            Rvalue::BinaryOp(bin_op, Box::new((lhs, rhs))),
+                        ))),
+                    });
+                    terminator.kind = TerminatorKind::Goto { target };
+                }
+                sym::add_with_overflow | sym::sub_with_overflow |
sym::mul_with_overflow => { if let Some(target) = *target { let lhs; let rhs; @@ -72,26 +122,21 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { rhs = args.next().unwrap(); } let bin_op = match intrinsic_name { - sym::wrapping_add => BinOp::Add, - sym::wrapping_sub => BinOp::Sub, - sym::wrapping_mul => BinOp::Mul, + sym::add_with_overflow => BinOp::Add, + sym::sub_with_overflow => BinOp::Sub, + sym::mul_with_overflow => BinOp::Mul, _ => bug!("unexpected intrinsic"), }; block.statements.push(Statement { source_info: terminator.source_info, kind: StatementKind::Assign(Box::new(( *destination, - Rvalue::BinaryOp(bin_op, Box::new((lhs, rhs))), + Rvalue::CheckedBinaryOp(bin_op, Box::new((lhs, rhs))), ))), }); terminator.kind = TerminatorKind::Goto { target }; } } - sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow => { - // The checked binary operations are not suitable target for lowering here, - // since their semantics depend on the value of overflow-checks flag used - // during codegen. Issue #35310. - } sym::size_of | sym::min_align_of => { if let Some(target) = *target { let tp_ty = substs.type_at(0); @@ -110,6 +155,58 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { terminator.kind = TerminatorKind::Goto { target }; } } + sym::read_via_copy => { + let [arg] = args.as_slice() else { + span_bug!(terminator.source_info.span, "Wrong number of arguments"); + }; + let derefed_place = + if let Some(place) = arg.place() && let Some(local) = place.as_local() { + tcx.mk_place_deref(local.into()) + } else { + span_bug!(terminator.source_info.span, "Only passing a local is supported"); + }; + terminator.kind = match *target { + None => { + // No target means this read something uninhabited, + // so it must be unreachable, and we don't need to + // preserve the assignment either. 
+ TerminatorKind::Unreachable + } + Some(target) => { + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Copy(derefed_place)), + ))), + }); + TerminatorKind::Goto { target } + } + } + } + sym::write_via_move => { + let target = target.unwrap(); + let Ok([ptr, val]) = <[_; 2]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for write_via_move intrinsic", + ); + }; + let derefed_place = + if let Some(place) = ptr.place() && let Some(local) = place.as_local() { + tcx.mk_place_deref(local.into()) + } else { + span_bug!(terminator.source_info.span, "Only passing a local is supported"); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + derefed_place, + Rvalue::Use(val), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } sym::discriminant_value => { if let (Some(target), Some(arg)) = (*target, args[0].place()) { let arg = tcx.mk_place_deref(arg); @@ -123,6 +220,78 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { terminator.kind = TerminatorKind::Goto { target }; } } + sym::offset => { + let target = target.unwrap(); + let Ok([ptr, delta]) = <[_; 2]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for offset intrinsic", + ); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::BinaryOp(BinOp::Offset, Box::new((ptr, delta))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::option_payload_ptr => { + if let (Some(target), Some(arg)) = (*target, args[0].place()) { + let ty::RawPtr(ty::TypeAndMut { ty: dest_ty, .. }) = + destination.ty(local_decls, tcx).ty.kind() + else { bug!(); }; + + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::AddressOf( + Mutability::Not, + arg.project_deeper( + &[ + PlaceElem::Deref, + PlaceElem::Downcast( + Some(sym::Some), + VariantIdx::from_u32(1), + ), + PlaceElem::Field(FieldIdx::from_u32(0), *dest_ty), + ], + tcx, + ), + ), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + } + sym::transmute | sym::transmute_unchecked => { + let dst_ty = destination.ty(local_decls, tcx).ty; + let Ok([arg]) = <[_; 1]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for transmute intrinsic", + ); + }; + + // Always emit the cast, even if we transmute to an uninhabited type, + // because that lets CTFE and codegen generate better error messages + // when such a transmute actually ends up reachable. 
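+                    //
+                    // Illustrative lowering (block names hypothetical):
+                    //     _dst = transmute::<u32, [u8; 4]>(move _x) -> bb1
+                    // becomes the statement
+                    //     _dst = move _x as [u8; 4] (Transmute);
+                    // followed by `goto -> bb1`, or by `unreachable` when the
+                    // call had no return target.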
+ block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Cast(CastKind::Transmute, arg, dst_ty), + ))), + }); + + if let Some(target) = *target { + terminator.kind = TerminatorKind::Goto { target }; + } else { + terminator.kind = TerminatorKind::Unreachable; + } + } _ if intrinsic_name.as_str().starts_with("simd_shuffle") => { validate_simd_shuffle(tcx, args, terminator.source_info.span); } @@ -146,11 +315,7 @@ fn resolve_rust_intrinsic<'tcx>( } fn validate_simd_shuffle<'tcx>(tcx: TyCtxt<'tcx>, args: &[Operand<'tcx>], span: Span) { - match &args[2] { - Operand::Constant(_) => {} // all good - _ => { - let msg = "last argument of `simd_shuffle` is required to be a `const` item"; - tcx.sess.span_err(span, msg); - } + if !matches!(args[2], Operand::Constant(_)) { + tcx.sess.emit_err(errors::SimdShuffleLastConst { span }); } } diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs index 47848cfa497..6e40dfa0d13 100644 --- a/compiler/rustc_mir_transform/src/lower_slice_len.rs +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -3,7 +3,7 @@ use crate::MirPass; use rustc_hir::def_id::DefId; -use rustc_index::vec::IndexVec; +use rustc_index::IndexSlice; use rustc_middle::mir::*; use rustc_middle::ty::{self, TyCtxt}; @@ -11,7 +11,7 @@ pub struct LowerSliceLenCalls; impl<'tcx> MirPass<'tcx> for LowerSliceLenCalls { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.opts.mir_opt_level() > 0 + sess.mir_opt_level() > 0 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -42,7 +42,7 @@ struct SliceLenPatchInformation<'tcx> { fn lower_slice_len_call<'tcx>( tcx: TyCtxt<'tcx>, block: &mut BasicBlockData<'tcx>, - local_decls: &IndexVec<Local, LocalDecl<'tcx>>, + local_decls: &IndexSlice<Local, LocalDecl<'tcx>>, slice_len_fn_item_def_id: DefId, ) { let mut patch_found: Option<SliceLenPatchInformation<'_>> = None; @@ -54,7 +54,6 @@ fn lower_slice_len_call<'tcx>( args, destination, target: Some(bb), - cleanup: None, from_hir_call: true, .. } => { @@ -68,8 +67,11 @@ fn lower_slice_len_call<'tcx>( ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => { // perform modifications // from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1` - // into `_5 = Len(*_6) + // into: + // ``` + // _5 = Len(*_6) // goto bb1 + // ``` // make new RValue for Len let deref_arg = tcx.mk_place_deref(arg); diff --git a/compiler/rustc_mir_transform/src/marker.rs b/compiler/rustc_mir_transform/src/marker.rs deleted file mode 100644 index 06819fc1d37..00000000000 --- a/compiler/rustc_mir_transform/src/marker.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::borrow::Cow; - -use crate::MirPass; -use rustc_middle::mir::{Body, MirPhase}; -use rustc_middle::ty::TyCtxt; - -/// Changes the MIR phase without changing the MIR itself. 
-pub struct PhaseChange(pub MirPhase); - -impl<'tcx> MirPass<'tcx> for PhaseChange { - fn phase_change(&self) -> Option<MirPhase> { - Some(self.0) - } - - fn name(&self) -> Cow<'_, str> { - Cow::from(format!("PhaseChange-{:?}", self.0)) - } - - fn run_pass(&self, _: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} -} diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index a0ba69c89b0..6eb48498274 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -41,12 +41,12 @@ pub struct MatchBranchSimplification; impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 3 + sess.mir_opt_level() >= 1 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let def_id = body.source.def_id(); - let param_env = tcx.param_env(def_id); + let param_env = tcx.param_env_reveal_all_normalized(def_id); let bbs = body.basic_blocks.as_mut(); let mut should_cleanup = false; @@ -55,18 +55,22 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { continue; } - let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind { + let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { TerminatorKind::SwitchInt { discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), - switch_ty, ref targets, .. } if targets.iter().len() == 1 => { let (value, target) = targets.iter().next().unwrap(); - if target == targets.otherwise() { + // We require that this block and the two possible target blocks all be + // distinct. + if target == targets.otherwise() + || bb_idx == target + || bb_idx == targets.otherwise() + { continue; } - (discr, value, switch_ty, target, targets.otherwise()) + (discr, value, target, targets.otherwise()) } // Only optimize switch int statements _ => continue, @@ -105,10 +109,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { } // Take ownership of items now that we know we can optimize. let discr = discr.clone(); + let discr_ty = discr.ty(&body.local_decls, tcx); // Introduce a temporary for the discriminant value. let source_info = bbs[bb_idx].terminator().source_info; - let discr_local = body.local_decls.push(LocalDecl::new(switch_ty, source_info.span)); + let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); // We already checked that first and second are different blocks, // and bb_idx has a different terminator from both of them. @@ -130,10 +135,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { (*f).clone() } else { // Different value between blocks. Make value conditional on switch condition. 
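+                // Illustrative (hypothetical MIR): for
+                //     switchInt(_1) -> [7: bb1, otherwise: bb2]
+                // where bb1 assigns `_2 = const true` and bb2 assigns
+                // `_2 = const false`, the blocks are merged into
+                //     _3 = _1;
+                //     _2 = Eq(move _3, const 7);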
- let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; let const_cmp = Operand::const_from_scalar( tcx, - switch_ty, + discr_ty, rustc_const_eval::interpret::Scalar::from_uint(val, size), rustc_span::DUMMY_SP, ); diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs index 22b6dead99c..3957cd92c4e 100644 --- a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs +++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs @@ -15,7 +15,7 @@ impl<'tcx> MirPass<'tcx> for MultipleReturnTerminators { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // find basic blocks with no statement and a return terminator - let mut bbs_simple_returns = BitSet::new_empty(body.basic_blocks().len()); + let mut bbs_simple_returns = BitSet::new_empty(body.basic_blocks.len()); let def_id = body.source.def_id(); let bbs = body.basic_blocks_mut(); for idx in bbs.indices() { diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs index c0217a10541..3d61d33ce35 100644 --- a/compiler/rustc_mir_transform/src/normalize_array_len.rs +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -1,287 +1,101 @@ //! This pass eliminates casting of arrays into slices when their length //! is taken using `.len()` method. Handy to preserve information in MIR for const prop +use crate::ssa::SsaLocals; use crate::MirPass; -use rustc_data_structures::fx::FxIndexMap; -use rustc_data_structures::intern::Interned; -use rustc_index::bit_set::BitSet; -use rustc_index::vec::IndexVec; +use rustc_index::IndexVec; +use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::{self, ReErased, Region, TyCtxt}; - -const MAX_NUM_BLOCKS: usize = 800; -const MAX_NUM_LOCALS: usize = 3000; +use rustc_middle::ty::{self, TyCtxt}; pub struct NormalizeArrayLen; impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 4 + sess.mir_opt_level() >= 3 } + #[instrument(level = "trace", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // early returns for edge cases of highly unrolled functions - if body.basic_blocks().len() > MAX_NUM_BLOCKS { - return; - } - if body.local_decls().len() > MAX_NUM_LOCALS { - return; - } + debug!(def_id = ?body.source.def_id()); normalize_array_len_calls(tcx, body) } } -pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // We don't ever touch terminators, so no need to invalidate the CFG cache - let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); - let local_decls = &mut body.local_decls; +fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ssa = SsaLocals::new(body); - // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]` - let mut interesting_locals = BitSet::new_empty(local_decls.len()); - for (local, decl) in local_decls.iter_enumerated() { - match decl.ty.kind() { - ty::Array(..) => { - interesting_locals.insert(local); - } - ty::Ref(.., ty, Mutability::Not) => match ty.kind() { - ty::Array(..) 
=> { - interesting_locals.insert(local); - } - _ => {} - }, - _ => {} - } - } - if interesting_locals.is_empty() { - // we have found nothing to analyze - return; - } - let num_intesting_locals = interesting_locals.count(); - let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - let mut patches_scratchpad = - FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - let mut replacements_scratchpad = - FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - for block in basic_blocks { - // make length calls for arrays [T; N] not to decay into length calls for &[T] - // that forbids constant propagation - normalize_array_len_call( - tcx, - block, - local_decls, - &interesting_locals, - &mut state, - &mut patches_scratchpad, - &mut replacements_scratchpad, - ); - state.clear(); - patches_scratchpad.clear(); - replacements_scratchpad.clear(); - } -} + let slice_lengths = compute_slice_length(tcx, &ssa, body); + debug!(?slice_lengths); -struct Patcher<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - patches_scratchpad: &'a FxIndexMap<usize, usize>, - replacements_scratchpad: &'a mut FxIndexMap<usize, Local>, - local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>, - statement_idx: usize, + Replacer { tcx, slice_lengths }.visit_body_preserves_cfg(body); } -impl<'tcx> Patcher<'_, 'tcx> { - fn patch_expand_statement( - &mut self, - statement: &mut Statement<'tcx>, - ) -> Option<std::vec::IntoIter<Statement<'tcx>>> { - let idx = self.statement_idx; - if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() { - let mut statements = Vec::with_capacity(2); - - // we are at statement that performs a cast. The only sound way is - // to create another local that performs a similar copy without a cast and then - // use this copy in the Len operation - - match &statement.kind { - StatementKind::Assign(box ( - .., - Rvalue::Cast( - CastKind::Pointer(ty::adjustment::PointerCast::Unsize), - operand, - _, - ), - )) => { - match operand { - Operand::Copy(place) | Operand::Move(place) => { - // create new local - let ty = operand.ty(self.local_decls, self.tcx); - let local_decl = LocalDecl::with_source_info(ty, statement.source_info); - let local = self.local_decls.push(local_decl); - // make it live - let mut make_live_statement = statement.clone(); - make_live_statement.kind = StatementKind::StorageLive(local); - statements.push(make_live_statement); - // copy into it - - let operand = Operand::Copy(*place); - let mut make_copy_statement = statement.clone(); - let assign_to = Place::from(local); - let rvalue = Rvalue::Use(operand); - make_copy_statement.kind = - StatementKind::Assign(Box::new((assign_to, rvalue))); - statements.push(make_copy_statement); - - // to reorder we have to copy and make NOP - statements.push(statement.clone()); - statement.make_nop(); - - self.replacements_scratchpad.insert(len_statemnt_idx, local); - } - _ => { - unreachable!("it's a bug in the implementation") - } - } - } - _ => { - unreachable!("it's a bug in the implementation") +fn compute_slice_length<'tcx>( + tcx: TyCtxt<'tcx>, + ssa: &SsaLocals, + body: &Body<'tcx>, +) -> IndexVec<Local, Option<ty::Const<'tcx>>> { + let mut slice_lengths = IndexVec::from_elem(None, &body.local_decls); + + for (local, rvalue, _) in ssa.assignments(body) { + match rvalue { + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + cast_ty, + ) => { + let operand_ty = operand.ty(body, tcx); + debug!(?operand_ty); + if let 
Some(operand_ty) = operand_ty.builtin_deref(true) + && let ty::Array(_, len) = operand_ty.ty.kind() + && let Some(cast_ty) = cast_ty.builtin_deref(true) + && let ty::Slice(..) = cast_ty.ty.kind() + { + slice_lengths[local] = Some(*len); } } - - self.statement_idx += 1; - - Some(statements.into_iter()) - } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() { - let mut statements = Vec::with_capacity(2); - - match &statement.kind { - StatementKind::Assign(box (into, Rvalue::Len(place))) => { - let add_deref = if let Some(..) = place.as_local() { - false - } else if let Some(..) = place.local_or_deref_local() { - true - } else { - unreachable!("it's a bug in the implementation") - }; - // replace len statement - let mut len_statement = statement.clone(); - let mut place = Place::from(local); - if add_deref { - place = self.tcx.mk_place_deref(place); - } - len_statement.kind = - StatementKind::Assign(Box::new((*into, Rvalue::Len(place)))); - statements.push(len_statement); - - // make temporary dead - let mut make_dead_statement = statement.clone(); - make_dead_statement.kind = StatementKind::StorageDead(local); - statements.push(make_dead_statement); - - // make original statement NOP - statement.make_nop(); + // The length information is stored in the fat pointer, so we treat `operand` as a value. + Rvalue::Use(operand) => { + if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() { + slice_lengths[local] = slice_lengths[rhs]; } - _ => { - unreachable!("it's a bug in the implementation") + } + // The length information is stored in the fat pointer. + // Reborrowing copies length information from one pointer to the other. + Rvalue::Ref(_, _, rhs) | Rvalue::AddressOf(_, rhs) => { + if let [PlaceElem::Deref] = rhs.projection[..] 
{ + slice_lengths[local] = slice_lengths[rhs.local]; } } - - self.statement_idx += 1; - - Some(statements.into_iter()) - } else { - self.statement_idx += 1; - None + _ => {} } } + + slice_lengths } -fn normalize_array_len_call<'tcx>( +struct Replacer<'tcx> { tcx: TyCtxt<'tcx>, - block: &mut BasicBlockData<'tcx>, - local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, - interesting_locals: &BitSet<Local>, - state: &mut FxIndexMap<Local, usize>, - patches_scratchpad: &mut FxIndexMap<usize, usize>, - replacements_scratchpad: &mut FxIndexMap<usize, Local>, -) { - for (statement_idx, statement) in block.statements.iter_mut().enumerate() { - match &mut statement.kind { - StatementKind::Assign(box (place, rvalue)) => { - match rvalue { - Rvalue::Cast( - CastKind::Pointer(ty::adjustment::PointerCast::Unsize), - operand, - cast_ty, - ) => { - let Some(local) = place.as_local() else { return }; - match operand { - Operand::Copy(place) | Operand::Move(place) => { - let Some(operand_local) = place.local_or_deref_local() else { return; }; - if !interesting_locals.contains(operand_local) { - return; - } - let operand_ty = local_decls[operand_local].ty; - match (operand_ty.kind(), cast_ty.kind()) { - (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { - if of_ty_src == of_ty_dst { - // this is a cast from [T; N] into [T], so we are good - state.insert(local, statement_idx); - } - } - // current way of patching doesn't allow to work with `mut` - ( - ty::Ref( - Region(Interned(ReErased, _)), - operand_ty, - Mutability::Not, - ), - ty::Ref( - Region(Interned(ReErased, _)), - cast_ty, - Mutability::Not, - ), - ) => { - match (operand_ty.kind(), cast_ty.kind()) { - // current way of patching doesn't allow to work with `mut` - (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { - if of_ty_src == of_ty_dst { - // this is a cast from [T; N] into [T], so we are good - state.insert(local, statement_idx); - } - } - _ => {} - } - } - _ => {} - } - } - _ => {} - } - } - Rvalue::Len(place) => { - let Some(local) = place.local_or_deref_local() else { - return; - }; - if let Some(cast_statement_idx) = state.get(&local).copied() { - patches_scratchpad.insert(cast_statement_idx, statement_idx); - } - } - _ => { - // invalidate - state.remove(&place.local); - } - } - } - _ => {} - } - } + slice_lengths: IndexVec<Local, Option<ty::Const<'tcx>>>, +} - let mut patcher = Patcher { - tcx, - patches_scratchpad: &*patches_scratchpad, - replacements_scratchpad, - local_decls, - statement_idx: 0, - }; +impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } - block.expand_statements(|st| patcher.patch_expand_statement(st)); + fn visit_rvalue(&mut self, rvalue: &mut Rvalue<'tcx>, loc: Location) { + if let Rvalue::Len(place) = rvalue + && let [PlaceElem::Deref] = &place.projection[..] 
+ && let Some(len) = self.slice_lengths[place.local] + { + *rvalue = Rvalue::Use(Operand::Constant(Box::new(Constant { + span: rustc_span::DUMMY_SP, + user_ty: None, + literal: ConstantKind::from_const(len, self.tcx), + }))); + } + self.super_rvalue(rvalue, loc); + } } diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs index bb063915f55..5ce96012b90 100644 --- a/compiler/rustc_mir_transform/src/nrvo.rs +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -34,7 +34,8 @@ pub struct RenameReturnPlace; impl<'tcx> MirPass<'tcx> for RenameReturnPlace { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() > 0 + // #111005 + sess.mir_opt_level() > 0 && sess.opts.unstable_opts.unsound_mir_opts } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { @@ -53,10 +54,10 @@ impl<'tcx> MirPass<'tcx> for RenameReturnPlace { def_id, returned_local ); - RenameToReturnPlace { tcx, to_rename: returned_local }.visit_body(body); + RenameToReturnPlace { tcx, to_rename: returned_local }.visit_body_preserves_cfg(body); // Clean up the `NOP`s we inserted for statements made useless by our renaming. - for block_data in body.basic_blocks_mut() { + for block_data in body.basic_blocks.as_mut_preserves_cfg() { block_data.statements.retain(|stmt| stmt.kind != mir::StatementKind::Nop); } @@ -89,7 +90,7 @@ fn local_eligible_for_nrvo(body: &mut mir::Body<'_>) -> Option<Local> { } let mut copied_to_return_place = None; - for block in body.basic_blocks().indices() { + for block in body.basic_blocks.indices() { // Look for blocks with a `Return` terminator. if !matches!(body[block].terminator().kind, mir::TerminatorKind::Return) { continue; @@ -102,12 +103,12 @@ fn local_eligible_for_nrvo(body: &mut mir::Body<'_>) -> Option<Local> { mir::LocalKind::Arg => return None, mir::LocalKind::ReturnPointer => bug!("Return place was assigned to itself?"), - mir::LocalKind::Var | mir::LocalKind::Temp => {} + mir::LocalKind::Temp => {} } // If multiple different locals are copied to the return place. We can't pick a // single one to rename. - if copied_to_return_place.map_or(false, |old| old != returned_local) { + if copied_to_return_place.is_some_and(|old| old != returned_local) { return None; } @@ -122,7 +123,7 @@ fn find_local_assigned_to_return_place( body: &mut mir::Body<'_>, ) -> Option<Local> { let mut block = start; - let mut seen = HybridBitSet::new_empty(body.basic_blocks().len()); + let mut seen = HybridBitSet::new_empty(body.basic_blocks.len()); // Iterate as long as `block` has exactly one predecessor that we have not yet visited. while seen.insert(block) { diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index e27d4ab1688..710eed3ed38 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -1,6 +1,4 @@ -use std::borrow::Cow; - -use rustc_middle::mir::{self, Body, MirPhase}; +use rustc_middle::mir::{self, Body, MirPhase, RuntimePhase}; use rustc_middle::ty::TyCtxt; use rustc_session::Session; @@ -8,13 +6,9 @@ use crate::{validate, MirPass}; /// Just like `MirPass`, except it cannot mutate `Body`. 
pub trait MirLint<'tcx> { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &'static str { let name = std::any::type_name::<Self>(); - if let Some(tail) = name.rfind(':') { - Cow::from(&name[tail + 1..]) - } else { - Cow::from(name) - } + if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name } } fn is_enabled(&self, _sess: &Session) -> bool { @@ -32,7 +26,7 @@ impl<'tcx, T> MirPass<'tcx> for Lint<T> where T: MirLint<'tcx>, { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &'static str { self.0.name() } @@ -55,7 +49,7 @@ impl<'tcx, T> MirPass<'tcx> for WithMinOptLevel<T> where T: MirPass<'tcx>, { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &'static str { self.1.name() } @@ -66,69 +60,94 @@ where fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { self.1.run_pass(tcx, body) } +} - fn phase_change(&self) -> Option<MirPhase> { - self.1.phase_change() - } +/// Run the sequence of passes without validating the MIR after each pass. The MIR is still +/// validated at the end. +pub fn run_passes_no_validate<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, +) { + run_passes_inner(tcx, body, passes, phase_change, false); } -pub fn run_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, passes: &[&dyn MirPass<'tcx>]) { - let start_phase = body.phase; - let mut cnt = 0; +/// The optional `phase_change` is applied after executing all the passes, if present +pub fn run_passes<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, +) { + run_passes_inner(tcx, body, passes, phase_change, true); +} - let validate = tcx.sess.opts.unstable_opts.validate_mir; +fn run_passes_inner<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, + validate_each: bool, +) { + let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir & !body.should_skip(); let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; trace!(?overridden_passes); - if validate { - validate_body(tcx, body, format!("start of phase transition from {:?}", start_phase)); - } - - for pass in passes { - let name = pass.name(); - - if let Some((_, polarity)) = overridden_passes.iter().rev().find(|(s, _)| s == &*name) { - trace!( - pass = %name, - "{} as requested by flag", - if *polarity { "Running" } else { "Not running" }, + if !body.should_skip() { + for pass in passes { + let name = pass.name(); + + let overridden = overridden_passes.iter().rev().find(|(s, _)| s == &*name).map( + |(_name, polarity)| { + trace!( + pass = %name, + "{} as requested by flag", + if *polarity { "Running" } else { "Not running" }, + ); + *polarity + }, ); - if !polarity { - continue; - } - } else { - if !pass.is_enabled(&tcx.sess) { + if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) { continue; } - } - let dump_enabled = pass.is_mir_dump_enabled(); - if dump_enabled { - dump_mir(tcx, body, start_phase, &name, cnt, false); - } + let dump_enabled = pass.is_mir_dump_enabled(); - pass.run_pass(tcx, body); + if dump_enabled { + dump_mir_for_pass(tcx, body, &name, false); + } + if validate { + validate_body(tcx, body, format!("before pass {}", name)); + } - if dump_enabled { - dump_mir(tcx, body, start_phase, &name, cnt, true); - cnt += 1; - } + tcx.sess.time(name, || pass.run_pass(tcx, body)); - if let Some(new_phase) = pass.phase_change() { - if body.phase >= new_phase { - panic!("Invalid 
MIR phase transition from {:?} to {:?}", body.phase, new_phase); + if dump_enabled { + dump_mir_for_pass(tcx, body, &name, true); + } + if validate { + validate_body(tcx, body, format!("after pass {}", name)); } - body.phase = new_phase; + body.pass_count += 1; } + } - if validate { - validate_body(tcx, body, format!("after pass {}", pass.name())); + if let Some(new_phase) = phase_change { + if body.phase >= new_phase { + panic!("Invalid MIR phase transition from {:?} to {:?}", body.phase, new_phase); + } + + body.phase = new_phase; + body.pass_count = 0; + + dump_mir_for_phase_change(tcx, body); + if validate || new_phase == MirPhase::Runtime(RuntimePhase::Optimized) { + validate_body(tcx, body, format!("after phase change to {}", new_phase.name())); } - } - if validate || body.phase == MirPhase::Optimized { - validate_body(tcx, body, format!("end of phase transition to {:?}", body.phase)); + body.pass_count = 1; } } @@ -136,22 +155,23 @@ pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: Strin validate::Validator { when, mir_phase: body.phase }.run_pass(tcx, body); } -pub fn dump_mir<'tcx>( +pub fn dump_mir_for_pass<'tcx>( tcx: TyCtxt<'tcx>, body: &Body<'tcx>, - phase: MirPhase, pass_name: &str, - cnt: usize, is_after: bool, ) { - let phase_index = phase as u32; - mir::dump_mir( tcx, - Some(&format_args!("{:03}-{:03}", phase_index, cnt)), + true, pass_name, if is_after { &"after" } else { &"before" }, body, |_, _| Ok(()), ); } + +pub fn dump_mir_for_phase_change<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + assert_eq!(body.pass_count, 0); + mir::dump_mir(tcx, true, body.phase.name(), &"after", body, |_, _| Ok(())) +} diff --git a/compiler/rustc_mir_transform/src/prettify.rs b/compiler/rustc_mir_transform/src/prettify.rs new file mode 100644 index 00000000000..6f46974ea00 --- /dev/null +++ b/compiler/rustc_mir_transform/src/prettify.rs @@ -0,0 +1,150 @@ +//! These two passes provide no value to the compiler, so are off at every level. +//! +//! However, they can be enabled on the command line +//! (`-Zmir-enable-passes=+ReorderBasicBlocks,+ReorderLocals`) +//! to make the MIR easier to read for humans. + +use crate::MirPass; +use rustc_index::{bit_set::BitSet, IndexSlice, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_session::Session; + +/// Rearranges the basic blocks into a *reverse post-order*. +/// +/// Thus after this pass, all the successors of a block are later than it in the +/// `IndexVec`, unless that successor is a back-edge (such as from a loop). +pub struct ReorderBasicBlocks; + +impl<'tcx> MirPass<'tcx> for ReorderBasicBlocks { + fn is_enabled(&self, _session: &Session) -> bool { + false + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let rpo: IndexVec<BasicBlock, BasicBlock> = + body.basic_blocks.postorder().iter().copied().rev().collect(); + if rpo.iter().is_sorted() { + return; + } + + let mut updater = BasicBlockUpdater { map: rpo.invert_bijective_mapping(), tcx }; + debug_assert_eq!(updater.map[START_BLOCK], START_BLOCK); + updater.visit_body(body); + + permute(body.basic_blocks.as_mut(), &updater.map); + } +} + +/// Rearranges the locals into *use* order. +/// +/// Thus after this pass, a local with a smaller [`Location`] where it was first +/// assigned or referenced will have a smaller number. +/// +/// (Does not reorder arguments nor the [`RETURN_PLACE`].) 
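Both passes compute a bijective index mapping (via `invert_bijective_mapping`) and then apply it with the `permute` helper defined further down in this file. A minimal standalone sketch of that permutation idea, using plain `usize` indices and a hypothetical `main` in place of rustc's `IndexVec` machinery:

```rust
// Sketch of the permutation step, assuming `map[old] == new`, i.e. the
// element currently at index `old` should end up at index `new`.
fn permute<T>(data: Vec<T>, map: &[usize]) -> Vec<T> {
    let mut enumerated: Vec<(usize, T)> = data.into_iter().enumerate().collect();
    // Sorting by the target index is what applies the permutation.
    enumerated.sort_by_key(|&(old, _)| map[old]);
    enumerated.into_iter().map(|(_, value)| value).collect()
}

fn main() {
    // A reverse post-order of [bb2, bb0, bb1] inverts to the map below.
    let blocks = vec!["bb0", "bb1", "bb2"];
    let map = [1, 2, 0]; // bb0 -> slot 1, bb1 -> slot 2, bb2 -> slot 0
    assert_eq!(permute(blocks, &map), vec!["bb2", "bb0", "bb1"]);
}
```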
+pub struct ReorderLocals; + +impl<'tcx> MirPass<'tcx> for ReorderLocals { + fn is_enabled(&self, _session: &Session) -> bool { + false + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut finder = + LocalFinder { map: IndexVec::new(), seen: BitSet::new_empty(body.local_decls.len()) }; + + // We can't reorder the return place or the arguments + for local in (0..=body.arg_count).map(Local::from_usize) { + finder.track(local); + } + + for (bb, bbd) in body.basic_blocks.iter_enumerated() { + finder.visit_basic_block_data(bb, bbd); + } + + // track everything in case there are some locals that we never saw, + // such as in non-block things like debug info or in non-uses. + for local in body.local_decls.indices() { + finder.track(local); + } + + if finder.map.iter().is_sorted() { + return; + } + + let mut updater = LocalUpdater { map: finder.map.invert_bijective_mapping(), tcx }; + + for local in (0..=body.arg_count).map(Local::from_usize) { + debug_assert_eq!(updater.map[local], local); + } + + updater.visit_body_preserves_cfg(body); + + permute(&mut body.local_decls, &updater.map); + } +} + +fn permute<I: rustc_index::Idx + Ord, T>(data: &mut IndexVec<I, T>, map: &IndexSlice<I, I>) { + // FIXME: It would be nice to have a less-awkward way to apply permutations, + // but I don't know one that exists. `sort_by_cached_key` has logic for it + // internally, but not in a way that we're allowed to use here. + let mut enumerated: Vec<_> = std::mem::take(data).into_iter_enumerated().collect(); + enumerated.sort_by_key(|p| map[p.0]); + *data = enumerated.into_iter().map(|p| p.1).collect(); +} + +struct BasicBlockUpdater<'tcx> { + map: IndexVec<BasicBlock, BasicBlock>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for BasicBlockUpdater<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, _location: Location) { + for succ in terminator.successors_mut() { + *succ = self.map[*succ]; + } + } +} + +struct LocalFinder { + map: IndexVec<Local, Local>, + seen: BitSet<Local>, +} + +impl LocalFinder { + fn track(&mut self, l: Local) { + if self.seen.insert(l) { + self.map.push(l); + } + } +} + +impl<'tcx> Visitor<'tcx> for LocalFinder { + fn visit_local(&mut self, l: Local, context: PlaceContext, _location: Location) { + // Exclude non-uses to keep `StorageLive` from controlling where we put + // a `Local`, since it might not actually be assigned until much later. 
+        if context.is_use() {
+            self.track(l);
+        }
+    }
+}
+
+struct LocalUpdater<'tcx> {
+    pub map: IndexVec<Local, Local>,
+    pub tcx: TyCtxt<'tcx>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
+        *l = self.map[*l];
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs
new file mode 100644
index 00000000000..bbd9f76ba5c
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/ref_prop.rs
@@ -0,0 +1,408 @@
+use rustc_data_structures::fx::FxHashSet;
+use rustc_index::bit_set::BitSet;
+use rustc_index::IndexVec;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::impls::MaybeStorageDead;
+use rustc_mir_dataflow::storage::always_storage_live_locals;
+use rustc_mir_dataflow::Analysis;
+
+use crate::ssa::{SsaLocals, StorageLiveLocals};
+use crate::MirPass;
+
+/// Propagate references using SSA analysis.
+///
+/// MIR building may produce a lot of borrow-dereference patterns.
+///
+/// This pass aims to transform the following pattern:
+///   _1 = &raw? mut? PLACE;
+///   _3 = *_1;
+///   _4 = &raw? mut? *_1;
+///
+/// Into
+///   _1 = &raw? mut? PLACE;
+///   _3 = PLACE;
+///   _4 = &raw? mut? PLACE;
+///
+/// where `PLACE` is a direct or an indirect place expression.
+///
+/// There are 3 properties that need to be upheld for this transformation to be legal:
+/// - place stability: `PLACE` must refer to the same memory wherever it appears;
+/// - pointer liveness: we must not introduce dereferences of dangling pointers;
+/// - `&mut` borrow uniqueness.
+///
+/// # Stability
+///
+/// If `PLACE` is an indirect projection, we consider it as constant if it is of the form
+/// `(*LOCAL).PROJECTIONS` where:
+/// - `LOCAL` is SSA;
+/// - all projections in `PROJECTIONS` have a stable offset (no dereference and no indexing).
+///
+/// If `PLACE` is a direct projection of a local, we consider it as constant if:
+/// - the local is always live, or it has a single `StorageLive`;
+/// - all projections have a stable offset.
+///
+/// # Liveness
+///
+/// When performing a substitution, we must take care not to introduce uses of dangling locals.
+/// To ensure this, we walk the body with the `MaybeStorageDead` dataflow analysis:
+/// - if we want to replace `*x` by reborrow `*y` and `y` may be dead, we allow replacement and
+///   mark storage statements on `y` for removal;
+/// - if we want to replace `*x` by non-reborrow `y` and `y` must be live, we allow replacement;
+/// - if we want to replace `*x` by non-reborrow `y` and `y` may be dead, we do not replace.
+///
+/// # Uniqueness
+///
+/// For `&mut` borrows, we also need to preserve the uniqueness property:
+/// we must avoid creating a state where we interleave uses of `*_1` and `_2`.
+/// To do so, we only perform full substitution of mutable borrows:
+/// we replace either all or none of the occurrences of `*_1`.
+///
+/// Some care has to be taken when `_1` is copied into other locals.
+///   _1 = &raw? mut? _2;
+///   _3 = *_1;
+///   _4 = _1
+///   _5 = *_4
+/// In such cases, fully substituting `_1` means fully substituting all of the copies.
+///
+/// For immutable borrows, we do not need to preserve such a uniqueness property,
+/// so we perform all the possible substitutions without removing the `_1 = &_2` statement.
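At the source level, the borrow-then-dereference pattern described above typically comes from code like the following hypothetical function; the MIR in the comments is approximate, with local numbering simplified:

```rust
// MIR building produces roughly:
//     _2 = &_1;      // r = &x
//     _0 = (*_2);    // return *r
// which this pass rewrites so the read goes to the place directly:
//     _0 = _1;
fn read_through_ref(x: u32) -> u32 {
    let r = &x;
    *r
}
```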
+pub struct ReferencePropagation;
+
+impl<'tcx> MirPass<'tcx> for ReferencePropagation {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 4
+    }
+
+    #[instrument(level = "trace", skip(self, tcx, body))]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!(def_id = ?body.source.def_id());
+        while propagate_ssa(tcx, body) {}
+    }
+}
+
+fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
+    let ssa = SsaLocals::new(body);
+
+    let mut replacer = compute_replacement(tcx, body, &ssa);
+    debug!(?replacer.targets);
+    debug!(?replacer.allowed_replacements);
+    debug!(?replacer.storage_to_remove);
+
+    replacer.visit_body_preserves_cfg(body);
+
+    if replacer.any_replacement {
+        crate::simplify::remove_unused_definitions(body);
+    }
+
+    replacer.any_replacement
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum Value<'tcx> {
+    /// Not a pointer, or we can't know.
+    Unknown,
+    /// We know the value to be a pointer to this place.
+    /// The boolean indicates whether the reference is mutable, subject to the uniqueness rule.
+    Pointer(Place<'tcx>, bool),
+}
+
+/// For each local, save the place corresponding to `*local`.
+#[instrument(level = "trace", skip(tcx, body))]
+fn compute_replacement<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &Body<'tcx>,
+    ssa: &SsaLocals,
+) -> Replacer<'tcx> {
+    let always_live_locals = always_storage_live_locals(body);
+
+    // Compute which locals only ever have a single `StorageLive` statement.
+    let storage_live = StorageLiveLocals::new(body, &always_live_locals);
+
+    // Compute `MaybeStorageDead` dataflow to check that we only replace when the pointee is
+    // definitely live.
+    let mut maybe_dead = MaybeStorageDead::new(always_live_locals)
+        .into_engine(tcx, body)
+        .iterate_to_fixpoint()
+        .into_results_cursor(body);
+
+    // Map from each local to its pointee.
+    let mut targets = IndexVec::from_elem(Value::Unknown, &body.local_decls);
+    // Set of locals for which we will remove their storage statement. This is useful for
+    // reborrowed references.
+    let mut storage_to_remove = BitSet::new_empty(body.local_decls.len());
+
+    let fully_replacable_locals = fully_replacable_locals(ssa);
+
+    // Returns true iff we can use `place` as a pointee.
+    //
+    // Note that we only need to verify that there is a single `StorageLive` statement, and we do
+    // not need to verify that it dominates all uses of that local.
+    //
+    // Consider the three statements:
+    //   SL : StorageLive(a)
+    //   DEF: b = &raw? mut? a
+    //   USE: stuff that uses *b
+    //
+    // First, we recall that DEF is checked to dominate USE. Now imagine for the sake of
+    // contradiction there is a DEF -> SL -> USE path. Consider two cases:
+    //
+    // - DEF dominates SL. We always have UB the first time control flow reaches DEF,
+    //   because the storage of `a` is dead. Since DEF dominates USE, that means we cannot
+    //   reach USE and so our optimization is ok.
+    //
+    // - DEF does not dominate SL. Then there is a `START_BLOCK -> SL` path not including DEF.
+    //   But we can extend this path to USE, meaning there is also a `START_BLOCK -> USE` path not
+    //   including DEF. This violates the DEF dominates USE condition, and so is impossible.
+    let is_constant_place = |place: Place<'_>| {
+        // We only allow `Deref` as the first projection, to avoid surprises.
+        if place.projection.first() == Some(&PlaceElem::Deref) {
+            // `place == (*some_local).xxx`; it is constant only if `some_local` is constant.
+            // We approximate constness using SSAness.
+            ssa.is_ssa(place.local) && place.projection[1..].iter().all(PlaceElem::is_stable_offset)
+        } else {
+            storage_live.has_single_storage(place.local)
+                && place.projection[..].iter().all(PlaceElem::is_stable_offset)
+        }
+    };
+
+    let mut can_perform_opt = |target: Place<'tcx>, loc: Location| {
+        if target.projection.first() == Some(&PlaceElem::Deref) {
+            // We are creating a reborrow. As `place.local` is a reference, removing the storage
+            // statements should not make it much harder for LLVM to optimize.
+            storage_to_remove.insert(target.local);
+            true
+        } else {
+            // This is a proper dereference. We can only allow it if `target` is live.
+            maybe_dead.seek_after_primary_effect(loc);
+            let maybe_dead = maybe_dead.contains(target.local);
+            !maybe_dead
+        }
+    };
+
+    for (local, rvalue, location) in ssa.assignments(body) {
+        debug!(?local);
+
+        // Only visit if we have something to do.
+        let Value::Unknown = targets[local] else { bug!() };
+
+        let ty = body.local_decls[local].ty;
+
+        // If this is not a reference or pointer, do nothing.
+        if !ty.is_any_ptr() {
+            debug!("not a reference or pointer");
+            continue;
+        }
+
+        // Whether the current local is subject to the uniqueness rule.
+        let needs_unique = ty.is_mutable_ptr();
+
+        // If this is a mutable reference that we cannot fully replace, mark it as unknown.
+        if needs_unique && !fully_replacable_locals.contains(local) {
+            debug!("not fully replaceable");
+            continue;
+        }
+
+        debug!(?rvalue);
+        match rvalue {
+            // This is a copy, just use the value we have in store for the previous one.
+            // As we are visiting in `assignment_order`, i.e. reverse postorder, `rhs` should
+            // have been visited before.
+            Rvalue::Use(Operand::Copy(place) | Operand::Move(place))
+            | Rvalue::CopyForDeref(place) => {
+                if let Some(rhs) = place.as_local() && ssa.is_ssa(rhs) {
+                    let target = targets[rhs];
+                    // Only see through immutable references and pointers, as we do not know yet if
+                    // mutable references are fully replaced.
+                    if !needs_unique && matches!(target, Value::Pointer(..)) {
+                        targets[local] = target;
+                    } else {
+                        targets[local] = Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique);
+                    }
+                }
+            }
+            Rvalue::Ref(_, _, place) | Rvalue::AddressOf(_, place) => {
+                let mut place = *place;
+                // Try to see through `place` in order to collapse reborrow chains.
+                if place.projection.first() == Some(&PlaceElem::Deref)
+                    && let Value::Pointer(target, inner_needs_unique) = targets[place.local]
+                    // Only see through immutable references and pointers, as we do not know yet if
+                    // mutable references are fully replaced.
+                    && !inner_needs_unique
+                    // Only collapse chain if the pointee is definitely live.
+                    && can_perform_opt(target, location)
+                {
+                    place = target.project_deeper(&place.projection[1..], tcx);
+                }
+                assert_ne!(place.local, local);
+                if is_constant_place(place) {
+                    targets[local] = Value::Pointer(place, needs_unique);
+                }
+            }
+            // We do not know what to do, so keep as not-a-pointer.
+            _ => {}
+        }
+    }
+
+    debug!(?targets);
+
+    let mut finder = ReplacementFinder {
+        targets: &mut targets,
+        can_perform_opt,
+        allowed_replacements: FxHashSet::default(),
+    };
+    let reachable_blocks = traversal::reachable_as_bitset(body);
+    for (bb, bbdata) in body.basic_blocks.iter_enumerated() {
+        // Only visit reachable blocks as we rely on dataflow.
+        if reachable_blocks.contains(bb) {
+            finder.visit_basic_block_data(bb, bbdata);
+        }
+    }
+
+    let allowed_replacements = finder.allowed_replacements;
+    return Replacer {
+        tcx,
+        targets,
+        storage_to_remove,
+        allowed_replacements,
+        fully_replacable_locals,
+        any_replacement: false,
+    };
+
+    struct ReplacementFinder<'a, 'tcx, F> {
+        targets: &'a mut IndexVec<Local, Value<'tcx>>,
+        can_perform_opt: F,
+        allowed_replacements: FxHashSet<(Local, Location)>,
+    }
+
+    impl<'tcx, F> Visitor<'tcx> for ReplacementFinder<'_, 'tcx, F>
+    where
+        F: FnMut(Place<'tcx>, Location) -> bool,
+    {
+        fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) {
+            if matches!(ctxt, PlaceContext::NonUse(_)) {
+                // There is no need to check liveness for non-uses.
+                return;
+            }
+
+            if place.projection.first() != Some(&PlaceElem::Deref) {
+                // This is not a dereference, nothing to do.
+                return;
+            }
+
+            let mut place = place.as_ref();
+            loop {
+                if let Value::Pointer(target, needs_unique) = self.targets[place.local] {
+                    let perform_opt = (self.can_perform_opt)(target, loc);
+                    debug!(?place, ?target, ?needs_unique, ?perform_opt);
+
+                    // This is a reborrow chain; recursively allow the replacement.
+                    //
+                    // This also allows us to detect cases where `target.local` is not replaceable,
+                    // and mark it as such.
+                    if let &[PlaceElem::Deref] = &target.projection[..] {
+                        assert!(perform_opt);
+                        self.allowed_replacements.insert((target.local, loc));
+                        place.local = target.local;
+                        continue;
+                    } else if perform_opt {
+                        self.allowed_replacements.insert((target.local, loc));
+                    } else if needs_unique {
+                        // This mutable reference is not fully replaceable, so drop it.
+                        self.targets[place.local] = Value::Unknown;
+                    }
+                }
+
+                break;
+            }
+        }
+    }
+}
+
+/// Compute the set of locals that can be fully replaced.
+///
+/// We consider a local to be replaceable iff it's only used in a `Deref` projection `*_local` or
+/// non-use position (like storage statements and debuginfo).
+fn fully_replacable_locals(ssa: &SsaLocals) -> BitSet<Local> {
+    let mut replacable = BitSet::new_empty(ssa.num_locals());
+
+    // First pass: for each local, record whether its uses can all be replaced.
+    for local in ssa.locals() {
+        if ssa.num_direct_uses(local) == 0 {
+            replacable.insert(local);
+        }
+    }
+
+    // Second pass: a local can only be fully replaced if all its copies can.
+    ssa.meet_copy_equivalence(&mut replacable);
+
+    replacable
+}
+
+/// Utility to help perform substitution of `*pattern` by `target`.
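Concretely, the substitution named in the doc comment above splices projections: replacing the leading `Deref` of `(*_1).field`, when `_1` is known to point at `_2.elem`, yields `_2.elem.field` (this is what the `project_deeper` calls in this pass do). A minimal sketch with projections modelled as strings, purely illustrative and independent of rustc's `Place` API:

```rust
// Hypothetical model: a place is a list of projection names, where the
// first entry of `place_projs` must be the Deref of the pattern local.
// The target's projections are spliced in front of the remaining ones.
fn substitute(place_projs: &[&str], target_projs: &[&str]) -> Vec<String> {
    assert_eq!(place_projs.first(), Some(&"Deref"));
    target_projs
        .iter()
        .chain(&place_projs[1..])
        .map(|p| p.to_string())
        .collect()
}

fn main() {
    // `(*_1).field` with `_1 = &_2.elem` becomes `_2.elem.field`.
    assert_eq!(substitute(&["Deref", "field"], &["elem"]), ["elem", "field"]);
}
```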
+struct Replacer<'tcx> { + tcx: TyCtxt<'tcx>, + targets: IndexVec<Local, Value<'tcx>>, + storage_to_remove: BitSet<Local>, + allowed_replacements: FxHashSet<(Local, Location)>, + any_replacement: bool, + fully_replacable_locals: BitSet<Local>, +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_var_debug_info(&mut self, debuginfo: &mut VarDebugInfo<'tcx>) { + if let VarDebugInfoContents::Place(ref mut place) = debuginfo.value + && place.projection.is_empty() + && let Value::Pointer(target, _) = self.targets[place.local] + && target.projection.iter().all(|p| p.can_use_in_debuginfo()) + { + if let Some((&PlaceElem::Deref, rest)) = target.projection.split_last() { + *place = Place::from(target.local).project_deeper(rest, self.tcx); + self.any_replacement = true; + } else if self.fully_replacable_locals.contains(place.local) + && let Some(references) = debuginfo.references.checked_add(1) + { + debuginfo.references = references; + *place = target; + self.any_replacement = true; + } + } + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) { + if place.projection.first() != Some(&PlaceElem::Deref) { + return; + } + + loop { + if let Value::Pointer(target, _) = self.targets[place.local] { + let perform_opt = matches!(ctxt, PlaceContext::NonUse(_)) + || self.allowed_replacements.contains(&(target.local, loc)); + + if perform_opt { + *place = target.project_deeper(&place.projection[1..], self.tcx); + self.any_replacement = true; + continue; + } + } + + break; + } + } + + fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) { + match stmt.kind { + StatementKind::StorageLive(l) | StatementKind::StorageDead(l) + if self.storage_to_remove.contains(l) => + { + stmt.make_nop(); + } + // Do not remove assignments as they may still be useful for debuginfo. + _ => self.super_statement(stmt, loc), + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_false_edges.rs b/compiler/rustc_mir_transform/src/remove_false_edges.rs deleted file mode 100644 index 71f5ccf7e24..00000000000 --- a/compiler/rustc_mir_transform/src/remove_false_edges.rs +++ /dev/null @@ -1,29 +0,0 @@ -use rustc_middle::mir::{Body, TerminatorKind}; -use rustc_middle::ty::TyCtxt; - -use crate::MirPass; - -/// Removes `FalseEdge` and `FalseUnwind` terminators from the MIR. -/// -/// These are only needed for borrow checking, and can be removed afterwards. -/// -/// FIXME: This should probably have its own MIR phase. -pub struct RemoveFalseEdges; - -impl<'tcx> MirPass<'tcx> for RemoveFalseEdges { - fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - for block in body.basic_blocks_mut() { - let terminator = block.terminator_mut(); - terminator.kind = match terminator.kind { - TerminatorKind::FalseEdge { real_target, .. } => { - TerminatorKind::Goto { target: real_target } - } - TerminatorKind::FalseUnwind { real_target, .. } => { - TerminatorKind::Goto { target: real_target } - } - - _ => continue, - } - } - } -} diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs index 89808d3d4cd..4941c9edce3 100644 --- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -33,8 +33,10 @@ impl RemoveNoopLandingPads { StatementKind::FakeRead(..) | StatementKind::StorageLive(_) | StatementKind::StorageDead(_) + | StatementKind::PlaceMention(..) 
             | StatementKind::AscribeUserType(..)
             | StatementKind::Coverage(..)
+            | StatementKind::ConstEvalCounter
             | StatementKind::Nop => {
                 // These are all noops in a landing pad
             }
@@ -51,7 +53,7 @@ impl RemoveNoopLandingPads {
                 StatementKind::Assign { .. }
                 | StatementKind::SetDiscriminant { .. }
                 | StatementKind::Deinit(..)
-                | StatementKind::CopyNonOverlapping(..)
+                | StatementKind::Intrinsic(..)
                 | StatementKind::Retag { .. } => {
                     return false;
                 }
@@ -70,11 +72,10 @@ impl RemoveNoopLandingPads {
             TerminatorKind::GeneratorDrop
             | TerminatorKind::Yield { .. }
             | TerminatorKind::Return
-            | TerminatorKind::Abort
+            | TerminatorKind::Terminate
             | TerminatorKind::Unreachable
             | TerminatorKind::Call { .. }
             | TerminatorKind::Assert { .. }
-            | TerminatorKind::DropAndReplace { .. }
             | TerminatorKind::Drop { .. }
             | TerminatorKind::InlineAsm { .. } => false,
         }
@@ -83,9 +84,9 @@ impl RemoveNoopLandingPads {
     fn remove_nop_landing_pads(&self, body: &mut Body<'_>) {
         debug!("body: {:#?}", body);

-        // make sure there's a single resume block
+        // make sure there's a resume block
        let resume_block = {
-            let patch = MirPatch::new(body);
+            let mut patch = MirPatch::new(body);
             let resume_block = patch.resume_block();
             patch.apply(body);
             resume_block
@@ -94,7 +95,7 @@ impl RemoveNoopLandingPads {
         let mut jumps_folded = 0;
         let mut landing_pads_removed = 0;
-        let mut nop_landing_pads = BitSet::new_empty(body.basic_blocks().len());
+        let mut nop_landing_pads = BitSet::new_empty(body.basic_blocks.len());

         // This is a post-order traversal, so that if A post-dominates B
         // then A will be visited before B.
@@ -102,11 +103,11 @@ impl RemoveNoopLandingPads {
         for bb in postorder {
             debug!("  processing {:?}", bb);
             if let Some(unwind) = body[bb].terminator_mut().unwind_mut() {
-                if let Some(unwind_bb) = *unwind {
+                if let UnwindAction::Cleanup(unwind_bb) = *unwind {
                     if nop_landing_pads.contains(unwind_bb) {
                         debug!("    removing noop landing pad");
                         landing_pads_removed += 1;
-                        *unwind = None;
+                        *unwind = UnwindAction::Continue;
                     }
                 }
             }
diff --git a/compiler/rustc_mir_transform/src/remove_place_mention.rs b/compiler/rustc_mir_transform/src/remove_place_mention.rs
new file mode 100644
index 00000000000..8be1c37572d
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/remove_place_mention.rs
@@ -0,0 +1,23 @@
+//! This pass removes `PlaceMention` statements, which have no effect during codegen.
+
+use crate::MirPass;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+pub struct RemovePlaceMention;
+
+impl<'tcx> MirPass<'tcx> for RemovePlaceMention {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        !sess.opts.unstable_opts.mir_keep_place_mention
+    }
+
+    fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("Running RemovePlaceMention on {:?}", body.source);
+        for data in body.basic_blocks.as_mut_preserves_cfg() {
+            data.statements.retain(|statement| match statement.kind {
+                StatementKind::PlaceMention(..)
| StatementKind::Nop => false, + _ => true, + }) + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs index c48aa9a90ef..1f9e521d315 100644 --- a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs +++ b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs @@ -1,14 +1,15 @@ use rustc_index::bit_set::ChunkedBitSet; -use rustc_middle::mir::{Body, Field, Rvalue, Statement, StatementKind, TerminatorKind}; +use rustc_middle::mir::{Body, TerminatorKind}; use rustc_middle::ty::subst::SubstsRef; use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, VariantDef}; use rustc_mir_dataflow::impls::MaybeInitializedPlaces; use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; use rustc_mir_dataflow::{self, move_path_children_matching, Analysis, MoveDataParamEnv}; +use rustc_target::abi::FieldIdx; use crate::MirPass; -/// Removes `Drop` and `DropAndReplace` terminators whose target is known to be uninitialized at +/// Removes `Drop` terminators whose target is known to be uninitialized at /// that point. /// /// This is redundant with drop elaboration, but we need to do it prior to const-checking, and @@ -21,7 +22,7 @@ pub struct RemoveUninitDrops; impl<'tcx> MirPass<'tcx> for RemoveUninitDrops { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let param_env = tcx.param_env(body.source.def_id()); - let Ok(move_data) = MoveData::gather_moves(body, tcx, param_env) else { + let Ok((_,move_data)) = MoveData::gather_moves(body, tcx, param_env) else { // We could continue if there are move errors, but there's not much point since our // init data isn't complete. return; @@ -35,10 +36,9 @@ impl<'tcx> MirPass<'tcx> for RemoveUninitDrops { .into_results_cursor(body); let mut to_remove = vec![]; - for (bb, block) in body.basic_blocks().iter_enumerated() { + for (bb, block) in body.basic_blocks.iter_enumerated() { let terminator = block.terminator(); - let (TerminatorKind::Drop { place, .. } | TerminatorKind::DropAndReplace { place, .. }) - = &terminator.kind + let TerminatorKind::Drop { place, .. } = &terminator.kind else { continue }; maybe_inits.seek_before_primary_effect(body.terminator_loc(bb)); @@ -64,24 +64,12 @@ impl<'tcx> MirPass<'tcx> for RemoveUninitDrops { for bb in to_remove { let block = &mut body.basic_blocks_mut()[bb]; - let (TerminatorKind::Drop { target, .. } | TerminatorKind::DropAndReplace { target, .. }) + let TerminatorKind::Drop { target, .. } = &block.terminator().kind else { unreachable!() }; // Replace block terminator with `Goto`. - let target = *target; - let old_terminator_kind = std::mem::replace( - &mut block.terminator_mut().kind, - TerminatorKind::Goto { target }, - ); - - // If this is a `DropAndReplace`, we need to emulate the assignment to the return place. - if let TerminatorKind::DropAndReplace { place, value, .. 
} = old_terminator_kind {
-                block.statements.push(Statement {
-                    source_info: block.terminator().source_info,
-                    kind: StatementKind::Assign(Box::new((place, Rvalue::Use(value)))),
-                });
-            }
+            block.terminator_mut().kind = TerminatorKind::Goto { target: *target };
         }
     }
 }
@@ -143,7 +131,7 @@ fn is_needs_drop_and_init<'tcx>(
                 .fields
                 .iter()
                 .enumerate()
-                .map(|(f, field)| (Field::from_usize(f), field.ty(tcx, substs), mpi))
+                .map(|(f, field)| (FieldIdx::from_usize(f), field.ty(tcx, substs), mpi))
                 .any(field_needs_drop_and_init)
         })
     }
@@ -151,7 +139,7 @@ fn is_needs_drop_and_init<'tcx>(
         ty::Tuple(fields) => fields
             .iter()
             .enumerate()
-            .map(|(f, f_ty)| (Field::from_usize(f), f_ty, mpi))
+            .map(|(f, f_ty)| (FieldIdx::from_usize(f), f_ty, mpi))
             .any(field_needs_drop_and_init),

         _ => true,
diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs
index 40be4f146db..1f37f03cff1 100644
--- a/compiler/rustc_mir_transform/src/remove_zsts.rs
+++ b/compiler/rustc_mir_transform/src/remove_zsts.rs
@@ -1,8 +1,9 @@
-//! Removes assignments to ZST places.
+//! Removes operations on ZST places, and converts ZST operands to constants.

 use crate::MirPass;
-use rustc_middle::mir::tcx::PlaceTy;
-use rustc_middle::mir::{Body, LocalDecls, Place, StatementKind};
+use rustc_middle::mir::interpret::ConstValue;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty, TyCtxt};

 pub struct RemoveZsts;
@@ -14,49 +15,36 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // Avoid query cycles (generators require optimized MIR for layout).
-        if tcx.type_of(body.source.def_id()).is_generator() {
+        if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
             return;
         }
-        let param_env = tcx.param_env(body.source.def_id());
-        let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
+        let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
         let local_decls = &body.local_decls;
-        for block in basic_blocks {
-            for statement in block.statements.iter_mut() {
-                if let StatementKind::Assign(box (place, _)) | StatementKind::Deinit(box place) =
-                    statement.kind
-                {
-                    let place_ty = place.ty(local_decls, tcx).ty;
-                    if !maybe_zst(place_ty) {
-                        continue;
-                    }
-                    let Ok(layout) = tcx.layout_of(param_env.and(place_ty)) else {
-                        continue;
-                    };
-                    if !layout.is_zst() {
-                        continue;
-                    }
-                    if involves_a_union(place, local_decls, tcx) {
-                        continue;
-                    }
-                    if tcx.consider_optimizing(|| {
-                        format!(
-                            "RemoveZsts - Place: {:?} SourceInfo: {:?}",
-                            place, statement.source_info
-                        )
-                    }) {
-                        statement.make_nop();
-                    }
-                }
-            }
-        }
+        let mut replacer = Replacer { tcx, param_env, local_decls };
+        for var_debug_info in &mut body.var_debug_info {
+            replacer.visit_var_debug_info(var_debug_info);
+        }
+        for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
+            replacer.visit_basic_block_data(bb, data);
+        }
     }
 }

+struct Replacer<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+    local_decls: &'a LocalDecls<'tcx>,
+}
+
 /// A cheap, approximate check to avoid unnecessary `layout_of` calls.
 fn maybe_zst(ty: Ty<'_>) -> bool {
     match ty.kind() {
         // maybe ZST (could be more precise)
-        ty::Adt(..) | ty::Array(..) | ty::Closure(..) | ty::Tuple(..) | ty::Opaque(..) => true,
+        ty::Adt(..)
+        | ty::Array(..)
+        | ty::Closure(..)
+        | ty::Tuple(..)
+        | ty::Alias(ty::Opaque, ..) => true,
         // definitely ZST
         ty::FnDef(..)
| ty::Never => true, // unreachable or can't be ZST @@ -64,23 +52,92 @@ fn maybe_zst(ty: Ty<'_>) -> bool { } } -/// Miri lazily allocates memory for locals on assignment, -/// so we must preserve writes to unions and union fields, -/// or it will ICE on reads of those fields. -fn involves_a_union<'tcx>( - place: Place<'tcx>, - local_decls: &LocalDecls<'tcx>, - tcx: TyCtxt<'tcx>, -) -> bool { - let mut place_ty = PlaceTy::from_ty(local_decls[place.local].ty); - if place_ty.ty.is_union() { - return true; +impl<'tcx> Replacer<'_, 'tcx> { + fn known_to_be_zst(&self, ty: Ty<'tcx>) -> bool { + if !maybe_zst(ty) { + return false; + } + let Ok(layout) = self.tcx.layout_of(self.param_env.and(ty)) else { + return false; + }; + layout.is_zst() } - for elem in place.projection { - place_ty = place_ty.projection_ty(tcx, elem); - if place_ty.ty.is_union() { - return true; + + fn make_zst(&self, ty: Ty<'tcx>) -> Constant<'tcx> { + debug_assert!(self.known_to_be_zst(ty)); + Constant { + span: rustc_span::DUMMY_SP, + user_ty: None, + literal: ConstantKind::Val(ConstValue::ZeroSized, ty), + } + } +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { + match var_debug_info.value { + VarDebugInfoContents::Const(_) => {} + VarDebugInfoContents::Place(place) => { + let place_ty = place.ty(self.local_decls, self.tcx).ty; + if self.known_to_be_zst(place_ty) { + var_debug_info.value = VarDebugInfoContents::Const(self.make_zst(place_ty)) + } + } + VarDebugInfoContents::Composite { ty, fragments: _ } => { + if self.known_to_be_zst(ty) { + var_debug_info.value = VarDebugInfoContents::Const(self.make_zst(ty)) + } + } + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) { + if let Operand::Constant(_) = operand { + return; + } + let op_ty = operand.ty(self.local_decls, self.tcx); + if self.known_to_be_zst(op_ty) + && self.tcx.consider_optimizing(|| { + format!("RemoveZsts - Operand: {:?} Location: {:?}", operand, loc) + }) + { + *operand = Operand::Constant(Box::new(self.make_zst(op_ty))) + } + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, loc: Location) { + let place_for_ty = match statement.kind { + StatementKind::Assign(box (place, ref rvalue)) => { + rvalue.is_safe_to_remove().then_some(place) + } + StatementKind::Deinit(box place) + | StatementKind::SetDiscriminant { box place, variant_index: _ } + | StatementKind::AscribeUserType(box (place, _), _) + | StatementKind::Retag(_, box place) + | StatementKind::PlaceMention(box place) + | StatementKind::FakeRead(box (_, place)) => Some(place), + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + Some(local.into()) + } + StatementKind::Coverage(_) + | StatementKind::Intrinsic(_) + | StatementKind::Nop + | StatementKind::ConstEvalCounter => None, + }; + if let Some(place_for_ty) = place_for_ty + && let ty = place_for_ty.ty(self.local_decls, self.tcx).ty + && self.known_to_be_zst(ty) + && self.tcx.consider_optimizing(|| { + format!("RemoveZsts - Place: {:?} SourceInfo: {:?}", place_for_ty, statement.source_info) + }) + { + statement.make_nop(); + } else { + self.super_statement(statement, loc); } } - return false; } diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs index 827ce0c02ac..243cb463560 100644 --- a/compiler/rustc_mir_transform/src/required_consts.rs +++ 
b/compiler/rustc_mir_transform/src/required_consts.rs @@ -1,5 +1,5 @@ use rustc_middle::mir::visit::Visitor; -use rustc_middle::mir::{Constant, Location}; +use rustc_middle::mir::{Constant, ConstantKind, Location}; use rustc_middle::ty::ConstKind; pub struct RequiredConstsVisitor<'a, 'tcx> { @@ -15,8 +15,13 @@ impl<'a, 'tcx> RequiredConstsVisitor<'a, 'tcx> { impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> { fn visit_constant(&mut self, constant: &Constant<'tcx>, _: Location) { let literal = constant.literal; - if let Some(ct) = literal.const_for_ty() && let ConstKind::Unevaluated(_) = ct.kind() { - self.required_consts.push(*constant); + match literal { + ConstantKind::Ty(c) => match c.kind() { + ConstKind::Param(_) | ConstKind::Error(_) | ConstKind::Value(_) => {} + _ => bug!("only ConstKind::Param/Value should be encountered here, got {:#?}", c), + }, + ConstantKind::Unevaluated(..) => self.required_consts.push(*constant), + ConstantKind::Val(..) => {} } } } diff --git a/compiler/rustc_mir_transform/src/reveal_all.rs b/compiler/rustc_mir_transform/src/reveal_all.rs index 8ea550fa123..23442f8b97b 100644 --- a/compiler/rustc_mir_transform/src/reveal_all.rs +++ b/compiler/rustc_mir_transform/src/reveal_all.rs @@ -9,7 +9,7 @@ pub struct RevealAll; impl<'tcx> MirPass<'tcx> for RevealAll { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.opts.mir_opt_level() >= 3 || super::inline::Inline.is_enabled(sess) + sess.mir_opt_level() >= 3 || super::inline::Inline.is_enabled(sess) } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -19,7 +19,7 @@ impl<'tcx> MirPass<'tcx> for RevealAll { } let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - RevealAllVisitor { tcx, param_env }.visit_body(body); + RevealAllVisitor { tcx, param_env }.visit_body_preserves_cfg(body); } } @@ -35,10 +35,22 @@ impl<'tcx> MutVisitor<'tcx> for RevealAllVisitor<'tcx> { } #[inline] + fn visit_constant(&mut self, constant: &mut Constant<'tcx>, _: Location) { + // We have to use `try_normalize_erasing_regions` here, since it's + // possible that we visit impossible-to-satisfy where clauses here, + // see #91745 + if let Ok(c) = self.tcx.try_normalize_erasing_regions(self.param_env, constant.literal) { + constant.literal = c; + } + } + + #[inline] fn visit_ty(&mut self, ty: &mut Ty<'tcx>, _: TyContext) { // We have to use `try_normalize_erasing_regions` here, since it's // possible that we visit impossible-to-satisfy where clauses here, // see #91745 - *ty = self.tcx.try_normalize_erasing_regions(self.param_env, *ty).unwrap_or(*ty); + if let Ok(t) = self.tcx.try_normalize_erasing_regions(self.param_env, *ty) { + *ty = t; + } } } diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs index 925eb10a1f7..f35a5fb4276 100644 --- a/compiler/rustc_mir_transform/src/separate_const_switch.rs +++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs @@ -46,7 +46,7 @@ pub struct SeparateConstSwitch; impl<'tcx> MirPass<'tcx> for SeparateConstSwitch { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 4 + sess.mir_opt_level() >= 2 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -62,7 +62,7 @@ impl<'tcx> MirPass<'tcx> for SeparateConstSwitch { pub fn separate_const_switch(body: &mut Body<'_>) -> usize { let mut new_blocks: SmallVec<[(BasicBlock, BasicBlock); 6]> = SmallVec::new(); let predecessors = body.basic_blocks.predecessors(); - 
'block_iter: for (block_id, block) in body.basic_blocks().iter_enumerated() { + 'block_iter: for (block_id, block) in body.basic_blocks.iter_enumerated() { if let TerminatorKind::SwitchInt { discr: Operand::Copy(switch_place) | Operand::Move(switch_place), .. @@ -90,7 +90,7 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize { let mut predecessors_left = predecessors[block_id].len(); 'predec_iter: for predecessor_id in predecessors[block_id].iter().copied() { - let predecessor = &body.basic_blocks()[predecessor_id]; + let predecessor = &body.basic_blocks[predecessor_id]; // First we make sure the predecessor jumps // in a reasonable way @@ -108,12 +108,11 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize { // The following terminators are not allowed TerminatorKind::Resume | TerminatorKind::Drop { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::Call { .. } | TerminatorKind::Assert { .. } | TerminatorKind::FalseUnwind { .. } | TerminatorKind::Yield { .. } - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::InlineAsm { .. } @@ -165,12 +164,11 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize { } TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Terminate | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::GeneratorDrop | TerminatorKind::Assert { .. } - | TerminatorKind::DropAndReplace { .. } | TerminatorKind::FalseUnwind { .. } | TerminatorKind::Drop { .. } | TerminatorKind::Call { .. } @@ -247,9 +245,11 @@ fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData< | StatementKind::StorageLive(_) | StatementKind::Retag(_, _) | StatementKind::AscribeUserType(_, _) + | StatementKind::PlaceMention(..) | StatementKind::Coverage(_) | StatementKind::StorageDead(_) - | StatementKind::CopyNonOverlapping(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter | StatementKind::Nop => {} } } @@ -303,8 +303,7 @@ fn find_determining_place<'tcx>( | Rvalue::NullaryOp(_, _) | Rvalue::ShallowInitBox(_, _) | Rvalue::UnaryOp(_, Operand::Constant(_)) - | Rvalue::Cast(_, Operand::Constant(_), _) - => return None, + | Rvalue::Cast(_, Operand::Constant(_), _) => return None, } } @@ -316,8 +315,10 @@ fn find_determining_place<'tcx>( | StatementKind::StorageDead(_) | StatementKind::Retag(_, _) | StatementKind::AscribeUserType(_, _) + | StatementKind::PlaceMention(..) 
| StatementKind::Coverage(_) - | StatementKind::CopyNonOverlapping(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter | StatementKind::Nop => {} // If the discriminant is set, it is always set diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index eaa61d8614d..5f12f1937c0 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -2,12 +2,12 @@ use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; -use rustc_middle::ty::query::Providers; -use rustc_middle::ty::subst::{InternalSubsts, Subst}; -use rustc_middle::ty::{self, EarlyBinder, Ty, TyCtxt}; -use rustc_target::abi::VariantIdx; +use rustc_middle::query::Providers; +use rustc_middle::ty::InternalSubsts; +use rustc_middle::ty::{self, EarlyBinder, GeneratorSubsts, Ty, TyCtxt}; +use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; -use rustc_index::vec::{Idx, IndexVec}; +use rustc_index::{Idx, IndexVec}; use rustc_span::Span; use rustc_target::spec::abi::Abi; @@ -15,10 +15,9 @@ use rustc_target::spec::abi::Abi; use std::fmt; use std::iter; -use crate::util::expand_aggregate; use crate::{ - abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, marker, pass_manager as pm, - remove_noop_landing_pads, simplify, + abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator, + pass_manager as pm, remove_noop_landing_pads, simplify, }; use rustc_middle::mir::patch::MirPatch; use rustc_mir_dataflow::elaborate_drops::{self, DropElaborator, DropFlagMode, DropStyle}; @@ -32,12 +31,12 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' let mut result = match instance { ty::InstanceDef::Item(..) => bug!("item {:?} passed to make_shim", instance), - ty::InstanceDef::VtableShim(def_id) => { + ty::InstanceDef::VTableShim(def_id) => { build_call_shim(tcx, instance, Some(Adjustment::Deref), CallKind::Direct(def_id)) } ty::InstanceDef::FnPtrShim(def_id, ty) => { let trait_ = tcx.trait_of_item(def_id).unwrap(); - let adjustment = match tcx.fn_trait_kind_from_lang_item(trait_) { + let adjustment = match tcx.fn_trait_kind_from_def_id(trait_) { Some(ty::ClosureKind::FnOnce) => Adjustment::Identity, Some(ty::ClosureKind::FnMut | ty::ClosureKind::Fn) => Adjustment::Deref, None => bug!("fn pointer {:?} is not an fn", ty), @@ -70,14 +69,16 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' // of this function. Is this intentional? if let Some(ty::Generator(gen_def_id, substs, _)) = ty.map(Ty::kind) { let body = tcx.optimized_mir(*gen_def_id).generator_drop().unwrap(); - let body = EarlyBinder(body.clone()).subst(tcx, substs); + let body = EarlyBinder::bind(body.clone()).subst(tcx, substs); debug!("make_shim({:?}) = {:?}", instance, body); return body; } build_drop_shim(tcx, def_id, ty) } + ty::InstanceDef::ThreadLocalShim(..) => build_thread_local_shim(tcx, instance), ty::InstanceDef::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty), + ty::InstanceDef::FnPtrAddrShim(def_id, ty) => build_fn_ptr_addr_shim(tcx, def_id, ty), ty::InstanceDef::Virtual(..) 
=> { bug!("InstanceDef::Virtual ({:?}) is for direct calls only", instance) } @@ -92,12 +93,13 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' &mut result, &[ &add_moves_for_packed_drops::AddMovesForPackedDrops, + &deref_separator::Derefer, &remove_noop_landing_pads::RemoveNoopLandingPads, - &simplify::SimplifyCfg::new("make_shim"), + &simplify::SimplifyCfg::MakeShim, &add_call_guards::CriticalCallEdges, &abort_unwinding_calls::AbortUnwindingCalls, - &marker::PhaseChange(MirPhase::Const), ], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), ); debug!("make_shim({:?}) = {:?}", instance, result); @@ -113,7 +115,7 @@ enum Adjustment { /// We get passed `&[mut] self` and call the target with `*self`. /// /// This either copies `self` (if `Self: Copy`, eg. for function items), or moves out of it - /// (for `VtableShim`, which effectively is passed `&own Self`). + /// (for `VTableShim`, which effectively is passed `&own Self`). Deref, /// We get passed `self: Self` and call the target with `&mut self`. @@ -147,11 +149,11 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) assert!(!matches!(ty, Some(ty) if ty.is_generator())); let substs = if let Some(ty) = ty { - tcx.intern_substs(&[ty.into()]) + tcx.mk_substs(&[ty.into()]) } else { InternalSubsts::identity_for_item(tcx, def_id) }; - let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); + let sig = tcx.fn_sig(def_id).subst(tcx, substs); let sig = tcx.erase_late_bound_regions(sig); let span = tcx.def_span(def_id); @@ -173,19 +175,36 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) let mut body = new_body(source, blocks, local_decls_for_sig(&sig, span), sig.inputs().len(), span); - if ty.is_some() { - // The first argument (index 0), but add 1 for the return value. - let dropee_ptr = Place::from(Local::new(1 + 0)); - if tcx.sess.opts.unstable_opts.mir_emit_retag { - // Function arguments should be retagged, and we make this one raw. - body.basic_blocks_mut()[START_BLOCK].statements.insert( - 0, - Statement { - source_info, - kind: StatementKind::Retag(RetagKind::Raw, Box::new(dropee_ptr)), - }, - ); + // The first argument (index 0), but add 1 for the return value. + let mut dropee_ptr = Place::from(Local::new(1 + 0)); + if tcx.sess.opts.unstable_opts.mir_emit_retag { + // We want to treat the function argument as if it was passed by `&mut`. As such, we + // generate + // ``` + // temp = &mut *arg; + // Retag(temp, FnEntry) + // ``` + // It's important that we do this first, before anything that depends on `dropee_ptr` + // has been put into the body. 
+ let reborrow = Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { allow_two_phase_borrow: false }, + tcx.mk_place_deref(dropee_ptr), + ); + let ref_ty = reborrow.ty(body.local_decls(), tcx); + dropee_ptr = body.local_decls.push(LocalDecl::new(ref_ty, span)).into(); + let new_statements = [ + StatementKind::Assign(Box::new((dropee_ptr, reborrow))), + StatementKind::Retag(RetagKind::FnEntry, Box::new(dropee_ptr)), + ]; + for s in new_statements { + body.basic_blocks_mut()[START_BLOCK] + .statements + .push(Statement { source_info, kind: s }); } + } + + if ty.is_some() { let patch = { let param_env = tcx.param_env_reveal_all_normalized(def_id); let mut elaborator = @@ -290,7 +309,7 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for DropShimElaborator<'a, 'tcx> { fn clear_drop_flag(&mut self, _location: Location, _path: Self::Path, _mode: DropFlagMode) {} - fn field_subpath(&self, _path: Self::Path, _field: Field) -> Option<Self::Path> { + fn field_subpath(&self, _path: Self::Path, _field: FieldIdx) -> Option<Self::Path> { None } fn deref_subpath(&self, _path: Self::Path) -> Option<Self::Path> { @@ -304,14 +323,42 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for DropShimElaborator<'a, 'tcx> { } } +fn build_thread_local_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'tcx> { + let def_id = instance.def_id(); + + let span = tcx.def_span(def_id); + let source_info = SourceInfo::outermost(span); + + let mut blocks = IndexVec::with_capacity(1); + blocks.push(BasicBlockData { + statements: vec![Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::ThreadLocalRef(def_id), + ))), + }], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + new_body( + MirSource::from_instance(instance), + blocks, + IndexVec::from_raw(vec![LocalDecl::new(tcx.thread_local_ptr_ty(def_id), span)]), + 0, + span, + ) +} + /// Builds a `Clone::clone` shim for `self_ty`. Here, `def_id` is `Clone::clone`. fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { debug!("build_clone_shim(def_id={:?})", def_id); - let param_env = tcx.param_env(def_id); + let param_env = tcx.param_env_reveal_all_normalized(def_id); let mut builder = CloneShimBuilder::new(tcx, def_id, self_ty); - let is_copy = self_ty.is_copy_modulo_regions(tcx.at(builder.span), param_env); + let is_copy = self_ty.is_copy_modulo_regions(tcx, param_env); let dest = Place::return_place(); let src = tcx.mk_place_deref(Place::from(Local::new(1 + 0))); @@ -322,6 +369,9 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) - builder.tuple_like_shim(dest, src, substs.as_closure().upvar_tys()) } ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), + ty::Generator(gen_def_id, substs, hir::Movability::Movable) => { + builder.generator_shim(dest, src, *gen_def_id, substs.as_generator()) + } _ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty), }; @@ -342,8 +392,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { // we must subst the self_ty because it's // otherwise going to be TySelf and we can't index // or access fields of a Place of type TySelf. 
- let substs = tcx.mk_substs_trait(self_ty, &[]); - let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); + let sig = tcx.fn_sig(def_id).subst(tcx, &[self_ty.into()]); let sig = tcx.erase_late_bound_regions(sig); let span = tcx.def_span(def_id); @@ -387,7 +436,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { /// offset=0 will give you the index of the next BasicBlock, /// offset=1 will give the index of the next-to-next block, /// offset=-1 will give you the index of the last-created block - fn block_index_offset(&mut self, offset: usize) -> BasicBlock { + fn block_index_offset(&self, offset: usize) -> BasicBlock { BasicBlock::new(self.blocks.len() + offset) } @@ -407,7 +456,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { fn make_place(&mut self, mutability: Mutability, ty: Ty<'tcx>) -> Place<'tcx> { let span = self.span; let mut local = LocalDecl::new(ty, span); - if mutability == Mutability::Not { + if mutability.is_not() { local = local.immutable(); } Place::from(self.local_decls.push(local)) @@ -423,10 +472,8 @@ impl<'tcx> CloneShimBuilder<'tcx> { ) { let tcx = self.tcx; - let substs = tcx.mk_substs_trait(ty, &[]); - // `func == Clone::clone(&ty) -> ty` - let func_ty = tcx.mk_fn_def(self.def_id, substs); + let func_ty = tcx.mk_fn_def(self.def_id, [ty]); let func = Operand::Constant(Box::new(Constant { span: self.span, user_ty: None, @@ -452,7 +499,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { args: vec![Operand::Move(ref_loc)], destination: dest, target: Some(next), - cleanup: Some(cleanup), + unwind: UnwindAction::Cleanup(cleanup), from_hir_call: true, fn_span: self.span, }, @@ -460,65 +507,122 @@ impl<'tcx> CloneShimBuilder<'tcx> { ); } - fn tuple_like_shim<I>(&mut self, dest: Place<'tcx>, src: Place<'tcx>, tys: I) + fn clone_fields<I>( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + target: BasicBlock, + mut unwind: BasicBlock, + tys: I, + ) -> BasicBlock where I: IntoIterator<Item = Ty<'tcx>>, { - let mut previous_field = None; + // For an iterator of length n, create 2*n + 1 blocks. for (i, ity) in tys.into_iter().enumerate() { - let field = Field::new(i); + // Each iteration creates two blocks, referred to here as block 2*i and block 2*i + 1. + // + // Block 2*i attempts to clone the field. If successful it branches to 2*i + 2 (the + // next clone block). If unsuccessful it branches to the previous unwind block, which + // is initially the `unwind` argument passed to this function. + // + // Block 2*i + 1 is the unwind block for this iteration. It drops the cloned value + // created by block 2*i. We store this block in `unwind` so that the next clone block + // will unwind to it if cloning fails. + + let field = FieldIdx::new(i); let src_field = self.tcx.mk_place_field(src, field, ity); let dest_field = self.tcx.mk_place_field(dest, field, ity); - // #(2i + 1) is the cleanup block for the previous clone operation - let cleanup_block = self.block_index_offset(1); - // #(2i + 2) is the next cloning block - // (or the Return terminator if this is the last block) + let next_unwind = self.block_index_offset(1); let next_block = self.block_index_offset(2); + self.make_clone_call(dest_field, src_field, ity, next_block, unwind); + self.block( + vec![], + TerminatorKind::Drop { + place: dest_field, + target: unwind, + unwind: UnwindAction::Terminate, + replace: false, + }, + true, + ); + unwind = next_unwind; + } + // If all clones succeed then we end up here. 
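The 2*i / 2*i + 1 block scheme described above exists for panic safety: if cloning field i unwinds, the chain of cleanup blocks drops the clones already made for fields 0..i. The same guarantee is observable from stable Rust; `Loud` is an illustrative type:

```rust
struct Loud(&'static str);

impl Clone for Loud {
    fn clone(&self) -> Self {
        if self.0 == "second" {
            panic!("clone of `second` fails");
        }
        Loud(self.0)
    }
}

impl Drop for Loud {
    fn drop(&mut self) {
        println!("dropping {}", self.0);
    }
}

fn main() {
    let pair = (Loud("first"), Loud("second"));
    // Cloning panics on the second field; unwinding runs the cleanup block
    // that drops the freshly cloned "first" before the panic propagates.
    let result = std::panic::catch_unwind(|| pair.clone());
    assert!(result.is_err());
}
```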
+ self.block(vec![], TerminatorKind::Goto { target }, false); + unwind + } - // BB #(2i) - // `dest.i = Clone::clone(&src.i);` - // Goto #(2i + 2) if ok, #(2i + 1) if unwinding happens. - self.make_clone_call(dest_field, src_field, ity, next_block, cleanup_block); - - // BB #(2i + 1) (cleanup) - if let Some((previous_field, previous_cleanup)) = previous_field.take() { - // Drop previous field and goto previous cleanup block. - self.block( - vec![], - TerminatorKind::Drop { - place: previous_field, - target: previous_cleanup, - unwind: None, - }, - true, - ); - } else { - // Nothing to drop, just resume. - self.block(vec![], TerminatorKind::Resume, true); - } + fn tuple_like_shim<I>(&mut self, dest: Place<'tcx>, src: Place<'tcx>, tys: I) + where + I: IntoIterator<Item = Ty<'tcx>>, + { + self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); + let unwind = self.block(vec![], TerminatorKind::Resume, true); + let target = self.block(vec![], TerminatorKind::Return, false); - previous_field = Some((dest_field, cleanup_block)); - } + let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, tys); + } - self.block(vec![], TerminatorKind::Return, false); + fn generator_shim( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + gen_def_id: DefId, + substs: GeneratorSubsts<'tcx>, + ) { + self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); + let unwind = self.block(vec![], TerminatorKind::Resume, true); + // This will get overwritten with a switch once we know the target blocks + let switch = self.block(vec![], TerminatorKind::Unreachable, false); + let unwind = self.clone_fields(dest, src, switch, unwind, substs.upvar_tys()); + let target = self.block(vec![], TerminatorKind::Return, false); + let unreachable = self.block(vec![], TerminatorKind::Unreachable, false); + let mut cases = Vec::with_capacity(substs.state_tys(gen_def_id, self.tcx).count()); + for (index, state_tys) in substs.state_tys(gen_def_id, self.tcx).enumerate() { + let variant_index = VariantIdx::new(index); + let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index); + let src = self.tcx.mk_place_downcast_unnamed(src, variant_index); + let clone_block = self.block_index_offset(1); + let start_block = self.block( + vec![self.make_statement(StatementKind::SetDiscriminant { + place: Box::new(Place::return_place()), + variant_index, + })], + TerminatorKind::Goto { target: clone_block }, + false, + ); + cases.push((index as u128, start_block)); + let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, state_tys); + } + let discr_ty = substs.discr_ty(self.tcx); + let temp = self.make_place(Mutability::Mut, discr_ty); + let rvalue = Rvalue::Discriminant(src); + let statement = self.make_statement(StatementKind::Assign(Box::new((temp, rvalue)))); + match &mut self.blocks[switch] { + BasicBlockData { statements, terminator: Some(Terminator { kind, .. }), .. } => { + statements.push(statement); + *kind = TerminatorKind::SwitchInt { + discr: Operand::Move(temp), + targets: SwitchTargets::new(cases.into_iter(), unreachable), + }; + } + BasicBlockData { terminator: None, .. } => unreachable!(), + } } } /// Builds a "call" shim for `instance`. The shim calls the function specified by `call_kind`, /// first adjusting its first argument according to `rcvr_adjustment`. 
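Combining `tuple_like_shim` with `clone_fields`, the shim for a two-field tuple has the fixed shape sketched below (reconstructed from the construction order in the code above; treat the exact block numbering as illustrative):

```rust
// bb0: goto -> bb3                       // skip over the two shared exits
// bb1 (cleanup): resume                  // shared unwind exit
// bb2: return                            // shared success exit
// bb3: _0.0 = Clone::clone(&(*_1).0)     // -> [ok: bb5, unwind: bb1]
// bb4 (cleanup): drop(_0.0) -> bb1       // undo field 0 if a later clone fails
// bb5: _0.1 = Clone::clone(&(*_1).1)     // -> [ok: bb7, unwind: bb4]
// bb6 (cleanup): drop(_0.1) -> bb4       // would be reached by a third field
// bb7: goto -> bb2                       // all clones succeeded
```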
+#[instrument(level = "debug", skip(tcx), ret)] fn build_call_shim<'tcx>( tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>, rcvr_adjustment: Option<Adjustment>, call_kind: CallKind<'tcx>, ) -> Body<'tcx> { - debug!( - "build_call_shim(instance={:?}, rcvr_adjustment={:?}, call_kind={:?})", - instance, rcvr_adjustment, call_kind - ); - // `FnPtrShim` contains the fn pointer type that a call shim is being built for - this is used // to substitute into the signature of the shim. It is not necessary for users of this // MIR body to perform further substitutions (see `InstanceDef::has_polymorphic_mir_body`). @@ -528,21 +632,23 @@ fn build_call_shim<'tcx>( let untuple_args = sig.inputs(); // Create substitutions for the `Self` and `Args` generic parameters of the shim body. - let arg_tup = tcx.mk_tup(untuple_args.iter()); - let sig_substs = tcx.mk_substs_trait(ty, &[ty::subst::GenericArg::from(arg_tup)]); + let arg_tup = tcx.mk_tup(untuple_args); - (Some(sig_substs), Some(untuple_args)) + (Some([ty.into(), arg_tup.into()]), Some(untuple_args)) } else { (None, None) }; let def_id = instance.def_id(); - let sig = tcx.bound_fn_sig(def_id); + let sig = tcx.fn_sig(def_id); let sig = sig.map_bound(|sig| tcx.erase_late_bound_regions(sig)); assert_eq!(sig_substs.is_some(), !instance.has_polymorphic_mir_body()); - let mut sig = - if let Some(sig_substs) = sig_substs { sig.subst(tcx, sig_substs) } else { sig.0 }; + let mut sig = if let Some(sig_substs) = sig_substs { + sig.subst(tcx, &sig_substs) + } else { + sig.subst_identity() + }; if let CallKind::Indirect(fnty) = call_kind { // `sig` determines our local decls, and thus the callee type in the `Call` terminator. This @@ -564,23 +670,23 @@ fn build_call_shim<'tcx>( Adjustment::Deref => tcx.mk_imm_ptr(fnty), Adjustment::RefMut => tcx.mk_mut_ptr(fnty), }; - sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); } // FIXME(eddyb) avoid having this snippet both here and in // `Instance::fn_sig` (introduce `InstanceDef::fn_sig`?). - if let ty::InstanceDef::VtableShim(..) = instance { + if let ty::InstanceDef::VTableShim(..) = instance { // Modify fn(self, ...) to fn(self: *mut Self, ...) let mut inputs_and_output = sig.inputs_and_output.to_vec(); let self_arg = &mut inputs_and_output[0]; debug_assert!(tcx.generics_of(def_id).has_self && *self_arg == tcx.types.self_param); *self_arg = tcx.mk_mut_ptr(*self_arg); - sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); } let span = tcx.def_span(def_id); - debug!("build_call_shim: sig={:?}", sig); + debug!(?sig); let mut local_decls = local_decls_for_sig(&sig, span); let source_info = SourceInfo::outermost(span); @@ -624,7 +730,7 @@ fn build_call_shim<'tcx>( // `FnDef` call with optional receiver. 
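The `untuple_args` machinery in this function mirrors how the `Fn*` traits pass arguments: the caller hands over one tuple and the shim forwards each field as a separate argument. A stable-Rust illustration with placeholder names:

```rust
fn add(a: u32, b: u32) -> u32 {
    a + b
}

// The `Fn` traits conceptually receive `(u32, u32)` as a single tuple; the
// call shim is what moves `args.0` and `args.1` out as separate arguments.
fn apply<F: Fn(u32, u32) -> u32>(f: F, args: (u32, u32)) -> u32 {
    f(args.0, args.1)
}

fn main() {
    assert_eq!(apply(add, (2, 3)), 5);
}
```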
CallKind::Direct(def_id) => { - let ty = tcx.type_of(def_id); + let ty = tcx.type_of(def_id).subst_identity(); ( Operand::Constant(Box::new(Constant { span, @@ -655,7 +761,7 @@ fn build_call_shim<'tcx>( if let Some(untuple_args) = untuple_args { let tuple_arg = Local::new(1 + (sig.inputs().len() - 1)); args.extend(untuple_args.iter().enumerate().map(|(i, ity)| { - Operand::Move(tcx.mk_place_field(Place::from(tuple_arg), Field::new(i), *ity)) + Operand::Move(tcx.mk_place_field(Place::from(tuple_arg), FieldIdx::new(i), *ity)) })); } @@ -678,10 +784,10 @@ fn build_call_shim<'tcx>( args, destination: Place::return_place(), target: Some(BasicBlock::new(1)), - cleanup: if let Some(Adjustment::RefMut) = rcvr_adjustment { - Some(BasicBlock::new(3)) + unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment { + UnwindAction::Cleanup(BasicBlock::new(3)) } else { - None + UnwindAction::Continue }, from_hir_call: true, fn_span: span, @@ -694,7 +800,12 @@ fn build_call_shim<'tcx>( block( &mut blocks, vec![], - TerminatorKind::Drop { place: rcvr_place(), target: BasicBlock::new(2), unwind: None }, + TerminatorKind::Drop { + place: rcvr_place(), + target: BasicBlock::new(2), + unwind: UnwindAction::Continue, + replace: false, + }, false, ); } @@ -705,7 +816,12 @@ fn build_call_shim<'tcx>( block( &mut blocks, vec![], - TerminatorKind::Drop { place: rcvr_place(), target: BasicBlock::new(4), unwind: None }, + TerminatorKind::Drop { + place: rcvr_place(), + target: BasicBlock::new(4), + unwind: UnwindAction::Terminate, + replace: false, + }, true, ); @@ -726,10 +842,14 @@ fn build_call_shim<'tcx>( pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { debug_assert!(tcx.is_constructor(ctor_id)); - let param_env = tcx.param_env(ctor_id); + let param_env = tcx.param_env_reveal_all_normalized(ctor_id); // Normalize the sig. 
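For `Adjustment::RefMut`, the `UnwindAction::Cleanup(BasicBlock::new(3))` edge above guarantees that the by-value receiver is still dropped when the callee unwinds, and the cleanup drop itself terminates on a second panic. The surface-Rust equivalent of that guarantee, with illustrative types:

```rust
struct Receiver;

impl Drop for Receiver {
    fn drop(&mut self) {
        println!("receiver dropped during unwind");
    }
}

fn callee(_r: &mut Receiver) {
    panic!("callee unwound");
}

// Analogue of the shim: call the target with `&mut self`, then drop `self`.
// If the call unwinds, `r` is dropped on the cleanup path instead.
fn shim_like(mut r: Receiver) {
    callee(&mut r);
}

fn main() {
    let result = std::panic::catch_unwind(|| shim_like(Receiver));
    assert!(result.is_err());
}
```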
- let sig = tcx.fn_sig(ctor_id).no_bound_vars().expect("LBR in ADT constructor signature"); + let sig = tcx + .fn_sig(ctor_id) + .subst_identity() + .no_bound_vars() + .expect("LBR in ADT constructor signature"); let sig = tcx.normalize_erasing_regions(param_env, sig); let ty::Adt(adt_def, substs) = sig.output().kind() else { @@ -744,11 +864,8 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { let source_info = SourceInfo::outermost(span); - let variant_index = if adt_def.is_enum() { - adt_def.variant_index_with_ctor_id(ctor_id) - } else { - VariantIdx::new(0) - }; + let variant_index = + if adt_def.is_enum() { adt_def.variant_index_with_ctor_id(ctor_id) } else { FIRST_VARIANT }; // Generate the following MIR: // @@ -758,19 +875,23 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { // return; debug!("build_ctor: variant_index={:?}", variant_index); - let statements = expand_aggregate( - Place::return_place(), - adt_def.variant(variant_index).fields.iter().enumerate().map(|(idx, field_def)| { - (Operand::Move(Place::from(Local::new(idx + 1))), field_def.ty(tcx, substs)) - }), - AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None), + let kind = AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None); + let variant = adt_def.variant(variant_index); + let statement = Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate( + Box::new(kind), + (0..variant.fields.len()) + .map(|idx| Operand::Move(Place::from(Local::new(idx + 1)))) + .collect(), + ), + ))), source_info, - tcx, - ) - .collect(); + }; let start_block = BasicBlockData { - statements, + statements: vec![statement], terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), is_cleanup: false, }; @@ -784,7 +905,43 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { span, ); - rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + crate::pass_manager::dump_mir_for_phase_change(tcx, &body); body } + +/// ```ignore (pseudo-impl) +/// impl FnPtr for fn(u32) { +/// fn addr(self) -> usize { +/// self as usize +/// } +/// } +/// ``` +fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { + assert!(matches!(self_ty.kind(), ty::FnPtr(..)), "expected fn ptr, found {self_ty}"); + let span = tcx.def_span(def_id); + let Some(sig) = tcx.fn_sig(def_id).subst(tcx, &[self_ty.into()]).no_bound_vars() else { + span_bug!(span, "FnPtr::addr with bound vars for `{self_ty}`"); + }; + let locals = local_decls_for_sig(&sig, span); + + let source_info = SourceInfo::outermost(span); + // FIXME: use `expose_addr` once we figure out whether function pointers have meaningful provenance. 
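`build_adt_ctor` gives tuple-struct and enum-variant constructors a body consisting of one aggregate assignment plus `return`. That synthesized body is what runs when a constructor is used as a plain function value:

```rust
fn main() {
    // `Some` here is the constructor function whose MIR `build_adt_ctor`
    // synthesizes: `_0 = Option::<u32>::Some(move _1); return;`
    let make: fn(u32) -> Option<u32> = Some;
    assert_eq!(make(7), Some(7));
    let all: Vec<Option<u32>> = (0u32..3).map(make).collect();
    assert_eq!(all, vec![Some(0), Some(1), Some(2)]);
}
```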
+ let rvalue = Rvalue::Cast( + CastKind::FnPtrToPtr, + Operand::Move(Place::from(Local::new(1))), + tcx.mk_imm_ptr(tcx.types.unit), + ); + let stmt = Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + }; + let statements = vec![stmt]; + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty)); + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) +} diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 980af984362..e59219321b7 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -28,27 +28,45 @@ //! return. use crate::MirPass; -use rustc_index::vec::{Idx, IndexVec}; +use rustc_data_structures::fx::{FxHashSet, FxIndexSet}; +use rustc_index::{Idx, IndexSlice, IndexVec}; use rustc_middle::mir::coverage::*; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; use smallvec::SmallVec; -use std::borrow::Cow; -use std::convert::TryInto; -pub struct SimplifyCfg { - label: String, +pub enum SimplifyCfg { + Initial, + PromoteConsts, + RemoveFalseEdges, + EarlyOpt, + ElaborateDrops, + Final, + MakeShim, + AfterUninhabitedEnumBranching, } impl SimplifyCfg { - pub fn new(label: &str) -> Self { - SimplifyCfg { label: format!("SimplifyCfg-{}", label) } + pub fn name(&self) -> &'static str { + match self { + SimplifyCfg::Initial => "SimplifyCfg-initial", + SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts", + SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges", + SimplifyCfg::EarlyOpt => "SimplifyCfg-early-opt", + SimplifyCfg::ElaborateDrops => "SimplifyCfg-elaborate-drops", + SimplifyCfg::Final => "SimplifyCfg-final", + SimplifyCfg::MakeShim => "SimplifyCfg-make_shim", + SimplifyCfg::AfterUninhabitedEnumBranching => { + "SimplifyCfg-after-uninhabited-enum-branching" + } + } } } pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { CfgSimplifier::new(body).simplify(); + remove_duplicate_unreachable_blocks(tcx, body); remove_dead_blocks(tcx, body); // FIXME: Should probably be moved into some kind of pass manager @@ -56,24 +74,24 @@ pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } impl<'tcx> MirPass<'tcx> for SimplifyCfg { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(&self.label) + fn name(&self) -> &'static str { + &self.name() } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - debug!("SimplifyCfg({:?}) - simplifying {:?}", self.label, body.source); + debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source); simplify_cfg(tcx, body); } } pub struct CfgSimplifier<'a, 'tcx> { - basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, + basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>, pred_count: IndexVec<BasicBlock, u32>, } impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> { pub fn new(body: &'a mut Body<'tcx>) -> Self { - let mut pred_count = IndexVec::from_elem(0u32, body.basic_blocks()); + let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks); // we can't use mir.predecessors() here because that counts // dead blocks, which we don't want to. 
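`build_fn_ptr_addr_shim` above lowers `FnPtr::addr` to a plain `FnPtrToPtr` cast rather than `expose_addr`, per the FIXME. The stable-Rust equivalent of what the shim computes:

```rust
fn f(_: u32) {}

fn main() {
    let p: fn(u32) = f;
    // Shim body: cast the fn pointer to a unit pointer (`FnPtrToPtr`),
    // i.e. take its address without exposing provenance.
    let addr = p as *const () as usize;
    assert_ne!(addr, 0);
}
```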
@@ -260,14 +278,72 @@ impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
 }
 }
 
+pub fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
+ if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
+ let otherwise = targets.otherwise();
+ if targets.iter().any(|t| t.1 == otherwise) {
+ *targets = SwitchTargets::new(
+ targets.iter().filter(|t| t.1 != otherwise),
+ targets.otherwise(),
+ );
+ }
+ }
+}
+
+pub fn remove_duplicate_unreachable_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ struct OptApplier<'tcx> {
+ tcx: TyCtxt<'tcx>,
+ duplicates: FxIndexSet<BasicBlock>,
+ }
+
+ impl<'tcx> MutVisitor<'tcx> for OptApplier<'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
+
+ fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
+ for target in terminator.successors_mut() {
+ // We don't have to check whether `target` is a cleanup block, because we have
+ // entirely excluded cleanup blocks in building the set of duplicates.
+ if self.duplicates.contains(target) {
+ *target = self.duplicates[0];
+ }
+ }
+
+ simplify_duplicate_switch_targets(terminator);
+
+ self.super_terminator(terminator, location);
+ }
+ }
+
+ let unreachable_blocks = body
+ .basic_blocks
+ .iter_enumerated()
+ .filter(|(_, bb)| {
+ // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
+ // terminator. Those blocks will be deleted by remove_dead_blocks, but we run just
+ // before then so we need to handle missing terminators.
+ // We also need to prevent confusing cleanup and non-cleanup blocks. In practice we
+ // don't emit empty unreachable cleanup blocks, so this simple check suffices.
+ bb.terminator.is_some() && bb.is_empty_unreachable() && !bb.is_cleanup
+ })
+ .map(|(block, _)| block)
+ .collect::<FxIndexSet<_>>();
+
+ if unreachable_blocks.len() > 1 {
+ OptApplier { tcx, duplicates: unreachable_blocks }.visit_body(body);
+ }
+}
+
 pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 let reachable = traversal::reachable_as_bitset(body);
-    let num_blocks = body.basic_blocks().len();
+    let num_blocks = body.basic_blocks.len();
 if num_blocks == reachable.count() {
 return;
 }
 
-    let basic_blocks = body.basic_blocks_mut();
+    let basic_blocks = body.basic_blocks.as_mut();
+    let source_scopes = &body.source_scopes;
 let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
 let mut used_blocks = 0;
 for alive_index in reachable.iter() {
@@ -282,7 +358,7 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 }
 
 if tcx.sess.instrument_coverage() {
-        save_unreachable_coverage(basic_blocks, used_blocks);
+        save_unreachable_coverage(basic_blocks, source_scopes, used_blocks);
 }
 
 basic_blocks.raw.truncate(used_blocks);
@@ -311,61 +387,87 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 /// `Unreachable` coverage statements. These are non-executable statements whose
 /// code regions are still recorded in the coverage map, representing regions
 /// with `0` executions.
+///
+/// If there are no live `Counter` `Coverage` statements remaining, we remove
+/// `Coverage` statements along with the dead blocks, since at least one
+/// counter per function is required by LLVM (and is necessary to add the
+/// `function_hash` to the counter's call to the LLVM intrinsic
+/// `instrprof.increment()`).
+///
+/// The `generator::StateTransform` MIR pass and MIR inlining can create
+/// atypical conditions, where all live `Counter`s are dropped from the MIR.
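A toy model of `remove_duplicate_unreachable_blocks`, using stand-in types rather than the compiler's: pick one canonical empty `unreachable` block and redirect every edge aimed at a duplicate to it, leaving the duplicates as dead blocks for `remove_dead_blocks` to delete:

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Term {
    Goto(usize),
    Unreachable,
}

// Redirect every edge that points at a duplicate `unreachable` block to one
// canonical copy; the duplicates then become dead blocks.
fn dedup_unreachable(blocks: &mut [Term]) {
    let dups: Vec<usize> = (0..blocks.len())
        .filter(|&b| blocks[b] == Term::Unreachable)
        .collect();
    let Some(&canonical) = dups.first() else { return };
    for term in blocks.iter_mut() {
        if let Term::Goto(target) = term {
            if dups.contains(target) {
                *target = canonical; // mirrors `*target = self.duplicates[0]`
            }
        }
    }
}

fn main() {
    let mut cfg = [Term::Goto(1), Term::Unreachable, Term::Goto(3), Term::Unreachable];
    dedup_unreachable(&mut cfg);
    assert_eq!(cfg[2], Term::Goto(1)); // both edges now share block 1
}
```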
+/// +/// With MIR inlining we can have coverage counters belonging to different +/// instances in a single body, so the strategy described above is applied to +/// coverage counters from each instance individually. fn save_unreachable_coverage( - basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'_>>, + source_scopes: &IndexSlice<SourceScope, SourceScopeData<'_>>, first_dead_block: usize, ) { - let has_live_counters = basic_blocks.raw[0..first_dead_block].iter().any(|live_block| { - live_block.statements.iter().any(|statement| { - if let StatementKind::Coverage(coverage) = &statement.kind { - matches!(coverage.kind, CoverageKind::Counter { .. }) - } else { - false + // Identify instances that still have some live coverage counters left. + let mut live = FxHashSet::default(); + for basic_block in &basic_blocks.raw[0..first_dead_block] { + for statement in &basic_block.statements { + let StatementKind::Coverage(coverage) = &statement.kind else { continue }; + let CoverageKind::Counter { .. } = coverage.kind else { continue }; + let instance = statement.source_info.scope.inlined_instance(source_scopes); + live.insert(instance); + } + } + + for block in &mut basic_blocks.raw[..first_dead_block] { + for statement in &mut block.statements { + let StatementKind::Coverage(_) = &statement.kind else { continue }; + let instance = statement.source_info.scope.inlined_instance(source_scopes); + if !live.contains(&instance) { + statement.make_nop(); } - }) - }); - if !has_live_counters { - // If there are no live `Counter` `Coverage` statements anymore, don't - // move dead coverage to the `START_BLOCK`. Just allow the dead - // `Coverage` statements to be dropped with the dead blocks. - // - // The `generator::StateTransform` MIR pass can create atypical - // conditions, where all live `Counter`s are dropped from the MIR. - // - // At least one Counter per function is required by LLVM (and necessary, - // to add the `function_hash` to the counter's call to the LLVM - // intrinsic `instrprof.increment()`). + } + } + + if live.is_empty() { return; } - // Retain coverage info for dead blocks, so coverage reports will still - // report `0` executions for the uncovered code regions. - let mut dropped_coverage = Vec::new(); - for dead_block in basic_blocks.raw[first_dead_block..].iter() { - for statement in dead_block.statements.iter() { - if let StatementKind::Coverage(coverage) = &statement.kind { - if let Some(code_region) = &coverage.code_region { - dropped_coverage.push((statement.source_info, code_region.clone())); - } + // Retain coverage for instances that still have some live counters left. + let mut retained_coverage = Vec::new(); + for dead_block in &basic_blocks.raw[first_dead_block..] 
{ + for statement in &dead_block.statements { + let StatementKind::Coverage(coverage) = &statement.kind else { continue }; + let Some(code_region) = &coverage.code_region else { continue }; + let instance = statement.source_info.scope.inlined_instance(source_scopes); + if live.contains(&instance) { + retained_coverage.push((statement.source_info, code_region.clone())); } } } let start_block = &mut basic_blocks[START_BLOCK]; - for (source_info, code_region) in dropped_coverage { - start_block.statements.push(Statement { + start_block.statements.extend(retained_coverage.into_iter().map( + |(source_info, code_region)| Statement { source_info, kind: StatementKind::Coverage(Box::new(Coverage { kind: CoverageKind::Unreachable, code_region: Some(code_region), })), - }) - } + }, + )); } -pub struct SimplifyLocals; +pub enum SimplifyLocals { + BeforeConstProp, + Final, +} impl<'tcx> MirPass<'tcx> for SimplifyLocals { + fn name(&self) -> &'static str { + match &self { + SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop", + SimplifyLocals::Final => "SimplifyLocals-final", + } + } + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { sess.mir_opt_level() > 0 } @@ -376,6 +478,18 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals { } } +pub fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) { + // First, we're going to get a count of *actual* uses for every `Local`. + let mut used_locals = UsedLocals::new(body); + + // Next, we're going to remove any `Local` with zero actual uses. When we remove those + // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals` + // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from + // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a + // fixedpoint where there are no more unused locals. + remove_unused_definitions_helper(&mut used_locals, body); +} + pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) { // First, we're going to get a count of *actual* uses for every `Local`. let mut used_locals = UsedLocals::new(body); @@ -385,7 +499,7 @@ pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) { // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a // fixedpoint where there are no more unused locals. - remove_unused_definitions(&mut used_locals, body); + remove_unused_definitions_helper(&mut used_locals, body); // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s. 
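The fixedpoint behavior that `remove_unused_definitions` and `simplify_locals` depend on can be modeled in a few lines: deleting a dead definition decrements the use counts of its operands, which can in turn make more locals dead. Stand-in representation, not the compiler's:

```rust
// defs[i] = Some((lhs, rhs_operands)) models "lhs = f(rhs...)"; a definition
// is dead once `uses[lhs] == 0`, and removing it decrements each operand's
// use count, possibly creating new dead definitions.
fn remove_dead(defs: &mut [Option<(usize, Vec<usize>)>], uses: &mut [usize]) {
    loop {
        let mut changed = false;
        for slot in defs.iter_mut() {
            if let Some((lhs, rhs)) = slot {
                if uses[*lhs] == 0 {
                    for &operand in rhs.iter() {
                        uses[operand] -= 1;
                    }
                    *slot = None;
                    changed = true;
                }
            }
        }
        if !changed {
            break; // fixedpoint: no definition became dead this round
        }
    }
}

fn main() {
    // _1 = const; _2 = use(_1); and _2 is itself never read.
    let mut defs = vec![Some((1, vec![])), Some((2, vec![1]))];
    let mut uses = vec![0usize, 1, 0]; // counts for _0, _1, _2
    remove_dead(&mut defs, &mut uses);
    assert!(defs.iter().all(Option::is_none)); // both removed after two rounds
}
```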
let map = make_local_map(&mut body.local_decls, &used_locals); @@ -394,7 +508,7 @@ pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) { if map.iter().any(Option::is_none) { // Update references to all vars and tmps now let mut updater = LocalUpdater { map, tcx }; - updater.visit_body(body); + updater.visit_body_preserves_cfg(body); body.local_decls.shrink_to_fit(); } @@ -405,7 +519,7 @@ fn make_local_map<V>( local_decls: &mut IndexVec<Local, V>, used_locals: &UsedLocals, ) -> IndexVec<Local, Option<Local>> { - let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, &*local_decls); + let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls); let mut used = Local::new(0); for alive_index in local_decls.indices() { @@ -456,7 +570,7 @@ impl UsedLocals { self.increment = false; // The location of the statement is irrelevant. - let location = Location { block: START_BLOCK, statement_index: 0 }; + let location = Location::START; self.visit_statement(statement, location); } @@ -481,15 +595,16 @@ impl UsedLocals { impl<'tcx> Visitor<'tcx> for UsedLocals { fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { match statement.kind { - StatementKind::CopyNonOverlapping(..) + StatementKind::Intrinsic(..) | StatementKind::Retag(..) | StatementKind::Coverage(..) | StatementKind::FakeRead(..) + | StatementKind::PlaceMention(..) | StatementKind::AscribeUserType(..) => { self.super_statement(statement, location); } - StatementKind::Nop => {} + StatementKind::ConstEvalCounter | StatementKind::Nop => {} StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {} @@ -520,7 +635,7 @@ impl<'tcx> Visitor<'tcx> for UsedLocals { } /// Removes unused definitions. Updates the used locals to reflect the changes made. -fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) { +fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) { // The use counts are updated as we remove the statements. A local might become unused // during the retain operation, leading to a temporary inconsistency (storage statements or // definitions referencing the local might remain). For correctness it is crucial that this @@ -530,7 +645,7 @@ fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) while modified { modified = false; - for data in body.basic_blocks_mut() { + for data in body.basic_blocks.as_mut_preserves_cfg() { // Remove unnecessary StorageLive and StorageDead annotations. data.statements.retain(|statement| { let keep = match &statement.kind { @@ -541,6 +656,7 @@ fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) StatementKind::SetDiscriminant { ref place, .. } | StatementKind::Deinit(ref place) => used_locals.is_used(place.local), + StatementKind::Nop => false, _ => true, }; diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs index 3bbae5b8976..1ff48816986 100644 --- a/compiler/rustc_mir_transform/src/simplify_branches.rs +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -2,36 +2,28 @@ use crate::MirPass; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; -use std::borrow::Cow; - -/// A pass that replaces a branch with a goto when its condition is known. 
-pub struct SimplifyConstCondition { - label: String, -} - -impl SimplifyConstCondition { - pub fn new(label: &str) -> Self { - SimplifyConstCondition { label: format!("SimplifyConstCondition-{}", label) } - } +pub enum SimplifyConstCondition { + AfterConstProp, + Final, } - +/// A pass that replaces a branch with a goto when its condition is known. impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(&self.label) + fn name(&self) -> &'static str { + match self { + SimplifyConstCondition::AfterConstProp => "SimplifyConstCondition-after-const-prop", + SimplifyConstCondition::Final => "SimplifyConstCondition-final", + } } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let param_env = tcx.param_env(body.source.def_id()); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); for block in body.basic_blocks_mut() { let terminator = block.terminator_mut(); terminator.kind = match terminator.kind { TerminatorKind::SwitchInt { - discr: Operand::Constant(ref c), - switch_ty, - ref targets, - .. + discr: Operand::Constant(ref c), ref targets, .. } => { - let constant = c.literal.try_eval_bits(tcx, param_env, switch_ty); + let constant = c.literal.try_eval_bits(tcx, param_env, c.ty()); if let Some(constant) = constant { let target = targets.target_for_value(constant); TerminatorKind::Goto { target } diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs index bbfaace7041..113ca2fc5ad 100644 --- a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -37,7 +37,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral { let opts = helper.find_optimizations(); let mut storage_deads_to_insert = vec![]; let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![]; - let param_env = tcx.param_env(body.source.def_id()); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); for opt in opts { trace!("SUCCESS: Applying {:?}", opt); // replace terminator with a switchInt that switches on the integer directly @@ -127,11 +127,8 @@ impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral { let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise); let terminator = bb.terminator_mut(); - terminator.kind = TerminatorKind::SwitchInt { - discr: Operand::Move(opt.to_switch_on), - switch_ty: opt.branch_value_ty, - targets, - }; + terminator.kind = + TerminatorKind::SwitchInt { discr: Operand::Move(opt.to_switch_on), targets }; } for (idx, bb_idx) in storage_deads_to_remove { @@ -151,7 +148,7 @@ struct OptimizationFinder<'a, 'tcx> { impl<'tcx> OptimizationFinder<'_, 'tcx> { fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> { self.body - .basic_blocks() + .basic_blocks .iter_enumerated() .filter_map(|(bb_idx, bb)| { // find switch diff --git a/compiler/rustc_mir_transform/src/simplify_try.rs b/compiler/rustc_mir_transform/src/simplify_try.rs deleted file mode 100644 index d52f1261b23..00000000000 --- a/compiler/rustc_mir_transform/src/simplify_try.rs +++ /dev/null @@ -1,822 +0,0 @@ -//! The general point of the optimizations provided here is to simplify something like: -//! -//! ```rust -//! # fn foo<T, E>(x: Result<T, E>) -> Result<T, E> { -//! match x { -//! Ok(x) => Ok(x), -//! Err(x) => Err(x) -//! } -//! # } -//! ``` -//! -//! into just `x`. 
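What the reworked `SimplifyConstCondition` pass does, modeled on plain data with placeholder types: a `SwitchInt` whose discriminant is a known constant collapses to a `Goto` at the matching target, and an unknown discriminant is left untouched:

```rust
// targets: (match value, block index) pairs, plus an `otherwise` block.
fn fold_switch(
    known_discr: Option<u128>,
    targets: &[(u128, usize)],
    otherwise: usize,
) -> Option<usize> {
    let value = known_discr?; // only simplify when the condition is known
    let target = targets
        .iter()
        .find(|&&(v, _)| v == value)
        .map(|&(_, block)| block)
        .unwrap_or(otherwise);
    Some(target) // SwitchInt { .. } becomes Goto { target }
}

fn main() {
    assert_eq!(fold_switch(Some(1), &[(0, 7), (1, 8)], 9), Some(8));
    assert_eq!(fold_switch(Some(5), &[(0, 7), (1, 8)], 9), Some(9)); // otherwise
    assert_eq!(fold_switch(None, &[(0, 7)], 9), None); // leave the SwitchInt
}
```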
- -use crate::{simplify, MirPass}; -use itertools::Itertools as _; -use rustc_index::{bit_set::BitSet, vec::IndexVec}; -use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::*; -use rustc_middle::ty::{self, List, Ty, TyCtxt}; -use rustc_target::abi::VariantIdx; -use std::iter::{once, Enumerate, Peekable}; -use std::slice::Iter; - -/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. -/// -/// This is done by transforming basic blocks where the statements match: -/// -/// ```ignore (MIR) -/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY ); -/// _TMP_2 = _LOCAL_TMP; -/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2; -/// discriminant(_LOCAL_0) = VAR_IDX; -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// _LOCAL_0 = move _LOCAL_1 -/// ``` -pub struct SimplifyArmIdentity; - -#[derive(Debug)] -struct ArmIdentityInfo<'tcx> { - /// Storage location for the variant's field - local_temp_0: Local, - /// Storage location holding the variant being read from - local_1: Local, - /// The variant field being read from - vf_s0: VarField<'tcx>, - /// Index of the statement which loads the variant being read - get_variant_field_stmt: usize, - - /// Tracks each assignment to a temporary of the variant's field - field_tmp_assignments: Vec<(Local, Local)>, - - /// Storage location holding the variant's field that was read from - local_tmp_s1: Local, - /// Storage location holding the enum that we are writing to - local_0: Local, - /// The variant field being written to - vf_s1: VarField<'tcx>, - - /// Storage location that the discriminant is being written to - set_discr_local: Local, - /// The variant being written - set_discr_var_idx: VariantIdx, - - /// Index of the statement that should be overwritten as a move - stmt_to_overwrite: usize, - /// SourceInfo for the new move - source_info: SourceInfo, - - /// Indices of matching Storage{Live,Dead} statements encountered. - /// (StorageLive index,, StorageDead index, Local) - storage_stmts: Vec<(usize, usize, Local)>, - - /// The statements that should be removed (turned into nops) - stmts_to_remove: Vec<usize>, - - /// Indices of debug variables that need to be adjusted to point to - // `{local_0}.{dbg_projection}`. - dbg_info_to_adjust: Vec<usize>, - - /// The projection used to rewrite debug info. - dbg_projection: &'tcx List<PlaceElem<'tcx>>, -} - -fn get_arm_identity_info<'a, 'tcx>( - stmts: &'a [Statement<'tcx>], - locals_count: usize, - debug_info: &'a [VarDebugInfo<'tcx>], -) -> Option<ArmIdentityInfo<'tcx>> { - // This can't possibly match unless there are at least 3 statements in the block - // so fail fast on tiny blocks. - if stmts.len() < 3 { - return None; - } - - let mut tmp_assigns = Vec::new(); - let mut nop_stmts = Vec::new(); - let mut storage_stmts = Vec::new(); - let mut storage_live_stmts = Vec::new(); - let mut storage_dead_stmts = Vec::new(); - - type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>; - - fn is_storage_stmt(stmt: &Statement<'_>) -> bool { - matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_)) - } - - /// Eats consecutive Statements which match `test`, performing the specified `action` for each. - /// The iterator `stmt_iter` is not advanced if none were matched. 
- fn try_eat<'a, 'tcx>( - stmt_iter: &mut StmtIter<'a, 'tcx>, - test: impl Fn(&'a Statement<'tcx>) -> bool, - mut action: impl FnMut(usize, &'a Statement<'tcx>), - ) { - while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) { - let (idx, stmt) = stmt_iter.next().unwrap(); - - action(idx, stmt); - } - } - - /// Eats consecutive `StorageLive` and `StorageDead` Statements. - /// The iterator `stmt_iter` is not advanced if none were found. - fn try_eat_storage_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - storage_live_stmts: &mut Vec<(usize, Local)>, - storage_dead_stmts: &mut Vec<(usize, Local)>, - ) { - try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { - if let StatementKind::StorageLive(l) = stmt.kind { - storage_live_stmts.push((idx, l)); - } else if let StatementKind::StorageDead(l) = stmt.kind { - storage_dead_stmts.push((idx, l)); - } - }) - } - - fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { - place.as_local().is_some() && p.as_local().is_some() - } else { - false - } - } - - /// Eats consecutive `Assign` Statements. - // The iterator `stmt_iter` is not advanced if none were found. - fn try_eat_assign_tmp_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - tmp_assigns: &mut Vec<(Local, Local)>, - nop_stmts: &mut Vec<usize>, - ) { - try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = - &stmt.kind - { - tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); - nop_stmts.push(idx); - } - }) - } - - fn find_storage_live_dead_stmts_for_local( - local: Local, - stmts: &[Statement<'_>], - ) -> Option<(usize, usize)> { - trace!("looking for {:?}", local); - let mut storage_live_stmt = None; - let mut storage_dead_stmt = None; - for (idx, stmt) in stmts.iter().enumerate() { - if stmt.kind == StatementKind::StorageLive(local) { - storage_live_stmt = Some(idx); - } else if stmt.kind == StatementKind::StorageDead(local) { - storage_dead_stmt = Some(idx); - } - } - - Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) - } - - // Try to match the expected MIR structure with the basic block we're processing. 
- // We want to see something that looks like: - // ``` - // (StorageLive(_) | StorageDead(_));* - // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; - // discriminant(LOCAL_FROM) = VariantIdx; - // (StorageLive(_) | StorageDead(_));* - // ``` - let mut stmt_iter = stmts.iter().enumerate().peekable(); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - let (get_variant_field_stmt, stmt) = stmt_iter.next()?; - let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - let (idx, stmt) = stmt_iter.next()?; - let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; - nop_stmts.push(idx); - - let (idx, stmt) = stmt_iter.next()?; - let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; - let discr_stmt_source_info = stmt.source_info; - nop_stmts.push(idx); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - for (live_idx, live_local) in storage_live_stmts { - if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { - let (dead_idx, _) = storage_dead_stmts.swap_remove(i); - storage_stmts.push((live_idx, dead_idx, live_local)); - - if live_local == local_tmp_s0 { - nop_stmts.push(get_variant_field_stmt); - } - } - } - // We sort primitive usize here so we can use unstable sort - nop_stmts.sort_unstable(); - - // Use one of the statements we're going to discard between the point - // where the storage location for the variant field becomes live and - // is killed. - let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; - let stmt_to_overwrite = - nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); - - let mut tmp_assigned_vars = BitSet::new_empty(locals_count); - for (l, r) in &tmp_assigns { - tmp_assigned_vars.insert(*l); - tmp_assigned_vars.insert(*r); - } - - let dbg_info_to_adjust: Vec<_> = debug_info - .iter() - .enumerate() - .filter_map(|(i, var_info)| { - if let VarDebugInfoContents::Place(p) = var_info.value { - if tmp_assigned_vars.contains(p.local) { - return Some(i); - } - } - - None - }) - .collect(); - - Some(ArmIdentityInfo { - local_temp_0: local_tmp_s0, - local_1, - vf_s0, - get_variant_field_stmt, - field_tmp_assignments: tmp_assigns, - local_tmp_s1, - local_0, - vf_s1, - set_discr_local, - set_discr_var_idx, - stmt_to_overwrite: *stmt_to_overwrite?, - source_info: discr_stmt_source_info, - storage_stmts, - stmts_to_remove: nop_stmts, - dbg_info_to_adjust, - dbg_projection, - }) -} - -fn optimization_applies<'tcx>( - opt_info: &ArmIdentityInfo<'tcx>, - local_decls: &IndexVec<Local, LocalDecl<'tcx>>, - local_uses: &IndexVec<Local, usize>, - var_debug_info: &[VarDebugInfo<'tcx>], -) -> bool { - trace!("testing if optimization applies..."); - - // FIXME(wesleywiser): possibly relax this restriction? 
- if opt_info.local_0 == opt_info.local_1 { - trace!("NO: moving into ourselves"); - return false; - } else if opt_info.vf_s0 != opt_info.vf_s1 { - trace!("NO: the field-and-variant information do not match"); - return false; - } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { - // FIXME(Centril,oli-obk): possibly relax to same layout? - trace!("NO: source and target locals have different types"); - return false; - } else if (opt_info.local_0, opt_info.vf_s0.var_idx) - != (opt_info.set_discr_local, opt_info.set_discr_var_idx) - { - trace!("NO: the discriminants do not match"); - return false; - } - - // Verify the assignment chain consists of the form b = a; c = b; d = c; etc... - if opt_info.field_tmp_assignments.is_empty() { - trace!("NO: no assignments found"); - return false; - } - let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; - let source_local = last_assigned_to; - for (l, r) in &opt_info.field_tmp_assignments { - if *r != last_assigned_to { - trace!("NO: found unexpected assignment {:?} = {:?}", l, r); - return false; - } - - last_assigned_to = *l; - } - - // Check that the first and last used locals are only used twice - // since they are of the form: - // - // ``` - // _first = ((_x as Variant).n: ty); - // _n = _first; - // ... - // ((_y as Variant).n: ty) = _n; - // discriminant(_y) = z; - // ``` - for (l, r) in &opt_info.field_tmp_assignments { - if local_uses[*l] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); - return false; - } else if local_uses[*r] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); - return false; - } - } - - // Check that debug info only points to full Locals and not projections. - for dbg_idx in &opt_info.dbg_info_to_adjust { - let dbg_info = &var_debug_info[*dbg_idx]; - if let VarDebugInfoContents::Place(p) = dbg_info.value { - if !p.projection.is_empty() { - trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p); - return false; - } - } - } - - if source_local != opt_info.local_temp_0 { - trace!( - "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", - source_local, - opt_info.local_temp_0 - ); - return false; - } else if last_assigned_to != opt_info.local_tmp_s1 { - trace!( - "NO: end of assignment chain does not match written enum temp: {:?} != {:?}", - last_assigned_to, - opt_info.local_tmp_s1 - ); - return false; - } - - trace!("SUCCESS: optimization applies!"); - true -} - -impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // FIXME(77359): This optimization can result in unsoundness. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - let source = body.source; - trace!("running SimplifyArmIdentity on {:?}", source); - - let local_uses = LocalUseCounter::get_local_uses(body); - for bb in body.basic_blocks.as_mut() { - if let Some(opt_info) = - get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info) - { - trace!("got opt_info = {:#?}", opt_info); - if !optimization_applies( - &opt_info, - &body.local_decls, - &local_uses, - &body.var_debug_info, - ) { - debug!("optimization skipped for {:?}", source); - continue; - } - - // Also remove unused Storage{Live,Dead} statements which correspond - // to temps used previously. 
- for (live_idx, dead_idx, local) in &opt_info.storage_stmts { - // The temporary that we've read the variant field into is scoped to this block, - // so we can remove the assignment. - if *local == opt_info.local_temp_0 { - bb.statements[opt_info.get_variant_field_stmt].make_nop(); - } - - for (left, right) in &opt_info.field_tmp_assignments { - if local == left || local == right { - bb.statements[*live_idx].make_nop(); - bb.statements[*dead_idx].make_nop(); - } - } - } - - // Right shape; transform - for stmt_idx in opt_info.stmts_to_remove { - bb.statements[stmt_idx].make_nop(); - } - - let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; - stmt.source_info = opt_info.source_info; - stmt.kind = StatementKind::Assign(Box::new(( - opt_info.local_0.into(), - Rvalue::Use(Operand::Move(opt_info.local_1.into())), - ))); - - bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); - - // Fix the debug info to point to the right local - for dbg_index in opt_info.dbg_info_to_adjust { - let dbg_info = &mut body.var_debug_info[dbg_index]; - assert!( - matches!(dbg_info.value, VarDebugInfoContents::Place(_)), - "value was not a Place" - ); - if let VarDebugInfoContents::Place(p) = &mut dbg_info.value { - assert!(p.projection.is_empty()); - p.local = opt_info.local_0; - p.projection = opt_info.dbg_projection; - } - } - - trace!("block is now {:?}", bb.statements); - } - } - } -} - -struct LocalUseCounter { - local_uses: IndexVec<Local, usize>, -} - -impl LocalUseCounter { - fn get_local_uses(body: &Body<'_>) -> IndexVec<Local, usize> { - let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; - counter.visit_body(body); - counter.local_uses - } -} - -impl Visitor<'_> for LocalUseCounter { - fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) { - if context.is_storage_marker() - || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) - { - return; - } - - self.local_uses[local] += 1; - } -} - -/// Match on: -/// ```ignore (MIR) -/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); -/// ``` -fn match_get_variant_field<'tcx>( - stmt: &Statement<'tcx>, -) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> { - match &stmt.kind { - StatementKind::Assign(box ( - place_into, - Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), - )) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*pf)?; - Some((local_into, local_from, vf, pf.projection)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; -/// ``` -fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { - match &stmt.kind { - StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*place_from)?; - Some((local_into, local_from, vf)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// discriminant(_LOCAL_TO_SET) = VAR_IDX; -/// ``` -fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> { - match &stmt.kind { - StatementKind::SetDiscriminant { place, variant_index } => { - Some((place.as_local()?, *variant_index)) - } - _ => None, - } -} - -#[derive(PartialEq, Debug)] -struct VarField<'tcx> { - field: Field, - field_ty: Ty<'tcx>, - var_idx: VariantIdx, -} - -/// Match on `((_LOCAL as Variant).FIELD: TY)`. 
-fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> { - match place.as_ref() { - PlaceRef { - local, - projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], - } => Some((local, VarField { field, field_ty: ty, var_idx })), - _ => None, - } -} - -/// Simplifies `SwitchInt(_) -> [targets]`, -/// where all the `targets` have the same form, -/// into `goto -> target_first`. -pub struct SimplifyBranchSame; - -impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // This optimization is disabled by default for now due to - // soundness concerns; see issue #89485 and PR #89489. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - trace!("Running SimplifyBranchSame on {:?}", body.source); - let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; - let opts = finder.find(); - - let did_remove_blocks = opts.len() > 0; - for opt in opts.iter() { - trace!("SUCCESS: Applying optimization {:?}", opt); - // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. - body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = - TerminatorKind::Goto { target: opt.bb_to_goto }; - } - - if did_remove_blocks { - // We have dead blocks now, so remove those. - simplify::remove_dead_blocks(tcx, body); - } - } -} - -#[derive(Debug)] -struct SimplifyBranchSameOptimization { - /// All basic blocks are equal so go to this one - bb_to_goto: BasicBlock, - /// Basic block where the terminator can be simplified to a goto - bb_to_opt_terminator: BasicBlock, -} - -struct SwitchTargetAndValue { - target: BasicBlock, - // None in case of the `otherwise` case - value: Option<u128>, -} - -struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { - body: &'a Body<'tcx>, - tcx: TyCtxt<'tcx>, -} - -impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> { - fn find(&self) -> Vec<SimplifyBranchSameOptimization> { - self.body - .basic_blocks() - .iter_enumerated() - .filter_map(|(bb_idx, bb)| { - let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { - TerminatorKind::SwitchInt { targets, discr, .. } => { - let targets_and_values: Vec<_> = targets.iter() - .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) - .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) - .collect(); - (discr, targets_and_values) - }, - _ => return None, - }; - - // find the adt that has its discriminant read - // assuming this must be the last statement of the block - let adt_matched_on = match &bb.statements.last()?.kind { - StatementKind::Assign(box (place, rhs)) - if Some(*place) == discr_switched_on.place() => - { - match rhs { - Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, - _ => { - trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); - return None; - } - } - } - other => { - trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); - return None - }, - }; - - let mut iter_bbs_reachable = targets_and_values - .iter() - .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target])) - .filter(|(_, bb)| { - // Reaching `unreachable` is UB so assume it doesn't happen. 
- bb.terminator().kind != TerminatorKind::Unreachable - }) - .peekable(); - - let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx); - let mut all_successors_equivalent = StatementEquality::TrivialEqual; - - // All successor basic blocks must be equal or contain statements that are pairwise considered equal. - for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { - let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup - && bb_l.terminator().kind == bb_r.terminator().kind - && bb_l.statements.len() == bb_r.statements.len(); - let statement_check = || { - bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { - let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); - if matches!(stmt_equality, StatementEquality::NotEqual) { - // short circuit - None - } else { - Some(acc.combine(&stmt_equality)) - } - }) - .unwrap_or(StatementEquality::NotEqual) - }; - if !trivial_checks { - all_successors_equivalent = StatementEquality::NotEqual; - break; - } - all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); - }; - - match all_successors_equivalent{ - StatementEquality::TrivialEqual => { - // statements are trivially equal, so just take first - trace!("Statements are trivially equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_first.target, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::ConsideredEqual(bb_to_choose) => { - trace!("Statements are considered equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_to_choose, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::NotEqual => { - trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); - None - } - } - }) - .collect() - } - - /// Tests if two statements can be considered equal - /// - /// Statements can be trivially equal if the kinds match. - /// But they can also be considered equal in the following case A: - /// ```ignore (MIR) - /// discriminant(_0) = 0; // bb1 - /// _0 = move _1; // bb2 - /// ``` - /// In this case the two statements are equal iff - /// - `_0` is an enum where the variant index 0 is fieldless, and - /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on - fn statement_equality( - &self, - adt_matched_on: Place<'tcx>, - x: &Statement<'tcx>, - x_target_and_value: &SwitchTargetAndValue, - y: &Statement<'tcx>, - y_target_and_value: &SwitchTargetAndValue, - ) -> StatementEquality { - let helper = |rhs: &Rvalue<'tcx>, - place: &Place<'tcx>, - variant_index: VariantIdx, - switch_value: u128, - side_to_choose| { - let place_type = place.ty(self.body, self.tcx).ty; - let adt = match *place_type.kind() { - ty::Adt(adt, _) if adt.is_enum() => adt, - _ => return StatementEquality::NotEqual, - }; - // We need to make sure that the switch value that targets the bb with - // SetDiscriminant is the same as the variant discriminant. 
- let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val; - if variant_discr != switch_value { - trace!( - "NO: variant discriminant {} does not equal switch value {}", - variant_discr, - switch_value - ); - return StatementEquality::NotEqual; - } - let variant_is_fieldless = adt.variant(variant_index).fields.is_empty(); - if !variant_is_fieldless { - trace!("NO: variant {:?} was not fieldless", variant_index); - return StatementEquality::NotEqual; - } - - match rhs { - Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { - StatementEquality::ConsideredEqual(side_to_choose) - } - _ => { - trace!( - "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", - rhs, - adt_matched_on - ); - StatementEquality::NotEqual - } - } - }; - match (&x.kind, &y.kind) { - // trivial case - (x, y) if x == y => StatementEquality::TrivialEqual, - - // check for case A - ( - StatementKind::Assign(box (_, rhs)), - &StatementKind::SetDiscriminant { ref place, variant_index }, - ) if y_target_and_value.value.is_some() => { - // choose basic block of x, as that has the assign - helper( - rhs, - place, - variant_index, - y_target_and_value.value.unwrap(), - x_target_and_value.target, - ) - } - ( - &StatementKind::SetDiscriminant { ref place, variant_index }, - &StatementKind::Assign(box (_, ref rhs)), - ) if x_target_and_value.value.is_some() => { - // choose basic block of y, as that has the assign - helper( - rhs, - place, - variant_index, - x_target_and_value.value.unwrap(), - y_target_and_value.target, - ) - } - _ => { - trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); - StatementEquality::NotEqual - } - } - } -} - -#[derive(Copy, Clone, Eq, PartialEq)] -enum StatementEquality { - /// The two statements are trivially equal; same kind - TrivialEqual, - /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. - /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. 
But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
- ConsideredEqual(BasicBlock),
- /// The two statements are not equal
- NotEqual,
-}
-
-impl StatementEquality {
- fn combine(&self, other: &StatementEquality) -> StatementEquality {
- use StatementEquality::*;
- match (self, other) {
- (TrivialEqual, TrivialEqual) => TrivialEqual,
- (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
- ConsideredEqual(*b)
- }
- (ConsideredEqual(b1), ConsideredEqual(b2)) => {
- if b1 == b2 {
- ConsideredEqual(*b1)
- } else {
- NotEqual
- }
- }
- (_, NotEqual) | (NotEqual, _) => NotEqual,
- }
- }
-}
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
new file mode 100644
index 00000000000..e4b3b8b9262
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -0,0 +1,470 @@
+use crate::MirPass;
+use rustc_index::bit_set::{BitSet, GrowableBitSet};
+use rustc_index::IndexVec;
+use rustc_middle::mir::patch::MirPatch;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
+use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT};
+
+pub struct ScalarReplacementOfAggregates;
+
+impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 2
+ }
+
+ #[instrument(level = "debug", skip(self, tcx, body))]
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!(def_id = ?body.source.def_id());
+
+ // Avoid query cycles (generators require optimized MIR for layout).
+ if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
+ return;
+ }
+
+ let mut excluded = excluded_locals(body);
+ let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+ loop {
+ debug!(?excluded);
+ let escaping = escaping_locals(tcx, param_env, &excluded, body);
+ debug!(?escaping);
+ let replacements = compute_flattening(tcx, param_env, body, escaping);
+ debug!(?replacements);
+ let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+ if !all_dead_locals.is_empty() {
+ excluded.union(&all_dead_locals);
+ excluded = {
+ let mut growable = GrowableBitSet::from(excluded);
+ growable.ensure(body.local_decls.len());
+ growable.into()
+ };
+ } else {
+ break;
+ }
+ }
+ }
+}
+
+/// Identify all locals that are not eligible for SROA.
+///
+/// There are 3 cases:
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
+/// - the local is a union or an enum;
+/// - the local's address is taken, and thus the relative addresses of the fields are observable to
+/// client code.
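+///
+/// For instance, in the following sketch (illustrative only; these locals are made up
+/// and not taken from the pass or its tests), only `a` stays eligible: `b` is an enum,
+/// and `c` has its address observed:
+/// ```ignore (illustrative)
+/// let a = (0u8, 1u16); // accessed only through its fields: can be split
+/// let b = Some(0u8); // enum: excluded
+/// let c = (0u8, 1u16);
+/// let p = &c; // address taken: `c` escapes and is excluded
+/// ```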
+fn escaping_locals<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ excluded: &BitSet<Local>,
+ body: &Body<'tcx>,
+) -> BitSet<Local> {
+ let is_excluded_ty = |ty: Ty<'tcx>| {
+ if ty.is_union() || ty.is_enum() {
+ return true;
+ }
+ if let ty::Adt(def, _substs) = ty.kind() {
+ if def.repr().flags.contains(ReprFlags::IS_SIMD) {
+ // Exclude #[repr(simd)] types so that they are not de-optimized into an array
+ return true;
+ }
+ // We already excluded unions and enums, so this ADT must have one variant
+ let variant = def.variant(FIRST_VARIANT);
+ if variant.fields.len() > 1 {
+ // If this has more than one field, it cannot be a wrapper that only provides a
+ // niche, so we do not want to automatically exclude it.
+ return false;
+ }
+ let Ok(layout) = tcx.layout_of(param_env.and(ty)) else {
+ // We can't get the layout
+ return true;
+ };
+ if layout.layout.largest_niche().is_some() {
+ // This type has a niche
+ return true;
+ }
+ }
+ // Default for non-ADTs
+ false
+ };
+
+ let mut set = BitSet::new_empty(body.local_decls.len());
+ set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
+ for (local, decl) in body.local_decls().iter_enumerated() {
+ if excluded.contains(local) || is_excluded_ty(decl.ty) {
+ set.insert(local);
+ }
+ }
+ let mut visitor = EscapeVisitor { set };
+ visitor.visit_body(body);
+ return visitor.set;
+
+ struct EscapeVisitor {
+ set: BitSet<Local>,
+ }
+
+ impl<'tcx> Visitor<'tcx> for EscapeVisitor {
+ fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
+ self.set.insert(local);
+ }
+
+ fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+ // Mirror the implementation in PreFlattenVisitor.
+ if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
+ return;
+ }
+ self.super_place(place, context, location);
+ }
+
+ fn visit_assign(
+ &mut self,
+ lvalue: &Place<'tcx>,
+ rvalue: &Rvalue<'tcx>,
+ location: Location,
+ ) {
+ if lvalue.as_local().is_some() {
+ match rvalue {
+ // Aggregate assignments are expanded in run_pass.
+ Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+ self.visit_rvalue(rvalue, location);
+ return;
+ }
+ _ => {}
+ }
+ }
+ self.super_assign(lvalue, rvalue, location)
+ }
+
+ fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+ match statement.kind {
+ // Storage statements are expanded in run_pass.
+ StatementKind::StorageLive(..)
+ | StatementKind::StorageDead(..)
+ | StatementKind::Deinit(..) => return,
+ _ => self.super_statement(statement, location),
+ }
+ }
+
+ // We ignore anything that happens in debuginfo, since we expand it using
+ // `VarDebugInfoContents::Composite`.
+ fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
+ }
+}
+
+#[derive(Default, Debug)]
+struct ReplacementMap<'tcx> {
+ /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
+ /// and deinit statements and debuginfo.
+ fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
+}
+
+impl<'tcx> ReplacementMap<'tcx> {
+ fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+ let &[PlaceElem::Field(f, _), ref rest @ ..]
= place.projection else { return None; }; + let fields = self.fragments[place.local].as_ref()?; + let (_, new_local) = fields[f]?; + Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) }) + } + + fn place_fragments( + &self, + place: Place<'tcx>, + ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)> + '_> { + let local = place.as_local()?; + let fields = self.fragments[local].as_ref()?; + Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| { + let (ty, local) = opt_ty_local?; + Some((field, ty, local)) + })) + } +} + +/// Compute the replacement of flattened places into locals. +/// +/// For each eligible place, we assign a new local to each accessed field. +/// The replacement will be done later in `ReplacementVisitor`. +fn compute_flattening<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + body: &mut Body<'tcx>, + escaping: BitSet<Local>, +) -> ReplacementMap<'tcx> { + let mut fragments = IndexVec::from_elem(None, &body.local_decls); + + for local in body.local_decls.indices() { + if escaping.contains(local) { + continue; + } + let decl = body.local_decls[local].clone(); + let ty = decl.ty; + iter_fields(ty, tcx, param_env, |variant, field, field_ty| { + if variant.is_some() { + // Downcasts are currently not supported. + return; + }; + let new_local = + body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }); + fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local)); + }); + } + ReplacementMap { fragments } +} + +/// Perform the replacement computed by `compute_flattening`. +fn replace_flattened_locals<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + replacements: ReplacementMap<'tcx>, +) -> BitSet<Local> { + let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len()); + for (local, replacements) in replacements.fragments.iter_enumerated() { + if replacements.is_some() { + all_dead_locals.insert(local); + } + } + debug!(?all_dead_locals); + if all_dead_locals.is_empty() { + return all_dead_locals; + } + + let mut visitor = ReplacementVisitor { + tcx, + local_decls: &body.local_decls, + replacements: &replacements, + all_dead_locals, + patch: MirPatch::new(body), + }; + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + visitor.visit_basic_block_data(bb, data); + } + for scope in &mut body.source_scopes { + visitor.visit_source_scope_data(scope); + } + for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() { + visitor.visit_user_type_annotation(index, annotation); + } + for var_debug_info in &mut body.var_debug_info { + visitor.visit_var_debug_info(var_debug_info); + } + let ReplacementVisitor { patch, all_dead_locals, .. } = visitor; + patch.apply(body); + all_dead_locals +} + +struct ReplacementVisitor<'tcx, 'll> { + tcx: TyCtxt<'tcx>, + /// This is only used to compute the type for `VarDebugInfoContents::Composite`. + local_decls: &'ll LocalDecls<'tcx>, + /// Work to do. + replacements: &'ll ReplacementMap<'tcx>, + /// This is used to check that we are not leaving references to replaced locals behind. 
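+ /// (Enforced by the `assert!` in `visit_local` at the bottom of this visitor.)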
+ all_dead_locals: BitSet<Local>, + patch: MirPatch<'tcx>, +} + +impl<'tcx> ReplacementVisitor<'tcx, '_> { + fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> { + let mut fragments = Vec::new(); + let parts = self.replacements.place_fragments(local.into())?; + for (field, ty, replacement_local) in parts { + fragments.push(VarDebugInfoFragment { + projection: vec![PlaceElem::Field(field, ty)], + contents: Place::from(replacement_local), + }); + } + Some(fragments) + } +} + +impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { + *place = repl + } else { + self.super_place(place, context, location) + } + } + + #[instrument(level = "trace", skip(self))] + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + match statement.kind { + // Duplicate storage and deinit statements, as they pretty much apply to all fields. + StatementKind::StorageLive(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageLive(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::StorageDead(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageDead(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::Deinit(box place) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + for (_, _, fl) in final_locals { + self.patch + .add_statement(location, StatementKind::Deinit(Box::new(fl.into()))); + } + statement.make_nop(); + return; + } + } + + // We have `a = Struct { 0: x, 1: y, .. }`. + // We replace it by + // ``` + // a_0 = x + // a_1 = y + // ... + // ``` + StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.replacements.fragments[local] + { + // This is ok as we delete the statement later. + let operands = std::mem::take(operands); + for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) { + if let Some((_, new_local)) = opt_ty_local { + // Replace mentions of SROA'd locals that appear in the operand. + self.visit_operand(&mut operand, location); + + let rvalue = Rvalue::Use(operand); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + } + statement.make_nop(); + return; + } + } + + // We have `a = some constant` + // We add the projections. + // ``` + // a_0 = a.0 + // a_1 = a.1 + // ... + // ``` + // ConstProp will pick up the pieces and replace them by actual constants. + StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + // Put the deaggregated statements *after* the original one. 
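+ // (`successor_within_block` only bumps the statement index, so the patch
+ // inserts the new per-field assignments immediately after this statement.)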
+ let location = location.successor_within_block(); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(place, field, ty); + let rvalue = Rvalue::Use(Operand::Move(rplace)); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + // We still need `place.local` to exist, so don't make it nop. + return; + } + } + + // We have `a = move? place` + // We replace it by + // ``` + // a_0 = move? place.0 + // a_1 = move? place.1 + // ... + // ``` + StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => { + let (rplace, copy) = match *op { + Operand::Copy(rplace) => (rplace, true), + Operand::Move(rplace) => (rplace, false), + Operand::Constant(_) => bug!(), + }; + if let Some(final_locals) = self.replacements.place_fragments(lhs) { + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(rplace, field, ty); + debug!(?rplace); + let rplace = self + .replacements + .replace_place(self.tcx, rplace.as_ref()) + .unwrap_or(rplace); + debug!(?rplace); + let rvalue = if copy { + Rvalue::Use(Operand::Copy(rplace)) + } else { + Rvalue::Use(Operand::Move(rplace)) + }; + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + statement.make_nop(); + return; + } + } + + _ => {} + } + self.super_statement(statement, location) + } + + #[instrument(level = "trace", skip(self))] + fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { + match &mut var_debug_info.value { + VarDebugInfoContents::Place(ref mut place) => { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { + *place = repl; + } else if let Some(local) = place.as_local() + && let Some(fragments) = self.gather_debug_info_fragments(local) + { + let ty = place.ty(self.local_decls, self.tcx).ty; + var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; + } + } + VarDebugInfoContents::Composite { ty: _, ref mut fragments } => { + let mut new_fragments = Vec::new(); + debug!(?fragments); + fragments + .drain_filter(|fragment| { + if let Some(repl) = + self.replacements.replace_place(self.tcx, fragment.contents.as_ref()) + { + fragment.contents = repl; + false + } else if let Some(local) = fragment.contents.as_local() + && let Some(frg) = self.gather_debug_info_fragments(local) + { + new_fragments.extend(frg.into_iter().map(|mut f| { + f.projection.splice(0..0, fragment.projection.iter().copied()); + f + })); + true + } else { + false + } + }) + .for_each(drop); + debug!(?fragments); + debug!(?new_fragments); + fragments.extend(new_fragments); + } + VarDebugInfoContents::Const(_) => {} + } + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert!(!self.all_dead_locals.contains(*local)); + } +} diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs new file mode 100644 index 00000000000..7a0d3a025f3 --- /dev/null +++ b/compiler/rustc_mir_transform/src/ssa.rs @@ -0,0 +1,345 @@ +//! We denote as "SSA" the set of locals that verify the following properties: +//! 1/ They are only assigned-to once, either as a function parameter, or in an assign statement; +//! 2/ This single assignment dominates all uses; +//! +//! As a consequence of rule 2, we consider that borrowed locals are not SSA, even if they are +//! `Freeze`, as we do not track that the assignment dominates all uses of the borrow. 
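+//!
+//! An illustrative sketch of the classification (made-up locals, schematic MIR, not
+//! taken from this module):
+//! ```ignore (illustrative)
+//! _1 = const 1_i32; // single assignment that dominates all uses: SSA
+//! _2 = _1; // also SSA, and copy-equivalent to `_1`
+//! _3 = const 2_i32;
+//! _3 = const 3_i32; // assigned twice: not SSA
+//! _4 = const 4_i32;
+//! _5 = &mut _4; // borrowed: `_4` is not considered SSA
+//! ```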
+
+use either::Either;
+use rustc_data_structures::graph::dominators::Dominators;
+use rustc_index::bit_set::BitSet;
+use rustc_index::{IndexSlice, IndexVec};
+use rustc_middle::middle::resolve_bound_vars::Set1;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+
+#[derive(Debug)]
+pub struct SsaLocals {
+ /// Assignments to each local. This defines whether the local is SSA.
+ assignments: IndexVec<Local, Set1<LocationExtended>>,
+ /// We visit the body in reverse postorder, to ensure each local is assigned before it is used.
+ /// We remember the order in which we saw the assignments to compute the SSA values in a single
+ /// pass.
+ assignment_order: Vec<Local>,
+ /// Copy equivalence classes between locals. See `copy_classes` for documentation.
+ copy_classes: IndexVec<Local, Local>,
+ /// Number of "direct" uses of each local, i.e. uses that are not dereferences.
+ /// We ignore non-uses (Storage statements, debuginfo).
+ direct_uses: IndexVec<Local, u32>,
+}
+
+/// We often encounter MIR bodies with 1 or 2 basic blocks. In those cases, it's unnecessary to
+/// actually compute dominators: we can just compare block indices, because bb0 is always the
+/// first block, and in any body all other blocks are always dominated by bb0.
+struct SmallDominators<'a> {
+ inner: Option<&'a Dominators<BasicBlock>>,
+}
+
+impl SmallDominators<'_> {
+ fn dominates(&self, first: Location, second: Location) -> bool {
+ if first.block == second.block {
+ first.statement_index <= second.statement_index
+ } else if let Some(inner) = &self.inner {
+ inner.dominates(first.block, second.block)
+ } else {
+ first.block < second.block
+ }
+ }
+
+ fn check_dominates(&mut self, set: &mut Set1<LocationExtended>, loc: Location) {
+ let assign_dominates = match *set {
+ Set1::Empty | Set1::Many => false,
+ Set1::One(LocationExtended::Arg) => true,
+ Set1::One(LocationExtended::Plain(assign)) => {
+ self.dominates(assign.successor_within_block(), loc)
+ }
+ };
+ // We are visiting a use that is not dominated by an assignment.
+ // Either there is a cycle involved, or we are reading an uninitialized local.
+ // Bail out.
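+ // (For instance, a use in a loop header whose only assignment happens later,
+ // along the back-edge, lands here and demotes the local to `Set1::Many`.)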
+ if !assign_dominates {
+ *set = Set1::Many;
+ }
+ }
+}
+
+impl SsaLocals {
+ pub fn new<'tcx>(body: &Body<'tcx>) -> SsaLocals {
+ let assignment_order = Vec::with_capacity(body.local_decls.len());
+
+ let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls);
+ let dominators =
+ if body.basic_blocks.len() > 2 { Some(body.basic_blocks.dominators()) } else { None };
+ let dominators = SmallDominators { inner: dominators };
+
+ let direct_uses = IndexVec::from_elem(0, &body.local_decls);
+ let mut visitor = SsaVisitor { assignments, assignment_order, dominators, direct_uses };
+
+ for local in body.args_iter() {
+ visitor.assignments[local] = Set1::One(LocationExtended::Arg);
+ }
+
+ if body.basic_blocks.len() > 2 {
+ for (bb, data) in traversal::reverse_postorder(body) {
+ visitor.visit_basic_block_data(bb, data);
+ }
+ } else {
+ for (bb, data) in body.basic_blocks.iter_enumerated() {
+ visitor.visit_basic_block_data(bb, data);
+ }
+ }
+
+ for var_debug_info in &body.var_debug_info {
+ visitor.visit_var_debug_info(var_debug_info);
+ }
+
+ debug!(?visitor.assignments);
+ debug!(?visitor.direct_uses);
+
+ visitor
+ .assignment_order
+ .retain(|&local| matches!(visitor.assignments[local], Set1::One(_)));
+ debug!(?visitor.assignment_order);
+
+ let mut ssa = SsaLocals {
+ assignments: visitor.assignments,
+ assignment_order: visitor.assignment_order,
+ direct_uses: visitor.direct_uses,
+ // This is filled by `compute_copy_classes`.
+ copy_classes: IndexVec::default(),
+ };
+ compute_copy_classes(&mut ssa, body);
+ ssa
+ }
+
+ pub fn num_locals(&self) -> usize {
+ self.assignments.len()
+ }
+
+ pub fn locals(&self) -> impl Iterator<Item = Local> {
+ self.assignments.indices()
+ }
+
+ pub fn is_ssa(&self, local: Local) -> bool {
+ matches!(self.assignments[local], Set1::One(_))
+ }
+
+ /// Returns the number of uses of a local that are not "Deref".
+ pub fn num_direct_uses(&self, local: Local) -> u32 {
+ self.direct_uses[local]
+ }
+
+ pub fn assignments<'a, 'tcx>(
+ &'a self,
+ body: &'a Body<'tcx>,
+ ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>, Location)> + 'a {
+ self.assignment_order.iter().filter_map(|&local| {
+ if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] {
+ // `loc` must point to a direct assignment to `local`.
+ let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
+ let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() };
+ assert_eq!(target.as_local(), Some(local));
+ Some((local, rvalue, loc))
+ } else {
+ None
+ }
+ })
+ }
+
+ /// Compute the equivalence classes for locals, based on copy statements.
+ ///
+ /// The returned vector maps each local to the one it copies. In the following case:
+ /// _a = &mut _0
+ /// _b = move? _a
+ /// _c = move? _a
+ /// _d = move? _c
+ /// We return the mapping
+ /// _a => _a // not a copy, so represented by itself
+ /// _b => _a
+ /// _c => _a
+ /// _d => _a // transitively through _c
+ ///
+ /// Exception: we do not see through the return place, as it cannot be substituted.
+ pub fn copy_classes(&self) -> &IndexSlice<Local, Local> {
+ &self.copy_classes
+ }
+
+ /// Make a property uniform on a copy equivalence class by removing elements.
+ pub fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) {
+ // Consolidate so that a local has the property iff all its copies have it.
+ //
+ // `copy_classes` defines equivalence classes between locals. The `local`s that recursively
+ // move/copy the same local all have the same `head`.
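+ //
+ // An illustrative run (made-up locals): with classes `_1 => _1, _2 => _1` and
+ // `property = {_1}`, the first loop notices that `_2` lacks the property and
+ // removes the head `_1`; the second loop then propagates the head's absence to
+ // every member, so the whole class uniformly lacks the property afterwards.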
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ // If any copy does not have `property`, then neither does the head.
+ if !property.contains(local) {
+ property.remove(head);
+ }
+ }
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ // If the head does not have `property` (because some copy in its class did not),
+ // then no copy in the class has it.
+ if !property.contains(head) {
+ property.remove(local);
+ }
+ }
+
+ // Verify that we correctly computed equivalence classes.
+ #[cfg(debug_assertions)]
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ assert_eq!(property.contains(local), property.contains(head));
+ }
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum LocationExtended {
+ Plain(Location),
+ Arg,
+}
+
+struct SsaVisitor<'a> {
+ dominators: SmallDominators<'a>,
+ assignments: IndexVec<Local, Set1<LocationExtended>>,
+ assignment_order: Vec<Local>,
+ direct_uses: IndexVec<Local, u32>,
+}
+
+impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
+ fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
+ match ctxt {
+ PlaceContext::MutatingUse(MutatingUseContext::Projection)
+ | PlaceContext::NonMutatingUse(NonMutatingUseContext::Projection) => bug!(),
+ // Anything can happen with raw pointers, so remove them.
+ // We do not verify that all uses of the borrow dominate the assignment to `local`,
+ // so we have to remove them too.
+ PlaceContext::NonMutatingUse(
+ NonMutatingUseContext::SharedBorrow
+ | NonMutatingUseContext::ShallowBorrow
+ | NonMutatingUseContext::AddressOf,
+ )
+ | PlaceContext::MutatingUse(_) => {
+ self.assignments[local] = Set1::Many;
+ }
+ PlaceContext::NonMutatingUse(_) => {
+ self.dominators.check_dominates(&mut self.assignments[local], loc);
+ self.direct_uses[local] += 1;
+ }
+ PlaceContext::NonUse(_) => {}
+ }
+ }
+
+ fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) {
+ if place.projection.first() == Some(&PlaceElem::Deref) {
+ // Do not do anything for storage statements and debuginfo.
+ if ctxt.is_use() {
+ // Only change the context if it is a real use, not a "use" in debuginfo.
+ let new_ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy);
+
+ self.visit_projection(place.as_ref(), new_ctxt, loc);
+ self.dominators.check_dominates(&mut self.assignments[place.local], loc);
+ }
+ return;
+ } else {
+ self.visit_projection(place.as_ref(), ctxt, loc);
+ self.visit_local(place.local, ctxt, loc);
+ }
+ }
+
+ fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, loc: Location) {
+ if let Some(local) = place.as_local() {
+ self.assignments[local].insert(LocationExtended::Plain(loc));
+ if let Set1::One(_) = self.assignments[local] {
+ // Only record if SSA-like, to avoid growing the vector needlessly.
+ self.assignment_order.push(local);
+ }
+ } else {
+ self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), loc);
+ }
+ self.visit_rvalue(rvalue, loc);
+ }
+}
+
+#[instrument(level = "trace", skip(ssa, body))]
+fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
+ let mut direct_uses = std::mem::take(&mut ssa.direct_uses);
+ let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len());
+
+ for (local, rvalue, _) in ssa.assignments(body) {
+ let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place))
+ = rvalue
+ else { continue };
+
+ let Some(rhs) = place.as_local() else { continue };
+ if !ssa.is_ssa(rhs) {
+ continue;
+ }
+
+ // We visit in `assignment_order`, i.e.
reverse post-order, so `rhs` has been + // visited before `local`, and we just have to copy the representing local. + let head = copies[rhs]; + + if local == RETURN_PLACE { + // `_0` is special, we cannot rename it. Instead, rename the class of `rhs` to + // `RETURN_PLACE`. This is only possible if the class head is a temporary, not an + // argument. + if body.local_kind(head) != LocalKind::Temp { + continue; + } + for h in copies.iter_mut() { + if *h == head { + *h = RETURN_PLACE; + } + } + } else { + copies[local] = head; + } + direct_uses[rhs] -= 1; + } + + debug!(?copies); + debug!(?direct_uses); + + // Invariant: `copies` must point to the head of an equivalence class. + #[cfg(debug_assertions)] + for &head in copies.iter() { + assert_eq!(copies[head], head); + } + debug_assert_eq!(copies[RETURN_PLACE], RETURN_PLACE); + + ssa.direct_uses = direct_uses; + ssa.copy_classes = copies; +} + +#[derive(Debug)] +pub(crate) struct StorageLiveLocals { + /// Set of "StorageLive" statements for each local. + storage_live: IndexVec<Local, Set1<LocationExtended>>, +} + +impl StorageLiveLocals { + pub(crate) fn new( + body: &Body<'_>, + always_storage_live_locals: &BitSet<Local>, + ) -> StorageLiveLocals { + let mut storage_live = IndexVec::from_elem(Set1::Empty, &body.local_decls); + for local in always_storage_live_locals.iter() { + storage_live[local] = Set1::One(LocationExtended::Arg); + } + for (block, bbdata) in body.basic_blocks.iter_enumerated() { + for (statement_index, statement) in bbdata.statements.iter().enumerate() { + if let StatementKind::StorageLive(local) = statement.kind { + storage_live[local] + .insert(LocationExtended::Plain(Location { block, statement_index })); + } + } + } + debug!(?storage_live); + StorageLiveLocals { storage_live } + } + + #[inline] + pub(crate) fn has_single_storage(&self, local: Local) -> bool { + matches!(self.storage_live[local], Set1::One(_)) + } +} diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index bd196f11879..5389b9f52eb 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -1,7 +1,7 @@ //! A pass that eliminates branches on uninhabited enum variants. use crate::MirPass; -use rustc_data_structures::stable_set::FxHashSet; +use rustc_data_structures::fx::FxHashSet; use rustc_middle::mir::{ BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, TerminatorKind, @@ -65,7 +65,7 @@ fn variant_discriminants<'tcx>( Variants::Multiple { variants, .. 
} => variants
 .iter_enumerated()
 .filter_map(|(idx, layout)| {
- (layout.abi() != Abi::Uninhabited)
+ (layout.abi != Abi::Uninhabited)
 .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
 })
 .collect(),
@@ -79,7 +79,7 @@ fn ensure_otherwise_unreachable<'tcx>(
 targets: &SwitchTargets,
 ) -> Option<BasicBlockData<'tcx>> {
 let otherwise = targets.otherwise();
- let bb = &body.basic_blocks()[otherwise];
+ let bb = &body.basic_blocks[otherwise];
 if bb.terminator().kind == TerminatorKind::Unreachable
 && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
 {
@@ -102,14 +102,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 trace!("UninhabitedEnumBranching starting for {:?}", body.source);

- for bb in body.basic_blocks().indices() {
+ for bb in body.basic_blocks.indices() {
 trace!("processing block {:?}", bb);

- let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks()[bb], tcx, body) else {
+ let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body) else {
 continue;
 };

- let layout = tcx.layout_of(tcx.param_env(body.source.def_id()).and(discriminant_ty));
+ let layout = tcx.layout_of(
+ tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
+ );

 let allowed_variants = if let Ok(layout) = layout {
 variant_discriminants(&layout, discriminant_ty, tcx)
diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs
index f916ca36217..bd1724bf842 100644
--- a/compiler/rustc_mir_transform/src/unreachable_prop.rs
+++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs
@@ -12,9 +12,8 @@ pub struct UnreachablePropagation;
 impl MirPass<'_> for UnreachablePropagation {
 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- // Enable only under -Zmir-opt-level=4 as in some cases (check the deeply-nested-opt
- // perf benchmark) LLVM may spend quite a lot of time optimizing the generated code.
- sess.mir_opt_level() >= 4
+ // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable.
+ sess.mir_opt_level() >= 2
 }

 fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -38,7 +37,19 @@ impl MirPass<'_> for UnreachablePropagation {
 }
 }

+ // We do want to keep some unreachable blocks, but make them empty.
+ for bb in unreachable_blocks {
+ if !tcx.consider_optimizing(|| {
+ format!("UnreachablePropagation {:?} ", body.source.def_id())
+ }) {
+ break;
+ }
+
+ body.basic_blocks_mut()[bb].statements.clear();
+ }
+
 let replaced = !replacements.is_empty();
+
 for (bb, terminator_kind) in replacements {
 if !tcx.consider_optimizing(|| {
 format!("UnreachablePropagation {:?} ", body.source.def_id())
@@ -57,42 +68,51 @@ impl MirPass<'_> for UnreachablePropagation {
 fn remove_successors<'tcx, F>(
 terminator_kind: &TerminatorKind<'tcx>,
- predicate: F,
+ is_unreachable: F,
 ) -> Option<TerminatorKind<'tcx>>
 where
 F: Fn(BasicBlock) -> bool,
 {
- let terminator = match *terminator_kind {
- TerminatorKind::Goto { target } if predicate(target) => TerminatorKind::Unreachable,
- TerminatorKind::SwitchInt { ref discr, switch_ty, ref targets } => {
+ let terminator = match terminator_kind {
+ // This will unconditionally run into an unreachable and is therefore unreachable as well.
+ TerminatorKind::Goto { target } if is_unreachable(*target) => TerminatorKind::Unreachable,
+ TerminatorKind::SwitchInt { targets, discr } => {
 let otherwise = targets.otherwise();

- let original_targets_len = targets.iter().len() + 1;
- let (mut values, mut targets): (Vec<_>, Vec<_>) =
- targets.iter().filter(|(_, bb)| !predicate(*bb)).unzip();
-
- if !predicate(otherwise) {
- targets.push(otherwise);
- } else {
- values.pop();
- }
+ // If all targets are unreachable, we can be unreachable as well.
+ if targets.all_targets().iter().all(|bb| is_unreachable(*bb)) {
+ TerminatorKind::Unreachable
+ } else if is_unreachable(otherwise) {
+ // If there are multiple targets, don't delete unreachable branches unless the otherwise
+ // target is itself unreachable: deleting a branch merges its values into the otherwise
+ // target, and merging into a *reachable* otherwise would lose the information that those
+ // values can never occur. Losing information about reachability causes worse codegen.
+ // For example (see tests/codegen/match-optimizes-away.rs)
+ //
+ // pub enum Two { A, B }
+ // pub fn identity(x: Two) -> Two {
+ // match x {
+ // Two::A => Two::A,
+ // Two::B => Two::B,
+ // }
+ // }
+ //
+ // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to
+ // turn it into just `x` later. Without the unreachable, such a transformation would be illegal.
+ // If the otherwise branch is unreachable, we can delete all other unreachable targets, as they will
+ // still point to the unreachable and therefore not lose reachability information.
+ let reachable_iter = targets.iter().filter(|(_, bb)| !is_unreachable(*bb));

- let retained_targets_len = targets.len();
+ let new_targets = SwitchTargets::new(reachable_iter, otherwise);

- if targets.is_empty() {
- TerminatorKind::Unreachable
- } else if targets.len() == 1 {
- TerminatorKind::Goto { target: targets[0] }
- } else if original_targets_len != retained_targets_len {
- TerminatorKind::SwitchInt {
- discr: discr.clone(),
- switch_ty,
- targets: SwitchTargets::new(
- values.iter().copied().zip(targets.iter().copied()),
- *targets.last().unwrap(),
- ),
+ // No unreachable branches were removed.
+ if new_targets.all_targets().len() == targets.all_targets().len() {
+ return None;
 }
+
+ TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }
 } else {
+ // If the otherwise branch is reachable, we don't want to delete any unreachable branches.
 return None;
 }
 }
