Diffstat (limited to 'compiler/rustc_mir_transform/src')
56 files changed, 4442 insertions, 2189 deletions
| diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index ba70a4453d6..d43fca3dc7e 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -1,5 +1,6 @@ use rustc_ast::InlineAsmOptions; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::layout; use rustc_middle::ty::{self, TyCtxt}; use rustc_target::spec::abi::Abi; diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index 430d9572e75..f880476cec2 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -24,7 +24,7 @@ fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> b | ty::Str | ty::FnDef(..) | ty::Never => false, - // References + // References and Boxes (`noalias` sources) ty::Ref(..) => true, ty::Adt(..) if ty.is_box() => true, ty::Adt(adt, _) if Some(adt.did()) == tcx.lang_items().ptr_unique() => true, @@ -118,22 +118,46 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // PART 3 - // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not. + // Add retag after assignments. for block_data in basic_blocks { // We want to insert statements as we iterate. To this end, we // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (retag_kind, place) = match block_data.statements[i].kind { // Retag after assignments of reference type. - StatementKind::Assign(box (ref place, ref rvalue)) if needs_retag(place) => { + StatementKind::Assign(box (ref place, ref rvalue)) => { let add_retag = match rvalue { // Ptr-creating operations already do their own internal retagging, no // need to also add a retag statement. - Rvalue::Ref(..) | Rvalue::AddressOf(..) => false, - _ => true, + // *Except* if we are deref'ing a Box, because those get desugared to directly working + // with the inner raw pointer! That's relevant for `AddressOf` as Miri otherwise makes it + // a NOP when the original pointer is already raw. + Rvalue::AddressOf(_mutbl, place) => { + // Using `is_box_global` here is a bit sketchy: if this code is + // generic over the allocator, we'll not add a retag! This is a hack + // to make Stacked Borrows compatible with custom allocator code. + // Long-term, we'll want to move to an aliasing model where "cast to + // raw pointer" is a complete NOP, and then this will no longer be + // an issue. + if place.is_indirect_first_projection() + && body.local_decls[place.local].ty.is_box_global(tcx) + { + Some(RetagKind::Raw) + } else { + None + } + } + Rvalue::Ref(..) 
=> None, + _ => { + if needs_retag(place) { + Some(RetagKind::Default) + } else { + None + } + } }; - if add_retag { - (RetagKind::Default, *place) + if let Some(kind) = add_retag { + (kind, *place) } else { continue; } diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs index 9eec724ef21..5199c41c58c 100644 --- a/compiler/rustc_mir_transform/src/check_alignment.rs +++ b/compiler/rustc_mir_transform/src/check_alignment.rs @@ -5,7 +5,7 @@ use rustc_middle::mir::{ interpret::Scalar, visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}, }; -use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut}; +use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt}; use rustc_session::Session; pub struct CheckAlignment; @@ -16,7 +16,7 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment { if sess.target.llvm_target == "i686-pc-windows-msvc" { return false; } - sess.opts.debug_assertions + sess.ub_checks() } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -106,7 +106,7 @@ impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { } let pointee_ty = - pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty; + pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer"); // Ideally we'd support this in the future, but for now we are limited to sized types. if !pointee_ty.is_sized(self.tcx, self.param_env) { debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty); @@ -157,7 +157,7 @@ fn insert_alignment_check<'tcx>( new_block: BasicBlock, ) { // Cast the pointer to a *const () - let const_raw_ptr = Ty::new_ptr(tcx, TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not }); + let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit); let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); block_data diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs index a405ed6088d..5f67bd75c48 100644 --- a/compiler/rustc_mir_transform/src/check_packed_ref.rs +++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs @@ -1,5 +1,6 @@ use rustc_middle::mir::visit::{PlaceContext, Visitor}; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::{self, TyCtxt}; use crate::MirLint; diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs deleted file mode 100644 index a0c3de3af58..00000000000 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ /dev/null @@ -1,615 +0,0 @@ -use rustc_data_structures::unord::{ExtendUnord, UnordItems, UnordSet}; -use rustc_hir as hir; -use rustc_hir::def::DefKind; -use rustc_hir::def_id::{DefId, LocalDefId}; -use rustc_hir::hir_id::HirId; -use rustc_hir::intravisit; -use rustc_hir::{BlockCheckMode, ExprKind, Node}; -use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::*; -use rustc_middle::query::Providers; -use rustc_middle::ty::{self, TyCtxt}; -use rustc_session::lint::builtin::{UNSAFE_OP_IN_UNSAFE_FN, UNUSED_UNSAFE}; -use rustc_session::lint::Level; - -use std::ops::Bound; - -use crate::errors; - -pub struct UnsafetyChecker<'a, 'tcx> { - body: &'a Body<'tcx>, - body_did: LocalDefId, - violations: Vec<UnsafetyViolation>, - source_info: SourceInfo, - tcx: TyCtxt<'tcx>, - param_env: ty::ParamEnv<'tcx>, - 
- /// Used `unsafe` blocks in this function. This is used for the "unused_unsafe" lint. - used_unsafe_blocks: UnordSet<HirId>, -} - -impl<'a, 'tcx> UnsafetyChecker<'a, 'tcx> { - fn new( - body: &'a Body<'tcx>, - body_did: LocalDefId, - tcx: TyCtxt<'tcx>, - param_env: ty::ParamEnv<'tcx>, - ) -> Self { - Self { - body, - body_did, - violations: vec![], - source_info: SourceInfo::outermost(body.span), - tcx, - param_env, - used_unsafe_blocks: Default::default(), - } - } -} - -impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - self.source_info = terminator.source_info; - match terminator.kind { - TerminatorKind::Goto { .. } - | TerminatorKind::SwitchInt { .. } - | TerminatorKind::Drop { .. } - | TerminatorKind::Yield { .. } - | TerminatorKind::Assert { .. } - | TerminatorKind::CoroutineDrop - | TerminatorKind::UnwindResume - | TerminatorKind::UnwindTerminate(_) - | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::FalseEdge { .. } - | TerminatorKind::FalseUnwind { .. } => { - // safe (at least as emitted during MIR construction) - } - - TerminatorKind::Call { ref func, .. } => { - let func_ty = func.ty(self.body, self.tcx); - let func_id = - if let ty::FnDef(func_id, _) = func_ty.kind() { Some(func_id) } else { None }; - let sig = func_ty.fn_sig(self.tcx); - if let hir::Unsafety::Unsafe = sig.unsafety() { - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::CallToUnsafeFunction, - ) - } - - if let Some(func_id) = func_id { - self.check_target_features(*func_id); - } - } - - TerminatorKind::InlineAsm { .. } => self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::UseOfInlineAssembly, - ), - } - self.super_terminator(terminator, location); - } - - fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { - self.source_info = statement.source_info; - match statement.kind { - StatementKind::Assign(..) - | StatementKind::FakeRead(..) - | StatementKind::SetDiscriminant { .. } - | StatementKind::Deinit(..) - | StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Retag { .. } - | StatementKind::PlaceMention(..) - | StatementKind::Coverage(..) - | StatementKind::Intrinsic(..) - | StatementKind::ConstEvalCounter - | StatementKind::Nop => { - // safe (at least as emitted during MIR construction) - } - // `AscribeUserType` just exists to help MIR borrowck. - // It has no semantics, and everything is already reported by `PlaceMention`. - StatementKind::AscribeUserType(..) => return, - } - self.super_statement(statement, location); - } - - fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { - match rvalue { - Rvalue::Aggregate(box ref aggregate, _) => match aggregate { - &AggregateKind::Array(..) | &AggregateKind::Tuple => {} - &AggregateKind::Adt(adt_did, ..) => { - match self.tcx.layout_scalar_valid_range(adt_did) { - (Bound::Unbounded, Bound::Unbounded) => {} - _ => self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::InitializingTypeWith, - ), - } - } - &AggregateKind::Closure(def_id, _) - | &AggregateKind::CoroutineClosure(def_id, _) - | &AggregateKind::Coroutine(def_id, _) => { - let def_id = def_id.expect_local(); - let UnsafetyCheckResult { violations, used_unsafe_blocks, .. 
} = - self.tcx.mir_unsafety_check_result(def_id); - self.register_violations(violations, used_unsafe_blocks.items().copied()); - } - }, - _ => {} - } - self.super_rvalue(rvalue, location); - } - - fn visit_operand(&mut self, op: &Operand<'tcx>, location: Location) { - if let Operand::Constant(constant) = op { - let maybe_uneval = match constant.const_ { - Const::Val(..) | Const::Ty(_) => None, - Const::Unevaluated(uv, _) => Some(uv), - }; - - if let Some(uv) = maybe_uneval { - if uv.promoted.is_none() { - let def_id = uv.def; - if self.tcx.def_kind(def_id) == DefKind::InlineConst { - let local_def_id = def_id.expect_local(); - let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = - self.tcx.mir_unsafety_check_result(local_def_id); - self.register_violations(violations, used_unsafe_blocks.items().copied()); - } - } - } - } - self.super_operand(op, location); - } - - fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { - // On types with `scalar_valid_range`, prevent - // * `&mut x.field` - // * `x.field = y;` - // * `&x.field` if `field`'s type has interior mutability - // because either of these would allow modifying the layout constrained field and - // insert values that violate the layout constraints. - if context.is_mutating_use() || context.is_borrow() { - self.check_mut_borrowing_layout_constrained_field(*place, context.is_mutating_use()); - } - - // Some checks below need the extra meta info of the local declaration. - let decl = &self.body.local_decls[place.local]; - - // Check the base local: it might be an unsafe-to-access static. We only check derefs of the - // temporary holding the static pointer to avoid duplicate errors - // <https://github.com/rust-lang/rust/pull/78068#issuecomment-731753506>. - if place.projection.first() == Some(&ProjectionElem::Deref) { - // If the projection root is an artificial local that we introduced when - // desugaring `static`, give a more specific error message - // (avoid the general "raw pointer" clause below, that would only be confusing). - if let LocalInfo::StaticRef { def_id, .. } = *decl.local_info() { - if self.tcx.is_mutable_static(def_id) { - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::UseOfMutableStatic, - ); - return; - } else if self.tcx.is_foreign_item(def_id) { - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::UseOfExternStatic, - ); - return; - } - } - } - - // Check for raw pointer `Deref`. - for (base, proj) in place.iter_projections() { - if proj == ProjectionElem::Deref { - let base_ty = base.ty(self.body, self.tcx).ty; - if base_ty.is_unsafe_ptr() { - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::DerefOfRawPointer, - ) - } - } - } - - // Check for union fields. For this we traverse right-to-left, as the last `Deref` changes - // whether we *read* the union field or potentially *write* to it (if this place is being assigned to). - let mut saw_deref = false; - for (base, proj) in place.iter_projections().rev() { - if proj == ProjectionElem::Deref { - saw_deref = true; - continue; - } - - let base_ty = base.ty(self.body, self.tcx).ty; - if base_ty.is_union() { - // If we did not hit a `Deref` yet and the overall place use is an assignment, the - // rules are different. 
- let assign_to_field = !saw_deref - && matches!( - context, - PlaceContext::MutatingUse( - MutatingUseContext::Store - | MutatingUseContext::Drop - | MutatingUseContext::AsmOutput - ) - ); - // If this is just an assignment, determine if the assigned type needs dropping. - if assign_to_field { - // We have to check the actual type of the assignment, as that determines if the - // old value is being dropped. - let assigned_ty = place.ty(&self.body.local_decls, self.tcx).ty; - if assigned_ty.needs_drop(self.tcx, self.param_env) { - // This would be unsafe, but should be outright impossible since we reject - // such unions. - assert!( - self.tcx.dcx().has_errors().is_some(), - "union fields that need dropping should be impossible: {assigned_ty}" - ); - } - } else { - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::AccessToUnionField, - ) - } - } - } - } -} - -impl<'tcx> UnsafetyChecker<'_, 'tcx> { - fn require_unsafe(&mut self, kind: UnsafetyViolationKind, details: UnsafetyViolationDetails) { - // Violations can turn out to be `UnsafeFn` during analysis, but they should not start out as such. - assert_ne!(kind, UnsafetyViolationKind::UnsafeFn); - - let source_info = self.source_info; - let lint_root = self.body.source_scopes[self.source_info.scope] - .local_data - .as_ref() - .assert_crate_local() - .lint_root; - self.register_violations( - [&UnsafetyViolation { source_info, lint_root, kind, details }], - UnordItems::empty(), - ); - } - - fn register_violations<'a>( - &mut self, - violations: impl IntoIterator<Item = &'a UnsafetyViolation>, - new_used_unsafe_blocks: UnordItems<HirId, impl Iterator<Item = HirId>>, - ) { - let safety = self.body.source_scopes[self.source_info.scope] - .local_data - .as_ref() - .assert_crate_local() - .safety; - match safety { - // `unsafe` blocks are required in safe code - Safety::Safe => violations.into_iter().for_each(|violation| { - match violation.kind { - UnsafetyViolationKind::General => {} - UnsafetyViolationKind::UnsafeFn => { - bug!("`UnsafetyViolationKind::UnsafeFn` in an `Safe` context") - } - } - if !self.violations.contains(violation) { - self.violations.push(violation.clone()) - } - }), - // With the RFC 2585, no longer allow `unsafe` operations in `unsafe fn`s - Safety::FnUnsafe => violations.into_iter().for_each(|violation| { - let mut violation = violation.clone(); - violation.kind = UnsafetyViolationKind::UnsafeFn; - if !self.violations.contains(&violation) { - self.violations.push(violation) - } - }), - Safety::BuiltinUnsafe => {} - Safety::ExplicitUnsafe(hir_id) => violations.into_iter().for_each(|_violation| { - self.used_unsafe_blocks.insert(hir_id); - }), - }; - - self.used_unsafe_blocks.extend_unord(new_used_unsafe_blocks); - } - fn check_mut_borrowing_layout_constrained_field( - &mut self, - place: Place<'tcx>, - is_mut_use: bool, - ) { - for (place_base, elem) in place.iter_projections().rev() { - match elem { - // Modifications behind a dereference don't affect the value of - // the pointer. - ProjectionElem::Deref => return, - ProjectionElem::Field(..) => { - let ty = place_base.ty(&self.body.local_decls, self.tcx).ty; - if let ty::Adt(def, _) = ty.kind() { - if self.tcx.layout_scalar_valid_range(def.did()) - != (Bound::Unbounded, Bound::Unbounded) - { - let details = if is_mut_use { - UnsafetyViolationDetails::MutationOfLayoutConstrainedField - - // Check `is_freeze` as late as possible to avoid cycle errors - // with opaque types. 
- } else if !place - .ty(self.body, self.tcx) - .ty - .is_freeze(self.tcx, self.param_env) - { - UnsafetyViolationDetails::BorrowOfLayoutConstrainedField - } else { - continue; - }; - self.require_unsafe(UnsafetyViolationKind::General, details); - } - } - } - _ => {} - } - } - } - - /// Checks whether calling `func_did` needs an `unsafe` context or not, i.e. whether - /// the called function has target features the calling function hasn't. - fn check_target_features(&mut self, func_did: DefId) { - // Unsafety isn't required on wasm targets. For more information see - // the corresponding check in typeck/src/collect.rs - if self.tcx.sess.target.options.is_like_wasm { - return; - } - - let callee_features = &self.tcx.codegen_fn_attrs(func_did).target_features; - // The body might be a constant, so it doesn't have codegen attributes. - let self_features = &self.tcx.body_codegen_attrs(self.body_did.to_def_id()).target_features; - - // Is `callee_features` a subset of `calling_features`? - if !callee_features.iter().all(|feature| self_features.contains(feature)) { - let missing: Vec<_> = callee_features - .iter() - .copied() - .filter(|feature| !self_features.contains(feature)) - .collect(); - let build_enabled = self - .tcx - .sess - .target_features - .iter() - .copied() - .filter(|feature| missing.contains(feature)) - .collect(); - self.require_unsafe( - UnsafetyViolationKind::General, - UnsafetyViolationDetails::CallToFunctionWith { missing, build_enabled }, - ) - } - } -} - -pub(crate) fn provide(providers: &mut Providers) { - *providers = Providers { mir_unsafety_check_result, ..*providers }; -} - -/// Context information for [`UnusedUnsafeVisitor`] traversal, -/// saves (innermost) relevant context -#[derive(Copy, Clone, Debug)] -enum Context { - Safe, - /// in an `unsafe fn` - UnsafeFn, - /// in a *used* `unsafe` block - /// (i.e. 
a block without unused-unsafe warning) - UnsafeBlock(HirId), -} - -struct UnusedUnsafeVisitor<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - used_unsafe_blocks: &'a UnordSet<HirId>, - context: Context, - unused_unsafes: &'a mut Vec<(HirId, UnusedUnsafe)>, -} - -impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> { - fn visit_block(&mut self, block: &'tcx hir::Block<'tcx>) { - if let hir::BlockCheckMode::UnsafeBlock(hir::UnsafeSource::UserProvided) = block.rules { - let used = match self.tcx.lint_level_at_node(UNUSED_UNSAFE, block.hir_id) { - (Level::Allow, _) => true, - _ => self.used_unsafe_blocks.contains(&block.hir_id), - }; - let unused_unsafe = match (self.context, used) { - (_, false) => UnusedUnsafe::Unused, - (Context::Safe, true) | (Context::UnsafeFn, true) => { - let previous_context = self.context; - self.context = Context::UnsafeBlock(block.hir_id); - intravisit::walk_block(self, block); - self.context = previous_context; - return; - } - (Context::UnsafeBlock(hir_id), true) => UnusedUnsafe::InUnsafeBlock(hir_id), - }; - self.unused_unsafes.push((block.hir_id, unused_unsafe)); - } - intravisit::walk_block(self, block); - } - - fn visit_inline_const(&mut self, c: &'tcx hir::ConstBlock) { - self.visit_body(self.tcx.hir().body(c.body)) - } - - fn visit_fn( - &mut self, - fk: intravisit::FnKind<'tcx>, - _fd: &'tcx hir::FnDecl<'tcx>, - b: hir::BodyId, - _s: rustc_span::Span, - _id: LocalDefId, - ) { - if matches!(fk, intravisit::FnKind::Closure) { - self.visit_body(self.tcx.hir().body(b)) - } - } -} - -fn check_unused_unsafe( - tcx: TyCtxt<'_>, - def_id: LocalDefId, - used_unsafe_blocks: &UnordSet<HirId>, -) -> Vec<(HirId, UnusedUnsafe)> { - let body_id = tcx.hir().maybe_body_owned_by(def_id); - - let Some(body_id) = body_id else { - debug!("check_unused_unsafe({:?}) - no body found", def_id); - return vec![]; - }; - - let body = tcx.hir().body(body_id); - let hir_id = tcx.local_def_id_to_hir_id(def_id); - let context = match tcx.hir().fn_sig_by_hir_id(hir_id) { - Some(sig) if sig.header.unsafety == hir::Unsafety::Unsafe => Context::UnsafeFn, - _ => Context::Safe, - }; - - debug!( - "check_unused_unsafe({:?}, context={:?}, body={:?}, used_unsafe_blocks={:?})", - def_id, body, context, used_unsafe_blocks - ); - - let mut unused_unsafes = vec![]; - - let mut visitor = UnusedUnsafeVisitor { - tcx, - used_unsafe_blocks, - context, - unused_unsafes: &mut unused_unsafes, - }; - intravisit::Visitor::visit_body(&mut visitor, body); - - unused_unsafes -} - -fn mir_unsafety_check_result(tcx: TyCtxt<'_>, def: LocalDefId) -> &UnsafetyCheckResult { - debug!("unsafety_violations({:?})", def); - - // N.B., this borrow is valid because all the consumers of - // `mir_built` force this. 
- let body = &tcx.mir_built(def).borrow(); - - if body.is_custom_mir() || body.tainted_by_errors.is_some() { - return tcx.arena.alloc(UnsafetyCheckResult { - violations: Vec::new(), - used_unsafe_blocks: Default::default(), - unused_unsafes: Some(Vec::new()), - }); - } - - let param_env = tcx.param_env(def); - - let mut checker = UnsafetyChecker::new(body, def, tcx, param_env); - checker.visit_body(body); - - let unused_unsafes = (!tcx.is_typeck_child(def.to_def_id())) - .then(|| check_unused_unsafe(tcx, def, &checker.used_unsafe_blocks)); - - tcx.arena.alloc(UnsafetyCheckResult { - violations: checker.violations, - used_unsafe_blocks: checker.used_unsafe_blocks, - unused_unsafes, - }) -} - -fn report_unused_unsafe(tcx: TyCtxt<'_>, kind: UnusedUnsafe, id: HirId) { - let span = tcx.sess.source_map().guess_head_span(tcx.hir().span(id)); - let nested_parent = if let UnusedUnsafe::InUnsafeBlock(id) = kind { - Some(tcx.sess.source_map().guess_head_span(tcx.hir().span(id))) - } else { - None - }; - tcx.emit_node_span_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); -} - -pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { - debug!("check_unsafety({:?})", def_id); - - // closures and inline consts are handled by their parent fn. - if tcx.is_typeck_child(def_id.to_def_id()) { - return; - } - - let UnsafetyCheckResult { violations, unused_unsafes, .. } = - tcx.mir_unsafety_check_result(def_id); - // Only suggest wrapping the entire function body in an unsafe block once - let mut suggest_unsafe_block = true; - - for &UnsafetyViolation { source_info, lint_root, kind, ref details } in violations.iter() { - let details = - errors::RequiresUnsafeDetail { violation: details.clone(), span: source_info.span }; - - match kind { - UnsafetyViolationKind::General => { - let op_in_unsafe_fn_allowed = unsafe_op_in_unsafe_fn_allowed(tcx, lint_root); - let note_non_inherited = tcx.hir().parent_iter(lint_root).find(|(id, node)| { - if let Node::Expr(block) = node - && let ExprKind::Block(block, _) = block.kind - && let BlockCheckMode::UnsafeBlock(_) = block.rules - { - true - } else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id) - && sig.header.is_unsafe() - { - true - } else { - false - } - }); - let enclosing = if let Some((id, _)) = note_non_inherited { - Some(tcx.sess.source_map().guess_head_span(tcx.hir().span(id))) - } else { - None - }; - tcx.dcx().emit_err(errors::RequiresUnsafe { - span: source_info.span, - enclosing, - details, - op_in_unsafe_fn_allowed, - }); - } - UnsafetyViolationKind::UnsafeFn => { - tcx.emit_node_span_lint( - UNSAFE_OP_IN_UNSAFE_FN, - lint_root, - source_info.span, - errors::UnsafeOpInUnsafeFn { - details, - suggest_unsafe_block: suggest_unsafe_block.then(|| { - let hir_id = tcx.local_def_id_to_hir_id(def_id); - let fn_sig = tcx - .hir() - .fn_sig_by_hir_id(hir_id) - .expect("this violation only occurs in fn"); - let body = tcx.hir().body_owned_by(def_id); - let body_span = tcx.hir().body(body).value.span; - let start = tcx.sess.source_map().start_point(body_span).shrink_to_hi(); - let end = tcx.sess.source_map().end_point(body_span).shrink_to_lo(); - (start, end, fn_sig.span) - }), - }, - ); - suggest_unsafe_block = false; - } - } - } - - for &(block_id, kind) in unused_unsafes.as_ref().unwrap() { - report_unused_unsafe(tcx, kind, block_id); - } -} - -fn unsafe_op_in_unsafe_fn_allowed(tcx: TyCtxt<'_>, id: HirId) -> bool { - tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, id).0 == Level::Allow -} diff --git 
a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs index 5b4bc4fa134..48a6a83e146 100644 --- a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs +++ b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs @@ -5,14 +5,19 @@ //! - [`AscribeUserType`] //! - [`FakeRead`] //! - [`Assign`] statements with a [`Fake`] borrow +//! - [`Coverage`] statements of kind [`BlockMarker`] or [`SpanMarker`] //! //! [`AscribeUserType`]: rustc_middle::mir::StatementKind::AscribeUserType //! [`Assign`]: rustc_middle::mir::StatementKind::Assign //! [`FakeRead`]: rustc_middle::mir::StatementKind::FakeRead //! [`Nop`]: rustc_middle::mir::StatementKind::Nop //! [`Fake`]: rustc_middle::mir::BorrowKind::Fake +//! [`Coverage`]: rustc_middle::mir::StatementKind::Coverage +//! [`BlockMarker`]: rustc_middle::mir::coverage::CoverageKind::BlockMarker +//! [`SpanMarker`]: rustc_middle::mir::coverage::CoverageKind::SpanMarker use crate::MirPass; +use rustc_middle::mir::coverage::CoverageKind; use rustc_middle::mir::{Body, BorrowKind, Rvalue, StatementKind, TerminatorKind}; use rustc_middle::ty::TyCtxt; @@ -24,7 +29,12 @@ impl<'tcx> MirPass<'tcx> for CleanupPostBorrowck { for statement in basic_block.statements.iter_mut() { match statement.kind { StatementKind::AscribeUserType(..) - | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Fake, _))) + | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Fake(_), _))) + | StatementKind::Coverage( + // These kinds of coverage statements are markers inserted during + // MIR building, and are not needed after InstrumentCoverage. + CoverageKind::BlockMarker { .. } | CoverageKind::SpanMarker { .. }, + ) | StatementKind::FakeRead(..) => statement.make_nop(), _ => (), } diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs index 0119b95cced..c1f9313a377 100644 --- a/compiler/rustc_mir_transform/src/copy_prop.rs +++ b/compiler/rustc_mir_transform/src/copy_prop.rs @@ -3,7 +3,6 @@ use rustc_index::IndexSlice; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; -use rustc_mir_dataflow::impls::borrowed_locals; use crate::ssa::SsaLocals; @@ -32,8 +31,8 @@ impl<'tcx> MirPass<'tcx> for CopyProp { } fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let borrowed_locals = borrowed_locals(body); - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let fully_moved = fully_moved_locals(&ssa, body); debug!(?fully_moved); @@ -51,7 +50,7 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { tcx, copy_classes: ssa.copy_classes(), fully_moved, - borrowed_locals, + borrowed_locals: ssa.borrowed_locals(), storage_to_remove, } .visit_body_preserves_cfg(body); @@ -101,7 +100,7 @@ struct Replacer<'a, 'tcx> { tcx: TyCtxt<'tcx>, fully_moved: BitSet<Local>, storage_to_remove: BitSet<Local>, - borrowed_locals: BitSet<Local>, + borrowed_locals: &'a BitSet<Local>, copy_classes: &'a IndexSlice<Local, Local>, } @@ -112,6 +111,12 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) { let new_local = self.copy_classes[*local]; + // We must not unify two locals that are borrowed. But this is fine if one is borrowed and + // the other is not. We chose to check the original local, and not the target. 
That way, if + // the original local is borrowed and the target is not, we do not pessimize the whole class. + if self.borrowed_locals.contains(*local) { + return; + } match ctxt { // Do not modify the local in storage statements. PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {} @@ -122,32 +127,14 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { } } - fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) { + fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, loc: Location) { if let Some(new_projection) = self.process_projection(place.projection, loc) { place.projection = self.tcx().mk_place_elems(&new_projection); } - let observes_address = match ctxt { - PlaceContext::NonMutatingUse( - NonMutatingUseContext::SharedBorrow - | NonMutatingUseContext::FakeBorrow - | NonMutatingUseContext::AddressOf, - ) => true, - // For debuginfo, merging locals is ok. - PlaceContext::NonUse(NonUseContext::VarDebugInfo) => { - self.borrowed_locals.contains(place.local) - } - _ => false, - }; - if observes_address && !place.is_indirect() { - // We observe the address of `place.local`. Do not replace it. - } else { - self.visit_local( - &mut place.local, - PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy), - loc, - ) - } + // Any non-mutating use context is ok. + let ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy); + self.visit_local(&mut place.local, ctxt, loc) } fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) { diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 54b13a40e92..a3e6e5a5a91 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -70,6 +70,7 @@ use rustc_middle::mir::*; use rustc_middle::ty::CoroutineArgs; use rustc_middle::ty::InstanceDef; use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_mir_dataflow::impls::{ MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, }; @@ -80,6 +81,10 @@ use rustc_span::symbol::sym; use rustc_span::Span; use rustc_target::abi::{FieldIdx, VariantIdx}; use rustc_target::spec::PanicStrategy; +use rustc_trait_selection::infer::TyCtxtInferExt as _; +use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _; +use rustc_trait_selection::traits::ObligationCtxt; +use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode}; use std::{iter, ops}; pub struct StateTransform; @@ -168,7 +173,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { Place { local: SELF_ARG, projection: self.tcx().mk_place_elems(&[ProjectionElem::Field( - FieldIdx::new(0), + FieldIdx::ZERO, self.ref_coroutine_ty, )]), }, @@ -267,7 +272,7 @@ impl<'tcx> TransformVisitor<'tcx> { Rvalue::Aggregate( Box::new(AggregateKind::Adt( option_def_id, - VariantIdx::from_usize(0), + VariantIdx::ZERO, self.tcx.mk_args(&[self.old_yield_ty.into()]), None, None, @@ -329,7 +334,7 @@ impl<'tcx> TransformVisitor<'tcx> { Rvalue::Aggregate( Box::new(AggregateKind::Adt( poll_def_id, - VariantIdx::from_usize(0), + VariantIdx::ZERO, args, None, None, @@ -358,7 +363,7 @@ impl<'tcx> TransformVisitor<'tcx> { Rvalue::Aggregate( Box::new(AggregateKind::Adt( option_def_id, - VariantIdx::from_usize(0), + VariantIdx::ZERO, args, None, None, @@ -420,7 +425,7 @@ impl<'tcx> TransformVisitor<'tcx> { Rvalue::Aggregate( Box::new(AggregateKind::Adt( coroutine_state_def_id, - 
VariantIdx::from_usize(0), + VariantIdx::ZERO, args, None, None, @@ -570,11 +575,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let coroutine_ty = body.local_decls.raw[1].ty; - let ref_coroutine_ty = Ty::new_ref( - tcx, - tcx.lifetimes.re_erased, - ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut }, - ); + let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty); // Replace the by value coroutine argument body.local_decls.raw[1].ty = ref_coroutine_ty; @@ -1230,7 +1231,7 @@ fn create_coroutine_drop_shim<'tcx>( tcx: TyCtxt<'tcx>, transform: &TransformVisitor<'tcx>, coroutine_ty: Ty<'tcx>, - body: &mut Body<'tcx>, + body: &Body<'tcx>, drop_clean: BasicBlock, ) -> Body<'tcx> { let mut body = body.clone(); @@ -1260,15 +1261,13 @@ fn create_coroutine_drop_shim<'tcx>( } // Replace the return variable - body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(Ty::new_unit(tcx), source_info); + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info); make_coroutine_state_argument_indirect(tcx, &mut body); // Change the coroutine argument from &mut to *mut - body.local_decls[SELF_ARG] = LocalDecl::with_source_info( - Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }), - source_info, - ); + body.local_decls[SELF_ARG] = + LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info); // Make sure we remove dead blocks to remove // unrelated code from the resume part of the function @@ -1590,10 +1589,46 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>( let (_, coroutine_layout, _) = compute_layout(liveness_info, body); check_suspend_tys(tcx, &coroutine_layout, body); + check_field_tys_sized(tcx, &coroutine_layout, def_id); Some(coroutine_layout) } +fn check_field_tys_sized<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_layout: &CoroutineLayout<'tcx>, + def_id: LocalDefId, +) { + // No need to check if unsized_locals/unsized_fn_params is disabled, + // since we will error during typeck. + if !tcx.features().unsized_locals && !tcx.features().unsized_fn_params { + return; + } + + let infcx = tcx.infer_ctxt().ignoring_regions().build(); + let param_env = tcx.param_env(def_id); + + let ocx = ObligationCtxt::new(&infcx); + for field_ty in &coroutine_layout.field_tys { + ocx.register_bound( + ObligationCause::new( + field_ty.source_info.span, + def_id, + ObligationCauseCode::SizedCoroutineInterior(def_id), + ), + param_env, + field_ty.ty, + tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)), + ); + } + + let errors = ocx.select_all_or_error(); + debug!(?errors); + if !errors.is_empty() { + infcx.err_ctxt().report_fulfillment_errors(errors); + } +} + impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(old_yield_ty) = body.yield_ty() else { @@ -1964,15 +1999,21 @@ fn check_must_not_suspend_ty<'tcx>( debug!("Checking must_not_suspend for {}", ty); match *ty.kind() { - ty::Adt(..) 
if ty.is_box() => { - let boxed_ty = ty.boxed_ty(); - let descr_pre = &format!("{}boxed ", data.descr_pre); + ty::Adt(_, args) if ty.is_box() => { + let boxed_ty = args.type_at(0); + let allocator_ty = args.type_at(1); check_must_not_suspend_ty( tcx, boxed_ty, hir_id, param_env, - SuspendCheckData { descr_pre, ..data }, + SuspendCheckData { descr_pre: &format!("{}boxed ", data.descr_pre), ..data }, + ) || check_must_not_suspend_ty( + tcx, + allocator_ty, + hir_id, + param_env, + SuspendCheckData { descr_pre: &format!("{}allocator ", data.descr_pre), ..data }, ) } ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index e40f4520671..10c0567eb4b 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -1,19 +1,89 @@ -//! A MIR pass which duplicates a coroutine's body and removes any derefs which -//! would be present for upvars that are taken by-ref. The result of which will -//! be a coroutine body that takes all of its upvars by-move, and which we stash -//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures. +//! This pass constructs a second coroutine body sufficient for return from +//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures). +//! +//! Consider an async closure like: +//! ```rust +//! #![feature(async_closure)] +//! +//! let x = vec![1, 2, 3]; +//! +//! let closure = async move || { +//! println!("{x:#?}"); +//! }; +//! ``` +//! +//! This desugars to something like: +//! ```rust,ignore (invalid-borrowck) +//! let x = vec![1, 2, 3]; +//! +//! let closure = move || { +//! async { +//! println!("{x:#?}"); +//! } +//! }; +//! ``` +//! +//! Important to note here is that while the outer closure *moves* `x: Vec<i32>` +//! into its upvars, the inner `async` coroutine simply captures a ref of `x`. +//! This is the "magic" of async closures -- the futures that they return are +//! allowed to borrow from their parent closure's upvars. +//! +//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`, +//! since all async closures implement that too)? Well, recall the signature: +//! ``` +//! use std::future::Future; +//! pub trait AsyncFnOnce<Args> +//! { +//! type CallOnceFuture: Future<Output = Self::Output>; +//! type Output; +//! fn async_call_once( +//! self, +//! args: Args +//! ) -> Self::CallOnceFuture; +//! } +//! ``` +//! +//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`. +//! How do we deal with the fact that the coroutine is supposed to take a reference +//! to the captured `x` from the parent closure, when that parent closure has been +//! destroyed? +//! +//! This is the second piece of magic of async closures. We can simply create a +//! *second* `async` coroutine body where that `x` that was previously captured +//! by reference is now captured by value. This means that we consume the outer +//! closure and return a new coroutine that will hold onto all of these captures, +//! and drop them when it is finished (i.e. after it has been `.await`ed). +//! +//! We do this with the analysis below, which detects the captures that come from +//! borrowing from the outer closure, and we simply peel off a `deref` projection +//! from them. This second body is stored alongside the first body, and optimized +//! 
with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`, +//! we use this "by-move" body instead. +//! +//! ## How does this work? +//! +//! This pass essentially remaps the body of the (child) closure of the coroutine-closure +//! to take the set of upvars of the parent closure by value. This at least requires +//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure +//! captures something by value; however, it may also require renumbering field indices +//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine +//! to split one field capture into two. -use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::unord::UnordMap; use rustc_hir as hir; +use rustc_middle::bug; +use rustc_middle::hir::place::{Projection, ProjectionKind}; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; -use rustc_target::abi::FieldIdx; +use rustc_target::abi::{FieldIdx, VariantIdx}; pub struct ByMoveBody; impl<'tcx> MirPass<'tcx> for ByMoveBody { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + // We only need to generate by-move coroutine bodies for coroutines that come + // from coroutine-closures. let Some(coroutine_def_id) = body.source.def_id().as_local() else { return; }; @@ -22,96 +92,117 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { else { return; }; + + // Also, let's skip processing any bodies with errors, since there's no guarantee + // the MIR body will be constructed well. let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; if coroutine_ty.references_error() { return; } - let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; - let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); - if coroutine_kind == ty::ClosureKind::FnOnce { + // We don't need to generate a by-move coroutine if the coroutine body was + // produced by the `CoroutineKindShim`, since it's already by-move. + if matches!(body.source.instance, ty::InstanceDef::CoroutineKindShim { .. }) { return; } - let mut by_ref_fields = FxIndexSet::default(); - let by_move_upvars = Ty::new_tup_from_iter( - tcx, - tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { - if capture.is_by_ref() { - by_ref_fields.insert(FieldIdx::from_usize(idx)); + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; + let args = args.as_coroutine(); + + let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap(); + + let parent_def_id = tcx.local_parent(coroutine_def_id); + let ty::CoroutineClosure(_, parent_args) = + *tcx.type_of(parent_def_id).instantiate_identity().kind() + else { + bug!(); + }; + let parent_closure_args = parent_args.as_coroutine_closure(); + let num_args = parent_closure_args + .coroutine_closure_sig() + .skip_binder() + .tupled_inputs_ty + .tuple_fields() + .len(); + + let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures( + tcx.closure_captures(parent_def_id).iter().copied(), + tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(), + |(parent_field_idx, parent_capture), (child_field_idx, child_capture)| { + // Store this set of additional projections (fields and derefs). + // We need to re-apply them later. 
+ let child_precise_captures = + &child_capture.place.projections[parent_capture.place.projections.len()..]; + + // If the parent captures by-move, and the child captures by-ref, then we + // need to peel an additional `deref` off of the body of the child. + let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref(); + if needs_deref { + assert_ne!( + coroutine_kind, + ty::ClosureKind::FnOnce, + "`FnOnce` coroutine-closures return coroutines that capture from \ + their body; it will always result in a borrowck error!" + ); } - capture.place.ty() - }), - ); - let by_move_coroutine_ty = Ty::new_coroutine( - tcx, - coroutine_def_id.to_def_id(), - ty::CoroutineArgs::new( - tcx, - ty::CoroutineArgsParts { - parent_args: args.as_coroutine().parent_args(), - kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), - resume_ty: args.as_coroutine().resume_ty(), - yield_ty: args.as_coroutine().yield_ty(), - return_ty: args.as_coroutine().return_ty(), - witness: args.as_coroutine().witness(), - tupled_upvars_ty: by_move_upvars, - }, - ) - .args, - ); - let mut by_move_body = body.clone(); - MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); - dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); - by_move_body.source = mir::MirSource { - instance: InstanceDef::CoroutineKindShim { - coroutine_def_id: coroutine_def_id.to_def_id(), - target_kind: ty::ClosureKind::FnOnce, + // Finally, store the type of the parent's captured place. We need + // this when building the field projection in the MIR body later on. + let mut parent_capture_ty = parent_capture.place.ty(); + parent_capture_ty = match parent_capture.info.capture_kind { + ty::UpvarCapture::ByValue => parent_capture_ty, + ty::UpvarCapture::ByRef(kind) => Ty::new_ref( + tcx, + tcx.lifetimes.re_erased, + parent_capture_ty, + kind.to_mutbl_lossy(), + ), + }; + + ( + FieldIdx::from_usize(child_field_idx + num_args), + ( + FieldIdx::from_usize(parent_field_idx + num_args), + parent_capture_ty, + needs_deref, + child_precise_captures, + ), + ) }, - promoted: None, - }; - body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + ) + .collect(); + + if coroutine_kind == ty::ClosureKind::FnOnce { + assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); + return; + } - // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body. - // This is actually just a copy of the by-ref body, but with a different self type. - // FIXME(async_closures): We could probably unify this with the by-ref body somehow. 
- if coroutine_kind == ty::ClosureKind::Fn { - let by_mut_coroutine_ty = Ty::new_coroutine( + let by_move_coroutine_ty = tcx + .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig()) + .to_coroutine_given_kind_and_upvars( tcx, + parent_closure_args.parent_args(), coroutine_def_id.to_def_id(), - ty::CoroutineArgs::new( - tcx, - ty::CoroutineArgsParts { - parent_args: args.as_coroutine().parent_args(), - kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut), - resume_ty: args.as_coroutine().resume_ty(), - yield_ty: args.as_coroutine().yield_ty(), - return_ty: args.as_coroutine().return_ty(), - witness: args.as_coroutine().witness(), - tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(), - }, - ) - .args, + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_erased, + parent_closure_args.tupled_upvars_ty(), + parent_closure_args.coroutine_captures_by_ref_ty(), ); - let mut by_mut_body = body.clone(); - by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty; - dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(())); - by_mut_body.source = mir::MirSource { - instance: InstanceDef::CoroutineKindShim { - coroutine_def_id: coroutine_def_id.to_def_id(), - target_kind: ty::ClosureKind::FnMut, - }, - promoted: None, - }; - body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body); - } + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body); + dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); + // FIXME: use query feeding to generate the body right here and then only store the `DefId` of the new body. + by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + }); + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); } } struct MakeByMoveBody<'tcx> { tcx: TyCtxt<'tcx>, - by_ref_fields: FxIndexSet<FieldIdx>, + field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>, by_move_coroutine_ty: Ty<'tcx>, } @@ -126,24 +217,57 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { context: mir::visit::PlaceContext, location: mir::Location, ) { + // Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a + // field projection. If this is in `field_remapping`, then it must not be an + // arg from calling the closure, but instead an upvar. if place.local == ty::CAPTURE_STRUCT_LOCAL - && !place.projection.is_empty() - && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] - && self.by_ref_fields.contains(&idx) + && let Some((&mir::ProjectionElem::Field(idx, _), projection)) = + place.projection.split_first() + && let Some(&(remapped_idx, remapped_ty, needs_deref, bridging_projections)) = + self.field_remapping.get(&idx) { - let (begin, end) = place.projection[1..].split_first().unwrap(); - // FIXME(async_closures): I'm actually a bit surprised to see that we always - // initially deref the by-ref upvars. If this is not actually true, then we - // will at least get an ICE that explains why this isn't true :^) - assert_eq!(*begin, mir::ProjectionElem::Deref); - // Peel one ref off of the ty. - let peeled_ty = ty.builtin_deref(true).unwrap().ty; + // As noted before, if the parent closure captures a field by value, and + // the child captures a field by ref, then for the by-move body we're + // generating, we also are taking that field by value. 
Peel off a deref, + // since a layer of ref'ing has now become redundant. + let final_projections = if needs_deref { + let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first() + else { + bug!( + "There should be at least a single deref for an upvar local initialization, found {projection:#?}" + ); + }; + // There may be more derefs, since we may also implicitly reborrow + // a captured mut pointer. + projection + } else { + projection + }; + + // These projections are applied in order to "bridge" the local that we are + // currently transforming *from* the old upvar that the by-ref coroutine used + // to capture *to* the upvar of the parent coroutine-closure. For example, if + // the parent captures `&s` but the child captures `&(s.field)`, then we will + // apply a field projection. + let bridging_projections = bridging_projections.iter().map(|elem| match elem.kind { + ProjectionKind::Deref => mir::ProjectionElem::Deref, + ProjectionKind::Field(idx, VariantIdx::ZERO) => { + mir::ProjectionElem::Field(idx, elem.ty) + } + _ => unreachable!("precise captures only through fields and derefs"), + }); + + // We start out with an adjusted field index (and ty), representing the + // upvar that we get from our parent closure. We apply any of the additional + // projections to make sure that to the rest of the body of the closure, the + // place looks the same, and then apply that final deref if necessary. *place = mir::Place { local: place.local, projection: self.tcx.mk_place_elems_from_iter( - [mir::ProjectionElem::Field(idx, peeled_ty)] + [mir::ProjectionElem::Field(remapped_idx, remapped_ty)] .into_iter() - .chain(end.iter().copied()), + .chain(bridging_projections) + .chain(final_projections.iter().copied()), ), }; } diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 9a1d8bae6b4..b5968517d77 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -1,27 +1,23 @@ +use std::fmt::{self, Debug}; + use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxHashMap; -use rustc_data_structures::graph::WithNumNodes; -use rustc_index::bit_set::BitSet; +use rustc_data_structures::graph::DirectedGraph; use rustc_index::IndexVec; -use rustc_middle::mir::coverage::*; +use rustc_middle::bug; +use rustc_middle::mir::coverage::{CounterId, CovTerm, Expression, ExpressionId, Op}; -use super::graph::{BasicCoverageBlock, CoverageGraph, TraverseCoverageGraphWithLoops}; - -use std::fmt::{self, Debug}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, TraverseCoverageGraphWithLoops}; /// The coverage counter or counter expression associated with a particular /// BCB node or BCB edge. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub(super) enum BcbCounter { Counter { id: CounterId }, Expression { id: ExpressionId }, } impl BcbCounter { - fn is_expression(&self) -> bool { - matches!(self, Self::Expression { .. }) - } - pub(super) fn as_term(&self) -> CovTerm { match *self { BcbCounter::Counter { id, .. 
} => CovTerm::Counter(id), @@ -39,6 +35,13 @@ impl Debug for BcbCounter { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +struct BcbExpression { + lhs: BcbCounter, + op: Op, + rhs: BcbCounter, +} + #[derive(Debug)] pub(super) enum CounterIncrementSite { Node { bcb: BasicCoverageBlock }, @@ -60,13 +63,13 @@ pub(super) struct CoverageCounters { /// We currently don't iterate over this map, but if we do in the future, /// switch it back to `FxIndexMap` to avoid query stability hazards. bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, - /// Tracks which BCBs have a counter associated with some incoming edge. - /// Only used by assertions, to verify that BCBs with incoming edge - /// counters do not have their own physical counters (expressions are allowed). - bcb_has_incoming_edge_counters: BitSet<BasicCoverageBlock>, + /// Table of expression data, associating each expression ID with its /// corresponding operator (+ or -) and its LHS/RHS operands. - expressions: IndexVec<ExpressionId, Expression>, + expressions: IndexVec<ExpressionId, BcbExpression>, + /// Remember expressions that have already been created (or simplified), + /// so that we don't create unnecessary duplicates. + expressions_memo: FxHashMap<BcbExpression, BcbCounter>, } impl CoverageCounters { @@ -83,8 +86,8 @@ impl CoverageCounters { counter_increment_sites: IndexVec::new(), bcb_counters: IndexVec::from_elem_n(None, num_bcbs), bcb_edge_counters: FxHashMap::default(), - bcb_has_incoming_edge_counters: BitSet::new_empty(num_bcbs), expressions: IndexVec::new(), + expressions_memo: FxHashMap::default(), }; MakeBcbCounters::new(&mut this, basic_coverage_blocks) @@ -99,8 +102,57 @@ impl CoverageCounters { } fn make_expression(&mut self, lhs: BcbCounter, op: Op, rhs: BcbCounter) -> BcbCounter { - let expression = Expression { lhs: lhs.as_term(), op, rhs: rhs.as_term() }; - let id = self.expressions.push(expression); + let new_expr = BcbExpression { lhs, op, rhs }; + *self + .expressions_memo + .entry(new_expr) + .or_insert_with(|| Self::make_expression_inner(&mut self.expressions, new_expr)) + } + + /// This is an associated function so that we can call it while borrowing + /// `&mut self.expressions_memo`. + fn make_expression_inner( + expressions: &mut IndexVec<ExpressionId, BcbExpression>, + new_expr: BcbExpression, + ) -> BcbCounter { + // Simplify expressions using basic algebra. + // + // Some of these cases might not actually occur in practice, depending + // on the details of how the instrumentor builds expressions. + let BcbExpression { lhs, op, rhs } = new_expr; + + if let BcbCounter::Expression { id } = lhs { + let lhs_expr = &expressions[id]; + + // Simplify `(a - b) + b` to `a`. + if lhs_expr.op == Op::Subtract && op == Op::Add && lhs_expr.rhs == rhs { + return lhs_expr.lhs; + } + // Simplify `(a + b) - b` to `a`. + if lhs_expr.op == Op::Add && op == Op::Subtract && lhs_expr.rhs == rhs { + return lhs_expr.lhs; + } + // Simplify `(a + b) - a` to `b`. + if lhs_expr.op == Op::Add && op == Op::Subtract && lhs_expr.lhs == rhs { + return lhs_expr.rhs; + } + } + + if let BcbCounter::Expression { id } = rhs { + let rhs_expr = &expressions[id]; + + // Simplify `a + (b - a)` to `b`. + if op == Op::Add && rhs_expr.op == Op::Subtract && lhs == rhs_expr.rhs { + return rhs_expr.lhs; + } + // Simplify `a - (a - b)` to `b`. 
+ if op == Op::Subtract && rhs_expr.op == Op::Subtract && lhs == rhs_expr.lhs { + return rhs_expr.rhs; + } + } + + // Simplification failed, so actually create the new expression. + let id = expressions.push(new_expr); BcbCounter::Expression { id } } @@ -122,14 +174,6 @@ impl CoverageCounters { } fn set_bcb_counter(&mut self, bcb: BasicCoverageBlock, counter_kind: BcbCounter) -> BcbCounter { - assert!( - // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also - // have an expression (to be injected into an existing `BasicBlock` represented by this - // `BasicCoverageBlock`). - counter_kind.is_expression() || !self.bcb_has_incoming_edge_counters.contains(bcb), - "attempt to add a `Counter` to a BCB target with existing incoming edge counters" - ); - if let Some(replaced) = self.bcb_counters[bcb].replace(counter_kind) { bug!( "attempt to set a BasicCoverageBlock coverage counter more than once; \ @@ -146,19 +190,6 @@ impl CoverageCounters { to_bcb: BasicCoverageBlock, counter_kind: BcbCounter, ) -> BcbCounter { - // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also - // have an expression (to be injected into an existing `BasicBlock` represented by this - // `BasicCoverageBlock`). - if let Some(node_counter) = self.bcb_counter(to_bcb) - && !node_counter.is_expression() - { - bug!( - "attempt to add an incoming edge counter from {from_bcb:?} \ - when the target BCB already has {node_counter:?}" - ); - } - - self.bcb_has_incoming_edge_counters.insert(to_bcb); if let Some(replaced) = self.bcb_edge_counters.insert((from_bcb, to_bcb), counter_kind) { bug!( "attempt to set an edge counter more than once; from_bcb: \ @@ -196,7 +227,21 @@ impl CoverageCounters { } pub(super) fn into_expressions(self) -> IndexVec<ExpressionId, Expression> { - self.expressions + let old_len = self.expressions.len(); + let expressions = self + .expressions + .into_iter() + .map(|BcbExpression { lhs, op, rhs }| Expression { + lhs: lhs.as_term(), + op, + rhs: rhs.as_term(), + }) + .collect::<IndexVec<ExpressionId, _>>(); + + // Expression IDs are indexes into this vector, so make sure we didn't + // accidentally invalidate them by changing its length. 
+ assert_eq!(old_len, expressions.len()); + expressions } } diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index c97192435ce..fd74a2a97e2 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -1,9 +1,10 @@ use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::graph::dominators::{self, Dominators}; -use rustc_data_structures::graph::{self, GraphSuccessors, WithNumNodes, WithStartNode}; +use rustc_data_structures::graph::{self, DirectedGraph, StartNode}; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::{self, BasicBlock, Terminator, TerminatorKind}; use std::cmp::Ordering; @@ -193,16 +194,14 @@ impl IndexMut<BasicCoverageBlock> for CoverageGraph { impl graph::DirectedGraph for CoverageGraph { type Node = BasicCoverageBlock; -} -impl graph::WithNumNodes for CoverageGraph { #[inline] fn num_nodes(&self) -> usize { self.bcbs.len() } } -impl graph::WithStartNode for CoverageGraph { +impl graph::StartNode for CoverageGraph { #[inline] fn start_node(&self) -> Self::Node { self.bcb_from_bb(mir::START_BLOCK) @@ -210,28 +209,16 @@ impl graph::WithStartNode for CoverageGraph { } } -type BcbSuccessors<'graph> = std::slice::Iter<'graph, BasicCoverageBlock>; - -impl<'graph> graph::GraphSuccessors<'graph> for CoverageGraph { - type Item = BasicCoverageBlock; - type Iter = std::iter::Cloned<BcbSuccessors<'graph>>; -} - -impl graph::WithSuccessors for CoverageGraph { +impl graph::Successors for CoverageGraph { #[inline] - fn successors(&self, node: Self::Node) -> <Self as GraphSuccessors<'_>>::Iter { + fn successors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { self.successors[node].iter().cloned() } } -impl<'graph> graph::GraphPredecessors<'graph> for CoverageGraph { - type Item = BasicCoverageBlock; - type Iter = std::iter::Copied<std::slice::Iter<'graph, BasicCoverageBlock>>; -} - -impl graph::WithPredecessors for CoverageGraph { +impl graph::Predecessors for CoverageGraph { #[inline] - fn predecessors(&self, node: Self::Node) -> <Self as graph::GraphPredecessors<'_>>::Iter { + fn predecessors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { self.predecessors[node].iter().copied() } } @@ -349,12 +336,20 @@ fn bcb_filtered_successors<'a, 'tcx>(terminator: &'a Terminator<'tcx>) -> Covera | FalseUnwind { real_target: target, .. } | Goto { target } => CoverageSuccessors::Chainable(target), - // These terminators can normally be chained, except when they have no + // A call terminator can normally be chained, except when they have no // successor because they are known to diverge. - Call { target: maybe_target, .. } | InlineAsm { destination: maybe_target, .. } => { - match maybe_target { - Some(target) => CoverageSuccessors::Chainable(target), - None => CoverageSuccessors::NotChainable(&[]), + Call { target: maybe_target, .. } => match maybe_target { + Some(target) => CoverageSuccessors::Chainable(target), + None => CoverageSuccessors::NotChainable(&[]), + }, + + // An inline asm terminator can normally be chained, except when it diverges or uses asm + // goto. + InlineAsm { ref targets, .. 
} => { + if targets.len() == 1 { + CoverageSuccessors::Chainable(targets[0]) + } else { + CoverageSuccessors::NotChainable(targets) } } diff --git a/compiler/rustc_mir_transform/src/coverage/mappings.rs b/compiler/rustc_mir_transform/src/coverage/mappings.rs new file mode 100644 index 00000000000..61aabea1d8b --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/mappings.rs @@ -0,0 +1,282 @@ +use std::collections::BTreeSet; + +use rustc_data_structures::graph::DirectedGraph; +use rustc_index::bit_set::BitSet; +use rustc_index::IndexVec; +use rustc_middle::mir::coverage::{BlockMarkerId, BranchSpan, ConditionInfo, CoverageKind}; +use rustc_middle::mir::{self, BasicBlock, StatementKind}; +use rustc_span::Span; + +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::spans::{ + extract_refined_covspans, unexpand_into_body_span_with_visible_macro, +}; +use crate::coverage::ExtractedHirInfo; + +/// Associates an ordinary executable code span with its corresponding BCB. +#[derive(Debug)] +pub(super) struct CodeMapping { + pub(super) span: Span, + pub(super) bcb: BasicCoverageBlock, +} + +/// This is separate from [`MCDCBranch`] to help prepare for larger changes +/// that will be needed for improved branch coverage in the future. +/// (See <https://github.com/rust-lang/rust/pull/124217>.) +#[derive(Debug)] +pub(super) struct BranchPair { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, +} + +/// Associates an MC/DC branch span with condition info besides fields for normal branch. +#[derive(Debug)] +pub(super) struct MCDCBranch { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, + /// If `None`, this actually represents a normal branch mapping inserted + /// for code that was too complex for MC/DC. + pub(super) condition_info: Option<ConditionInfo>, + pub(super) decision_depth: u16, +} + +/// Associates an MC/DC decision with its join BCBs. +#[derive(Debug)] +pub(super) struct MCDCDecision { + pub(super) span: Span, + pub(super) end_bcbs: BTreeSet<BasicCoverageBlock>, + pub(super) bitmap_idx: u32, + pub(super) conditions_num: u16, + pub(super) decision_depth: u16, +} + +#[derive(Default)] +pub(super) struct ExtractedMappings { + pub(super) code_mappings: Vec<CodeMapping>, + pub(super) branch_pairs: Vec<BranchPair>, + pub(super) mcdc_bitmap_bytes: u32, + pub(super) mcdc_branches: Vec<MCDCBranch>, + pub(super) mcdc_decisions: Vec<MCDCDecision>, +} + +/// Extracts coverage-relevant spans from MIR, and associates them with +/// their corresponding BCBs. +pub(super) fn extract_all_mapping_info_from_mir( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> ExtractedMappings { + if hir_info.is_async_fn { + // An async function desugars into a function that returns a future, + // with the user code wrapped in a closure. Any spans in the desugared + // outer function will be unhelpful, so just keep the signature span + // and ignore all of the spans in the MIR body. 
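(For context on the desugaring mentioned in the comment above: an `async fn` is lowered to an ordinary function whose own body only constructs the future, so only the signature span is worth keeping. The sketch below is a rough, self-contained illustration with invented names, not the compiler's literal output.)

    use std::future::Future;

    // What the user writes:
    async fn original(x: u32) -> u32 { x + 1 }

    // Roughly what it desugars to: the outer fn has no user code of its own.
    fn desugared(x: u32) -> impl Future<Output = u32> {
        async move { x + 1 } // all user code ends up inside this inner body
    }

    fn main() {
        // Build the futures without polling them; nothing in the bodies runs here.
        let _ = original(1);
        let _ = desugared(1);
    }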
+ let mut mappings = ExtractedMappings::default(); + if let Some(span) = hir_info.fn_sig_span_extended { + mappings.code_mappings.push(CodeMapping { span, bcb: START_BCB }); + } + return mappings; + } + + let mut code_mappings = vec![]; + let mut branch_pairs = vec![]; + let mut mcdc_bitmap_bytes = 0; + let mut mcdc_branches = vec![]; + let mut mcdc_decisions = vec![]; + + extract_refined_covspans(mir_body, hir_info, basic_coverage_blocks, &mut code_mappings); + + branch_pairs.extend(extract_branch_pairs(mir_body, hir_info, basic_coverage_blocks)); + + extract_mcdc_mappings( + mir_body, + hir_info.body_span, + basic_coverage_blocks, + &mut mcdc_bitmap_bytes, + &mut mcdc_branches, + &mut mcdc_decisions, + ); + + ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes, + mcdc_branches, + mcdc_decisions, + } +} + +impl ExtractedMappings { + pub(super) fn all_bcbs_with_counter_mappings( + &self, + basic_coverage_blocks: &CoverageGraph, // Only used for allocating a correctly-sized set + ) -> BitSet<BasicCoverageBlock> { + // Fully destructure self to make sure we don't miss any fields that have mappings. + let Self { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes: _, + mcdc_branches, + mcdc_decisions, + } = self; + + // Identify which BCBs have one or more mappings. + let mut bcbs_with_counter_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); + let mut insert = |bcb| { + bcbs_with_counter_mappings.insert(bcb); + }; + + for &CodeMapping { span: _, bcb } in code_mappings { + insert(bcb); + } + for &BranchPair { true_bcb, false_bcb, .. } in branch_pairs { + insert(true_bcb); + insert(false_bcb); + } + for &MCDCBranch { true_bcb, false_bcb, .. } in mcdc_branches { + insert(true_bcb); + insert(false_bcb); + } + + // MC/DC decisions refer to BCBs, but don't require those BCBs to have counters. + if bcbs_with_counter_mappings.is_empty() { + debug_assert!( + mcdc_decisions.is_empty(), + "A function with no counter mappings shouldn't have any decisions: {mcdc_decisions:?}", + ); + } + + bcbs_with_counter_mappings + } +} + +fn resolve_block_markers( + branch_info: &mir::coverage::BranchInfo, + mir_body: &mir::Body<'_>, +) -> IndexVec<BlockMarkerId, Option<BasicBlock>> { + let mut block_markers = IndexVec::<BlockMarkerId, Option<BasicBlock>>::from_elem_n( + None, + branch_info.num_block_markers, + ); + + // Fill out the mapping from block marker IDs to their enclosing blocks. + for (bb, data) in mir_body.basic_blocks.iter_enumerated() { + for statement in &data.statements { + if let StatementKind::Coverage(CoverageKind::BlockMarker { id }) = statement.kind { + block_markers[id] = Some(bb); + } + } + } + + block_markers +} + +// FIXME: There is currently a lot of redundancy between +// `extract_branch_pairs` and `extract_mcdc_mappings`. This is needed so +// that they can each be modified without interfering with the other, but in +// the long term we should try to bring them together again when branch coverage +// and MC/DC coverage support are more mature. + +pub(super) fn extract_branch_pairs( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> Vec<BranchPair> { + let Some(branch_info) = mir_body.coverage_branch_info.as_deref() else { return vec![] }; + + let block_markers = resolve_block_markers(branch_info, mir_body); + + branch_info + .branch_spans + .iter() + .filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| { + // For now, ignore any branch span that was introduced by + // expansion. 
This makes things like assert macros less noisy. + if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let (span, _) = + unexpand_into_body_span_with_visible_macro(raw_span, hir_info.body_span)?; + + let bcb_from_marker = + |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + + Some(BranchPair { span, true_bcb, false_bcb }) + }) + .collect::<Vec<_>>() +} + +pub(super) fn extract_mcdc_mappings( + mir_body: &mir::Body<'_>, + body_span: Span, + basic_coverage_blocks: &CoverageGraph, + mcdc_bitmap_bytes: &mut u32, + mcdc_branches: &mut impl Extend<MCDCBranch>, + mcdc_decisions: &mut impl Extend<MCDCDecision>, +) { + let Some(branch_info) = mir_body.coverage_branch_info.as_deref() else { return }; + + let block_markers = resolve_block_markers(branch_info, mir_body); + + let bcb_from_marker = + |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); + + let check_branch_bcb = + |raw_span: Span, true_marker: BlockMarkerId, false_marker: BlockMarkerId| { + // For now, ignore any branch span that was introduced by + // expansion. This makes things like assert macros less noisy. + if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let (span, _) = unexpand_into_body_span_with_visible_macro(raw_span, body_span)?; + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + Some((span, true_bcb, false_bcb)) + }; + + mcdc_branches.extend(branch_info.mcdc_branch_spans.iter().filter_map( + |&mir::coverage::MCDCBranchSpan { + span: raw_span, + condition_info, + true_marker, + false_marker, + decision_depth, + }| { + let (span, true_bcb, false_bcb) = + check_branch_bcb(raw_span, true_marker, false_marker)?; + Some(MCDCBranch { span, true_bcb, false_bcb, condition_info, decision_depth }) + }, + )); + + mcdc_decisions.extend(branch_info.mcdc_decision_spans.iter().filter_map( + |decision: &mir::coverage::MCDCDecisionSpan| { + let (span, _) = unexpand_into_body_span_with_visible_macro(decision.span, body_span)?; + + let end_bcbs = decision + .end_markers + .iter() + .map(|&marker| bcb_from_marker(marker)) + .collect::<Option<_>>()?; + + // Each decision containing N conditions needs 2^N bits of space in + // the bitmap, rounded up to a whole number of bytes. + // The decision's "bitmap index" points to its first byte in the bitmap. 
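(A quick worked example of the sizing rule described just above; a standalone sketch that reuses the same arithmetic, not part of the patch.)

    fn main() {
        // Each decision with `n` conditions reserves ceil(2^n / 8) bytes;
        // its bitmap_idx is the running byte total before it was added.
        let mut mcdc_bitmap_bytes: u32 = 0;
        for conditions_num in [2u32, 3, 4] {
            let bitmap_idx = mcdc_bitmap_bytes;
            mcdc_bitmap_bytes += (1_u32 << conditions_num).div_ceil(8);
            println!("{conditions_num} conditions: bitmap_idx = {bitmap_idx}, total = {mcdc_bitmap_bytes} bytes");
        }
        // 2 conditions -> idx 0 (1 byte), 3 -> idx 1 (1 byte), 4 -> idx 2 (2 bytes).
    }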
+ let bitmap_idx = *mcdc_bitmap_bytes; + *mcdc_bitmap_bytes += (1_u32 << decision.conditions_num).div_ceil(8); + + Some(MCDCDecision { + span, + end_bcbs, + bitmap_idx, + conditions_num: decision.conditions_num as u16, + decision_depth: decision.decision_depth, + }) + }, + )); +} diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index 4c5be0a3f4b..28e0c633d5a 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -2,22 +2,16 @@ pub mod query; mod counters; mod graph; +mod mappings; mod spans; - #[cfg(test)] mod tests; -use self::counters::{CounterIncrementSite, CoverageCounters}; -use self::graph::{BasicCoverageBlock, CoverageGraph}; -use self::spans::{BcbMapping, BcbMappingKind, CoverageSpans}; - -use crate::MirPass; - -use rustc_middle::hir; -use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; -use rustc_middle::mir::coverage::*; +use rustc_middle::mir::coverage::{ + CodeRegion, CoverageKind, DecisionInfo, FunctionCoverageInfo, Mapping, MappingKind, +}; use rustc_middle::mir::{ - self, BasicBlock, BasicBlockData, Coverage, SourceInfo, Statement, StatementKind, Terminator, + self, BasicBlock, BasicBlockData, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, }; use rustc_middle::ty::TyCtxt; @@ -25,6 +19,11 @@ use rustc_span::def_id::LocalDefId; use rustc_span::source_map::SourceMap; use rustc_span::{BytePos, Pos, RelativeBytePos, Span, Symbol}; +use crate::coverage::counters::{CounterIncrementSite, CoverageCounters}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::mappings::ExtractedMappings; +use crate::MirPass; + /// Inserts `StatementKind::Coverage` statements that either instrument the binary with injected /// counters, via intrinsic `llvm.instrprof.increment`, and/or inject metadata used during codegen /// to construct the coverage map. @@ -44,7 +43,7 @@ impl<'tcx> MirPass<'tcx> for InstrumentCoverage { let def_id = mir_source.def_id().expect_local(); - if !is_eligible_for_coverage(tcx, def_id) { + if !tcx.is_eligible_for_coverage(def_id) { trace!("InstrumentCoverage skipped for {def_id:?} (not eligible)"); return; } @@ -71,24 +70,27 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); //////////////////////////////////////////////////// - // Compute coverage spans from the `CoverageGraph`. - let Some(coverage_spans) = - spans::generate_coverage_spans(mir_body, &hir_info, &basic_coverage_blocks) - else { - // No relevant spans were found in MIR, so skip instrumenting this function. - return; - }; + // Extract coverage spans and other mapping info from MIR. + let extracted_mappings = + mappings::extract_all_mapping_info_from_mir(mir_body, &hir_info, &basic_coverage_blocks); //////////////////////////////////////////////////// // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` // and all `Expression` dependencies (operands) are also generated, for any other // `BasicCoverageBlock`s not already associated with a coverage span. 
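(To make the Counter/Expression mix concrete: a simplified, illustrative-only model, assuming the usual scheme where one branch gets a physical counter and its sibling is derived by subtraction so fewer increments need to be emitted; the real types are the BcbCounter/BcbExpression kinds seen earlier in this patch.)

    // Hypothetical miniature of the counter/expression mix for `if cond { a } else { b }`.
    #[derive(Clone, Copy, Debug)]
    enum Count {
        Counter(u32),   // a physical llvm.instrprof.increment counter
        Sub(u32, u32),  // counter[a] - counter[b], computed from other counts
    }

    fn main() {
        let entry = Count::Counter(0);    // executions of the `if` itself
        let then_arm = Count::Counter(1); // physical counter in the `then` block
        let else_arm = Count::Sub(0, 1);  // derived: entry - then, no extra instrumentation
        println!("{entry:?}, {then_arm:?}, {else_arm:?}");
    }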
- let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); + let bcbs_with_counter_mappings = + extracted_mappings.all_bcbs_with_counter_mappings(&basic_coverage_blocks); + if bcbs_with_counter_mappings.is_empty() { + // No relevant spans were found in MIR, so skip instrumenting this function. + return; + } + + let bcb_has_counter_mappings = |bcb| bcbs_with_counter_mappings.contains(bcb); let coverage_counters = - CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); + CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_counter_mappings); - let mappings = create_mappings(tcx, &hir_info, &coverage_spans, &coverage_counters); + let mappings = create_mappings(tcx, &hir_info, &extracted_mappings, &coverage_counters); if mappings.is_empty() { // No spans could be converted into valid mappings, so skip this function. debug!("no spans could be converted into valid mappings; skipping"); @@ -98,15 +100,26 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: inject_coverage_statements( mir_body, &basic_coverage_blocks, - bcb_has_coverage_spans, + bcb_has_counter_mappings, &coverage_counters, ); + inject_mcdc_statements(mir_body, &basic_coverage_blocks, &extracted_mappings); + + let mcdc_num_condition_bitmaps = extracted_mappings + .mcdc_decisions + .iter() + .map(|&mappings::MCDCDecision { decision_depth, .. }| decision_depth) + .max() + .map_or(0, |max| usize::from(max) + 1); + mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { function_source_hash: hir_info.function_source_hash, num_counters: coverage_counters.num_counters(), + mcdc_bitmap_bytes: extracted_mappings.mcdc_bitmap_bytes, expressions: coverage_counters.into_expressions(), mappings, + mcdc_num_condition_bitmaps, })); } @@ -118,15 +131,18 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: fn create_mappings<'tcx>( tcx: TyCtxt<'tcx>, hir_info: &ExtractedHirInfo, - coverage_spans: &CoverageSpans, + extracted_mappings: &ExtractedMappings, coverage_counters: &CoverageCounters, ) -> Vec<Mapping> { let source_map = tcx.sess.source_map(); let body_span = hir_info.body_span; let source_file = source_map.lookup_source_file(body_span.lo()); - use rustc_session::RemapFileNameExt; - let file_name = Symbol::intern(&source_file.name.for_codegen(tcx.sess).to_string_lossy()); + + use rustc_session::{config::RemapPathScopeComponents, RemapFileNameExt}; + let file_name = Symbol::intern( + &source_file.name.for_scope(tcx.sess, RemapPathScopeComponents::MACRO).to_string_lossy(), + ); let term_for_bcb = |bcb| { coverage_counters @@ -134,17 +150,59 @@ fn create_mappings<'tcx>( .expect("all BCBs with spans were given counters") .as_term() }; - - coverage_spans - .all_bcb_mappings() - .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { - let kind = match bcb_mapping_kind { - BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), + let region_for_span = |span: Span| make_code_region(source_map, file_name, span, body_span); + + // Fully destructure the mappings struct to make sure we don't miss any kinds. + let ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes: _, + mcdc_branches, + mcdc_decisions, + } = extracted_mappings; + let mut mappings = Vec::new(); + + mappings.extend(code_mappings.iter().filter_map( + // Ordinary code mappings are the simplest kind. 
+ |&mappings::CodeMapping { span, bcb }| { + let code_region = region_for_span(span)?; + let kind = MappingKind::Code(term_for_bcb(bcb)); + Some(Mapping { kind, code_region }) + }, + )); + + mappings.extend(branch_pairs.iter().filter_map( + |&mappings::BranchPair { span, true_bcb, false_bcb }| { + let true_term = term_for_bcb(true_bcb); + let false_term = term_for_bcb(false_bcb); + let kind = MappingKind::Branch { true_term, false_term }; + let code_region = region_for_span(span)?; + Some(Mapping { kind, code_region }) + }, + )); + + mappings.extend(mcdc_branches.iter().filter_map( + |&mappings::MCDCBranch { span, true_bcb, false_bcb, condition_info, decision_depth: _ }| { + let code_region = region_for_span(span)?; + let true_term = term_for_bcb(true_bcb); + let false_term = term_for_bcb(false_bcb); + let kind = match condition_info { + Some(mcdc_params) => MappingKind::MCDCBranch { true_term, false_term, mcdc_params }, + None => MappingKind::Branch { true_term, false_term }, }; - let code_region = make_code_region(source_map, file_name, span, body_span)?; Some(Mapping { kind, code_region }) - }) - .collect::<Vec<_>>() + }, + )); + + mappings.extend(mcdc_decisions.iter().filter_map( + |&mappings::MCDCDecision { span, bitmap_idx, conditions_num, .. }| { + let code_region = region_for_span(span)?; + let kind = MappingKind::MCDCDecision(DecisionInfo { bitmap_idx, conditions_num }); + Some(Mapping { kind, code_region }) + }, + )); + + mappings } /// For each BCB node or BCB edge that has an associated coverage counter, @@ -199,6 +257,53 @@ fn inject_coverage_statements<'tcx>( } } +/// For each conditions inject statements to update condition bitmap after it has been evaluated. +/// For each decision inject statements to update test vector bitmap after it has been evaluated. +fn inject_mcdc_statements<'tcx>( + mir_body: &mut mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + extracted_mappings: &ExtractedMappings, +) { + // Inject test vector update first because `inject_statement` always insert new statement at head. + for &mappings::MCDCDecision { + span: _, + ref end_bcbs, + bitmap_idx, + conditions_num: _, + decision_depth, + } in &extracted_mappings.mcdc_decisions + { + for end in end_bcbs { + let end_bb = basic_coverage_blocks[*end].leader_bb(); + inject_statement( + mir_body, + CoverageKind::TestVectorBitmapUpdate { bitmap_idx, decision_depth }, + end_bb, + ); + } + } + + for &mappings::MCDCBranch { span: _, true_bcb, false_bcb, condition_info, decision_depth } in + &extracted_mappings.mcdc_branches + { + let Some(condition_info) = condition_info else { continue }; + let id = condition_info.condition_id; + + let true_bb = basic_coverage_blocks[true_bcb].leader_bb(); + inject_statement( + mir_body, + CoverageKind::CondBitmapUpdate { id, value: true, decision_depth }, + true_bb, + ); + let false_bb = basic_coverage_blocks[false_bcb].leader_bb(); + inject_statement( + mir_body, + CoverageKind::CondBitmapUpdate { id, value: false, decision_depth }, + false_bb, + ); + } +} + /// Given two basic blocks that have a control-flow edge between them, creates /// and returns a new block that sits between those blocks. 
fn inject_edge_counter_basic_block( @@ -228,10 +333,7 @@ fn inject_statement(mir_body: &mut mir::Body<'_>, counter_kind: CoverageKind, bb debug!(" injecting statement {counter_kind:?} for {bb:?}"); let data = &mut mir_body[bb]; let source_info = data.terminator().source_info; - let statement = Statement { - source_info, - kind: StatementKind::Coverage(Box::new(Coverage { kind: counter_kind })), - }; + let statement = Statement { source_info, kind: StatementKind::Coverage(counter_kind) }; data.statements.insert(0, statement); } @@ -349,37 +451,6 @@ fn check_code_region(code_region: CodeRegion) -> Option<CodeRegion> { } } -fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { - // Only instrument functions, methods, and closures (not constants since they are evaluated - // at compile time by Miri). - // FIXME(#73156): Handle source code coverage in const eval, but note, if and when const - // expressions get coverage spans, we will probably have to "carve out" space for const - // expressions from coverage spans in enclosing MIR's, like we do for closures. (That might - // be tricky if const expressions have no corresponding statements in the enclosing MIR. - // Closures are carved out by their initial `Assign` statement.) - if !tcx.def_kind(def_id).is_fn_like() { - trace!("InstrumentCoverage skipped for {def_id:?} (not an fn-like)"); - return false; - } - - // Don't instrument functions with `#[automatically_derived]` on their - // enclosing impl block, on the assumption that most users won't care about - // coverage for derived impls. - if let Some(impl_of) = tcx.impl_of_method(def_id.to_def_id()) - && tcx.is_automatically_derived(impl_of) - { - trace!("InstrumentCoverage skipped for {def_id:?} (automatically derived)"); - return false; - } - - if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { - trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); - return false; - } - - true -} - /// Function information extracted from HIR by the coverage instrumentor. #[derive(Debug)] struct ExtractedHirInfo { @@ -396,8 +467,7 @@ fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHir // to HIR for it. let hir_node = tcx.hir_node_by_def_id(def_id); - let (_, fn_body_id) = - hir::map::associated_body(hir_node).expect("HIR node is a function with body"); + let fn_body_id = hir_node.body_id().expect("HIR node is a function with body"); let hir_body = tcx.hir().body(fn_body_id); let maybe_fn_sig = hir_node.fn_sig(); diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs index dfc7c3a713b..65715253647 100644 --- a/compiler/rustc_mir_transform/src/coverage/query.rs +++ b/compiler/rustc_mir_transform/src/coverage/query.rs @@ -1,14 +1,49 @@ -use super::*; - use rustc_data_structures::captures::Captures; -use rustc_middle::mir::coverage::*; -use rustc_middle::mir::{Body, CoverageIdsInfo}; -use rustc_middle::query::Providers; -use rustc_middle::ty::{self}; +use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; +use rustc_middle::mir::coverage::{CounterId, CoverageKind}; +use rustc_middle::mir::{Body, CoverageIdsInfo, Statement, StatementKind}; +use rustc_middle::query::TyCtxtAt; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::util::Providers; +use rustc_span::def_id::LocalDefId; -/// A `query` provider for retrieving coverage information injected into MIR. +/// Registers query/hook implementations related to coverage. 
pub(crate) fn provide(providers: &mut Providers) { - providers.coverage_ids_info = |tcx, def_id| coverage_ids_info(tcx, def_id); + providers.hooks.is_eligible_for_coverage = + |TyCtxtAt { tcx, .. }, def_id| is_eligible_for_coverage(tcx, def_id); + providers.queries.coverage_ids_info = coverage_ids_info; +} + +/// Hook implementation for [`TyCtxt::is_eligible_for_coverage`]. +fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + // Only instrument functions, methods, and closures (not constants since they are evaluated + // at compile time by Miri). + // FIXME(#73156): Handle source code coverage in const eval, but note, if and when const + // expressions get coverage spans, we will probably have to "carve out" space for const + // expressions from coverage spans in enclosing MIR's, like we do for closures. (That might + // be tricky if const expressions have no corresponding statements in the enclosing MIR. + // Closures are carved out by their initial `Assign` statement.) + if !tcx.def_kind(def_id).is_fn_like() { + trace!("InstrumentCoverage skipped for {def_id:?} (not an fn-like)"); + return false; + } + + // Don't instrument functions with `#[automatically_derived]` on their + // enclosing impl block, on the assumption that most users won't care about + // coverage for derived impls. + if let Some(impl_of) = tcx.impl_of_method(def_id.to_def_id()) + && tcx.is_automatically_derived(impl_of) + { + trace!("InstrumentCoverage skipped for {def_id:?} (automatically derived)"); + return false; + } + + if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); + return false; + } + + true } /// Query implementation for `coverage_ids_info`. 
@@ -19,24 +54,22 @@ fn coverage_ids_info<'tcx>( let mir_body = tcx.instance_mir(instance_def); let max_counter_id = all_coverage_in_mir_body(mir_body) - .filter_map(|coverage| match coverage.kind { + .filter_map(|kind| match *kind { CoverageKind::CounterIncrement { id } => Some(id), _ => None, }) .max() - .unwrap_or(CounterId::START); + .unwrap_or(CounterId::ZERO); CoverageIdsInfo { max_counter_id } } fn all_coverage_in_mir_body<'a, 'tcx>( body: &'a Body<'tcx>, -) -> impl Iterator<Item = &'a Coverage> + Captures<'tcx> { +) -> impl Iterator<Item = &'a CoverageKind> + Captures<'tcx> { body.basic_blocks.iter().flat_map(|bb_data| &bb_data.statements).filter_map(|statement| { match statement.kind { - StatementKind::Coverage(box ref coverage) if !is_inlined(body, statement) => { - Some(coverage) - } + StatementKind::Coverage(ref kind) if !is_inlined(body, statement) => Some(kind), _ => None, } }) diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 4260a6f0c6f..f2f76ac70c2 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -1,89 +1,32 @@ -use rustc_data_structures::graph::WithNumNodes; -use rustc_index::bit_set::BitSet; +use rustc_middle::bug; use rustc_middle::mir; use rustc_span::{BytePos, Span}; -use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::mappings; use crate::coverage::spans::from_mir::SpanFromMir; use crate::coverage::ExtractedHirInfo; mod from_mir; -#[derive(Clone, Copy, Debug)] -pub(super) enum BcbMappingKind { - /// Associates an ordinary executable code span with its corresponding BCB. - Code(BasicCoverageBlock), -} - -#[derive(Debug)] -pub(super) struct BcbMapping { - pub(super) kind: BcbMappingKind, - pub(super) span: Span, -} - -pub(super) struct CoverageSpans { - bcb_has_mappings: BitSet<BasicCoverageBlock>, - mappings: Vec<BcbMapping>, -} - -impl CoverageSpans { - pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { - self.bcb_has_mappings.contains(bcb) - } - - pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { - self.mappings.iter() - } -} +// FIXME(#124545) It's awkward that we have to re-export this, because it's an +// internal detail of `from_mir` that is also needed when handling branch and +// MC/DC spans. Ideally we would find a more natural home for it. +pub(super) use from_mir::unexpand_into_body_span_with_visible_macro; -/// Extracts coverage-relevant spans from MIR, and associates them with -/// their corresponding BCBs. -/// -/// Returns `None` if no coverage-relevant spans could be extracted. -pub(super) fn generate_coverage_spans( +pub(super) fn extract_refined_covspans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, basic_coverage_blocks: &CoverageGraph, -) -> Option<CoverageSpans> { - let mut mappings = vec![]; - - if hir_info.is_async_fn { - // An async function desugars into a function that returns a future, - // with the user code wrapped in a closure. Any spans in the desugared - // outer function will be unhelpful, so just keep the signature span - // and ignore all of the spans in the MIR body. 
- if let Some(span) = hir_info.fn_sig_span_extended { - mappings.push(BcbMapping { kind: BcbMappingKind::Code(START_BCB), span }); - } - } else { - let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( - mir_body, - hir_info, - basic_coverage_blocks, - ); - let coverage_spans = SpansRefiner::refine_sorted_spans(sorted_spans); - mappings.extend(coverage_spans.into_iter().map(|RefinedCovspan { bcb, span, .. }| { - // Each span produced by the generator represents an ordinary code region. - BcbMapping { kind: BcbMappingKind::Code(bcb), span } - })); - } - - if mappings.is_empty() { - return None; - } - - // Identify which BCBs have one or more mappings. - let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); - let mut insert = |bcb| { - bcb_has_mappings.insert(bcb); - }; - for &BcbMapping { kind, span: _ } in &mappings { - match kind { - BcbMappingKind::Code(bcb) => insert(bcb), - } - } - - Some(CoverageSpans { bcb_has_mappings, mappings }) + code_mappings: &mut impl Extend<mappings::CodeMapping>, +) { + let sorted_spans = + from_mir::mir_to_initial_sorted_coverage_spans(mir_body, hir_info, basic_coverage_blocks); + let coverage_spans = SpansRefiner::refine_sorted_spans(sorted_spans); + code_mappings.extend(coverage_spans.into_iter().map(|RefinedCovspan { bcb, span, .. }| { + // Each span produced by the generator represents an ordinary code region. + mappings::CodeMapping { span, bcb } + })); } #[derive(Debug)] diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 099a354f45d..d1727a94a35 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -1,5 +1,7 @@ use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; +use rustc_middle::mir::coverage::CoverageKind; use rustc_middle::mir::{ self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, @@ -179,16 +181,12 @@ fn is_closure_like(statement: &Statement<'_>) -> bool { /// If the MIR `Statement` has a span contributive to computing coverage spans, /// return it; otherwise return `None`. fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> { - use mir::coverage::CoverageKind; - match statement.kind { // These statements have spans that are often outside the scope of the executed source code // for their parent `BasicBlock`. StatementKind::StorageLive(_) | StatementKind::StorageDead(_) - // Ignore `ConstEvalCounter`s | StatementKind::ConstEvalCounter - // Ignore `Nop`s | StatementKind::Nop => None, // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead` @@ -210,25 +208,31 @@ fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> { StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None, // Retain spans from most other statements. - StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding` + StatementKind::FakeRead(_) | StatementKind::Intrinsic(..) - | StatementKind::Coverage(box mir::Coverage { + | StatementKind::Coverage( // The purpose of `SpanMarker` is to be matched and accepted here. - kind: CoverageKind::SpanMarker - }) + CoverageKind::SpanMarker, + ) | StatementKind::Assign(_) | StatementKind::SetDiscriminant { .. } | StatementKind::Deinit(..) | StatementKind::Retag(_, _) | StatementKind::PlaceMention(..) 
- | StatementKind::AscribeUserType(_, _) => { - Some(statement.source_info.span) - } - - StatementKind::Coverage(box mir::Coverage { - // These coverage statements should not exist prior to coverage instrumentation. - kind: CoverageKind::CounterIncrement { .. } | CoverageKind::ExpressionUsed { .. } - }) => bug!("Unexpected coverage statement found during coverage instrumentation: {statement:?}"), + | StatementKind::AscribeUserType(_, _) => Some(statement.source_info.span), + + // Block markers are used for branch coverage, so ignore them here. + StatementKind::Coverage(CoverageKind::BlockMarker { .. }) => None, + + // These coverage statements should not exist prior to coverage instrumentation. + StatementKind::Coverage( + CoverageKind::CounterIncrement { .. } + | CoverageKind::ExpressionUsed { .. } + | CoverageKind::CondBitmapUpdate { .. } + | CoverageKind::TestVectorBitmapUpdate { .. }, + ) => bug!( + "Unexpected coverage statement found during coverage instrumentation: {statement:?}" + ), } } @@ -280,7 +284,7 @@ fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { /// /// [^1]Expansions result from Rust syntax including macros, syntactic sugar, /// etc.). -fn unexpand_into_body_span_with_visible_macro( +pub(crate) fn unexpand_into_body_span_with_visible_macro( original_span: Span, body_span: Span, ) -> Option<(Span, Option<Symbol>)> { diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs index d9a3c0cb162..ca64688e6b8 100644 --- a/compiler/rustc_mir_transform/src/coverage/tests.rs +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -28,9 +28,9 @@ use super::counters; use super::graph::{self, BasicCoverageBlock}; use itertools::Itertools; -use rustc_data_structures::graph::WithNumNodes; -use rustc_data_structures::graph::WithSuccessors; +use rustc_data_structures::graph::{DirectedGraph, Successors}; use rustc_index::{Idx, IndexVec}; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::ty; use rustc_span::{BytePos, Pos, Span, DUMMY_SP}; @@ -88,7 +88,6 @@ impl<'tcx> MockBlocks<'tcx> { | TerminatorKind::FalseEdge { real_target: ref mut target, .. } | TerminatorKind::FalseUnwind { real_target: ref mut target, .. } | TerminatorKind::Goto { ref mut target } - | TerminatorKind::InlineAsm { destination: Some(ref mut target), .. } | TerminatorKind::Yield { resume: ref mut target, .. } => *target = to_block, ref invalid => bug!("Invalid from_block: {:?}", invalid), } @@ -185,10 +184,12 @@ fn debug_basic_blocks(mir_body: &Body<'_>) -> String { | TerminatorKind::FalseEdge { real_target: target, .. } | TerminatorKind::FalseUnwind { real_target: target, .. } | TerminatorKind::Goto { target } - | TerminatorKind::InlineAsm { destination: Some(target), .. } | TerminatorKind::Yield { resume: target, .. } => { format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), target) } + TerminatorKind::InlineAsm { targets, .. } => { + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), targets) + } TerminatorKind::SwitchInt { targets, .. 
} => { format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), targets) } diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 07e6ecccaa4..483fd753e70 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -23,10 +23,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return false; } - if tcx.intrinsic(def_id).is_some_and(|i| i.must_be_overridden) { - return false; - } - // This just reproduces the logic from Instance::requires_inline. match tcx.def_kind(def_id) { DefKind::Ctor(..) | DefKind::Closure => return true, diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 14d9b0b0350..53a016f01ec 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,52 +2,23 @@ //! //! Currently, this pass only propagates scalar values. -use rustc_const_eval::interpret::{ - ImmTy, Immediate, InterpCx, OpTy, PlaceTy, PointerArithmetic, Projectable, -}; +use rustc_const_eval::const_eval::{throw_machine_stop_str, DummyMachine}; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; -use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar}; +use rustc_middle::bug; +use rustc_middle::mir::interpret::{InterpResult, Scalar}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::query::TyCtxtAt; -use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; +use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{ Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, }; use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor}; -use rustc_span::def_id::DefId; use rustc_span::DUMMY_SP; use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; -/// Macro for machine-specific `InterpError` without allocation. -/// (These will never be shown to the user, but they help diagnose ICEs.) -pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{ - // We make a new local type for it. The type itself does not carry any information, - // but its vtable (for the `MachineStopType` trait) does. - #[derive(Debug)] - struct Zst; - // Printing this type shows the desired string. - impl std::fmt::Display for Zst { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, $($tt)*) - } - } - - impl rustc_middle::mir::interpret::MachineStopType for Zst { - fn diagnostic_message(&self) -> rustc_errors::DiagMessage { - self.to_string().into() - } - - fn add_args( - self: Box<Self>, - _: &mut dyn FnMut(rustc_errors::DiagArgName, rustc_errors::DiagArgValue), - ) {} - } - throw_machine_stop!(Zst) -}} - // These constants are somewhat random guesses and have not been optimized. // If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). 
const BLOCK_LIMIT: usize = 100; @@ -194,7 +165,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { } } } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { + Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => { // Flood everything now, so we can use `insert_value_idx` directly later. state.flood(target.as_ref(), self.map()); @@ -213,7 +184,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { if let Some(overflow_target) = overflow_target { let overflow = match overflow { FlatSet::Top => FlatSet::Top, - FlatSet::Elem(overflow) => FlatSet::Elem(Scalar::from_bool(overflow)), + FlatSet::Elem(overflow) => FlatSet::Elem(overflow), FlatSet::Bottom => FlatSet::Bottom, }; // We have flooded `target` earlier. @@ -232,7 +203,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { if let Some(target_len) = self.map().find_len(target.as_ref()) && let operand_ty = operand.ty(self.local_decls, self.tcx) && let Some(operand_ty) = operand_ty.builtin_deref(true) - && let ty::Array(_, len) = operand_ty.ty.kind() + && let ty::Array(_, len) = operand_ty.kind() && let Some(len) = Const::Ty(*len).try_eval_scalar_int(self.tcx, self.param_env) { state.insert_value_idx(target_len, FlatSet::Elem(len.into()), self.map()); @@ -293,15 +264,16 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { FlatSet::Top => FlatSet::Top, } } - Rvalue::BinaryOp(op, box (left, right)) => { + Rvalue::BinaryOp(op, box (left, right)) if !op.is_overflowing() => { // Overflows must be ignored here. + // The overflowing operators are handled in `handle_assign`. let (val, _overflow) = self.binary_op(state, *op, left, right); val } Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) { FlatSet::Elem(value) => self .ecx - .wrapping_unary_op(*op, &value) + .unary_op(*op, &value) .map_or(FlatSet::Top, |val| self.wrap_immediate(*val)), FlatSet::Bottom => FlatSet::Bottom, FlatSet::Top => FlatSet::Top, @@ -393,7 +365,9 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { } } Operand::Constant(box constant) => { - if let Ok(constant) = self.ecx.eval_mir_constant(&constant.const_, None, None) { + if let Ok(constant) = + self.ecx.eval_mir_constant(&constant.const_, constant.span, None) + { self.assign_constant(state, place, constant, &[]); } } @@ -464,7 +438,7 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { op: BinOp, left: &Operand<'tcx>, right: &Operand<'tcx>, - ) -> (FlatSet<Scalar>, FlatSet<bool>) { + ) -> (FlatSet<Scalar>, FlatSet<Scalar>) { let left = self.eval_operand(left, state); let right = self.eval_operand(right, state); @@ -472,9 +446,17 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), // Both sides are known, do the actual computation. (FlatSet::Elem(left), FlatSet::Elem(right)) => { - match self.ecx.overflowing_binary_op(op, &left, &right) { - Ok((val, overflow)) => { - (FlatSet::Elem(val.to_scalar()), FlatSet::Elem(overflow)) + match self.ecx.binary_op(op, &left, &right) { + // Ideally this would return an Immediate, since it's sometimes + // a pair and sometimes not. But as a hack we always return a pair + // and just make the 2nd component `Bottom` when it does not exist. 
+ Ok(val) => { + if matches!(val.layout.abi, Abi::ScalarPair(..)) { + let (val, overflow) = val.to_scalar_pair(); + (FlatSet::Elem(val), FlatSet::Elem(overflow)) + } else { + (FlatSet::Elem(val.to_scalar()), FlatSet::Bottom) + } } _ => (FlatSet::Top, FlatSet::Top), } @@ -500,7 +482,7 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { (FlatSet::Elem(arg_scalar), FlatSet::Bottom) } BinOp::Mul if layout.ty.is_integral() && arg_value == 0 => { - (FlatSet::Elem(arg_scalar), FlatSet::Elem(false)) + (FlatSet::Elem(arg_scalar), FlatSet::Elem(Scalar::from_bool(false))) } _ => (FlatSet::Top, FlatSet::Top), } @@ -712,6 +694,7 @@ fn try_write_constant<'tcx>( // Unsupported for now. ty::Array(_, _) + | ty::Pat(_, _) // Do not attempt to support indirection in constants. | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) @@ -886,159 +869,3 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { } } } - -pub(crate) struct DummyMachine; - -impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine { - rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>); - type MemoryKind = !; - const PANIC_ON_ALLOC_FAIL: bool = true; - - #[inline(always)] - fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool { - false // no reason to enforce alignment - } - - fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { - false - } - - fn before_access_global( - _tcx: TyCtxtAt<'tcx>, - _machine: &Self, - _alloc_id: AllocId, - alloc: ConstAllocation<'tcx>, - _static_def_id: Option<DefId>, - is_write: bool, - ) -> InterpResult<'tcx> { - if is_write { - throw_machine_stop_str!("can't write to global"); - } - - // If the static allocation is mutable, then we can't const prop it as its content - // might be different at runtime. - if alloc.inner().mutability.is_mut() { - throw_machine_stop_str!("can't access mutable globals in ConstProp"); - } - - Ok(()) - } - - fn find_mir_or_eval_fn( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _abi: rustc_target::spec::abi::Abi, - _args: &[rustc_const_eval::interpret::FnArg<'tcx, Self::Provenance>], - _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { - unimplemented!() - } - - fn panic_nounwind( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _msg: &str, - ) -> interpret::InterpResult<'tcx> { - unimplemented!() - } - - fn call_intrinsic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>], - _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> interpret::InterpResult<'tcx> { - unimplemented!() - } - - fn assert_panic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _msg: &rustc_middle::mir::AssertMessage<'tcx>, - _unwind: UnwindAction, - ) -> interpret::InterpResult<'tcx> { - unimplemented!() - } - - fn binary_ptr_op( - ecx: &InterpCx<'mir, 'tcx, Self>, - bin_op: BinOp, - left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, - right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, - ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> { - use rustc_middle::mir::BinOp::*; - Ok(match bin_op { - Eq | Ne | Lt | Le | Gt | Ge => { - // Types can differ, e.g. 
fn ptrs with different `for`. - assert_eq!(left.layout.abi, right.layout.abi); - let size = ecx.pointer_size(); - // Just compare the bits. ScalarPairs are compared lexicographically. - // We thus always compare pairs and simply fill scalars up with 0. - // If the pointer has provenance, `to_bits` will return `Err` and we bail out. - let left = match **left { - Immediate::Scalar(l) => (l.to_bits(size)?, 0), - Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?), - Immediate::Uninit => panic!("we should never see uninit data here"), - }; - let right = match **right { - Immediate::Scalar(r) => (r.to_bits(size)?, 0), - Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?), - Immediate::Uninit => panic!("we should never see uninit data here"), - }; - let res = match bin_op { - Eq => left == right, - Ne => left != right, - Lt => left < right, - Le => left <= right, - Gt => left > right, - Ge => left >= right, - _ => bug!(), - }; - (ImmTy::from_bool(res, *ecx.tcx), false) - } - - // Some more operations are possible with atomics. - // The return value always has the provenance of the *left* operand. - Add | Sub | BitOr | BitAnd | BitXor => { - throw_machine_stop_str!("pointer arithmetic is not handled") - } - - _ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op), - }) - } - - fn expose_ptr( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _ptr: interpret::Pointer<Self::Provenance>, - ) -> interpret::InterpResult<'tcx> { - unimplemented!() - } - - fn init_frame_extra( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _frame: rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance>, - ) -> interpret::InterpResult< - 'tcx, - rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, - > { - unimplemented!() - } - - fn stack<'a>( - _ecx: &'a InterpCx<'mir, 'tcx, Self>, - ) -> &'a [rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] - { - // Return an empty stack instead of panicking, as `cur_span` uses it to evaluate constants. - &[] - } - - fn stack_mut<'a>( - _ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - ) -> &'a mut Vec< - rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, - > { - unimplemented!() - } -} diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs index e6317e5469c..08dba1de500 100644 --- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -13,6 +13,7 @@ //! use crate::util::is_within_packed; +use rustc_middle::bug; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index ca63f5550ae..370e930b740 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -160,7 +160,7 @@ pub fn deduced_param_attrs<'tcx>( return &[]; } - // If the Freeze language item isn't present, then don't bother. + // If the Freeze lang item isn't present, then don't bother. 
if tcx.lang_items().freeze_trait().is_none() { return &[]; } diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 2c8201b1903..b1016c0867c 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -135,6 +135,7 @@ use crate::MirPass; use rustc_data_structures::fx::{FxIndexMap, IndexEntry, IndexOccupiedEntry}; use rustc_index::bit_set::BitSet; use rustc_index::interval::SparseIntervalMatrix; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::HasLocalDecls; use rustc_middle::mir::{dump_mir, PassWhere}; @@ -563,7 +564,7 @@ impl WriteInfo { | Rvalue::ShallowInitBox(op, _) => { self.add_operand(op); } - Rvalue::BinaryOp(_, ops) | Rvalue::CheckedBinaryOp(_, ops) => { + Rvalue::BinaryOp(_, ops) => { for op in [&ops.0, &ops.1] { self.add_operand(op); } @@ -648,7 +649,8 @@ impl WriteInfo { } InlineAsmOperand::Const { .. } | InlineAsmOperand::SymFn { .. } - | InlineAsmOperand::SymStatic { .. } => (), + | InlineAsmOperand::SymStatic { .. } + | InlineAsmOperand::Label { .. } => {} } } } diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs index 0d600f0f937..9edb8bcee6e 100644 --- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -1,6 +1,6 @@ use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; -use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty::{Ty, TyCtxt}; use std::fmt::Debug; use super::simplify::simplify_cfg; @@ -11,6 +11,7 @@ use super::simplify::simplify_cfg; /// let y: Option<()>; /// match (x,y) { /// (Some(_), Some(_)) => {0}, +/// (None, None) => {2}, /// _ => {1} /// } /// ``` @@ -23,10 +24,10 @@ use super::simplify::simplify_cfg; /// if discriminant_x == discriminant_y { /// match x { /// Some(_) => 0, -/// _ => 1, // <---- -/// } // | Actually the same bb -/// } else { // | -/// 1 // <-------------- +/// None => 2, +/// } +/// } else { +/// 1 /// } /// ``` /// @@ -47,18 +48,18 @@ use super::simplify::simplify_cfg; /// | | | /// ================= | | | /// | BBU | <-| | | ============================ -/// |---------------| | \-------> | BBD | -/// |---------------| | | |--------------------------| -/// | unreachable | | | | _dl = discriminant(P) | -/// ================= | | |--------------------------| -/// | | | switchInt(_dl) | -/// ================= | | | d | ---> BBD.2 +/// |---------------| \-------> | BBD | +/// |---------------| | |--------------------------| +/// | unreachable | | | _dl = discriminant(P) | +/// ================= | |--------------------------| +/// | | switchInt(_dl) | +/// ================= | | d | ---> BBD.2 /// | BB9 | <--------------- | otherwise | /// |---------------| ============================ /// | ... | /// ================= /// ``` -/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the +/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. 
In the /// code: /// - `BB1` is `parent` and `BBC, BBD` are children /// - `P` is `child_place` @@ -78,7 +79,7 @@ use super::simplify::simplify_cfg; /// |---------------------| | | switchInt(Q) | /// | switchInt(_t) | | | c | ---> BBC.2 /// | false | --------/ | d | ---> BBD.2 -/// | otherwise | ---------------- | otherwise | +/// | otherwise | /--------- | otherwise | /// ======================= | ============================ /// | /// ================= | @@ -87,16 +88,11 @@ use super::simplify::simplify_cfg; /// | ... | /// ================= /// ``` -/// -/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`. -/// The filter on which `P` are allowed (together with discussion of its correctness) is found in -/// `may_hoist`. pub struct EarlyOtherwiseBranch; impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // unsound: https://github.com/rust-lang/rust/issues/95162 - sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts + sess.mir_opt_level() >= 2 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { @@ -172,7 +168,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { }; (value, targets.target_for_value(value)) }); - let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination); + // The otherwise either is the same target branch or an unreachable. + let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise()); // Create `bbEq` in example above let eq_switch = BasicBlockData::new(Some(Terminator { @@ -217,85 +214,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { } } -/// Returns true if computing the discriminant of `place` may be hoisted out of the branch -fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool { - // FIXME(JakobDegen): This is unsound. Someone could write code like this: - // ```rust - // let Q = val; - // if discriminant(P) == otherwise { - // let ptr = &mut Q as *mut _ as *mut u8; - // unsafe { *ptr = 10; } // Any invalid value for the type - // } - // - // match P { - // A => match Q { - // A => { - // // code - // } - // _ => { - // // don't use Q - // } - // } - // _ => { - // // don't use Q - // } - // }; - // ``` - // - // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an - // invalid value, which is UB. - // - // In order to fix this, we would either need to show that the discriminant computation of - // `place` is computed in all branches, including the `otherwise` branch, or we would need - // another analysis pass to determine that the place is fully initialized. It might even be best - // to have the hoisting be performed in a different pass and just do the CFG changing in this - // pass. - for (place, proj) in place.iter_projections() { - match proj { - // Dereferencing in the computation of `place` might cause issues from one of two - // categories. First, the referent might be invalid. We protect against this by - // dereferencing references only (not pointers). Second, the use of a reference may - // invalidate other references that are used later (for aliasing reasons). Consider - // where such an invalidated reference may appear: - // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so - // cannot contain referenced data. 
- // - In `BBU`: Not possible since that block contains only the `unreachable` terminator - // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to - // reaching that block in the input to our transformation, and so any data - // invalidated by that computation could not have been used there. - // - In `BB9`: Not possible since control flow might have reached `BB9` via the - // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would - // have invalidated the data when computing `discriminant(P)` - // So dereferencing here is correct. - ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() { - ty::Ref(..) => {} - _ => return false, - }, - // Field projections are always valid - ProjectionElem::Field(..) => {} - // We cannot allow - // downcasts either, since the correctness of the downcast may depend on the parent - // branch being taken. An easy example of this is - // ``` - // Q = discriminant(_3) - // P = (_3 as Variant) - // ``` - // However, checking if the child and parent place are the same and only erroring then - // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may - // be replaced by another optimization pass with any other condition that can be proven - // equivalent. - ProjectionElem::Downcast(..) => { - return false; - } - // We cannot allow indexing since the index may be out of bounds. - _ => { - return false; - } - } - } - true -} - #[derive(Debug)] struct OptimizationData<'tcx> { destination: BasicBlock, @@ -315,18 +233,40 @@ fn evaluate_candidate<'tcx>( return None; }; let parent_ty = parent_discr.ty(body.local_decls(), tcx); - let parent_dest = { - let poss = targets.otherwise(); - // If the fallthrough on the parent is trivially unreachable, we can let the - // children choose the destination - if bbs[poss].statements.len() == 0 - && bbs[poss].terminator().kind == TerminatorKind::Unreachable - { - None - } else { - Some(poss) - } - }; + if !bbs[targets.otherwise()].is_empty_unreachable() { + // Someone could write code like this: + // ```rust + // let Q = val; + // if discriminant(P) == otherwise { + // let ptr = &mut Q as *mut _ as *mut u8; + // // It may be difficult for us to effectively determine whether values are valid. + // // Invalid values can come from all sorts of corners. + // unsafe { *ptr = 10; } + // } + // + // match P { + // A => match Q { + // A => { + // // code + // } + // _ => { + // // don't use Q + // } + // } + // _ => { + // // don't use Q + // } + // }; + // ``` + // + // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an + // invalid value, which is UB. + // In order to fix this, **we would either need to show that the discriminant computation of + // `place` is computed in all branches**. + // FIXME(#95162) For the moment, we adopt a conservative approach and + // consider only the `otherwise` branch has no statements and an unreachable terminator. 
+ return None; + } let (_, child) = targets.iter().next()?; let child_terminator = &bbs[child].terminator(); let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } = @@ -344,13 +284,7 @@ fn evaluate_candidate<'tcx>( let (_, Rvalue::Discriminant(child_place)) = &**boxed else { return None; }; - let destination = parent_dest.unwrap_or(child_targets.otherwise()); - - // Verify that the optimization is legal in general - // We can hoist evaluating the child discriminant out of the branch - if !may_hoist(tcx, body, *child_place) { - return None; - } + let destination = child_targets.otherwise(); // Verify that the optimization is legal for each branch for (value, child) in targets.iter() { @@ -411,5 +345,5 @@ fn verify_candidate_branch<'tcx>( if let Some(_) = iter.next() { return false; } - return true; + true } diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index 96943435bab..d955b96d06a 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -3,10 +3,10 @@ //! Box is not actually a pointer so it is incorrect to dereference it directly. use rustc_hir::def_id::DefId; -use rustc_index::Idx; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::{Ty, TyCtxt}; use rustc_target::abi::FieldIdx; @@ -32,9 +32,9 @@ pub fn build_projection<'tcx>( ptr_ty: Ty<'tcx>, ) -> [PlaceElem<'tcx>; 3] { [ - PlaceElem::Field(FieldIdx::new(0), unique_ty), - PlaceElem::Field(FieldIdx::new(0), nonnull_ty), - PlaceElem::Field(FieldIdx::new(0), ptr_ty), + PlaceElem::Field(FieldIdx::ZERO, unique_ty), + PlaceElem::Field(FieldIdx::ZERO, nonnull_ty), + PlaceElem::Field(FieldIdx::ZERO, ptr_ty), ] } @@ -91,15 +91,14 @@ pub struct ElaborateBoxDerefs; impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { if let Some(def_id) = tcx.lang_items().owned_box() { - let unique_did = - tcx.adt_def(def_id).non_enum_variant().fields[FieldIdx::from_u32(0)].did; + let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[FieldIdx::ZERO].did; let Some(nonnull_def) = tcx.type_of(unique_did).instantiate_identity().ty_adt_def() else { span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") }; - let nonnull_did = nonnull_def.non_enum_variant().fields[FieldIdx::from_u32(0)].did; + let nonnull_did = nonnull_def.non_enum_variant().fields[FieldIdx::ZERO].did; let patch = MirPatch::new(body); diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs index 80e9172bbe1..0634e321ea3 100644 --- a/compiler/rustc_mir_transform/src/errors.rs +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -1,11 +1,6 @@ -use std::borrow::Cow; - -use rustc_errors::{ - codes::*, Applicability, DecorateLint, Diag, DiagArgValue, DiagCtxt, DiagMessage, - EmissionGuarantee, IntoDiagnostic, Level, -}; +use rustc_errors::{codes::*, Diag, DiagMessage, LintDiagnostic}; use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; -use rustc_middle::mir::{AssertKind, UnsafetyViolationDetails}; +use rustc_middle::mir::AssertKind; use rustc_middle::ty::TyCtxt; use rustc_session::lint::{self, Lint}; use rustc_span::def_id::DefId; @@ -42,168 +37,6 @@ pub(crate) struct UnalignedPackedRef { pub span: Span, } -#[derive(LintDiagnostic)] -#[diag(mir_transform_unused_unsafe)] -pub(crate) 
struct UnusedUnsafe { - #[label(mir_transform_unused_unsafe)] - pub span: Span, - #[label] - pub nested_parent: Option<Span>, -} - -pub(crate) struct RequiresUnsafe { - pub span: Span, - pub details: RequiresUnsafeDetail, - pub enclosing: Option<Span>, - pub op_in_unsafe_fn_allowed: bool, -} - -// The primary message for this diagnostic should be '{$label} is unsafe and...', -// so we need to eagerly translate the label here, which isn't supported by the derive API -// We could also exhaustively list out the primary messages for all unsafe violations, -// but this would result in a lot of duplication. -impl<'a, G: EmissionGuarantee> IntoDiagnostic<'a, G> for RequiresUnsafe { - #[track_caller] - fn into_diagnostic(self, dcx: &'a DiagCtxt, level: Level) -> Diag<'a, G> { - let mut diag = Diag::new(dcx, level, fluent::mir_transform_requires_unsafe); - diag.code(E0133); - diag.span(self.span); - diag.span_label(self.span, self.details.label()); - let desc = dcx.eagerly_translate_to_string(self.details.label(), [].into_iter()); - diag.arg("details", desc); - diag.arg("op_in_unsafe_fn_allowed", self.op_in_unsafe_fn_allowed); - self.details.add_subdiagnostics(&mut diag); - if let Some(sp) = self.enclosing { - diag.span_label(sp, fluent::mir_transform_not_inherited); - } - diag - } -} - -#[derive(Clone)] -pub(crate) struct RequiresUnsafeDetail { - pub span: Span, - pub violation: UnsafetyViolationDetails, -} - -impl RequiresUnsafeDetail { - // FIXME: make this translatable - #[allow(rustc::diagnostic_outside_of_impl)] - #[allow(rustc::untranslatable_diagnostic)] - fn add_subdiagnostics<G: EmissionGuarantee>(&self, diag: &mut Diag<'_, G>) { - use UnsafetyViolationDetails::*; - match self.violation { - CallToUnsafeFunction => { - diag.note(fluent::mir_transform_call_to_unsafe_note); - } - UseOfInlineAssembly => { - diag.note(fluent::mir_transform_use_of_asm_note); - } - InitializingTypeWith => { - diag.note(fluent::mir_transform_initializing_valid_range_note); - } - CastOfPointerToInt => { - diag.note(fluent::mir_transform_const_ptr2int_note); - } - UseOfMutableStatic => { - diag.note(fluent::mir_transform_use_of_static_mut_note); - } - UseOfExternStatic => { - diag.note(fluent::mir_transform_use_of_extern_static_note); - } - DerefOfRawPointer => { - diag.note(fluent::mir_transform_deref_ptr_note); - } - AccessToUnionField => { - diag.note(fluent::mir_transform_union_access_note); - } - MutationOfLayoutConstrainedField => { - diag.note(fluent::mir_transform_mutation_layout_constrained_note); - } - BorrowOfLayoutConstrainedField => { - diag.note(fluent::mir_transform_mutation_layout_constrained_borrow_note); - } - CallToFunctionWith { ref missing, ref build_enabled } => { - diag.help(fluent::mir_transform_target_feature_call_help); - diag.arg( - "missing_target_features", - DiagArgValue::StrListSepByAnd( - missing.iter().map(|feature| Cow::from(feature.to_string())).collect(), - ), - ); - diag.arg("missing_target_features_count", missing.len()); - if !build_enabled.is_empty() { - diag.note(fluent::mir_transform_target_feature_call_note); - diag.arg( - "build_target_features", - DiagArgValue::StrListSepByAnd( - build_enabled - .iter() - .map(|feature| Cow::from(feature.to_string())) - .collect(), - ), - ); - diag.arg("build_target_features_count", build_enabled.len()); - } - } - } - } - - fn label(&self) -> DiagMessage { - use UnsafetyViolationDetails::*; - match self.violation { - CallToUnsafeFunction => fluent::mir_transform_call_to_unsafe_label, - UseOfInlineAssembly => 
fluent::mir_transform_use_of_asm_label, - InitializingTypeWith => fluent::mir_transform_initializing_valid_range_label, - CastOfPointerToInt => fluent::mir_transform_const_ptr2int_label, - UseOfMutableStatic => fluent::mir_transform_use_of_static_mut_label, - UseOfExternStatic => fluent::mir_transform_use_of_extern_static_label, - DerefOfRawPointer => fluent::mir_transform_deref_ptr_label, - AccessToUnionField => fluent::mir_transform_union_access_label, - MutationOfLayoutConstrainedField => { - fluent::mir_transform_mutation_layout_constrained_label - } - BorrowOfLayoutConstrainedField => { - fluent::mir_transform_mutation_layout_constrained_borrow_label - } - CallToFunctionWith { .. } => fluent::mir_transform_target_feature_call_label, - } - } -} - -pub(crate) struct UnsafeOpInUnsafeFn { - pub details: RequiresUnsafeDetail, - - /// These spans point to: - /// 1. the start of the function body - /// 2. the end of the function body - /// 3. the function signature - pub suggest_unsafe_block: Option<(Span, Span, Span)>, -} - -impl<'a> DecorateLint<'a, ()> for UnsafeOpInUnsafeFn { - #[track_caller] - fn decorate_lint<'b>(self, diag: &'b mut Diag<'a, ()>) { - let desc = diag.dcx.eagerly_translate_to_string(self.details.label(), [].into_iter()); - diag.arg("details", desc); - diag.span_label(self.details.span, self.details.label()); - self.details.add_subdiagnostics(diag); - - if let Some((start, end, fn_sig)) = self.suggest_unsafe_block { - diag.span_note(fn_sig, fluent::mir_transform_note); - diag.tool_only_multipart_suggestion( - fluent::mir_transform_suggestion, - vec![(start, " unsafe {".into()), (end, "}".into())], - Applicability::MaybeIncorrect, - ); - } - } - - fn msg(&self) -> DiagMessage { - fluent::mir_transform_unsafe_op_in_unsafe_fn - } -} - pub(crate) struct AssertLint<P> { pub span: Span, pub assert_kind: AssertKind<P>, @@ -215,7 +48,7 @@ pub(crate) enum AssertLintKind { UnconditionalPanic, } -impl<'a, P: std::fmt::Debug> DecorateLint<'a, ()> for AssertLint<P> { +impl<'a, P: std::fmt::Debug> LintDiagnostic<'a, ()> for AssertLint<P> { fn decorate_lint<'b>(self, diag: &'b mut Diag<'a, ()>) { let message = self.assert_kind.diagnostic_message(); self.assert_kind.add_args(&mut |name, value| { @@ -269,7 +102,7 @@ pub(crate) struct MustNotSupend<'tcx, 'a> { } // Needed for def_path_str -impl<'a> DecorateLint<'a, ()> for MustNotSupend<'_, '_> { +impl<'a> LintDiagnostic<'a, ()> for MustNotSupend<'_, '_> { fn decorate_lint<'b>(self, diag: &'b mut rustc_errors::Diag<'a, ()>) { diag.span_label(self.yield_sp, fluent::_subdiag::label); if let Some(reason) = self.reason { diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index d9387ecd14c..5e3cd853675 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -4,6 +4,7 @@ use rustc_middle::query::LocalCrate; use rustc_middle::query::Providers; use rustc_middle::ty::layout; use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_session::lint::builtin::FFI_UNWIND_CALLS; use rustc_target::spec::abi::Abi; use rustc_target::spec::PanicStrategy; @@ -11,7 +12,7 @@ use rustc_target::spec::PanicStrategy; use crate::errors; /// Some of the functions declared as "may unwind" by `fn_can_unwind` can't actually unwind. 
In -/// particular, `extern "C"` is still considered as can-unwind on stable, but we need to to consider +/// particular, `extern "C"` is still considered as can-unwind on stable, but we need to consider /// it cannot-unwind here. So below we check `fn_can_unwind() && abi_can_unwind()` before concluding /// that a function call can unwind. fn abi_can_unwind(abi: Abi) -> bool { diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index e935dc7f5eb..434529ccff4 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -121,7 +121,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { fn is_fn_ref(ty: Ty<'tcx>) -> Option<(DefId, GenericArgsRef<'tcx>)> { let referent_ty = match ty.kind() { ty::Ref(_, referent_ty, _) => Some(referent_ty), - ty::RawPtr(ty_and_mut) => Some(&ty_and_mut.ty), + ty::RawPtr(referent_ty, _) => Some(referent_ty), _ => None, }; referent_ty @@ -158,7 +158,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { .lint_root; // FIXME: use existing printing routines to print the function signature let fn_sig = self.tcx.fn_sig(fn_id).instantiate(self.tcx, fn_args); - let unsafety = fn_sig.unsafety().prefix_str(); + let unsafety = fn_sig.safety().prefix_str(); let abi = match fn_sig.abi() { Abi::Rust => String::from(""), other_abi => { diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index a080e2423d4..9d2e7153eb5 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -82,6 +82,7 @@ //! Second, when writing constants in MIR, we do not write `Const::Slice` or `Const` //! that contain `AllocId`s. +use rustc_const_eval::const_eval::DummyMachine; use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind}; use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar}; use rustc_data_structures::fx::FxIndexSet; @@ -90,18 +91,18 @@ use rustc_hir::def::DefKind; use rustc_index::bit_set::BitSet; use rustc_index::newtype_index; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::layout::LayoutOf; -use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut}; +use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::def_id::DefId; use rustc_span::DUMMY_SP; use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT}; use smallvec::SmallVec; use std::borrow::Cow; -use crate::dataflow_const_prop::DummyMachine; use crate::ssa::{AssignedValue, SsaLocals}; use either::Either; @@ -121,7 +122,7 @@ impl<'tcx> MirPass<'tcx> for GVN { fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - let ssa = SsaLocals::new(body); + let ssa = SsaLocals::new(tcx, body, param_env); // Clone dominators as we need them while mutating the body. let dominators = body.basic_blocks.dominators().clone(); @@ -131,7 +132,7 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |local, value, location| { let value = match value { // We do not know anything of this assigned value. - AssignedValue::Arg | AssignedValue::Terminator(_) => None, + AssignedValue::Arg | AssignedValue::Terminator => None, // Try to get some insight. 
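As a rough illustration of the redundancy GVN removes (hypothetical code, not from this diff; the names are made up): repeated loads through a shared reference to a freeze type are known to see the same value, so the second computation can reuse the first.

```rust
// Not from this diff: `p` is a shared reference to a freeze type, so both
// loads of `*p` observe the same value and the two `*p + 1` computations
// receive the same value number.
fn twice(p: &u32) -> u32 {
    let a = *p + 1;
    let b = *p + 1; // can reuse the value computed for `a`
    a + b
}

fn main() {
    assert_eq!(twice(&4), 10);
}
```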
AssignedValue::Rvalue(rvalue) => { let value = state.simplify_rvalue(rvalue, location); @@ -222,7 +223,7 @@ enum Value<'tcx> { NullaryOp(NullOp<'tcx>, Ty<'tcx>), UnaryOp(UnOp, VnIndex), BinaryOp(BinOp, VnIndex, VnIndex), - CheckedBinaryOp(BinOp, VnIndex, VnIndex), + CheckedBinaryOp(BinOp, VnIndex, VnIndex), // FIXME get rid of this, work like MIR instead Cast { kind: CastKind, value: VnIndex, @@ -355,7 +356,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } fn insert_tuple(&mut self, values: Vec<VnIndex>) -> VnIndex { - self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::from_u32(0), values)) + self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::ZERO, values)) } #[instrument(level = "trace", skip(self), ret)] @@ -367,7 +368,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { Repeat(..) => return None, Constant { ref value, disambiguator: _ } => { - self.ecx.eval_mir_constant(value, None, None).ok()? + self.ecx.eval_mir_constant(value, DUMMY_SP, None).ok()? } Aggregate(kind, variant, ref fields) => { let fields = fields @@ -451,11 +452,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { AddressKind::Ref(bk) => Ty::new_ref( self.tcx, self.tcx.lifetimes.re_erased, - ty::TypeAndMut { ty: mplace.layout.ty, mutbl: bk.to_mutbl_lossy() }, + mplace.layout.ty, + bk.to_mutbl_lossy(), ), - AddressKind::Address(mutbl) => { - Ty::new_ptr(self.tcx, TypeAndMut { ty: mplace.layout.ty, mutbl }) - } + AddressKind::Address(mutbl) => Ty::new_ptr(self.tcx, mplace.layout.ty, mutbl), }; let layout = self.ecx.layout_of(ty).ok()?; ImmTy::from_immediate(pointer, layout).into() @@ -488,7 +488,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { NullOp::OffsetOf(fields) => { layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() } - NullOp::DebugAssertions => return None, + NullOp::UbChecks => return None, }; let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); let imm = ImmTy::try_from_uint(val, usize_layout)?; @@ -497,7 +497,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { UnaryOp(un_op, operand) => { let operand = self.evaluated[operand].as_ref()?; let operand = self.ecx.read_immediate(operand).ok()?; - let (val, _) = self.ecx.overflowing_unary_op(un_op, &operand).ok()?; + let val = self.ecx.unary_op(un_op, &operand).ok()?; val.into() } BinaryOp(bin_op, lhs, rhs) => { @@ -505,7 +505,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let lhs = self.ecx.read_immediate(lhs).ok()?; let rhs = self.evaluated[rhs].as_ref()?; let rhs = self.ecx.read_immediate(rhs).ok()?; - let (val, _) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + let val = self.ecx.binary_op(bin_op, &lhs, &rhs).ok()?; val.into() } CheckedBinaryOp(bin_op, lhs, rhs) => { @@ -513,14 +513,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let lhs = self.ecx.read_immediate(lhs).ok()?; let rhs = self.evaluated[rhs].as_ref()?; let rhs = self.ecx.read_immediate(rhs).ok()?; - let (val, overflowed) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; - let tuple = Ty::new_tup_from_iter( - self.tcx, - [val.layout.ty, self.tcx.types.bool].into_iter(), - ); - let tuple = self.ecx.layout_of(tuple).ok()?; - ImmTy::from_scalar_pair(val.to_scalar(), Scalar::from_bool(overflowed), tuple) - .into() + let val = self + .ecx + .binary_op(bin_op.wrapping_to_overflowing().unwrap(), &lhs, &rhs) + .ok()?; + val.into() } Cast { kind, value, from: _, to } => match kind { CastKind::IntToInt | CastKind::IntToFloat => { @@ -595,7 +592,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let ty = place.ty(self.local_decls, self.tcx).ty; if let 
Some(Mutability::Not) = ty.ref_mutability() && let Some(pointee_ty) = ty.builtin_deref(true) - && pointee_ty.ty.is_freeze(self.tcx, self.param_env) + && pointee_ty.is_freeze(self.tcx, self.param_env) { // An immutable borrow `_x` always points to the same value for the // lifetime of the borrow, so we can merge all instances of `*_x`. @@ -725,6 +722,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // Invariant: `value` holds the value up-to the `index`th projection excluded. let mut value = self.locals[place.local]?; for (index, proj) in place.projection.iter().enumerate() { + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. } = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&place.projection[index..], self.tcx).as_ref(); + } if let Some(local) = self.try_as_local(value, location) { // Both `local` and `Place { local: place.local, projection: projection[..index] }` // hold the same value. Therefore, following place holds the value in the original @@ -736,6 +741,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { value = self.project(base, value, proj)?; } + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. } = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&[], self.tcx).as_ref(); + } if let Some(new_local) = self.try_as_local(value, location) { place_ref = PlaceRef { local: new_local, projection: &[] }; } @@ -815,23 +828,18 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // on both operands for side effect. let lhs = lhs?; let rhs = rhs?; - if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) { - return Some(value); - } - Value::BinaryOp(op, lhs, rhs) - } - Rvalue::CheckedBinaryOp(op, box (ref mut lhs, ref mut rhs)) => { - let ty = lhs.ty(self.local_decls, self.tcx); - let lhs = self.simplify_operand(lhs, location); - let rhs = self.simplify_operand(rhs, location); - // Only short-circuit options after we called `simplify_operand` - // on both operands for side effect. - let lhs = lhs?; - let rhs = rhs?; - if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) { - return Some(value); + + if let Some(op) = op.overflowing_to_wrapping() { + if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) { + return Some(value); + } + Value::CheckedBinaryOp(op, lhs, rhs) + } else { + if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) { + return Some(value); + } + Value::BinaryOp(op, lhs, rhs) } - Value::CheckedBinaryOp(op, lhs, rhs) } Rvalue::UnaryOp(op, ref mut arg) => { let arg = self.simplify_operand(arg, location)?; @@ -886,6 +894,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, // Coroutines are never ZST, as they at least contain the implicit states. AggregateKind::Coroutine(..) => false, + AggregateKind::RawPtr(..) => bug!("MIR for RawPtr aggregate must have 2 fields"), }; if is_zst { @@ -911,6 +920,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } // Do not track unions. AggregateKind::Adt(_, _, _, _, Some(_)) => return None, + // FIXME: Do the extra work to GVN `from_raw_parts` + AggregateKind::RawPtr(..) 
=> return None, }; let fields: Option<Vec<_>> = fields @@ -1115,9 +1126,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { if let Value::Cast { kind, from, to, .. } = self.get(inner) && let CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize) = kind && let Some(from) = from.builtin_deref(true) - && let ty::Array(_, len) = from.ty.kind() + && let ty::Array(_, len) = from.kind() && let Some(to) = to.builtin_deref(true) - && let ty::Slice(..) = to.ty.kind() + && let ty::Slice(..) = to.kind() { return self.insert_constant(Const::from_ty_const(*len, self.tcx)); } @@ -1203,7 +1214,7 @@ impl<'tcx> VnState<'_, 'tcx> { // not give the same value as the former mention. && value.is_deterministic() { - return Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_: value }); + return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_: value }); } let op = self.evaluated[index].as_ref()?; @@ -1220,7 +1231,7 @@ impl<'tcx> VnState<'_, 'tcx> { assert!(!value.may_have_provenance(self.tcx, op.layout.size)); let const_ = Const::Val(value, op.layout.ty); - Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ }) + Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ }) } /// If there is a local which is assigned `index`, and its assignment strictly dominates `loc`, diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 36546a03cdf..fe2237dd2e9 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -1,17 +1,17 @@ //! Inlining pass for MIR functions use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; -use rustc_const_eval::transform::validate::validate_types; use rustc_hir::def::DefKind; use rustc_hir::def_id::DefId; use rustc_index::bit_set::BitSet; use rustc_index::Idx; +use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TypeVisitableExt; use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; -use rustc_session::config::OptLevel; +use rustc_session::config::{DebugInfo, OptLevel}; use rustc_span::source_map::Spanned; use rustc_span::sym; use rustc_target::abi::FieldIdx; @@ -20,6 +20,7 @@ use rustc_target::spec::abi::Abi; use crate::cost_checker::CostChecker; use crate::simplify::simplify_cfg; use crate::util; +use crate::validate::validate_types; use std::iter; use std::ops::{Range, RangeFrom}; @@ -165,7 +166,7 @@ impl<'tcx> Inliner<'tcx> { caller_body: &mut Body<'tcx>, callsite: &CallSite<'tcx>, ) -> Result<std::ops::Range<BasicBlock>, &'static str> { - self.check_mir_is_available(caller_body, &callsite.callee)?; + self.check_mir_is_available(caller_body, callsite.callee)?; let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id()); let cross_crate_inlinable = self.tcx.cross_crate_inlinable(callsite.callee.def_id()); @@ -213,6 +214,7 @@ impl<'tcx> Inliner<'tcx> { MirPhase::Runtime(RuntimePhase::Optimized), self.param_env, &callee_body, + &caller_body, ) .is_empty() { @@ -297,7 +299,7 @@ impl<'tcx> Inliner<'tcx> { fn check_mir_is_available( &self, caller_body: &Body<'tcx>, - callee: &Instance<'tcx>, + callee: Instance<'tcx>, ) -> Result<(), &'static str> { let caller_def_id = caller_body.source.def_id(); let callee_def_id = callee.def_id(); @@ -323,7 +325,7 @@ impl<'tcx> Inliner<'tcx> { // do not need to catch this here, we can wait until the inliner decides to continue // 
inlining a second time. InstanceDef::VTableShim(_) - | InstanceDef::ReifyShim(_) + | InstanceDef::ReifyShim(..) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } @@ -331,7 +333,8 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) - | InstanceDef::FnPtrAddrShim(..) => return Ok(()), + | InstanceDef::FnPtrAddrShim(..) + | InstanceDef::AsyncDropGlueCtorShim(..) => return Ok(()), } if self.tcx.is_constructor(callee_def_id) { @@ -353,7 +356,7 @@ impl<'tcx> Inliner<'tcx> { // If we know for sure that the function we're calling will itself try to // call us, then we avoid inlining that function. - if self.tcx.mir_callgraph_reachable((*callee, caller_def_id.expect_local())) { + if self.tcx.mir_callgraph_reachable((callee, caller_def_id.expect_local())) { return Err("caller might be reachable from callee (query cycle avoidance)"); } @@ -565,7 +568,8 @@ impl<'tcx> Inliner<'tcx> { mut callee_body: Body<'tcx>, ) { let terminator = caller_body[callsite.block].terminator.take().unwrap(); - let TerminatorKind::Call { args, destination, unwind, target, .. } = terminator.kind else { + let TerminatorKind::Call { func, args, destination, unwind, target, .. } = terminator.kind + else { bug!("unexpected terminator kind {:?}", terminator.kind); }; @@ -697,7 +701,19 @@ impl<'tcx> Inliner<'tcx> { // Insert all of the (mapped) parts of the callee body into the caller. caller_body.local_decls.extend(callee_body.drain_vars_and_temps()); caller_body.source_scopes.extend(&mut callee_body.source_scopes.drain(..)); - caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + if self + .tcx + .sess + .opts + .unstable_opts + .inline_mir_preserve_debug + .unwrap_or(self.tcx.sess.opts.debuginfo != DebugInfo::None) + { + // Note that we need to preserve these in the standard library so that + // people working on rust can build with or without debuginfo while + // still getting consistent results from the mir-opt tests. + caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + } caller_body.basic_blocks_mut().extend(callee_body.basic_blocks_mut().drain(..)); caller_body[callsite.block].terminator = Some(Terminator { @@ -705,18 +721,31 @@ impl<'tcx> Inliner<'tcx> { kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) }, }); - // Copy only unevaluated constants from the callee_body into the caller_body. - // Although we are only pushing `ConstKind::Unevaluated` consts to - // `required_consts`, here we may not only have `ConstKind::Unevaluated` - // because we are calling `instantiate_and_normalize_erasing_regions`. - caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter( - |&ct| match ct.const_ { - Const::Ty(_) => { - bug!("should never encounter ty::UnevaluatedConst in `required_consts`") - } - Const::Val(..) | Const::Unevaluated(..) => true, - }, - )); + // Copy required constants from the callee_body into the caller_body. Although we are only + // pushing unevaluated consts to `required_consts`, here they may have been evaluated + // because we are calling `instantiate_and_normalize_erasing_regions` -- so we filter again. 
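A sketch of the call-graph shape the `mir_callgraph_reachable` check above guards against (hypothetical code, not from this diff; the functions are made up): mutually recursive functions must not be inlined into each other, both to avoid query cycles and to keep the bodies from growing without bound.

```rust
// Not from this diff: a call graph with a cycle. Inlining `odd` into `even`
// (or vice versa) without the reachability check could be repeated forever,
// so the inliner refuses once the caller is reachable from the callee.
fn even(n: u32) -> bool {
    if n == 0 { true } else { odd(n - 1) }
}

fn odd(n: u32) -> bool {
    if n == 0 { false } else { even(n - 1) }
}

fn main() {
    assert!(even(10));
    assert!(!even(7));
}
```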
+ caller_body.required_consts.extend( + callee_body.required_consts.into_iter().filter(|ct| ct.const_.is_required_const()), + ); + // Now that we incorporated the callee's `required_consts`, we can remove the callee from + // `mentioned_items` -- but we have to take their `mentioned_items` in return. This does + // some extra work here to save the monomorphization collector work later. It helps a lot, + // since monomorphization can avoid a lot of work when the "mentioned items" are similar to + // the actually used items. By doing this we can entirely avoid visiting the callee! + // We need to reconstruct the `required_item` for the callee so that we can find and + // remove it. + let callee_item = MentionedItem::Fn(func.ty(caller_body, self.tcx)); + if let Some(idx) = + caller_body.mentioned_items.iter().position(|item| item.node == callee_item) + { + // We found the callee, so remove it and add its items instead. + caller_body.mentioned_items.remove(idx); + caller_body.mentioned_items.extend(callee_body.mentioned_items); + } else { + // If we can't find the callee, there's no point in adding its items. Probably it + // already got removed by being inlined elsewhere in the same function, so we already + // took its items. + } } fn make_call_args( @@ -1036,8 +1065,8 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { { bug!("False unwinds should have been removed before inlining") } - TerminatorKind::InlineAsm { ref mut destination, ref mut unwind, .. } => { - if let Some(ref mut tgt) = *destination { + TerminatorKind::InlineAsm { ref mut targets, ref mut unwind, .. } => { + for tgt in targets.iter_mut() { *tgt = self.map_block(*tgt); } *unwind = self.map_unwind(*unwind); @@ -1051,13 +1080,14 @@ fn try_instance_mir<'tcx>( tcx: TyCtxt<'tcx>, instance: InstanceDef<'tcx>, ) -> Result<&'tcx Body<'tcx>, &'static str> { - if let ty::InstanceDef::DropGlue(_, Some(ty)) = instance + if let ty::InstanceDef::DropGlue(_, Some(ty)) + | ty::InstanceDef::AsyncDropGlueCtorShim(_, Some(ty)) = instance && let ty::Adt(def, args) = ty.kind() { let fields = def.all_fields(); for field in fields { let field_ty = field.ty(tcx, args); - if field_ty.has_param() && field_ty.has_projections() { + if field_ty.has_param() && field_ty.has_aliases() { return Err("cannot build drop shim for polymorphic type"); } } diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index f2b6dcac586..8c5f965108b 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -5,6 +5,7 @@ use rustc_middle::mir::TerminatorKind; use rustc_middle::ty::TypeVisitableExt; use rustc_middle::ty::{self, GenericArgsRef, InstanceDef, TyCtxt}; use rustc_session::Limit; +use rustc_span::sym; // FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking // this query ridiculously often. @@ -84,7 +85,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( // again, a function item can end up getting inlined. Thus we'll be able to cause // a cycle that way InstanceDef::VTableShim(_) - | InstanceDef::ReifyShim(_) + | InstanceDef::ReifyShim(..) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } @@ -93,8 +94,10 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::CloneShim(..) => {} // This shim does not call any other functions, thus there can be no recursion. - InstanceDef::FnPtrAddrShim(..) 
=> continue, - InstanceDef::DropGlue(..) => { + InstanceDef::FnPtrAddrShim(..) => { + continue; + } + InstanceDef::DropGlue(..) | InstanceDef::AsyncDropGlueCtorShim(..) => { // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this // needs some more analysis. @@ -164,11 +167,20 @@ pub(crate) fn mir_inliner_callees<'tcx>( let mut calls = FxIndexSet::default(); for bb_data in body.basic_blocks.iter() { let terminator = bb_data.terminator(); - if let TerminatorKind::Call { func, .. } = &terminator.kind { + if let TerminatorKind::Call { func, args: call_args, .. } = &terminator.kind { let ty = func.ty(&body.local_decls, tcx); - let call = match ty.kind() { - ty::FnDef(def_id, args) => (*def_id, *args), - _ => continue, + let ty::FnDef(def_id, generic_args) = ty.kind() else { + continue; + }; + let call = if tcx.is_intrinsic(*def_id, sym::const_eval_select) { + let func = &call_args[2].node; + let ty = func.ty(&body.local_decls, tcx); + let ty::FnDef(def_id, generic_args) = ty.kind() else { + continue; + }; + (*def_id, *generic_args) + } else { + (*def_id, *generic_args) }; calls.insert(call); } diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs index 6b33d81c1c4..f1adeab3f88 100644 --- a/compiler/rustc_mir_transform/src/instsimplify.rs +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -1,10 +1,13 @@ //! Performs various peephole optimizations. use crate::simplify::simplify_duplicate_switch_targets; +use rustc_ast::attr; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::ty::layout; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{self, GenericArgsRef, ParamEnv, Ty, TyCtxt}; +use rustc_span::sym; use rustc_span::symbol::Symbol; use rustc_target::abi::FieldIdx; use rustc_target::spec::abi::Abi; @@ -22,13 +25,19 @@ impl<'tcx> MirPass<'tcx> for InstSimplify { local_decls: &body.local_decls, param_env: tcx.param_env_reveal_all_normalized(body.source.def_id()), }; + let preserve_ub_checks = + attr::contains_name(tcx.hir().krate_attrs(), sym::rustc_preserve_ub_checks); for block in body.basic_blocks.as_mut() { for statement in block.statements.iter_mut() { match statement.kind { StatementKind::Assign(box (_place, ref mut rvalue)) => { + if !preserve_ub_checks { + ctx.simplify_ub_check(&statement.source_info, rvalue); + } ctx.simplify_bool_cmp(&statement.source_info, rvalue); ctx.simplify_ref_deref(&statement.source_info, rvalue); ctx.simplify_len(&statement.source_info, rvalue); + ctx.simplify_ptr_aggregate(&statement.source_info, rvalue); ctx.simplify_cast(rvalue); } _ => {} @@ -51,8 +60,17 @@ struct InstSimplifyContext<'tcx, 'a> { impl<'tcx> InstSimplifyContext<'tcx, '_> { fn should_simplify(&self, source_info: &SourceInfo, rvalue: &Rvalue<'tcx>) -> bool { + self.should_simplify_custom(source_info, "Rvalue", rvalue) + } + + fn should_simplify_custom( + &self, + source_info: &SourceInfo, + label: &str, + value: impl std::fmt::Debug, + ) -> bool { self.tcx.consider_optimizing(|| { - format!("InstSimplify - Rvalue: {rvalue:?} SourceInfo: {source_info:?}") + format!("InstSimplify - {label}: {value:?} SourceInfo: {source_info:?}") }) } @@ -104,7 +122,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None } } - /// Transform "&(*a)" ==> "a". + /// Transform `&(*a)` ==> `a`. 
fn simplify_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Ref(_, _, place) = rvalue { if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { @@ -124,7 +142,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } - /// Transform "Len([_; N])" ==> "N". + /// Transform `Len([_; N])` ==> `N`. fn simplify_len(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Len(ref place) = *rvalue { let place_ty = place.ty(self.local_decls, self.tcx).ty; @@ -140,6 +158,38 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } + /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`. + fn simplify_ptr_aggregate(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue + { + let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx); + if meta_ty.is_unit() { + // The mutable borrows we're holding prevent printing `rvalue` here + if !self.should_simplify_custom( + source_info, + "Aggregate::RawPtr", + (&pointee_ty, *mutability, &fields), + ) { + return; + } + + let mut fields = std::mem::take(fields); + let _meta = fields.pop().unwrap(); + let data = fields.pop().unwrap(); + let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability); + *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty); + } + } + } + + fn simplify_ub_check(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue { + let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks()); + let constant = ConstOperand { span: source_info.span, const_, user_ty: None }; + *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant))); + } + } + fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Cast(kind, operand, cast_ty) = rvalue { let operand_ty = operand.ty(self.local_decls, self.tcx); diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index ad8f21ffbda..ae807655b68 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -36,10 +36,12 @@ //! cost by `MAX_COST`. use rustc_arena::DroplessArena; +use rustc_const_eval::const_eval::DummyMachine; use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; @@ -50,7 +52,6 @@ use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; use crate::cost_checker::CostChecker; -use crate::dataflow_const_prop::DummyMachine; pub struct JumpThreading; @@ -416,7 +417,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { match rhs { // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. Operand::Constant(constant) => { - let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?; + let constant = + self.ecx.eval_mir_constant(&constant.const_, constant.span, None).ok()?; self.process_constant(bb, lhs, constant, state); } // Transfer the conditions on the copied rhs. 
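A small, hypothetical snippet (not part of this diff) showing the source-level counterpart of the new `AggregateKind::RawPtr` handling: assembling a wide raw pointer keeps both the data pointer and its metadata, whereas a thin pointer carries `()` as metadata, which is the case `simplify_ptr_aggregate` rewrites into a plain pointer-to-pointer cast.

```rust
use std::ptr;

// Not from this diff: `slice_from_raw_parts` builds a wide raw pointer from
// (data, len). A thin pointer would have unit metadata, and that is the
// shape InstSimplify turns into a PtrToPtr cast instead of an aggregate.
fn main() {
    let data = [1u32, 2, 3];
    let wide: *const [u32] = ptr::slice_from_raw_parts(data.as_ptr(), data.len());
    // SAFETY: `wide` points to the live array `data` and carries its length.
    let len = unsafe { (*wide).len() };
    assert_eq!(len, 3);
}
```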
diff --git a/compiler/rustc_mir_transform/src/known_panics_lint.rs b/compiler/rustc_mir_transform/src/known_panics_lint.rs index 27477769cef..0fa5c1b9126 100644 --- a/compiler/rustc_mir_transform/src/known_panics_lint.rs +++ b/compiler/rustc_mir_transform/src/known_panics_lint.rs @@ -6,13 +6,15 @@ use std::fmt::Debug; +use rustc_const_eval::const_eval::DummyMachine; use rustc_const_eval::interpret::{ format_interp_error, ImmTy, InterpCx, InterpResult, Projectable, Scalar, }; use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; use rustc_hir::HirId; -use rustc_index::{bit_set::BitSet, Idx, IndexVec}; +use rustc_index::{bit_set::BitSet, IndexVec}; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; @@ -20,7 +22,6 @@ use rustc_middle::ty::{self, ConstInt, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisi use rustc_span::Span; use rustc_target::abi::{Abi, FieldIdx, HasDataLayout, Size, TargetDataLayout, VariantIdx}; -use crate::dataflow_const_prop::DummyMachine; use crate::errors::{AssertLint, AssertLintKind}; use crate::MirLint; @@ -124,10 +125,8 @@ impl<'tcx> Value<'tcx> { fields.ensure_contains_elem(*idx, || Value::Uninit) } (PlaceElem::Field(..), val @ Value::Uninit) => { - *val = Value::Aggregate { - variant: VariantIdx::new(0), - fields: Default::default(), - }; + *val = + Value::Aggregate { variant: VariantIdx::ZERO, fields: Default::default() }; val.project_mut(&[*proj])? } _ => return None, @@ -261,7 +260,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // manually normalized. let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.const_).ok()?; - self.use_ecx(|this| this.ecx.eval_mir_constant(&val, Some(c.span), None))? + self.use_ecx(|this| this.ecx.eval_mir_constant(&val, c.span, None))? .as_mplace_or_imm() .right() } @@ -305,20 +304,25 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { let arg = self.eval_operand(arg)?; - if let (val, true) = self.use_ecx(|this| { - let val = this.ecx.read_immediate(&arg)?; - let (_res, overflow) = this.ecx.overflowing_unary_op(op, &val)?; - Ok((val, overflow)) - })? { - // `AssertKind` only has an `OverflowNeg` variant, so make sure that is - // appropriate to use. - assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - self.report_assert_as_lint( - location, - AssertLintKind::ArithmeticOverflow, - AssertKind::OverflowNeg(val.to_const_int()), - ); - return None; + // The only operator that can overflow is `Neg`. + if op == UnOp::Neg && arg.layout.ty.is_integral() { + // Compute this as `0 - arg` so we can use `SubWithOverflow` to check for overflow. + let (arg, overflow) = self.use_ecx(|this| { + let arg = this.ecx.read_immediate(&arg)?; + let (_res, overflow) = this + .ecx + .binary_op(BinOp::SubWithOverflow, &ImmTy::from_int(0, arg.layout), &arg)? + .to_scalar_pair(); + Ok((arg, overflow.to_bool()?)) + })?; + if overflow { + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::OverflowNeg(arg.to_const_int()), + ); + return None; + } } Some(()) @@ -364,11 +368,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } } - if let (Some(l), Some(r)) = (l, r) { - // The remaining operators are handled through `overflowing_binary_op`. + // Div/Rem are handled via the assertions they trigger. 
+ // But for Add/Sub/Mul, those assertions only exist in debug builds, and we want to + // lint in release builds as well, so we check on the operation instead. + // So normalize to the "overflowing" operator, and then ensure that it + // actually is an overflowing operator. + let op = op.wrapping_to_overflowing().unwrap_or(op); + // The remaining operators are handled through `wrapping_to_overflowing`. + if let (Some(l), Some(r)) = (l, r) + && l.layout.ty.is_integral() + && op.is_overflowing() + { if self.use_ecx(|this| { - let (_res, overflow) = this.ecx.overflowing_binary_op(op, &l, &r)?; - Ok(overflow) + let (_res, overflow) = this.ecx.binary_op(op, &l, &r)?.to_scalar_pair(); + overflow.to_bool() })? { self.report_assert_as_lint( location, @@ -402,15 +415,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); self.check_binary_op(*op, left, right, location)?; } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { - trace!( - "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", - op, - left, - right - ); - self.check_binary_op(*op, left, right, location)?; - } // Do not try creating references (#67862) Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { @@ -556,24 +560,16 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let right = self.eval_operand(right)?; let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; - let val = - self.use_ecx(|this| this.ecx.wrapping_binary_op(bin_op, &left, &right))?; - val.into() - } - - CheckedBinaryOp(bin_op, box (ref left, ref right)) => { - let left = self.eval_operand(left)?; - let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; - - let right = self.eval_operand(right)?; - let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; - - let (val, overflowed) = - self.use_ecx(|this| this.ecx.overflowing_binary_op(bin_op, &left, &right))?; - let overflowed = ImmTy::from_bool(overflowed, self.tcx); - Value::Aggregate { - variant: VariantIdx::new(0), - fields: [Value::from(val), overflowed.into()].into_iter().collect(), + let val = self.use_ecx(|this| this.ecx.binary_op(bin_op, &left, &right))?; + if matches!(val.layout.abi, Abi::ScalarPair(..)) { + // FIXME `Value` should properly support pairs in `Immediate`... but currently it does not. + let (val, overflow) = val.to_pair(&self.ecx); + Value::Aggregate { + variant: VariantIdx::ZERO, + fields: [val.into(), overflow.into()].into_iter().collect(), + } + } else { + val.into() } } @@ -581,36 +577,25 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let operand = self.eval_operand(operand)?; let val = self.use_ecx(|this| this.ecx.read_immediate(&operand))?; - let val = self.use_ecx(|this| this.ecx.wrapping_unary_op(un_op, &val))?; + let val = self.use_ecx(|this| this.ecx.unary_op(un_op, &val))?; val.into() } - Aggregate(ref kind, ref fields) => { - // Do not const prop union fields as they can be - // made to produce values that don't match their - // underlying layout's type (see ICE #121534). 
- // If the last element of the `Adt` tuple - // is `Some` it indicates the ADT is a union - if let AggregateKind::Adt(_, _, _, _, Some(_)) = **kind { - return None; - }; - Value::Aggregate { - fields: fields - .iter() - .map(|field| { - self.eval_operand(field).map_or(Value::Uninit, Value::Immediate) - }) - .collect(), - variant: match **kind { - AggregateKind::Adt(_, variant, _, _, _) => variant, - AggregateKind::Array(_) - | AggregateKind::Tuple - | AggregateKind::Closure(_, _) - | AggregateKind::Coroutine(_, _) - | AggregateKind::CoroutineClosure(_, _) => VariantIdx::new(0), - }, - } - } + Aggregate(ref kind, ref fields) => Value::Aggregate { + fields: fields + .iter() + .map(|field| self.eval_operand(field).map_or(Value::Uninit, Value::Immediate)) + .collect(), + variant: match **kind { + AggregateKind::Adt(_, variant, _, _, _) => variant, + AggregateKind::Array(_) + | AggregateKind::Tuple + | AggregateKind::RawPtr(_, _) + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::ZERO, + }, + }, Repeat(ref op, n) => { trace!(?op, ?n); @@ -639,7 +624,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { NullOp::OffsetOf(fields) => { op_layout.offset_of_subfield(self, fields.iter()).bytes() } - NullOp::DebugAssertions => return None, + NullOp::UbChecks => return None, }; ImmTy::from_scalar(Scalar::from_target_usize(val, self), layout).into() } @@ -798,7 +783,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { if let Some(ref value) = self.eval_operand(discr) && let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value)) && let Ok(constant) = value_const.try_to_int() - && let Ok(constant) = constant.to_bits(constant.size()) + && let Ok(constant) = constant.try_to_bits(constant.size()) { // We managed to evaluate the discriminant, so we know we only need to visit // one target. @@ -897,13 +882,20 @@ impl CanConstProp { }; for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { let ty = body.local_decls[local].ty; - match tcx.layout_of(param_env.and(ty)) { - Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} - // Either the layout fails to compute, then we can't use this local anyway - // or the local is too large, then we don't want to. - _ => { - *val = ConstPropMode::NoPropagation; - continue; + if ty.is_union() { + // Unions are incompatible with the current implementation of + // const prop because Rust has no concept of an active + // variant of a union + *val = ConstPropMode::NoPropagation; + } else { + match tcx.layout_of(param_env.and(ty)) { + Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} + // Either the layout fails to compute, then we can't use this local anyway + // or the local is too large, then we don't want to. 
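For reference, a hypothetical snippet (not from this diff) with the kind of provably panicking operations this lint reports. The lints involved (`unconditional_panic`, `arithmetic_overflow`) are deny-by-default, so the `allow` attribute is needed for it to compile at all; running it then panics at the indexing, exactly as the lint predicted.

```rust
// Not from this diff: operations whose panics the known-panics lint can
// prove at compile time from constant operands.
#![allow(unconditional_panic, arithmetic_overflow)]

fn main() {
    let xs = [1, 2, 3];
    let i = 5;
    let _ = xs[i]; // lint: this operation will panic at runtime (index out of bounds)

    let x: u8 = 200;
    let _ = x + 100; // lint: this arithmetic operation will overflow
}
```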
+ _ => { + *val = ConstPropMode::NoPropagation; + continue; + } } } } diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs index 8be96b6ba8f..e407929c9a7 100644 --- a/compiler/rustc_mir_transform/src/large_enums.rs +++ b/compiler/rustc_mir_transform/src/large_enums.rs @@ -1,7 +1,7 @@ -use crate::rustc_middle::ty::util::IntTypeExt; use rustc_data_structures::fx::FxHashMap; use rustc_middle::mir::interpret::AllocId; use rustc_middle::mir::*; +use rustc_middle::ty::util::IntTypeExt; use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt}; use rustc_session::Session; use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants}; diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index cd9b98e4f32..e4670633914 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -4,11 +4,9 @@ #![feature(cow_is_borrowed)] #![feature(decl_macro)] #![feature(impl_trait_in_assoc_type)] -#![feature(inline_const)] #![feature(is_sorted)] #![feature(let_chains)] #![feature(map_try_insert)] -#![cfg_attr(bootstrap, feature(min_specialization))] #![feature(never_type)] #![feature(option_get_or_insert_default)] #![feature(round_char_boundary)] @@ -18,8 +16,6 @@ #[macro_use] extern crate tracing; -#[macro_use] -extern crate rustc_middle; use hir::ConstContext; use required_consts::RequiredConstsVisitor; @@ -37,8 +33,10 @@ use rustc_middle::mir::{ LocalDecl, MirPass, MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, SourceInfo, Statement, StatementKind, TerminatorKind, START_BLOCK, }; -use rustc_middle::query::Providers; +use rustc_middle::query; use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt}; +use rustc_middle::util::Providers; +use rustc_middle::{bug, span_bug}; use rustc_span::{source_map::Spanned, sym, DUMMY_SP}; use rustc_trait_selection::traits; @@ -53,7 +51,6 @@ mod add_moves_for_packed_drops; mod add_retag; mod check_const_item_mutation; mod check_packed_ref; -pub mod check_unsafety; mod remove_place_mention; // This pass is public to allow external drivers to perform MIR cleanup mod add_subtyping_projections; @@ -88,6 +85,7 @@ mod lint; mod lower_intrinsics; mod lower_slice_len; mod match_branches; +mod mentioned_items; mod multiple_return_terminators; mod normalize_array_len; mod nrvo; @@ -109,24 +107,23 @@ pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; mod sroa; -mod uninhabited_enum_branching; +mod unreachable_enum_branching; mod unreachable_prop; +mod validate; -use rustc_const_eval::transform::check_consts::{self, ConstCx}; -use rustc_const_eval::transform::validate; +use rustc_const_eval::check_consts::{self, ConstCx}; use rustc_mir_dataflow::rustc_peek; rustc_fluent_macro::fluent_messages! 
{ "../messages.ftl" } pub fn provide(providers: &mut Providers) { - check_unsafety::provide(providers); coverage::query::provide(providers); ffi_unwind_calls::provide(providers); shim::provide(providers); cross_crate_inline::provide(providers); - *providers = Providers { + providers.queries = query::Providers { mir_keys, - mir_const, + mir_built, mir_const_qualif, mir_promoted, mir_drops_elaborated_and_const_checked, @@ -139,7 +136,7 @@ pub fn provide(providers: &mut Providers) { mir_inliner_callees: inline::cycle::mir_inliner_callees, promoted_mir, deduced_param_attrs: deduce_param_attrs::deduced_param_attrs, - ..*providers + ..providers.queries }; } @@ -258,9 +255,9 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { // N.B., this `borrow()` is guaranteed to be valid (i.e., the value // cannot yet be stolen), because `mir_promoted()`, which steals - // from `mir_const()`, forces this query to execute before + // from `mir_built()`, forces this query to execute before // performing the steal. - let body = &tcx.mir_const(def).borrow(); + let body = &tcx.mir_built(def).borrow(); if body.return_ty().references_error() { // It's possible to reach here without an error being emitted (#121103). @@ -278,19 +275,8 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { validator.qualifs_in_return_place() } -/// Make MIR ready for const evaluation. This is run on all MIR, not just on consts! -/// FIXME(oli-obk): it's unclear whether we still need this phase (and its corresponding query). -/// We used to have this for pre-miri MIR based const eval. -fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { - // MIR unsafety check uses the raw mir, so make sure it is run. - if !tcx.sess.opts.unstable_opts.thir_unsafeck { - tcx.ensure_with_value().mir_unsafety_check_result(def); - } - - // has_ffi_unwind_calls query uses the raw mir, so make sure it is run. - tcx.ensure_with_value().has_ffi_unwind_calls(def); - - let mut body = tcx.mir_built(def).steal(); +fn mir_built(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { + let mut body = tcx.build_mir(def); pass_manager::dump_mir_for_phase_change(tcx, &body); @@ -333,16 +319,20 @@ fn mir_promoted( } DefKind::AssocConst | DefKind::Const - | DefKind::Static(_) + | DefKind::Static { .. } | DefKind::InlineConst | DefKind::AnonConst => tcx.mir_const_qualif(def), _ => ConstQualifs::default(), }; - let mut body = tcx.mir_const(def).steal(); + // has_ffi_unwind_calls query uses the raw mir, so make sure it is run. + tcx.ensure_with_value().has_ffi_unwind_calls(def); + let mut body = tcx.mir_built(def).steal(); if let Some(error_reported) = const_qualifs.tainted_by_errors { body.tainted_by_errors = Some(error_reported); } + // Collect `required_consts` *before* promotion, so if there are any consts being promoted + // we still add them to the list in the outer MIR body. 
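A hypothetical snippet (not from this diff) showing what promotion lifts out of a body, which is why `required_consts` has to be collected from the outer body first:

```rust
// Not from this diff: `&7` is a promotable constant expression, so it is
// moved out of `main`'s MIR into a separate `promoted` body and the borrow
// is given the 'static lifetime.
fn main() {
    let r: &'static u32 = &7;
    assert_eq!(*r, 7);
}
```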
let mut required_consts = Vec::new(); let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts); for (bb, bb_data) in traversal::reverse_postorder(&body) { @@ -506,7 +496,7 @@ fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let passes: &[&dyn MirPass<'tcx>] = &[ &cleanup_post_borrowck::CleanupPostBorrowck, &remove_noop_landing_pads::RemoveNoopLandingPads, - &simplify::SimplifyCfg::EarlyOpt, + &simplify::SimplifyCfg::PostAnalysis, &deref_separator::Derefer, ]; @@ -528,11 +518,11 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // AddMovesForPackedDrops needs to run after drop // elaboration. &add_moves_for_packed_drops::AddMovesForPackedDrops, - // `AddRetag` needs to run after `ElaborateDrops`. Otherwise it should run fairly late, + // `AddRetag` needs to run after `ElaborateDrops` but before `ElaborateBoxDerefs`. Otherwise it should run fairly late, // but before optimizations begin. + &add_retag::AddRetag, &elaborate_box_derefs::ElaborateBoxDerefs, &coroutine::StateTransform, - &add_retag::AddRetag, &Lint(known_panics_lint::KnownPanicsLint), ]; pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial))); @@ -543,7 +533,7 @@ fn run_runtime_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let passes: &[&dyn MirPass<'tcx>] = &[ &lower_intrinsics::LowerIntrinsics, &remove_place_mention::RemovePlaceMention, - &simplify::SimplifyCfg::ElaborateDrops, + &simplify::SimplifyCfg::PreOptimizations, ]; pm::run_passes(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::PostCleanup))); @@ -565,6 +555,10 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { tcx, body, &[ + // Before doing anything, remember which items are being mentioned so that the set of items + // visited does not depend on the optimization level. + &mentioned_items::MentionedItems, + // Add some UB checks before any UB gets optimized away. &check_alignment::CheckAlignment, // Before inlining: trim down MIR with passes to reduce inlining work. @@ -579,9 +573,10 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &remove_zsts::RemoveZsts, &remove_unneeded_drops::RemoveUnneededDrops, // Type instantiation may create uninhabited enums. - &uninhabited_enum_branching::UninhabitedEnumBranching, + // Also eliminates some unreachable branches based on variants of enums. + &unreachable_enum_branching::UnreachableEnumBranching, &unreachable_prop::UnreachablePropagation, - &o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), + &o1(simplify::SimplifyCfg::AfterUnreachableEnumBranching), // Inlining may have introduced a lot of redundant code and a large move pattern. // Now, we need to shrink the generated MIR. @@ -632,12 +627,6 @@ fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> { } fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> { - if tcx.intrinsic(did).is_some_and(|i| i.must_be_overridden) { - span_bug!( - tcx.def_span(did), - "this intrinsic must be overridden by the codegen backend, it has no meaningful body", - ) - } if tcx.is_constructor(did.to_def_id()) { // There's no reason to run all of the MIR passes on constructors when // we can just output the MIR we want directly. 
This also saves const diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs index f317c025e96..221301b2ceb 100644 --- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -2,6 +2,7 @@ use rustc_middle::mir::*; use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_span::symbol::sym; pub struct LowerIntrinsics; @@ -20,13 +21,13 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { sym::unreachable => { terminator.kind = TerminatorKind::Unreachable; } - sym::debug_assertions => { + sym::ub_checks => { let target = target.unwrap(); block.statements.push(Statement { source_info: terminator.source_info, kind: StatementKind::Assign(Box::new(( *destination, - Rvalue::NullaryOp(NullOp::DebugAssertions, tcx.types.bool), + Rvalue::NullaryOp(NullOp::UbChecks, tcx.types.bool), ))), }); terminator.kind = TerminatorKind::Goto { target }; @@ -90,6 +91,7 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { sym::wrapping_add | sym::wrapping_sub | sym::wrapping_mul + | sym::three_way_compare | sym::unchecked_add | sym::unchecked_sub | sym::unchecked_mul @@ -109,6 +111,7 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { sym::wrapping_add => BinOp::Add, sym::wrapping_sub => BinOp::Sub, sym::wrapping_mul => BinOp::Mul, + sym::three_way_compare => BinOp::Cmp, sym::unchecked_add => BinOp::AddUnchecked, sym::unchecked_sub => BinOp::SubUnchecked, sym::unchecked_mul => BinOp::MulUnchecked, @@ -137,16 +140,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { rhs = args.next().unwrap(); } let bin_op = match intrinsic.name { - sym::add_with_overflow => BinOp::Add, - sym::sub_with_overflow => BinOp::Sub, - sym::mul_with_overflow => BinOp::Mul, + sym::add_with_overflow => BinOp::AddWithOverflow, + sym::sub_with_overflow => BinOp::SubWithOverflow, + sym::mul_with_overflow => BinOp::MulWithOverflow, _ => bug!("unexpected intrinsic"), }; block.statements.push(Statement { source_info: terminator.source_info, kind: StatementKind::Assign(Box::new(( *destination, - Rvalue::CheckedBinaryOp(bin_op, Box::new((lhs.node, rhs.node))), + Rvalue::BinaryOp(bin_op, Box::new((lhs.node, rhs.node))), ))), }); terminator.kind = TerminatorKind::Goto { target }; @@ -285,6 +288,34 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { terminator.kind = TerminatorKind::Unreachable; } } + sym::aggregate_raw_ptr => { + let Ok([data, meta]) = <[_; 2]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for aggregate_raw_ptr intrinsic", + ); + }; + let target = target.unwrap(); + let pointer_ty = generic_args.type_at(0); + let kind = if let ty::RawPtr(pointee_ty, mutability) = pointer_ty.kind() { + AggregateKind::RawPtr(*pointee_ty, *mutability) + } else { + span_bug!( + terminator.source_info.span, + "Return type of aggregate_raw_ptr intrinsic must be a raw pointer", + ); + }; + let fields = [data.node, meta.node]; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Aggregate(Box::new(kind), fields.into()), + ))), + }); + + terminator.kind = TerminatorKind::Goto { target }; + } _ => {} } } diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs index 8137525a332..2267a621a83 100644 --- a/compiler/rustc_mir_transform/src/lower_slice_len.rs +++ 
b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -21,7 +21,7 @@ impl<'tcx> MirPass<'tcx> for LowerSliceLenCalls { pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let language_items = tcx.lang_items(); let Some(slice_len_fn_item_def_id) = language_items.slice_len_fn() else { - // there is no language item to compare to :) + // there is no lang item to compare to :) return; }; diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 6d4332793af..1411d9be223 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,11 +1,128 @@ +use rustc_index::IndexSlice; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; +use rustc_target::abi::Size; use std::iter; use super::simplify::simplify_cfg; pub struct MatchBranchSimplification; +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let mut should_cleanup = false; + for i in 0..body.basic_blocks.len() { + let bbs = &*body.basic_blocks; + let bb_idx = BasicBlock::from_usize(i); + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { + continue; + } + + match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), + ref targets, + .. + // We require that the possible target blocks don't contain this block. + } if !targets.all_targets().contains(&bb_idx) => {} + // Only optimize switch int statements + _ => continue, + }; + + if SimplifyToIf.simplify(tcx, body, bb_idx, param_env).is_some() { + should_cleanup = true; + continue; + } + // unsound: https://github.com/rust-lang/rust/issues/124150 + if tcx.sess.opts.unstable_opts.unsound_mir_opts + && SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some() + { + should_cleanup = true; + continue; + } + } + + if should_cleanup { + simplify_cfg(body); + } + } +} + +trait SimplifyMatch<'tcx> { + /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise. + /// Generic code is written here, and we generally don't need a custom implementation. + fn simplify( + &mut self, + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + switch_bb_idx: BasicBlock, + param_env: ParamEnv<'tcx>, + ) -> Option<()> { + let bbs = &body.basic_blocks; + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), + _ => unreachable!(), + }; + + let discr_ty = discr.ty(body.local_decls(), tcx); + self.can_simplify(tcx, targets, param_env, bbs, discr_ty)?; + + let mut patch = MirPatch::new(body); + + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + + // Introduce a temporary for the discriminant value. 
+ let source_info = bbs[switch_bb_idx].terminator().source_info; + let discr_local = patch.new_temp(discr_ty, source_info.span); + + let (_, first) = targets.iter().next().unwrap(); + let statement_index = bbs[switch_bb_idx].statements.len(); + let parent_end = Location { block: switch_bb_idx, statement_index }; + patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); + patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); + self.new_stmts(tcx, targets, param_env, &mut patch, parent_end, bbs, discr_local, discr_ty); + patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); + patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); + patch.apply(body); + Some(()) + } + + /// Check that the BBs to be simplified satisfies all distinct and + /// that the terminator are the same. + /// There are also conditions for different ways of simplification. + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, + ) -> Option<()>; + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ); +} + +struct SimplifyToIf; + /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false /// to the same place), merge a target block statements into the source block, @@ -37,144 +154,349 @@ pub struct MatchBranchSimplification; /// goto -> bb3; /// } /// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + _discr_ty: Ty<'tcx>, + ) -> Option<()> { + if targets.iter().len() != 1 { + return None; + } + // We require that the possible target blocks all be distinct. + let (_, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + if first == second { + return None; + } + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + return None; + } -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 - } + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let second_stmts = &bbs[second].statements; + if first_stmts.len() != second_stmts.len() { + return None; + } + for (f, s) in iter::zip(first_stmts, second_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let def_id = body.source.def_id(); - let param_env = tcx.param_env_reveal_all_normalized(def_id); + // If two statements are const bool assignments to the same place, we can optimize. 
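To make the `SimplifyToIf` case concrete, here is a minimal, hypothetical piece of user code (not compiler code) of the shape this simplification targets: the two arms assign opposite `bool` constants to the same place, so the whole `switchInt` can be replaced by a single comparison against the discriminant.

```rust
// Illustrative user code only; the function names are made up for this sketch.
// `matchy` lowers to a two-armed switch whose arms assign opposite bool
// constants, which SimplifyToIf can collapse into the comparison in `ifless`.
fn matchy(x: u8) -> bool {
    match x {
        7 => false,
        _ => true,
    }
}

fn ifless(x: u8) -> bool {
    x != 7
}

fn main() {
    for x in [3u8, 7] {
        assert_eq!(matchy(x), ifless(x));
    }
}
```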
+ ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty().is_bool() + && s_c.const_.ty().is_bool() + && f_c.const_.try_eval_bool(tcx, param_env).is_some() + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} - let bbs = body.basic_blocks.as_mut(); - let mut should_cleanup = false; - 'outer: for bb_idx in bbs.indices() { - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { - continue; + // Otherwise we cannot optimize. Try another block. + _ => return None, } + } + Some(()) + } - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { - TerminatorKind::SwitchInt { - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), - ref targets, - .. - } if targets.iter().len() == 1 => { - let (value, target) = targets.iter().next().unwrap(); - // We require that this block and the two possible target blocks all be - // distinct. - if target == targets.otherwise() - || bb_idx == target - || bb_idx == targets.otherwise() - { - continue; + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) { + let (val, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let first = &bbs[first]; + let second = &bbs[second]; + for (f, s) in iter::zip(&first.statements, &second.statements) { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => { + patch.add_statement(parent_end, f.kind.clone()); + } + + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + patch.add_statement(parent_end, f.kind.clone()); + } else { + // Different value between blocks. Make value conditional on switch condition. + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + patch.add_assign(parent_end, *lhs, rhs); } - (discr, value, target, targets.otherwise()) } - // Only optimize switch int statements - _ => continue, - }; - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { - continue; + _ => unreachable!(), } + } + } +} - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. 
- let first_stmts = &bbs[first].statements; - let scnd_stmts = &bbs[second].statements; - if first_stmts.len() != scnd_stmts.len() { - continue; - } - for (f, s) in iter::zip(first_stmts, scnd_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} +#[derive(Default)] +struct SimplifyToExp { + transfrom_types: Vec<TransfromType>, +} - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, param_env).is_some() - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} +#[derive(Clone, Copy)] +enum CompareType<'tcx, 'a> { + /// Identical statements. + Same(&'a StatementKind<'tcx>), + /// Assignment statements have the same value. + Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), + /// Enum variant comparison type. + Discr { place: &'a Place<'tcx>, ty: Ty<'tcx>, is_signed: bool }, +} - // Otherwise we cannot optimize. Try another block. - _ => continue 'outer, - } - } - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); - let discr_ty = discr.ty(&body.local_decls, tcx); +enum TransfromType { + Same, + Eq, + Discr, +} + +impl From<CompareType<'_, '_>> for TransfromType { + fn from(compare_type: CompareType<'_, '_>) -> Self { + match compare_type { + CompareType::Same(_) => TransfromType::Same, + CompareType::Eq(_, _, _) => TransfromType::Eq, + CompareType::Discr { .. } => TransfromType::Discr, + } + } +} + +/// If we find that the value of match is the same as the assignment, +/// merge a target block statements into the source block, +/// using cast to transform different integer types. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; +/// } +/// +/// bb1: { +/// unreachable; +/// } +/// +/// bb2: { +/// _0 = const 1_i16; +/// goto -> bb5; +/// } +/// +/// bb3: { +/// _0 = const 2_i16; +/// goto -> bb5; +/// } +/// +/// bb4: { +/// _0 = const 3_i16; +/// goto -> bb5; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _0 = _3 as i16 (IntToInt); +/// goto -> bb5; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, + ) -> Option<()> { + if targets.iter().len() < 2 || targets.iter().len() > 64 { + return None; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return None; + } + if !bbs[targets.otherwise()].is_empty_unreachable() { + return None; + } + let mut target_iter = targets.iter(); + let (first_val, first_target) = target_iter.next().unwrap(); + let first_terminator_kind = &bbs[first_target].terminator().kind; + // Check that destinations are identical, and if not, then don't optimize this block + if !targets + .iter() + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) + { + return None; + } - // Introduce a temporary for the discriminant value. 
- let source_info = bbs[bb_idx].terminator().source_info; - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); + let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; + let first_stmts = &bbs[first_target].statements; + let (second_val, second_target) = target_iter.next().unwrap(); + let second_stmts = &bbs[second_target].statements; + if first_stmts.len() != second_stmts.len() { + return None; + } - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); + fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool { + l.assert_int(l.size()) == ScalarInt::try_from_uint(r, size).unwrap().assert_int(size) + } - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => (*f).clone(), + // We first compare the two branches, and then the other branches need to fulfill the same conditions. + let mut compare_types = Vec::new(); + for (f, s) in iter::zip(first_stmts, second_stmts) { + let compare_type = match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => CompareType::Same(f_s), - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - (*f).clone() - } else { - // Different value between blocks. Make value conditional on switch condition. - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - Statement { - source_info: f.source_info, - kind: StatementKind::Assign(Box::new((*lhs, rhs))), + // If two statements are assignments with the match values to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty() == s_c.const_.ty() + && f_c.const_.ty().is_integral() => + { + match ( + f_c.const_.try_eval_scalar_int(tcx, param_env), + s_c.const_.try_eval_scalar_int(tcx, param_env), + ) { + (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), + // Enum variants can also be simplified to an assignment statement if their values are equal. + // We need to consider both unsigned and signed scenarios here. 
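For the `SimplifyToExp` case (gated behind `-Zunsound-mir-opts` in the earlier hunk because of the linked issue), a minimal, hypothetical example of qualifying user code is a match whose arms each assign exactly the matched discriminant value, so the whole switch reduces to an `IntToInt` cast of the discriminant:

```rust
// Illustrative user code only. Each arm stores exactly the value of the
// matched discriminant, so the switch is equivalent to casting the
// discriminant to i16, which is the rewrite SimplifyToExp performs.
#[repr(u8)]
enum Tag {
    A = 1,
    B = 2,
    C = 3,
}

fn as_i16(t: Tag) -> i16 {
    match t {
        Tag::A => 1,
        Tag::B => 2,
        Tag::C => 3,
    }
}

fn main() {
    assert_eq!(as_i16(Tag::A), 1);
    assert_eq!(as_i16(Tag::B), 2);
    assert_eq!(as_i16(Tag::C), 3);
}
```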
+ (Some(f), Some(s)) + if ((f_c.const_.ty().is_signed() || discr_ty.is_signed()) + && int_equal(f, first_val, discr_size) + && int_equal(s, second_val, discr_size)) + || (Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) + == ScalarInt::try_from_uint(second_val, s.size())) => + { + CompareType::Discr { + place: lhs_f, + ty: f_c.const_.ty(), + is_signed: f_c.const_.ty().is_signed() || discr_ty.is_signed(), } } + _ => { + return None; + } } + } + + // Otherwise we cannot optimize. Try another block. + _ => return None, + }; + compare_types.push(compare_type); + } - _ => unreachable!(), + // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. + for (other_val, other_target) in target_iter { + let other_stmts = &bbs[other_target].statements; + if compare_types.len() != other_stmts.len() { + return None; + } + for (f, s) in iter::zip(&compare_types, other_stmts) { + match (*f, &s.kind) { + (CompareType::Same(f_s), s_s) if f_s == s_s => {} + ( + CompareType::Eq(lhs_f, f_ty, val), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} + ( + CompareType::Discr { place: lhs_f, ty: f_ty, is_signed }, + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { + let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { + return None; + }; + if is_signed + && s_c.const_.ty().is_signed() + && int_equal(f, other_val, discr_size) + { + continue; + } + if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) { + continue; + } + return None; + } + _ => return None, } - }); - - from.statements - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); - from.statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(discr_local), - Rvalue::Use(discr), - ))), - }); - from.statements.extend(new_stmts); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); - from.terminator_mut().kind = first.terminator().kind.clone(); - should_cleanup = true; + } } + self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect(); + Some(()) + } - if should_cleanup { - simplify_cfg(body); + fn new_stmts( + &self, + _tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + _param_env: ParamEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) { + let (_, first) = targets.iter().next().unwrap(); + let first = &bbs[first]; + + for (t, s) in iter::zip(&self.transfrom_types, &first.statements) { + match (t, &s.kind) { + (TransfromType::Same, _) | (TransfromType::Eq, _) => { + patch.add_statement(parent_end, s.kind.clone()); + } + ( + TransfromType::Discr, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + ) => { + let operand = Operand::Copy(Place::from(discr_local)); + let r_val = if f_c.const_.ty() == discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) + }; + patch.add_assign(parent_end, *lhs, r_val); + } + _ => unreachable!(), + } } } } diff --git a/compiler/rustc_mir_transform/src/mentioned_items.rs b/compiler/rustc_mir_transform/src/mentioned_items.rs new file mode 100644 index 00000000000..db2bb60bdac --- /dev/null +++ 
b/compiler/rustc_mir_transform/src/mentioned_items.rs @@ -0,0 +1,117 @@ +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::{self, Location, MentionedItem, MirPass}; +use rustc_middle::ty::{self, adjustment::PointerCoercion, TyCtxt}; +use rustc_session::Session; +use rustc_span::source_map::Spanned; + +pub struct MentionedItems; + +struct MentionedItemsVisitor<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a mir::Body<'tcx>, + mentioned_items: &'a mut Vec<Spanned<MentionedItem<'tcx>>>, +} + +impl<'tcx> MirPass<'tcx> for MentionedItems { + fn is_enabled(&self, _sess: &Session) -> bool { + // If this pass is skipped the collector assume that nothing got mentioned! We could + // potentially skip it in opt-level 0 if we are sure that opt-level will never *remove* uses + // of anything, but that still seems fragile. Furthermore, even debug builds use level 1, so + // special-casing level 0 is just not worth it. + true + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + debug_assert!(body.mentioned_items.is_empty()); + let mut mentioned_items = Vec::new(); + MentionedItemsVisitor { tcx, body, mentioned_items: &mut mentioned_items }.visit_body(body); + body.mentioned_items = mentioned_items; + } +} + +// This visitor is carefully in sync with the one in `rustc_monomorphize::collector`. We are +// visiting the exact same places but then instead of monomorphizing and creating `MonoItems`, we +// have to remain generic and just recording the relevant information in `mentioned_items`, where it +// will then be monomorphized later during "mentioned items" collection. +impl<'tcx> Visitor<'tcx> for MentionedItemsVisitor<'_, 'tcx> { + fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, location: Location) { + self.super_terminator(terminator, location); + let span = || self.body.source_info(location).span; + match &terminator.kind { + mir::TerminatorKind::Call { func, .. } => { + let callee_ty = func.ty(self.body, self.tcx); + self.mentioned_items + .push(Spanned { node: MentionedItem::Fn(callee_ty), span: span() }); + } + mir::TerminatorKind::Drop { place, .. } => { + let ty = place.ty(self.body, self.tcx).ty; + self.mentioned_items.push(Spanned { node: MentionedItem::Drop(ty), span: span() }); + } + mir::TerminatorKind::InlineAsm { ref operands, .. } => { + for op in operands { + match *op { + mir::InlineAsmOperand::SymFn { ref value } => { + self.mentioned_items.push(Spanned { + node: MentionedItem::Fn(value.const_.ty()), + span: span(), + }); + } + _ => {} + } + } + } + _ => {} + } + } + + fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: Location) { + self.super_rvalue(rvalue, location); + let span = || self.body.source_info(location).span; + match *rvalue { + // We need to detect unsizing casts that required vtables. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::Unsize), + ref operand, + target_ty, + ) + | mir::Rvalue::Cast(mir::CastKind::DynStar, ref operand, target_ty) => { + // This isn't monomorphized yet so we can't tell what the actual types are -- just + // add everything that may involve a vtable. 
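As a rough illustration of what this visitor records (hypothetical user code, not part of the pass), a trait-object coercion may need a vtable and is therefore recorded conservatively, while an array-to-slice coercion never does and is skipped:

```rust
// Illustrative user code only. The trait-object coercion below may need a
// vtable, so it would be recorded as a mentioned unsize cast even before
// monomorphization; the array-to-slice coercion never involves a vtable and
// matches the explicit exception in the code that follows.
use std::fmt::Display;

fn main() {
    let value = 5u32;
    let _as_dyn: &dyn Display = &value; // vtable-involving unsizing

    let array = [1u8, 2, 3];
    let _as_slice: &[u8] = &array; // &[T; N] -> &[T], no vtable
}
```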
+ let source_ty = operand.ty(self.body, self.tcx); + let may_involve_vtable = match ( + source_ty.builtin_deref(true).map(|t| t.kind()), + target_ty.builtin_deref(true).map(|t| t.kind()), + ) { + (Some(ty::Array(..)), Some(ty::Str | ty::Slice(..))) => false, // &str/&[T] unsizing + _ => true, + }; + if may_involve_vtable { + self.mentioned_items.push(Spanned { + node: MentionedItem::UnsizeCast { source_ty, target_ty }, + span: span(), + }); + } + } + // Similarly, record closures that are turned into function pointers. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::ClosureFnPointer(_)), + ref operand, + _, + ) => { + let source_ty = operand.ty(self.body, self.tcx); + self.mentioned_items + .push(Spanned { node: MentionedItem::Closure(source_ty), span: span() }); + } + // And finally, function pointer reification casts. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer), + ref operand, + _, + ) => { + let fn_ty = operand.ty(self.body, self.tcx); + self.mentioned_items.push(Spanned { node: MentionedItem::Fn(fn_ty), span: span() }); + } + _ => {} + } + } +} diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs index 128634bd7f2..2070895c900 100644 --- a/compiler/rustc_mir_transform/src/normalize_array_len.rs +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -22,7 +22,8 @@ impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { } fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let slice_lengths = compute_slice_length(tcx, &ssa, body); debug!(?slice_lengths); @@ -47,9 +48,9 @@ fn compute_slice_length<'tcx>( let operand_ty = operand.ty(body, tcx); debug!(?operand_ty); if let Some(operand_ty) = operand_ty.builtin_deref(true) - && let ty::Array(_, len) = operand_ty.ty.kind() + && let ty::Array(_, len) = operand_ty.kind() && let Some(cast_ty) = cast_ty.builtin_deref(true) - && let ty::Slice(..) = cast_ty.ty.kind() + && let ty::Slice(..) = cast_ty.kind() { slice_lengths[local] = Some(*len); } diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs index c3a92911bbf..885dbd5f339 100644 --- a/compiler/rustc_mir_transform/src/nrvo.rs +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -2,6 +2,7 @@ use rustc_hir::Mutability; use rustc_index::bit_set::BitSet; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor}; use rustc_middle::mir::{self, BasicBlock, Local, Location}; use rustc_middle::ty::TyCtxt; @@ -84,7 +85,7 @@ impl<'tcx> MirPass<'tcx> for RenameReturnPlace { /// /// If the MIR fulfills both these conditions, this function returns the `Local` that is assigned /// to the return place along all possible paths through the control-flow graph. 
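For orientation, a hypothetical example of the pattern the NRVO pass looks for: a single local is written and then returned unchanged on every path, so it can be renamed to the return place and the final copy disappears.

```rust
// Illustrative user code only: `buf` is initialized, filled in, and then
// returned unchanged on every path, which is the shape RenameReturnPlace
// (NRVO) can rewrite to build the value directly in the return slot.
fn build_table() -> [u32; 256] {
    let mut buf = [0u32; 256];
    let mut i = 0;
    while i < 256 {
        buf[i] = (i as u32) * 2;
        i += 1;
    }
    buf
}

fn main() {
    let table = build_table();
    assert_eq!(table[10], 20);
}
```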
-fn local_eligible_for_nrvo(body: &mut mir::Body<'_>) -> Option<Local> { +fn local_eligible_for_nrvo(body: &mir::Body<'_>) -> Option<Local> { if IsReturnPlaceRead::run(body) { return None; } @@ -118,10 +119,7 @@ fn local_eligible_for_nrvo(body: &mut mir::Body<'_>) -> Option<Local> { copied_to_return_place } -fn find_local_assigned_to_return_place( - start: BasicBlock, - body: &mut mir::Body<'_>, -) -> Option<Local> { +fn find_local_assigned_to_return_place(start: BasicBlock, body: &mir::Body<'_>) -> Option<Local> { let mut block = start; let mut seen = BitSet::new_empty(body.basic_blocks.len()); diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index 77478cc741d..17a1c3c7157 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -186,9 +186,6 @@ fn run_passes_inner<'tcx>( if let Some(by_move_body) = coroutine.by_move_body.as_mut() { run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); } - if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() { - run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each); - } } } diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index 2e11da4d585..e37f90ae7f4 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -13,12 +13,14 @@ //! move analysis runs after promotion on broken MIR. use either::{Left, Right}; +use rustc_data_structures::fx::FxHashSet; use rustc_hir as hir; use rustc_middle::mir; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, List, Ty, TyCtxt, TypeVisitableExt}; +use rustc_middle::{bug, span_bug}; use rustc_span::Span; use rustc_index::{Idx, IndexSlice, IndexVec}; @@ -28,7 +30,7 @@ use std::assert_matches::assert_matches; use std::cell::Cell; use std::{cmp, iter, mem}; -use rustc_const_eval::transform::check_consts::{qualifs, ConstCx}; +use rustc_const_eval::check_consts::{qualifs, ConstCx}; /// A `MirPass` for promotion. /// @@ -175,6 +177,12 @@ fn collect_temps_and_candidates<'tcx>( struct Validator<'a, 'tcx> { ccx: &'a ConstCx<'a, 'tcx>, temps: &'a mut IndexSlice<Local, TempState>, + /// For backwards compatibility, we are promoting function calls in `const`/`static` + /// initializers. But we want to avoid evaluating code that might panic and that otherwise would + /// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed + /// to execute. In other words, we only promote such calls in basic blocks that are definitely + /// not dead code. Here we cache the result of computing that set of basic blocks. + promotion_safe_blocks: Option<FxHashSet<BasicBlock>>, } impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> { @@ -260,7 +268,9 @@ impl<'tcx> Validator<'_, 'tcx> { self.validate_rvalue(rhs) } Right(terminator) => match &terminator.kind { - TerminatorKind::Call { func, args, .. } => self.validate_call(func, args), + TerminatorKind::Call { func, args, .. } => { + self.validate_call(func, args, loc.block) + } TerminatorKind::Yield { .. } => Err(Unpromotable), kind => { span_bug!(terminator.source_info.span, "{:?} not promotable", kind); @@ -384,7 +394,7 @@ impl<'tcx> Validator<'_, 'tcx> { match kind { // Reject these borrow types just to be safe. 
// FIXME(RalfJung): could we allow them? Should we? No point in it until we have a usecase. - BorrowKind::Fake | BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture } => { + BorrowKind::Fake(_) | BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture } => { return Err(Unpromotable); } @@ -434,7 +444,7 @@ impl<'tcx> Validator<'_, 'tcx> { Rvalue::ThreadLocalRef(_) => return Err(Unpromotable), // ptr-to-int casts are not possible in consts and thus not promotable - Rvalue::Cast(CastKind::PointerExposeAddress, _, _) => return Err(Unpromotable), + Rvalue::Cast(CastKind::PointerExposeProvenance, _, _) => return Err(Unpromotable), // all other casts including int-to-ptr casts are fine, they just use the integer value // at pointer type. @@ -446,7 +456,7 @@ impl<'tcx> Validator<'_, 'tcx> { NullOp::SizeOf => {} NullOp::AlignOf => {} NullOp::OffsetOf(_) => {} - NullOp::DebugAssertions => {} + NullOp::UbChecks => {} }, Rvalue::ShallowInitBox(_, _) => return Err(Unpromotable), @@ -460,11 +470,11 @@ impl<'tcx> Validator<'_, 'tcx> { self.validate_operand(operand)?; } - Rvalue::BinaryOp(op, box (lhs, rhs)) | Rvalue::CheckedBinaryOp(op, box (lhs, rhs)) => { + Rvalue::BinaryOp(op, box (lhs, rhs)) => { let op = *op; let lhs_ty = lhs.ty(self.body, self.tcx); - if let ty::RawPtr(_) | ty::FnPtr(..) = lhs_ty.kind() { + if let ty::RawPtr(_, _) | ty::FnPtr(..) = lhs_ty.kind() { // Raw and fn pointer operations are not allowed inside consts and thus not promotable. assert!(matches!( op, @@ -490,14 +500,14 @@ impl<'tcx> Validator<'_, 'tcx> { } _ => None, }; - match rhs_val.map(|x| x.try_to_uint(sz).unwrap()) { + match rhs_val.map(|x| x.assert_uint(sz)) { // for the zero test, int vs uint does not matter Some(x) if x != 0 => {} // okay _ => return Err(Unpromotable), // value not known or 0 -- not okay } // Furthermore, for signed divison, we also have to exclude `int::MIN / -1`. if lhs_ty.is_signed() { - match rhs_val.map(|x| x.try_to_int(sz).unwrap()) { + match rhs_val.map(|x| x.assert_int(sz)) { Some(-1) | None => { // The RHS is -1 or unknown, so we have to be careful. // But is the LHS int::MIN? @@ -508,7 +518,7 @@ impl<'tcx> Validator<'_, 'tcx> { _ => None, }; let lhs_min = sz.signed_int_min(); - match lhs_val.map(|x| x.try_to_int(sz).unwrap()) { + match lhs_val.map(|x| x.assert_int(sz)) { Some(x) if x != lhs_min => {} // okay _ => return Err(Unpromotable), // value not known or int::MIN -- not okay } @@ -525,13 +535,17 @@ impl<'tcx> Validator<'_, 'tcx> { | BinOp::Lt | BinOp::Ge | BinOp::Gt + | BinOp::Cmp | BinOp::Offset | BinOp::Add | BinOp::AddUnchecked + | BinOp::AddWithOverflow | BinOp::Sub | BinOp::SubUnchecked + | BinOp::SubWithOverflow | BinOp::Mul | BinOp::MulUnchecked + | BinOp::MulWithOverflow | BinOp::BitXor | BinOp::BitAnd | BinOp::BitOr @@ -587,29 +601,79 @@ impl<'tcx> Validator<'_, 'tcx> { Ok(()) } + /// Computes the sets of blocks of this MIR that are definitely going to be executed + /// if the function returns successfully. That makes it safe to promote calls in them + /// that might fail. + fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> { + let mut safe_blocks = FxHashSet::default(); + let mut safe_block = START_BLOCK; + loop { + safe_blocks.insert(safe_block); + // Let's see if we can find another safe block. + safe_block = match body.basic_blocks[safe_block].terminator().kind { + TerminatorKind::Goto { target } => target, + TerminatorKind::Call { target: Some(target), .. } + | TerminatorKind::Drop { target, .. } => { + // This calls a function or the destructor. 
`target` does not get executed if + // the callee loops or panics. But in both cases the const already fails to + // evaluate, so we are fine considering `target` a safe block for promotion. + target + } + TerminatorKind::Assert { target, .. } => { + // Similar to above, we only consider successful execution. + target + } + _ => { + // No next safe block. + break; + } + }; + } + safe_blocks + } + + /// Returns whether the block is "safe" for promotion, which means it cannot be dead code. + /// We use this to avoid promoting operations that can fail in dead code. + fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool { + let body = self.body; + let safe_blocks = + self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body)); + safe_blocks.contains(&block) + } + fn validate_call( &mut self, callee: &Operand<'tcx>, args: &[Spanned<Operand<'tcx>>], + block: BasicBlock, ) -> Result<(), Unpromotable> { + // Validate the operands. If they fail, there's no question -- we cannot promote. + self.validate_operand(callee)?; + for arg in args { + self.validate_operand(&arg.node)?; + } + + // Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can + // accept them at this point. let fn_ty = callee.ty(self.body, self.tcx); + if let ty::FnDef(def_id, _) = *fn_ty.kind() { + if self.tcx.is_promotable_const_fn(def_id) { + return Ok(()); + } + } - // Inside const/static items, we promote all (eligible) function calls. - // Everywhere else, we require `#[rustc_promotable]` on the callee. - let promote_all_const_fn = matches!( + // Ideally, we'd stop here and reject the rest. + // But for backward compatibility, we have to accept some promotion in const/static + // initializers. Inline consts are explicitly excluded, they are more recent so we have no + // backwards compatibility reason to allow more promotion inside of them. + let promote_all_fn = matches!( self.const_kind, Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false }) ); - if !promote_all_const_fn { - if let ty::FnDef(def_id, _) = *fn_ty.kind() { - // Never promote runtime `const fn` calls of - // functions without `#[rustc_promotable]`. - if !self.tcx.is_promotable_const_fn(def_id) { - return Err(Unpromotable); - } - } + if !promote_all_fn { + return Err(Unpromotable); } - + // Make sure the callee is a `const fn`. let is_const_fn = match *fn_ty.kind() { ty::FnDef(def_id, _) => self.tcx.is_const_fn_raw(def_id), _ => false, @@ -617,23 +681,23 @@ impl<'tcx> Validator<'_, 'tcx> { if !is_const_fn { return Err(Unpromotable); } - - self.validate_operand(callee)?; - for arg in args { - self.validate_operand(&arg.node)?; + // The problem is, this may promote calls to functions that panic. + // We don't want to introduce compilation errors if there's a panic in a call in dead code. + // So we ensure that this is not dead code. + if !self.is_promotion_safe_block(block) { + return Err(Unpromotable); } - + // This passed all checks, so let's accept. Ok(()) } } -// FIXME(eddyb) remove the differences for promotability in `static`, `const`, `const fn`. fn validate_candidates( ccx: &ConstCx<'_, '_>, temps: &mut IndexSlice<Local, TempState>, candidates: &[Candidate], ) -> Vec<Candidate> { - let mut validator = Validator { ccx, temps }; + let mut validator = Validator { ccx, temps, promotion_safe_blocks: None }; candidates .iter() @@ -652,6 +716,10 @@ struct Promoter<'a, 'tcx> { /// If true, all nested temps are also kept in the /// source MIR, not moved to the promoted MIR. 
keep_original: bool, + + /// If true, add the new const (the promoted) to the required_consts of the parent MIR. + /// This is initially false and then set by the visitor when it encounters a `Call` terminator. + add_to_required: bool, } impl<'a, 'tcx> Promoter<'a, 'tcx> { @@ -754,6 +822,10 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { TerminatorKind::Call { mut func, mut args, call_source: desugar, fn_span, .. } => { + // This promoted involves a function call, so it may fail to evaluate. + // Let's make sure it is added to `required_consts` so that that failure cannot get lost. + self.add_to_required = true; + self.visit_operand(&mut func, loc); for arg in &mut args { self.visit_operand(&mut arg.node, loc); @@ -788,7 +860,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { fn promote_candidate(mut self, candidate: Candidate, next_promoted_id: usize) -> Body<'tcx> { let def = self.source.source.def_id(); - let mut rvalue = { + let (mut rvalue, promoted_op) = { let promoted = &mut self.promoted; let promoted_id = Promoted::new(next_promoted_id); let tcx = self.tcx; @@ -798,11 +870,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def)); let uneval = mir::UnevaluatedConst { def, args, promoted: Some(promoted_id) }; - Operand::Constant(Box::new(ConstOperand { - span, - user_ty: None, - const_: Const::Unevaluated(uneval, ty), - })) + ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) } }; let blocks = self.source.basic_blocks.as_mut(); @@ -820,11 +888,8 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let ty = local_decls[place.local].ty; let span = statement.source_info.span; - let ref_ty = Ty::new_ref( - tcx, - tcx.lifetimes.re_erased, - ty::TypeAndMut { ty, mutbl: borrow_kind.to_mutbl_lossy() }, - ); + let ref_ty = + Ty::new_ref(tcx, tcx.lifetimes.re_erased, ty, borrow_kind.to_mutbl_lossy()); let mut projection = vec![PlaceElem::Deref]; projection.extend(place.projection); @@ -838,22 +903,26 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let promoted_ref = local_decls.push(promoted_ref); assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref); + let promoted_operand = promoted_operand(ref_ty, span); let promoted_ref_statement = Statement { source_info: statement.source_info, kind: StatementKind::Assign(Box::new(( Place::from(promoted_ref), - Rvalue::Use(promoted_operand(ref_ty, span)), + Rvalue::Use(Operand::Constant(Box::new(promoted_operand))), ))), }; self.extra_statements.push((loc, promoted_ref_statement)); - Rvalue::Ref( - tcx.lifetimes.re_erased, - *borrow_kind, - Place { - local: mem::replace(&mut place.local, promoted_ref), - projection: List::empty(), - }, + ( + Rvalue::Ref( + tcx.lifetimes.re_erased, + *borrow_kind, + Place { + local: mem::replace(&mut place.local, promoted_ref), + projection: List::empty(), + }, + ), + promoted_operand, ) }; @@ -865,6 +934,12 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let span = self.promoted.span; self.assign(RETURN_PLACE, rvalue, span); + + // Now that we did promotion, we know whether we'll want to add this to `required_consts`. 
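A minimal, hypothetical sketch of the backwards-compatibility situation described here: a `const fn` call borrowed in a `static` initializer. Whether a particular borrow takes exactly this promotion path depends on details not shown, but this is the kind of code the extra `required_consts` bookkeeping is meant to keep sound.

```rust
// Illustrative user code only. The reference to the call result must live for
// 'static, which is accepted for backwards compatibility in const/static
// initializers. Because a call like this could in principle fail to evaluate
// (panic, division by zero, ...), the resulting promoted constant has to stay
// in `required_consts` so the failure cannot be silently lost.
const fn halve(x: u32) -> u32 {
    x / 2
}

static HALF_OF_TEN: &u32 = &halve(10);

fn main() {
    assert_eq!(*HALF_OF_TEN, 5);
}
```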
+ if self.add_to_required { + self.source.required_consts.push(promoted_op); + } + self.promoted } } @@ -880,6 +955,14 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Promoter<'a, 'tcx> { *local = self.promote_temp(*local); } } + + fn visit_constant(&mut self, constant: &mut ConstOperand<'tcx>, _location: Location) { + if constant.const_.is_required_const() { + self.promoted.required_consts.push(*constant); + } + + // Skipping `super_constant` as the visitor is otherwise only looking for locals. + } } fn promote_candidates<'tcx>( @@ -933,8 +1016,10 @@ fn promote_candidates<'tcx>( temps: &mut temps, extra_statements: &mut extra_statements, keep_original: false, + add_to_required: false, }; + // `required_consts` of the promoted itself gets filled while building the MIR body. let mut promoted = promoter.promote_candidate(candidate, promotions.len()); promoted.source.promoted = Some(promotions.next_index()); promotions.push(promoted); diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs index d5642be5513..801ef14c9cd 100644 --- a/compiler/rustc_mir_transform/src/ref_prop.rs +++ b/compiler/rustc_mir_transform/src/ref_prop.rs @@ -1,6 +1,7 @@ use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; @@ -82,7 +83,8 @@ impl<'tcx> MirPass<'tcx> for ReferencePropagation { } fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let mut replacer = compute_replacement(tcx, body, &ssa); debug!(?replacer.targets); diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs index abde6a47e83..71ac929d35e 100644 --- a/compiler/rustc_mir_transform/src/required_consts.rs +++ b/compiler/rustc_mir_transform/src/required_consts.rs @@ -1,6 +1,5 @@ use rustc_middle::mir::visit::Visitor; -use rustc_middle::mir::{Const, ConstOperand, Location}; -use rustc_middle::ty::ConstKind; +use rustc_middle::mir::{ConstOperand, Location}; pub struct RequiredConstsVisitor<'a, 'tcx> { required_consts: &'a mut Vec<ConstOperand<'tcx>>, @@ -14,14 +13,8 @@ impl<'a, 'tcx> RequiredConstsVisitor<'a, 'tcx> { impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> { fn visit_constant(&mut self, constant: &ConstOperand<'tcx>, _: Location) { - let const_ = constant.const_; - match const_ { - Const::Ty(c) => match c.kind() { - ConstKind::Param(_) | ConstKind::Error(_) | ConstKind::Value(_) => {} - _ => bug!("only ConstKind::Param/Value should be encountered here, got {:#?}", c), - }, - Const::Unevaluated(..) => self.required_consts.push(*constant), - Const::Val(..) 
=> {} + if constant.const_.is_required_const() { + self.required_consts.push(*constant); } } } diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 733e2f93b25..dcf54ad2cfc 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -3,8 +3,9 @@ use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::query::Providers; +use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt}; -use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL}; +use rustc_middle::{bug, span_bug}; use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_index::{Idx, IndexVec}; @@ -17,11 +18,13 @@ use std::iter; use crate::{ abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator, - pass_manager as pm, remove_noop_landing_pads, simplify, + mentioned_items, pass_manager as pm, remove_noop_landing_pads, simplify, }; use rustc_middle::mir::patch::MirPatch; use rustc_mir_dataflow::elaborate_drops::{self, DropElaborator, DropFlagMode, DropStyle}; +mod async_destructor_ctor; + pub fn provide(providers: &mut Providers) { providers.mir_shims = make_shim; } @@ -55,7 +58,7 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' // a virtual call, or a direct call to a function for which // indirect calls must be codegen'd differently than direct ones // (such as `#[track_caller]`). - ty::InstanceDef::ReifyShim(def_id) => { + ty::InstanceDef::ReifyShim(def_id, _) => { build_call_shim(tcx, instance, None, CallKind::Direct(def_id)) } ty::InstanceDef::ClosureOnceShim { call_once: _, track_caller: _ } => { @@ -72,37 +75,12 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id, - target_kind, - } => match target_kind { - ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"), - ty::ClosureKind::FnMut => { - // No need to optimize the body, it has already been optimized - // since we steal it from the `AsyncFn::call` body and just fix - // the return type. 
- return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); - } - ty::ClosureKind::FnOnce => { - build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) - } - }, + receiver_by_ref, + } => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref), - ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind { - ty::ClosureKind::Fn => unreachable!(), - ty::ClosureKind::FnMut => { - return tcx - .optimized_mir(coroutine_def_id) - .coroutine_by_mut_body() - .unwrap() - .clone(); - } - ty::ClosureKind::FnOnce => { - return tcx - .optimized_mir(coroutine_def_id) - .coroutine_by_move_body() - .unwrap() - .clone(); - } - }, + ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => { + return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone(); + } ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end @@ -123,21 +101,11 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() { coroutine_body.coroutine_drop().unwrap() } else { - match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() { - ty::ClosureKind::Fn => { - unreachable!() - } - ty::ClosureKind::FnMut => coroutine_body - .coroutine_by_mut_body() - .unwrap() - .coroutine_drop() - .unwrap(), - ty::ClosureKind::FnOnce => coroutine_body - .coroutine_by_move_body() - .unwrap() - .coroutine_drop() - .unwrap(), - } + assert_eq!( + args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(), + ty::ClosureKind::FnOnce + ); + coroutine_body.coroutine_by_move_body().unwrap().coroutine_drop().unwrap() }; let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); @@ -147,6 +115,7 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' tcx, &mut body, &[ + &mentioned_items::MentionedItems, &abort_unwinding_calls::AbortUnwindingCalls, &add_call_guards::CriticalCallEdges, ], @@ -161,6 +130,9 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' ty::InstanceDef::ThreadLocalShim(..) => build_thread_local_shim(tcx, instance), ty::InstanceDef::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty), ty::InstanceDef::FnPtrAddrShim(def_id, ty) => build_fn_ptr_addr_shim(tcx, def_id, ty), + ty::InstanceDef::AsyncDropGlueCtorShim(def_id, ty) => { + async_destructor_ctor::build_async_destructor_ctor_shim(tcx, def_id, ty) + } ty::InstanceDef::Virtual(..) 
=> { bug!("InstanceDef::Virtual ({:?}) is for direct calls only", instance) } @@ -178,6 +150,7 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' tcx, &mut result, &[ + &mentioned_items::MentionedItems, &add_moves_for_packed_drops::AddMovesForPackedDrops, &deref_separator::Derefer, &remove_noop_landing_pads::RemoveNoopLandingPads, @@ -572,14 +545,8 @@ impl<'tcx> CloneShimBuilder<'tcx> { const_: Const::zero_sized(func_ty), })); - let ref_loc = self.make_place( - Mutability::Not, - Ty::new_ref( - tcx, - tcx.lifetimes.re_erased, - ty::TypeAndMut { ty, mutbl: hir::Mutability::Not }, - ), - ); + let ref_loc = + self.make_place(Mutability::Not, Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, ty)); // `let ref_loc: &ty = &src;` let statement = self.make_statement(StatementKind::Assign(Box::new(( @@ -804,11 +771,7 @@ fn build_call_shim<'tcx>( // let rcvr = &mut rcvr; let ref_rcvr = local_decls.push( LocalDecl::new( - Ty::new_ref( - tcx, - tcx.lifetimes.re_erased, - ty::TypeAndMut { ty: sig.inputs()[0], mutbl: hir::Mutability::Mut }, - ), + Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, sig.inputs()[0]), span, ) .immutable(), @@ -1028,7 +991,7 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t let locals = local_decls_for_sig(&sig, span); let source_info = SourceInfo::outermost(span); - // FIXME: use `expose_addr` once we figure out whether function pointers have meaningful provenance. + // FIXME: use `expose_provenance` once we figure out whether function pointers have meaningful provenance. let rvalue = Rvalue::Cast( CastKind::FnPtrToPtr, Operand::Move(Place::from(Local::new(1))), @@ -1051,12 +1014,26 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t fn build_construct_coroutine_by_move_shim<'tcx>( tcx: TyCtxt<'tcx>, coroutine_closure_def_id: DefId, + receiver_by_ref: bool, ) -> Body<'tcx> { - let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); let ty::CoroutineClosure(_, args) = *self_ty.kind() else { bug!(); }; + // We use `&mut Self` here because we only need to emit an ABI-compatible shim body, + // rather than match the signature exactly (which might take `&self` instead). + // + // The self type here is a coroutine-closure, not a coroutine, and we never read from + // it because it never has any captures, because this is only true in the Fn/FnMut + // implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only + // if the coroutine-closure has no captures. + if receiver_by_ref { + // Triple-check that there's no captures here. 
+ assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit); + self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty); + } + let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { tcx.mk_fn_sig( [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()), @@ -1070,7 +1047,7 @@ fn build_construct_coroutine_by_move_shim<'tcx>( args.as_coroutine_closure().coroutine_captures_by_ref_ty(), ), sig.c_variadic, - sig.unsafety, + sig.safety, sig.abi, ) }); @@ -1112,49 +1089,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>( let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id, - target_kind: ty::ClosureKind::FnOnce, + receiver_by_ref, }); let body = new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); - dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(())); - - body -} - -fn build_construct_coroutine_by_mut_shim<'tcx>( - tcx: TyCtxt<'tcx>, - coroutine_closure_def_id: DefId, -) -> Body<'tcx> { - let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone(); - let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); - let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else { - bug!(); - }; - let args = args.as_coroutine_closure(); - - body.local_decls[RETURN_PLACE].ty = - tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| { - sig.to_coroutine_given_kind_and_upvars( - tcx, - args.parent_args(), - tcx.coroutine_for_closure(coroutine_closure_def_id), - ty::ClosureKind::FnMut, - tcx.lifetimes.re_erased, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), - ) - })); - body.local_decls[CAPTURE_STRUCT_LOCAL].ty = - Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty); - - body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { - coroutine_closure_def_id, - target_kind: ty::ClosureKind::FnMut, - }); - - body.pass_count = 0; - dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(())); + dump_mir( + tcx, + false, + if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" }, + &0, + &body, + |_, _| Ok(()), + ); body } diff --git a/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs new file mode 100644 index 00000000000..f4481c22fc1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs @@ -0,0 +1,618 @@ +use std::iter; + +use itertools::Itertools; +use rustc_ast::Mutability; +use rustc_const_eval::interpret; +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, Body, CallSource, CastKind, Const, ConstOperand, ConstValue, Local, + LocalDecl, MirSource, Operand, Place, PlaceElem, Rvalue, SourceInfo, Statement, StatementKind, + Terminator, TerminatorKind, UnwindAction, UnwindTerminateReason, RETURN_PLACE, +}; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::util::Discr; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug}; +use rustc_span::source_map::respan; +use rustc_span::{Span, Symbol}; +use rustc_target::abi::{FieldIdx, VariantIdx}; +use rustc_target::spec::PanicStrategy; + +use super::{local_decls_for_sig, new_body}; + +pub fn build_async_destructor_ctor_shim<'tcx>( + tcx: 
TyCtxt<'tcx>, + def_id: DefId, + ty: Option<Ty<'tcx>>, +) -> Body<'tcx> { + debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty); + + AsyncDestructorCtorShimBuilder::new(tcx, def_id, ty).build() +} + +/// Builder for async_drop_in_place shim. Functions as a stack machine +/// to build up an expression using combinators. Stack contains pairs +/// of locals and types. Combinator is a not yet instantiated pair of a +/// function and a type, is considered to be an operator which consumes +/// operands from the stack by instantiating its function and its type +/// with operand types and moving locals into the function call. Top +/// pair is considered to be the last operand. +// FIXME: add mir-opt tests +struct AsyncDestructorCtorShimBuilder<'tcx> { + tcx: TyCtxt<'tcx>, + def_id: DefId, + self_ty: Option<Ty<'tcx>>, + span: Span, + source_info: SourceInfo, + param_env: ty::ParamEnv<'tcx>, + + stack: Vec<Operand<'tcx>>, + last_bb: BasicBlock, + top_cleanup_bb: Option<BasicBlock>, + + locals: IndexVec<Local, LocalDecl<'tcx>>, + bbs: IndexVec<BasicBlock, BasicBlockData<'tcx>>, +} + +#[derive(Clone, Copy)] +enum SurfaceDropKind { + Async, + Sync, +} + +impl<'tcx> AsyncDestructorCtorShimBuilder<'tcx> { + const SELF_PTR: Local = Local::from_u32(1); + const INPUT_COUNT: usize = 1; + const MAX_STACK_LEN: usize = 2; + + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Option<Ty<'tcx>>) -> Self { + let args = if let Some(ty) = self_ty { + tcx.mk_args(&[ty.into()]) + } else { + ty::GenericArgs::identity_for_item(tcx, def_id) + }; + let sig = tcx.fn_sig(def_id).instantiate(tcx, args); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + let span = tcx.def_span(def_id); + + let source_info = SourceInfo::outermost(span); + + debug_assert_eq!(sig.inputs().len(), Self::INPUT_COUNT); + let locals = local_decls_for_sig(&sig, span); + + // Usual case: noop() + unwind resume + return + let mut bbs = IndexVec::with_capacity(3); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + AsyncDestructorCtorShimBuilder { + tcx, + def_id, + self_ty, + span, + source_info, + param_env, + + stack: Vec::with_capacity(Self::MAX_STACK_LEN), + last_bb: bbs.push(BasicBlockData::new(None)), + top_cleanup_bb: match tcx.sess.panic_strategy() { + PanicStrategy::Unwind => { + // Don't drop input arg because it's just a pointer + Some(bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::UnwindResume, + }), + is_cleanup: true, + })) + } + PanicStrategy::Abort => None, + }, + + locals, + bbs, + } + } + + fn build(self) -> Body<'tcx> { + let (tcx, def_id, Some(self_ty)) = (self.tcx, self.def_id, self.self_ty) else { + return self.build_zst_output(); + }; + + let surface_drop_kind = || { + let param_env = tcx.param_env_reveal_all_normalized(def_id); + if self_ty.has_surface_async_drop(tcx, param_env) { + Some(SurfaceDropKind::Async) + } else if self_ty.has_surface_drop(tcx, param_env) { + Some(SurfaceDropKind::Sync) + } else { + None + } + }; + + match self_ty.kind() { + ty::Array(elem_ty, _) => self.build_slice(true, *elem_ty), + ty::Slice(elem_ty) => self.build_slice(false, *elem_ty), + + ty::Tuple(elem_tys) => self.build_chain(None, elem_tys.iter()), + ty::Adt(adt_def, args) if adt_def.is_struct() => { + let field_tys = adt_def.non_enum_variant().fields.iter().map(|f| f.ty(tcx, args)); + self.build_chain(surface_drop_kind(), field_tys) + } + ty::Closure(_, args) => self.build_chain(None, args.as_closure().upvar_tys().iter()), + 
ty::CoroutineClosure(_, args) => { + self.build_chain(None, args.as_coroutine_closure().upvar_tys().iter()) + } + + ty::Adt(adt_def, args) if adt_def.is_enum() => { + self.build_enum(*adt_def, *args, surface_drop_kind()) + } + + ty::Adt(adt_def, _) => { + assert!(adt_def.is_union()); + match surface_drop_kind().unwrap() { + SurfaceDropKind::Async => self.build_fused_async_surface(), + SurfaceDropKind::Sync => self.build_fused_sync_surface(), + } + } + + ty::Bound(..) + | ty::Foreign(_) + | ty::Placeholder(_) + | ty::Infer(ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) | ty::TyVar(_)) + | ty::Param(_) + | ty::Alias(..) => { + bug!("Building async destructor for unexpected type: {self_ty:?}") + } + + _ => { + bug!( + "Building async destructor constructor shim is not yet implemented for type: {self_ty:?}" + ) + } + } + } + + fn build_enum( + mut self, + adt_def: ty::AdtDef<'tcx>, + args: ty::GenericArgsRef<'tcx>, + surface_drop: Option<SurfaceDropKind>, + ) -> Body<'tcx> { + let tcx = self.tcx; + + let surface = match surface_drop { + None => None, + Some(kind) => { + self.put_self(); + Some(match kind { + SurfaceDropKind::Async => self.combine_async_surface(), + SurfaceDropKind::Sync => self.combine_sync_surface(), + }) + } + }; + + let mut other = None; + for (variant_idx, discr) in adt_def.discriminants(tcx) { + let variant = adt_def.variant(variant_idx); + + let mut chain = None; + for (field_idx, field) in variant.fields.iter_enumerated() { + let field_ty = field.ty(tcx, args); + self.put_variant_field(variant.name, variant_idx, field_idx, field_ty); + let defer = self.combine_defer(field_ty); + chain = Some(match chain { + None => defer, + Some(chain) => self.combine_chain(chain, defer), + }) + } + let variant_dtor = chain.unwrap_or_else(|| self.put_noop()); + + other = Some(match other { + None => variant_dtor, + Some(other) => { + self.put_self(); + self.put_discr(discr); + self.combine_either(other, variant_dtor) + } + }); + } + let variants_dtor = other.unwrap_or_else(|| self.put_noop()); + + let dtor = match surface { + None => variants_dtor, + Some(surface) => self.combine_chain(surface, variants_dtor), + }; + self.combine_fuse(dtor); + self.return_() + } + + fn build_chain<I>(mut self, surface_drop: Option<SurfaceDropKind>, elem_tys: I) -> Body<'tcx> + where + I: Iterator<Item = Ty<'tcx>> + ExactSizeIterator, + { + let surface = match surface_drop { + None => None, + Some(kind) => { + self.put_self(); + Some(match kind { + SurfaceDropKind::Async => self.combine_async_surface(), + SurfaceDropKind::Sync => self.combine_sync_surface(), + }) + } + }; + + let mut chain = None; + for (field_idx, field_ty) in elem_tys.enumerate().map(|(i, ty)| (FieldIdx::new(i), ty)) { + self.put_field(field_idx, field_ty); + let defer = self.combine_defer(field_ty); + chain = Some(match chain { + None => defer, + Some(chain) => self.combine_chain(chain, defer), + }) + } + let chain = chain.unwrap_or_else(|| self.put_noop()); + + let dtor = match surface { + None => chain, + Some(surface) => self.combine_chain(surface, chain), + }; + self.combine_fuse(dtor); + self.return_() + } + + fn build_zst_output(mut self) -> Body<'tcx> { + self.put_zst_output(); + self.return_() + } + + fn build_fused_async_surface(mut self) -> Body<'tcx> { + self.put_self(); + let surface = self.combine_async_surface(); + self.combine_fuse(surface); + self.return_() + } + + fn build_fused_sync_surface(mut self) -> Body<'tcx> { + self.put_self(); + let surface = self.combine_sync_surface(); + 
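For context on the fold that `build_chain`/`build_enum` perform above: one `defer` is produced per field and the results are folded into a single chained destructor, optionally preceded by the surface destructor. The std-only sketch below shows only that fold shape, using plain closures in place of the async combinator futures; `chain` and the per-field closures are illustrative stand-ins, not the real lang items.

```rust
// Synchronous stand-in for the "chain" combinator: run `first`, then `second`.
fn chain(first: impl FnOnce(), second: impl FnOnce()) -> impl FnOnce() {
    move || {
        first();
        second();
    }
}

fn main() {
    // One "defer" per field, in declaration order (illustrative only).
    let fields: Vec<Box<dyn FnOnce()>> = vec![
        Box::new(|| println!("drop field 0")),
        Box::new(|| println!("drop field 1")),
        Box::new(|| println!("drop field 2")),
    ];

    // Fold like `build_chain`: the first defer starts the chain,
    // every later one is appended with `chain`.
    let mut dtor: Option<Box<dyn FnOnce()>> = None;
    for field_dtor in fields {
        let next: Box<dyn FnOnce()> = match dtor.take() {
            None => field_dtor,
            Some(so_far) => Box::new(chain(so_far, field_dtor)),
        };
        dtor = Some(next);
    }

    // Running the folded value drops the "fields" front to back.
    (dtor.expect("at least one field"))();
}
```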
self.combine_fuse(surface); + self.return_() + } + + fn build_slice(mut self, is_array: bool, elem_ty: Ty<'tcx>) -> Body<'tcx> { + if is_array { + self.put_array_as_slice(elem_ty) + } else { + self.put_self() + } + let dtor = self.combine_slice(elem_ty); + self.combine_fuse(dtor); + self.return_() + } + + fn put_zst_output(&mut self) { + let return_ty = self.locals[RETURN_PLACE].ty; + self.put_operand(Operand::Constant(Box::new(ConstOperand { + span: self.span, + user_ty: None, + const_: Const::zero_sized(return_ty), + }))); + } + + /// Puts `to_drop: *mut Self` on top of the stack. + fn put_self(&mut self) { + self.put_operand(Operand::Copy(Self::SELF_PTR.into())) + } + + /// Given that `Self is [ElemTy; N]` puts `to_drop: *mut [ElemTy]` + /// on top of the stack. + fn put_array_as_slice(&mut self, elem_ty: Ty<'tcx>) { + let slice_ptr_ty = Ty::new_mut_ptr(self.tcx, Ty::new_slice(self.tcx, elem_ty)); + self.put_temp_rvalue(Rvalue::Cast( + CastKind::PointerCoercion(PointerCoercion::Unsize), + Operand::Copy(Self::SELF_PTR.into()), + slice_ptr_ty, + )) + } + + /// If given Self is a struct puts `to_drop: *mut FieldTy` on top + /// of the stack. + fn put_field(&mut self, field: FieldIdx, field_ty: Ty<'tcx>) { + let place = Place { + local: Self::SELF_PTR, + projection: self + .tcx + .mk_place_elems(&[PlaceElem::Deref, PlaceElem::Field(field, field_ty)]), + }; + self.put_temp_rvalue(Rvalue::AddressOf(Mutability::Mut, place)) + } + + /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of + /// the stack. + fn put_variant_field( + &mut self, + variant_sym: Symbol, + variant: VariantIdx, + field: FieldIdx, + field_ty: Ty<'tcx>, + ) { + let place = Place { + local: Self::SELF_PTR, + projection: self.tcx.mk_place_elems(&[ + PlaceElem::Deref, + PlaceElem::Downcast(Some(variant_sym), variant), + PlaceElem::Field(field, field_ty), + ]), + }; + self.put_temp_rvalue(Rvalue::AddressOf(Mutability::Mut, place)) + } + + /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of + /// the stack. + fn put_discr(&mut self, discr: Discr<'tcx>) { + let (size, _) = discr.ty.int_size_and_signed(self.tcx); + self.put_operand(Operand::const_from_scalar( + self.tcx, + discr.ty, + interpret::Scalar::from_uint(discr.val, size), + self.span, + )); + } + + /// Puts `x: RvalueType` on top of the stack. + fn put_temp_rvalue(&mut self, rvalue: Rvalue<'tcx>) { + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + let source_info = self.source_info; + + let local_ty = rvalue.ty(&self.locals, self.tcx); + // We need to create a new local to be able to "consume" it with + // a combinator + let local = self.locals.push(LocalDecl::with_source_info(local_ty, source_info)); + last_bb.statements.extend_from_slice(&[ + Statement { source_info, kind: StatementKind::StorageLive(local) }, + Statement { + source_info, + kind: StatementKind::Assign(Box::new((local.into(), rvalue))), + }, + ]); + + self.put_operand(Operand::Move(local.into())); + } + + /// Puts operand on top of the stack. 
+ fn put_operand(&mut self, operand: Operand<'tcx>) { + if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb { + let source_info = self.source_info; + match &operand { + Operand::Copy(_) | Operand::Constant(_) => { + *top_cleanup_bb = self.bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Goto { target: *top_cleanup_bb }, + }), + is_cleanup: true, + }); + } + Operand::Move(place) => { + let local = place.as_local().unwrap(); + *top_cleanup_bb = self.bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: if self.locals[local].ty.needs_drop(self.tcx, self.param_env) { + TerminatorKind::Drop { + place: local.into(), + target: *top_cleanup_bb, + unwind: UnwindAction::Terminate( + UnwindTerminateReason::InCleanup, + ), + replace: false, + } + } else { + TerminatorKind::Goto { target: *top_cleanup_bb } + }, + }), + is_cleanup: true, + }); + } + }; + } + self.stack.push(operand); + } + + /// Puts `noop: async_drop::Noop` on top of the stack + fn put_noop(&mut self) -> Ty<'tcx> { + self.apply_combinator(0, LangItem::AsyncDropNoop, &[]) + } + + fn combine_async_surface(&mut self) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::SurfaceAsyncDropInPlace, &[self.self_ty.unwrap().into()]) + } + + fn combine_sync_surface(&mut self) -> Ty<'tcx> { + self.apply_combinator( + 1, + LangItem::AsyncDropSurfaceDropInPlace, + &[self.self_ty.unwrap().into()], + ) + } + + fn combine_fuse(&mut self, inner_future_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropFuse, &[inner_future_ty.into()]) + } + + fn combine_slice(&mut self, elem_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropSlice, &[elem_ty.into()]) + } + + fn combine_defer(&mut self, to_drop_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropDefer, &[to_drop_ty.into()]) + } + + fn combine_chain(&mut self, first: Ty<'tcx>, second: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(2, LangItem::AsyncDropChain, &[first.into(), second.into()]) + } + + fn combine_either(&mut self, other: Ty<'tcx>, matched: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator( + 4, + LangItem::AsyncDropEither, + &[other.into(), matched.into(), self.self_ty.unwrap().into()], + ) + } + + fn return_(mut self) -> Body<'tcx> { + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + let source_info = self.source_info; + + let (1, Some(output)) = (self.stack.len(), self.stack.pop()) else { + span_bug!( + self.span, + "async destructor ctor shim builder finished with invalid number of stack items: expected 1 found {}", + self.stack.len(), + ) + }; + #[cfg(debug_assertions)] + if let Some(ty) = self.self_ty { + debug_assert_eq!( + output.ty(&self.locals, self.tcx), + ty.async_destructor_ty(self.tcx, self.param_env), + "output async destructor types did not match for type: {ty:?}", + ); + } + + let dead_storage = match &output { + Operand::Move(place) => Some(Statement { + source_info, + kind: StatementKind::StorageDead(place.as_local().unwrap()), + }), + _ => None, + }; + + last_bb.statements.extend( + iter::once(Statement { + source_info, + kind: StatementKind::Assign(Box::new((RETURN_PLACE.into(), Rvalue::Use(output)))), + }) + .chain(dead_storage), + ); + + last_bb.terminator = Some(Terminator { source_info, kind: TerminatorKind::Return }); + + let source = MirSource::from_instance(ty::InstanceDef::AsyncDropGlueCtorShim( + self.def_id, + self.self_ty, + )); 
+ new_body(source, self.bbs, self.locals, Self::INPUT_COUNT, self.span) + } + + fn apply_combinator( + &mut self, + arity: usize, + function: LangItem, + args: &[ty::GenericArg<'tcx>], + ) -> Ty<'tcx> { + let function = self.tcx.require_lang_item(function, Some(self.span)); + let operands_split = self + .stack + .len() + .checked_sub(arity) + .expect("async destructor ctor shim combinator tried to consume too many items"); + let operands = &self.stack[operands_split..]; + + let func_ty = Ty::new_fn_def(self.tcx, function, args.iter().copied()); + let func_sig = func_ty.fn_sig(self.tcx).no_bound_vars().unwrap(); + #[cfg(debug_assertions)] + operands.iter().zip(func_sig.inputs()).for_each(|(operand, expected_ty)| { + let operand_ty = operand.ty(&self.locals, self.tcx); + if operand_ty == *expected_ty { + return; + } + + // If projection of Discriminant then compare with `Ty::discriminant_ty` + if let ty::Alias(ty::Projection, ty::AliasTy { args, def_id, .. }) = expected_ty.kind() + && Some(*def_id) == self.tcx.lang_items().discriminant_type() + && args.first().unwrap().as_type().unwrap().discriminant_ty(self.tcx) == operand_ty + { + return; + } + + span_bug!( + self.span, + "Operand type and combinator argument type are not equal. + operand_ty: {:?} + argument_ty: {:?} +", + operand_ty, + expected_ty + ); + }); + + let target = self.bbs.push(BasicBlockData { + statements: operands + .iter() + .rev() + .filter_map(|o| { + if let Operand::Move(Place { local, projection }) = o { + assert!(projection.is_empty()); + Some(Statement { + source_info: self.source_info, + kind: StatementKind::StorageDead(*local), + }) + } else { + None + } + }) + .collect(), + terminator: None, + is_cleanup: false, + }); + + let dest_ty = func_sig.output(); + let dest = + self.locals.push(LocalDecl::with_source_info(dest_ty, self.source_info).immutable()); + + let unwind = if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb { + for _ in 0..arity { + *top_cleanup_bb = + self.bbs[*top_cleanup_bb].terminator().successors().exactly_one().ok().unwrap(); + } + UnwindAction::Cleanup(*top_cleanup_bb) + } else { + UnwindAction::Unreachable + }; + + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + last_bb.statements.push(Statement { + source_info: self.source_info, + kind: StatementKind::StorageLive(dest), + }); + last_bb.terminator = Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Call { + func: Operand::Constant(Box::new(ConstOperand { + span: self.span, + user_ty: None, + const_: Const::Val(ConstValue::ZeroSized, func_ty), + })), + destination: dest.into(), + target: Some(target), + unwind, + call_source: CallSource::Misc, + fn_span: self.span, + args: self.stack.drain(operands_split..).map(|o| respan(self.span, o)).collect(), + }, + }); + + self.put_operand(Operand::Move(dest.into())); + self.last_bb = target; + + dest_ty + } +} diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 8c8818bd68e..5bbe3bb747f 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -37,11 +37,14 @@ pub enum SimplifyCfg { Initial, PromoteConsts, RemoveFalseEdges, - EarlyOpt, - ElaborateDrops, + /// Runs at the beginning of "analysis to runtime" lowering, *before* drop elaboration. + PostAnalysis, + /// Runs at the end of "analysis to runtime" lowering, *after* drop elaboration. + /// This is before the main optimization passes on runtime MIR kick in. 
+ PreOptimizations, Final, MakeShim, - AfterUninhabitedEnumBranching, + AfterUnreachableEnumBranching, } impl SimplifyCfg { @@ -50,12 +53,12 @@ impl SimplifyCfg { SimplifyCfg::Initial => "SimplifyCfg-initial", SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts", SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges", - SimplifyCfg::EarlyOpt => "SimplifyCfg-early-opt", - SimplifyCfg::ElaborateDrops => "SimplifyCfg-elaborate-drops", + SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis", + SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations", SimplifyCfg::Final => "SimplifyCfg-final", SimplifyCfg::MakeShim => "SimplifyCfg-make_shim", - SimplifyCfg::AfterUninhabitedEnumBranching => { - "SimplifyCfg-after-uninhabited-enum-branching" + SimplifyCfg::AfterUnreachableEnumBranching => { + "SimplifyCfg-after-unreachable-enum-branching" } } } @@ -412,7 +415,7 @@ fn make_local_map<V>( used_locals: &UsedLocals, ) -> IndexVec<Local, Option<Local>> { let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls); - let mut used = Local::new(0); + let mut used = Local::ZERO; for alive_index in local_decls.indices() { // `is_used` treats the `RETURN_PLACE` and arguments as used. diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs index 35a052166bd..c746041ebd8 100644 --- a/compiler/rustc_mir_transform/src/simplify_branches.rs +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -19,6 +19,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); 'blocks: for block in body.basic_blocks_mut() { for stmt in block.statements.iter_mut() { + // Simplify `assume` of a known value: either a NOP or unreachable. if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind && let NonDivergingIntrinsic::Assume(discr) = intrinsic && let Operand::Constant(ref c) = discr diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs index 1a8cfc41178..03907babf2b 100644 --- a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -2,6 +2,7 @@ use std::iter; use super::MirPass; use rustc_middle::{ + bug, mir::{ interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets, TerminatorKind, diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 06d5e17fdd6..f19c34cae7a 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,6 +1,7 @@ use rustc_data_structures::flat_map_in_place::FlatMapInPlace; use rustc_index::bit_set::{BitSet, GrowableBitSet}; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; @@ -69,6 +70,11 @@ fn escaping_locals<'tcx>( // Exclude #[repr(simd)] types so that they are not de-optimized into an array return true; } + if Some(def.did()) == tcx.lang_items().dyn_metadata() { + // codegen wants to see the `DynMetadata<T>`, + // not the inner reference-to-opaque-type. 
+ return true; + } // We already excluded unions and enums, so this ADT must have one variant let variant = def.variant(FIRST_VARIANT); if variant.fields.len() > 1 { diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs index e4fdbd6ae69..fb870425f6e 100644 --- a/compiler/rustc_mir_transform/src/ssa.rs +++ b/compiler/rustc_mir_transform/src/ssa.rs @@ -2,15 +2,18 @@ //! 1/ They are only assigned-to once, either as a function parameter, or in an assign statement; //! 2/ This single assignment dominates all uses; //! -//! As a consequence of rule 2, we consider that borrowed locals are not SSA, even if they are -//! `Freeze`, as we do not track that the assignment dominates all uses of the borrow. +//! As we do not track indirect assignments, a local that has its address taken (either by +//! AddressOf or by borrowing) is considered non-SSA. However, it is UB to modify through an +//! immutable borrow of a `Freeze` local. Those can still be considered to be SSA. use rustc_data_structures::graph::dominators::Dominators; use rustc_index::bit_set::BitSet; use rustc_index::{IndexSlice, IndexVec}; +use rustc_middle::bug; use rustc_middle::middle::resolve_bound_vars::Set1; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; +use rustc_middle::ty::{ParamEnv, TyCtxt}; pub struct SsaLocals { /// Assignments to each local. This defines whether the local is SSA. @@ -24,24 +27,33 @@ pub struct SsaLocals { /// Number of "direct" uses of each local, ie. uses that are not dereferences. /// We ignore non-uses (Storage statements, debuginfo). direct_uses: IndexVec<Local, u32>, + /// Set of SSA locals that are immutably borrowed. + borrowed_locals: BitSet<Local>, } pub enum AssignedValue<'a, 'tcx> { Arg, Rvalue(&'a mut Rvalue<'tcx>), - Terminator(&'a mut TerminatorKind<'tcx>), + Terminator, } impl SsaLocals { - pub fn new<'tcx>(body: &Body<'tcx>) -> SsaLocals { + pub fn new<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ParamEnv<'tcx>) -> SsaLocals { let assignment_order = Vec::with_capacity(body.local_decls.len()); let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls); let dominators = body.basic_blocks.dominators(); let direct_uses = IndexVec::from_elem(0, &body.local_decls); - let mut visitor = - SsaVisitor { body, assignments, assignment_order, dominators, direct_uses }; + let borrowed_locals = BitSet::new_empty(body.local_decls.len()); + let mut visitor = SsaVisitor { + body, + assignments, + assignment_order, + dominators, + direct_uses, + borrowed_locals, + }; for local in body.args_iter() { visitor.assignments[local] = Set1::One(DefLocation::Argument); @@ -58,6 +70,16 @@ impl SsaLocals { visitor.visit_var_debug_info(var_debug_info); } + // The immutability of shared borrows only works on `Freeze` locals. If the visitor found + // borrows, we need to check the types. For raw pointers and mutable borrows, the locals + // have already been marked as non-SSA. + debug!(?visitor.borrowed_locals); + for local in visitor.borrowed_locals.iter() { + if !body.local_decls[local].ty.is_freeze(tcx, param_env) { + visitor.assignments[local] = Set1::Many; + } + } + debug!(?visitor.assignments); debug!(?visitor.direct_uses); @@ -70,6 +92,7 @@ impl SsaLocals { assignments: visitor.assignments, assignment_order: visitor.assignment_order, direct_uses: visitor.direct_uses, + borrowed_locals: visitor.borrowed_locals, // This is filled by `compute_copy_classes`. 
copy_classes: IndexVec::default(), }; @@ -149,8 +172,7 @@ impl SsaLocals { Set1::One(DefLocation::CallReturn { call, .. }) => { let bb = &mut basic_blocks[call]; let loc = Location { block: call, statement_index: bb.statements.len() }; - let term = bb.terminator_mut(); - f(local, AssignedValue::Terminator(&mut term.kind), loc) + f(local, AssignedValue::Terminator, loc) } _ => {} } @@ -175,6 +197,11 @@ impl SsaLocals { &self.copy_classes } + /// Set of SSA locals that are immutably borrowed. + pub fn borrowed_locals(&self) -> &BitSet<Local> { + &self.borrowed_locals + } + /// Make a property uniform on a copy equivalence class by removing elements. pub fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) { // Consolidate to have a local iff all its copies are. @@ -209,6 +236,8 @@ struct SsaVisitor<'tcx, 'a> { assignments: IndexVec<Local, Set1<DefLocation>>, assignment_order: Vec<Local>, direct_uses: IndexVec<Local, u32>, + // Track locals that are immutably borrowed, so we can check their type is `Freeze` later. + borrowed_locals: BitSet<Local>, } impl SsaVisitor<'_, '_> { @@ -233,16 +262,18 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'tcx, '_> { PlaceContext::MutatingUse(MutatingUseContext::Projection) | PlaceContext::NonMutatingUse(NonMutatingUseContext::Projection) => bug!(), // Anything can happen with raw pointers, so remove them. - // We do not verify that all uses of the borrow dominate the assignment to `local`, - // so we have to remove them too. - PlaceContext::NonMutatingUse( - NonMutatingUseContext::SharedBorrow - | NonMutatingUseContext::FakeBorrow - | NonMutatingUseContext::AddressOf, - ) + PlaceContext::NonMutatingUse(NonMutatingUseContext::AddressOf) | PlaceContext::MutatingUse(_) => { self.assignments[local] = Set1::Many; } + // Immutable borrows are ok, but we need to delay a check that the type is `Freeze`. + PlaceContext::NonMutatingUse( + NonMutatingUseContext::SharedBorrow | NonMutatingUseContext::FakeBorrow, + ) => { + self.borrowed_locals.insert(local); + self.check_dominates(local, loc); + self.direct_uses[local] += 1; + } PlaceContext::NonMutatingUse(_) => { self.check_dominates(local, loc); self.direct_uses[local] += 1; diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs deleted file mode 100644 index e68d37f4c70..00000000000 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! A pass that eliminates branches on uninhabited enum variants. - -use crate::MirPass; -use rustc_data_structures::fx::FxHashSet; -use rustc_middle::mir::{ - BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind, -}; -use rustc_middle::ty::layout::TyAndLayout; -use rustc_middle::ty::{Ty, TyCtxt}; -use rustc_target::abi::{Abi, Variants}; - -pub struct UninhabitedEnumBranching; - -fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> { - if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { - p.as_local() - } else { - None - } -} - -/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the -/// discriminant is read from. Otherwise, returns None. -fn get_switched_on_type<'tcx>( - block_data: &BasicBlockData<'tcx>, - tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, -) -> Option<Ty<'tcx>> { - let terminator = block_data.terminator(); - - // Only bother checking blocks which terminate by switching on a local. 
- let local = get_discriminant_local(&terminator.kind)?; - - let stmt_before_term = block_data.statements.last()?; - - if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind - && l.as_local() == Some(local) - { - let ty = place.ty(body, tcx).ty; - if ty.is_enum() { - return Some(ty); - } - } - - None -} - -fn variant_discriminants<'tcx>( - layout: &TyAndLayout<'tcx>, - ty: Ty<'tcx>, - tcx: TyCtxt<'tcx>, -) -> FxHashSet<u128> { - match &layout.variants { - Variants::Single { index } => { - let mut res = FxHashSet::default(); - res.insert( - ty.discriminant_for_variant(tcx, *index) - .map_or(index.as_u32() as u128, |discr| discr.val), - ); - res - } - Variants::Multiple { variants, .. } => variants - .iter_enumerated() - .filter_map(|(idx, layout)| { - (layout.abi != Abi::Uninhabited) - .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) - }) - .collect(), - } -} - -impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() > 0 - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - trace!("UninhabitedEnumBranching starting for {:?}", body.source); - - let mut removable_switchs = Vec::new(); - - for (bb, bb_data) in body.basic_blocks.iter_enumerated() { - trace!("processing block {:?}", bb); - - if bb_data.is_cleanup { - continue; - } - - let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue }; - - let layout = tcx.layout_of( - tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty), - ); - - let allowed_variants = if let Ok(layout) = layout { - variant_discriminants(&layout, discriminant_ty, tcx) - } else { - continue; - }; - - trace!("allowed_variants = {:?}", allowed_variants); - - let terminator = bb_data.terminator(); - let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() }; - - let mut reachable_count = 0; - for (index, (val, _)) in targets.iter().enumerate() { - if allowed_variants.contains(&val) { - reachable_count += 1; - } else { - removable_switchs.push((bb, index)); - } - } - - if reachable_count == allowed_variants.len() { - removable_switchs.push((bb, targets.iter().count())); - } - } - - if removable_switchs.is_empty() { - return; - } - - let new_block = BasicBlockData::new(Some(Terminator { - source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info, - kind: TerminatorKind::Unreachable, - })); - let unreachable_block = body.basic_blocks.as_mut().push(new_block); - - for (bb, index) in removable_switchs { - let bb = &mut body.basic_blocks.as_mut()[bb]; - let terminator = bb.terminator_mut(); - let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() }; - targets.all_targets_mut()[index] = unreachable_block; - } - } -} diff --git a/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs new file mode 100644 index 00000000000..1404a45f4d2 --- /dev/null +++ b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs @@ -0,0 +1,209 @@ +//! A pass that eliminates branches on uninhabited or unreachable enum variants. 
+ +use crate::MirPass; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind, + TerminatorKind, +}; +use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_target::abi::{Abi, Variants}; + +pub struct UnreachableEnumBranching; + +fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> { + if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { + p.as_local() + } else { + None + } +} + +/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the +/// discriminant is read from. Otherwise, returns None. +fn get_switched_on_type<'tcx>( + block_data: &BasicBlockData<'tcx>, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, +) -> Option<Ty<'tcx>> { + let terminator = block_data.terminator(); + + // Only bother checking blocks which terminate by switching on a local. + let local = get_discriminant_local(&terminator.kind)?; + + let stmt_before_term = block_data.statements.last()?; + + if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind + && l.as_local() == Some(local) + { + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); + } + } + + None +} + +fn variant_discriminants<'tcx>( + layout: &TyAndLayout<'tcx>, + ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, +) -> FxHashSet<u128> { + match &layout.variants { + Variants::Single { index } => { + let mut res = FxHashSet::default(); + res.insert( + ty.discriminant_for_variant(tcx, *index) + .map_or(index.as_u32() as u128, |discr| discr.val), + ); + res + } + Variants::Multiple { variants, .. } => variants + .iter_enumerated() + .filter_map(|(idx, layout)| { + (layout.abi != Abi::Uninhabited) + .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) + }) + .collect(), + } +} + +impl<'tcx> MirPass<'tcx> for UnreachableEnumBranching { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("UnreachableEnumBranching starting for {:?}", body.source); + + let mut unreachable_targets = Vec::new(); + let mut patch = MirPatch::new(body); + + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { + trace!("processing block {:?}", bb); + + if bb_data.is_cleanup { + continue; + } + + let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue }; + + let layout = tcx.layout_of( + tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty), + ); + + let mut allowed_variants = if let Ok(layout) = layout { + // Find allowed variants based on uninhabited. + variant_discriminants(&layout, discriminant_ty, tcx) + } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) { + // If there are some generics, we can still get the allowed variants. 
+ variant_range + .map(|variant| { + discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val + }) + .collect() + } else { + continue; + }; + + trace!("allowed_variants = {:?}", allowed_variants); + + unreachable_targets.clear(); + let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else { + bug!() + }; + + for (index, (val, _)) in targets.iter().enumerate() { + if !allowed_variants.remove(&val) { + unreachable_targets.push(index); + } + } + let otherwise_is_empty_unreachable = + body.basic_blocks[targets.otherwise()].is_empty_unreachable(); + fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool { + // After resolving https://github.com/llvm/llvm-project/issues/78578, + // We can remove this check. + // The main issue here is that `early-tailduplication` causes compile time overhead + // and potential performance problems. + // Simply put, when encounter a switch (indirect branch) statement, + // `early-tailduplication` tries to duplicate the switch branch statement with BB + // into (each) predecessors. This makes CFG very complex. + // We can understand it as it transforms the following code + // ```rust + // match a { ... many cases }; + // match b { ... many cases }; + // ``` + // into + // ```rust + // match a { ... many match b { goto BB cases } } + // ... BB cases + // ``` + // Abandon this transformation when it is possible (the best effort) + // to encounter the problem. + let mut successors = basic_blocks[bb].terminator().successors(); + let Some(first_successor) = successors.next() else { return true }; + if successors.next().is_some() { + return true; + } + if let TerminatorKind::SwitchInt { .. } = + &basic_blocks[first_successor].terminator().kind + { + return false; + }; + true + } + // If and only if there is a variant that does not have a branch set, + // change the current of otherwise as the variant branch and set otherwise to unreachable. + // It transforms following code + // ```rust + // match c { + // Ordering::Less => 1, + // Ordering::Equal => 2, + // _ => 3, + // } + // ``` + // to + // ```rust + // match c { + // Ordering::Less => 1, + // Ordering::Equal => 2, + // Ordering::Greater => 3, + // } + // ``` + let otherwise_is_last_variant = !otherwise_is_empty_unreachable + && allowed_variants.len() == 1 + // Despite the LLVM issue, we hope that small enum can still be transformed. + // This is valuable for both `a <= b` and `if let Some/Ok(v)`. + && (targets.all_targets().len() <= 3 + || check_successors(&body.basic_blocks, targets.otherwise())); + let replace_otherwise_to_unreachable = otherwise_is_last_variant + || (!otherwise_is_empty_unreachable && allowed_variants.is_empty()); + + if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable { + continue; + } + + let unreachable_block = patch.unreachable_no_cleanup_block(); + let mut targets = targets.clone(); + if replace_otherwise_to_unreachable { + if otherwise_is_last_variant { + // We have checked that `allowed_variants` has only one element. 
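At the source level, the kind of branch this pass removes looks like the following: with `E = Infallible` the `Err` variant is uninhabited, so the switch arm for its discriminant can never be taken and is redirected to an unreachable block, letting later passes delete the dead code. `Maybe` is a made-up type used only for illustration.

```rust
use std::convert::Infallible;

#[allow(dead_code)]
enum Maybe<E> {
    Ok(u32),
    Err(E),
}

fn get(m: Maybe<Infallible>) -> u32 {
    match m {
        Maybe::Ok(v) => v,
        // This variant cannot be constructed for `E = Infallible`; the
        // discriminant switch arm generated for it is unreachable.
        Maybe::Err(e) => match e {},
    }
}

fn main() {
    println!("{}", get(Maybe::Ok(7)));
}
```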
+ #[allow(rustc::potential_query_instability)] + let last_variant = *allowed_variants.iter().next().unwrap(); + targets.add_target(last_variant, targets.otherwise()); + } + unreachable_targets.push(targets.iter().count()); + } + for index in unreachable_targets.iter() { + targets.all_targets_mut()[*index] = unreachable_block; + } + patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() }); + } + + patch.apply(body); + } +} diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index bff59aa6023..a6c3c3b189e 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -3,6 +3,7 @@ //! post-order traversal of the blocks. use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; @@ -14,11 +15,7 @@ pub struct UnreachablePropagation; impl MirPass<'_> for UnreachablePropagation { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable. - - // FIXME(#116171) Coverage gets confused by MIR passes that can remove all - // coverage statements from an instrumented function. This pass can be - // re-enabled when coverage codegen is robust against that happening. - sess.mir_opt_level() >= 2 && !sess.instrument_coverage() + sess.mir_opt_level() >= 2 } fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs new file mode 100644 index 00000000000..66cc65de647 --- /dev/null +++ b/compiler/rustc_mir_transform/src/validate.rs @@ -0,0 +1,1406 @@ +//! Validates the MIR to ensure that invariants are upheld. + +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_index::bit_set::BitSet; +use rustc_index::IndexVec; +use rustc_infer::traits::Reveal; +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeVisitableExt, Variance}; +use rustc_middle::{bug, span_bug}; +use rustc_target::abi::{Size, FIRST_VARIANT}; +use rustc_target::spec::abi::Abi; + +use crate::util::is_within_packed; + +use crate::util::relate_types; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum EdgeKind { + Unwind, + Normal, +} + +pub struct Validator { + /// Describes at which point in the pipeline this validation is happening. + pub when: String, + /// The phase for which we are upholding the dialect. If the given phase forbids a specific + /// element, this validator will now emit errors if that specific element is encountered. + /// Note that phases that change the dialect cause all *following* phases to check the + /// invariants of the new dialect. A phase that changes dialects never checks the new invariants + /// itself. + pub mir_phase: MirPhase, +} + +impl<'tcx> MirPass<'tcx> for Validator { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // FIXME(JakobDegen): These bodies never instantiated in codegend anyway, so it's not + // terribly important that they pass the validator. However, I think other passes might + // still see them, in which case they might be surprised. 
It would probably be better if we + // didn't put this through the MIR pipeline at all. + if matches!(body.source.instance, InstanceDef::Intrinsic(..) | InstanceDef::Virtual(..)) { + return; + } + let def_id = body.source.def_id(); + let mir_phase = self.mir_phase; + let param_env = match mir_phase.reveal() { + Reveal::UserFacing => tcx.param_env(def_id), + Reveal::All => tcx.param_env_reveal_all_normalized(def_id), + }; + + let can_unwind = if mir_phase <= MirPhase::Runtime(RuntimePhase::Initial) { + // In this case `AbortUnwindingCalls` haven't yet been executed. + true + } else if !tcx.def_kind(def_id).is_fn_like() { + true + } else { + let body_ty = tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, + ty::Coroutine(..) => Abi::Rust, + // No need to do MIR validation on error bodies + ty::Error(_) => return, + _ => { + span_bug!(body.span, "unexpected body ty: {:?} phase {:?}", body_ty, mir_phase) + } + }; + + ty::layout::fn_can_unwind(tcx, Some(def_id), body_abi) + }; + + let mut cfg_checker = CfgChecker { + when: &self.when, + body, + tcx, + mir_phase, + unwind_edge_count: 0, + reachable_blocks: traversal::reachable_as_bitset(body), + value_cache: FxHashSet::default(), + can_unwind, + }; + cfg_checker.visit_body(body); + cfg_checker.check_cleanup_control_flow(); + + // Also run the TypeChecker. + for (location, msg) in validate_types(tcx, self.mir_phase, param_env, body, body) { + cfg_checker.fail(location, msg); + } + + if let MirPhase::Runtime(_) = body.phase { + if let ty::InstanceDef::Item(_) = body.source.instance { + if body.has_free_regions() { + cfg_checker.fail( + Location::START, + format!("Free regions in optimized {} MIR", body.phase.name()), + ); + } + } + } + + // Enforce that coroutine-closure layouts are identical. + if let Some(layout) = body.coroutine_layout_raw() + && let Some(by_move_body) = body.coroutine_by_move_body() + && let Some(by_move_layout) = by_move_body.coroutine_layout_raw() + { + // FIXME(async_closures): We could do other validation here? + if layout.variant_fields.len() != by_move_layout.variant_fields.len() { + cfg_checker.fail( + Location::START, + format!( + "Coroutine layout has different number of variant fields from \ + by-move coroutine layout:\n\ + layout: {layout:#?}\n\ + by_move_layout: {by_move_layout:#?}", + ), + ); + } + } + } +} + +struct CfgChecker<'a, 'tcx> { + when: &'a str, + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + mir_phase: MirPhase, + unwind_edge_count: usize, + reachable_blocks: BitSet<BasicBlock>, + value_cache: FxHashSet<u128>, + // If `false`, then the MIR must not contain `UnwindAction::Continue` or + // `TerminatorKind::Resume`. + can_unwind: bool, +} + +impl<'a, 'tcx> CfgChecker<'a, 'tcx> { + #[track_caller] + fn fail(&self, location: Location, msg: impl AsRef<str>) { + // We might see broken MIR when other errors have already occurred. 
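+        // (In other words: if no error has been emitted yet, the assert below fails
+        // and reports the broken MIR itself; otherwise broken MIR is tolerated here.)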
+ assert!( + self.tcx.dcx().has_errors().is_some(), + "broken MIR in {:?} ({}) at {:?}:\n{}", + self.body.source.instance, + self.when, + location, + msg.as_ref(), + ); + } + + fn check_edge(&mut self, location: Location, bb: BasicBlock, edge_kind: EdgeKind) { + if bb == START_BLOCK { + self.fail(location, "start block must not have predecessors") + } + if let Some(bb) = self.body.basic_blocks.get(bb) { + let src = self.body.basic_blocks.get(location.block).unwrap(); + match (src.is_cleanup, bb.is_cleanup, edge_kind) { + // Non-cleanup blocks can jump to non-cleanup blocks along non-unwind edges + (false, false, EdgeKind::Normal) + // Cleanup blocks can jump to cleanup blocks along non-unwind edges + | (true, true, EdgeKind::Normal) => {} + // Non-cleanup blocks can jump to cleanup blocks along unwind edges + (false, true, EdgeKind::Unwind) => { + self.unwind_edge_count += 1; + } + // All other jumps are invalid + _ => { + self.fail( + location, + format!( + "{:?} edge to {:?} violates unwind invariants (cleanup {:?} -> {:?})", + edge_kind, + bb, + src.is_cleanup, + bb.is_cleanup, + ) + ) + } + } + } else { + self.fail(location, format!("encountered jump to invalid basic block {bb:?}")) + } + } + + fn check_cleanup_control_flow(&self) { + if self.unwind_edge_count <= 1 { + return; + } + let doms = self.body.basic_blocks.dominators(); + let mut post_contract_node = FxHashMap::default(); + // Reusing the allocation across invocations of the closure + let mut dom_path = vec![]; + let mut get_post_contract_node = |mut bb| { + let root = loop { + if let Some(root) = post_contract_node.get(&bb) { + break *root; + } + let parent = doms.immediate_dominator(bb).unwrap(); + dom_path.push(bb); + if !self.body.basic_blocks[parent].is_cleanup { + break bb; + } + bb = parent; + }; + for bb in dom_path.drain(..) 
{ + post_contract_node.insert(bb, root); + } + root + }; + + let mut parent = IndexVec::from_elem(None, &self.body.basic_blocks); + for (bb, bb_data) in self.body.basic_blocks.iter_enumerated() { + if !bb_data.is_cleanup || !self.reachable_blocks.contains(bb) { + continue; + } + let bb = get_post_contract_node(bb); + for s in bb_data.terminator().successors() { + let s = get_post_contract_node(s); + if s == bb { + continue; + } + let parent = &mut parent[bb]; + match parent { + None => { + *parent = Some(s); + } + Some(e) if *e == s => (), + Some(e) => self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: The blocks dominated by {:?} have edges to both {:?} and {:?}", + bb, + s, + *e + ) + ), + } + } + } + + // Check for cycles + let mut stack = FxHashSet::default(); + for i in 0..parent.len() { + let mut bb = BasicBlock::from_usize(i); + stack.clear(); + stack.insert(bb); + loop { + let Some(parent) = parent[bb].take() else { break }; + let no_cycle = stack.insert(parent); + if !no_cycle { + self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: Cycle involving edge {bb:?} -> {parent:?}", + ), + ); + break; + } + bb = parent; + } + } + } + + fn check_unwind_edge(&mut self, location: Location, unwind: UnwindAction) { + let is_cleanup = self.body.basic_blocks[location.block].is_cleanup; + match unwind { + UnwindAction::Cleanup(unwind) => { + if is_cleanup { + self.fail(location, "`UnwindAction::Cleanup` in cleanup block"); + } + self.check_edge(location, unwind, EdgeKind::Unwind); + } + UnwindAction::Continue => { + if is_cleanup { + self.fail(location, "`UnwindAction::Continue` in cleanup block"); + } + + if !self.can_unwind { + self.fail(location, "`UnwindAction::Continue` in no-unwind function"); + } + } + UnwindAction::Terminate(UnwindTerminateReason::InCleanup) => { + if !is_cleanup { + self.fail( + location, + "`UnwindAction::Terminate(InCleanup)` in a non-cleanup block", + ); + } + } + // These are allowed everywhere. + UnwindAction::Unreachable | UnwindAction::Terminate(UnwindTerminateReason::Abi) => (), + } + } + + fn is_critical_call_edge(&self, target: Option<BasicBlock>, unwind: UnwindAction) -> bool { + let Some(target) = target else { return false }; + matches!(unwind, UnwindAction::Cleanup(_) | UnwindAction::Terminate(_)) + && self.body.basic_blocks.predecessors()[target].len() > 1 + } +} + +impl<'a, 'tcx> Visitor<'tcx> for CfgChecker<'a, 'tcx> { + fn visit_local(&mut self, local: Local, _context: PlaceContext, location: Location) { + if self.body.local_decls.get(local).is_none() { + self.fail( + location, + format!("local {local:?} has no corresponding declaration in `body.local_decls`"), + ); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::AscribeUserType(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::SetDiscriminant { .. } => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + } + StatementKind::Deinit(..) 
=> { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.mir_phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::Coverage(kind) => { + if self.mir_phase >= MirPhase::Analysis(AnalysisPhase::PostCleanup) + && let CoverageKind::BlockMarker { .. } | CoverageKind::SpanMarker { .. } = kind + { + self.fail( + location, + format!("{kind:?} should have been removed after analysis"), + ); + } + } + StatementKind::Assign(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) + | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::Goto { target } => { + self.check_edge(location, *target, EdgeKind::Normal); + } + TerminatorKind::SwitchInt { targets, discr: _ } => { + for (_, target) in targets.iter() { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_edge(location, targets.otherwise(), EdgeKind::Normal); + + self.value_cache.clear(); + self.value_cache.extend(targets.iter().map(|(value, _)| value)); + let has_duplicates = targets.iter().len() != self.value_cache.len(); + if has_duplicates { + self.fail( + location, + format!( + "duplicated values in `SwitchInt` terminator: {:?}", + terminator.kind, + ), + ); + } + } + TerminatorKind::Drop { target, unwind, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::Call { args, destination, target, unwind, .. } => { + if let Some(target) = target { + self.check_edge(location, *target, EdgeKind::Normal); + } + self.check_unwind_edge(location, *unwind); + + // The code generation assumes that there are no critical call edges. The assumption + // is used to simplify inserting code that should be executed along the return edge + // from the call. FIXME(tmiasko): Since this is a strictly code generation concern, + // the code generation should be responsible for handling it. + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Optimized) + && self.is_critical_call_edge(*target, *unwind) + { + self.fail( + location, + format!( + "encountered critical edge in `Call` terminator {:?}", + terminator.kind, + ), + ); + } + + // The call destination place and Operand::Move place used as an argument might be + // passed by a reference to the callee. Consequently they cannot be packed. + if is_within_packed(self.tcx, &self.body.local_decls, *destination).is_some() { + // This is bad! The callee will expect the memory to be aligned. + self.fail( + location, + format!( + "encountered packed place in `Call` terminator destination: {:?}", + terminator.kind, + ), + ); + } + for arg in args { + if let Operand::Move(place) = &arg.node { + if is_within_packed(self.tcx, &self.body.local_decls, *place).is_some() { + // This is bad! The callee will expect the memory to be aligned. 
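The alignment concern flagged here is the same one surface Rust guards against: a field of a `#[repr(packed)]` struct may be under-aligned, so it must be copied to a properly aligned local before anything takes a reference to it. A small safe illustration (the types are made up for this example):

```rust
#[repr(C, packed)]
struct Wire {
    tag: u8,
    value: u32, // stored at offset 1, i.e. misaligned for u32
}

fn takes_ref(v: &u32) -> u32 {
    *v
}

fn main() {
    let w = Wire { tag: 1, value: 0xDEAD_BEEF };
    println!("tag = {}", w.tag);
    // `takes_ref(&w.value)` is rejected by current rustc (unaligned reference to a
    // packed field) -- the same situation the validator refuses to see as a `Call`
    // destination or `Move` argument. Copy to an aligned local first:
    let v = w.value;
    println!("value = {}", takes_ref(&v));
}
```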
+ self.fail( + location, + format!( + "encountered `Move` of a packed place in `Call` terminator: {:?}", + terminator.kind, + ), + ); + } + } + } + } + TerminatorKind::Assert { target, unwind, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::Yield { resume, drop, .. } => { + if self.body.coroutine.is_none() { + self.fail(location, "`Yield` cannot appear outside coroutine bodies"); + } + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Yield` should have been replaced by coroutine lowering"); + } + self.check_edge(location, *resume, EdgeKind::Normal); + if let Some(drop) = drop { + self.check_edge(location, *drop, EdgeKind::Normal); + } + } + TerminatorKind::FalseEdge { real_target, imaginary_target } => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseEdge` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_edge(location, *imaginary_target, EdgeKind::Normal); + } + TerminatorKind::FalseUnwind { real_target, unwind } => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseUnwind` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::InlineAsm { targets, unwind, .. } => { + for &target in targets { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::CoroutineDrop => { + if self.body.coroutine.is_none() { + self.fail(location, "`CoroutineDrop` cannot appear outside coroutine bodies"); + } + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`CoroutineDrop` should have been replaced by coroutine lowering", + ); + } + } + TerminatorKind::UnwindResume => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindResume` from non-cleanup basic block") + } + if !self.can_unwind { + self.fail(location, "Cannot `UnwindResume` in a function that cannot unwind") + } + } + TerminatorKind::UnwindTerminate(_) => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindTerminate` from non-cleanup basic block") + } + } + TerminatorKind::Return => { + let bb = location.block; + if self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `Return` from cleanup basic block") + } + } + TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } + + fn visit_source_scope(&mut self, scope: SourceScope) { + if self.body.source_scopes.get(scope).is_none() { + self.tcx.dcx().span_bug( + self.body.span, + format!( + "broken MIR in {:?} ({}):\ninvalid source scope {:?}", + self.body.source.instance, self.when, scope, + ), + ); + } + } +} + +/// A faster version of the validation pass that only checks those things which may break when +/// instantiating any generic parameters. +/// +/// `caller_body` is used to detect cycles in MIR inlining and MIR validation before +/// `optimized_mir` is available. 
+pub fn validate_types<'tcx>( + tcx: TyCtxt<'tcx>, + mir_phase: MirPhase, + param_env: ty::ParamEnv<'tcx>, + body: &Body<'tcx>, + caller_body: &Body<'tcx>, +) -> Vec<(Location, String)> { + let mut type_checker = + TypeChecker { body, caller_body, tcx, param_env, mir_phase, failures: Vec::new() }; + type_checker.visit_body(body); + type_checker.failures +} + +struct TypeChecker<'a, 'tcx> { + body: &'a Body<'tcx>, + caller_body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + mir_phase: MirPhase, + failures: Vec<(Location, String)>, +} + +impl<'a, 'tcx> TypeChecker<'a, 'tcx> { + fn fail(&mut self, location: Location, msg: impl Into<String>) { + self.failures.push((location, msg.into())); + } + + /// Check if src can be assigned into dest. + /// This is not precise, it will accept some incorrect assignments. + fn mir_assign_valid_types(&self, src: Ty<'tcx>, dest: Ty<'tcx>) -> bool { + // Fast path before we normalize. + if src == dest { + // Equal types, all is good. + return true; + } + + // We sometimes have to use `defining_opaque_types` for subtyping + // to succeed here and figuring out how exactly that should work + // is annoying. It is harmless enough to just not validate anything + // in that case. We still check this after analysis as all opaque + // types have been revealed at this point. + if (src, dest).has_opaque_types() { + return true; + } + + // After borrowck subtyping should be fully explicit via + // `Subtype` projections. + let variance = if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + Variance::Invariant + } else { + Variance::Covariant + }; + + crate::util::relate_types(self.tcx, self.param_env, variance, src, dest) + } +} + +impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + // This check is somewhat expensive, so only run it when -Zvalidate-mir is passed. + if self.tcx.sess.opts.unstable_opts.validate_mir + && self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) + { + // `Operand::Copy` is only supposed to be used with `Copy` types. 
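The `Operand::Copy` rule being checked corresponds to the ordinary surface distinction that only `Copy` data may be duplicated implicitly; everything else moves. For example:

```rust
#[derive(Clone, Copy)]
struct Point {
    x: i32,
    y: i32,
}

fn main() {
    let p = Point { x: 1, y: 2 };
    let q = p; // `Point: Copy`, so `p` stays usable (a copy, not a move)
    println!("{} {}", p.x, q.y);

    let s = String::from("hi");
    let t = s; // `String` is not `Copy`: this is a move
    // println!("{s}"); // would not compile: `s` was moved out of
    println!("{t}");
}
```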
+ if let Operand::Copy(place) = operand { + let ty = place.ty(&self.body.local_decls, self.tcx).ty; + + if !ty.is_copy_modulo_regions(self.tcx, self.param_env) { + self.fail(location, format!("`Operand::Copy` with non-`Copy` type {ty}")); + } + } + } + + self.super_operand(operand, location); + } + + fn visit_projection_elem( + &mut self, + place_ref: PlaceRef<'tcx>, + elem: PlaceElem<'tcx>, + context: PlaceContext, + location: Location, + ) { + match elem { + ProjectionElem::OpaqueCast(ty) + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) => + { + self.fail( + location, + format!("explicit opaque type cast to `{ty}` after `RevealAll`"), + ) + } + ProjectionElem::Index(index) => { + let index_ty = self.body.local_decls[index].ty; + if index_ty != self.tcx.types.usize { + self.fail(location, format!("bad index ({index_ty:?} != usize)")) + } + } + ProjectionElem::Deref + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::PostCleanup) => + { + let base_ty = place_ref.ty(&self.body.local_decls, self.tcx).ty; + + if base_ty.is_box() { + self.fail( + location, + format!("{base_ty:?} dereferenced after ElaborateBoxDerefs"), + ) + } + } + ProjectionElem::Field(f, ty) => { + let parent_ty = place_ref.ty(&self.body.local_decls, self.tcx); + let fail_out_of_bounds = |this: &mut Self, location| { + this.fail(location, format!("Out of bounds field {f:?} for {parent_ty:?}")); + }; + let check_equal = |this: &mut Self, location, f_ty| { + if !this.mir_assign_valid_types(ty, f_ty) { + this.fail( + location, + format!( + "Field projection `{place_ref:?}.{f:?}` specified type `{ty:?}`, but actual type is `{f_ty:?}`" + ) + ) + } + }; + + let kind = match parent_ty.ty.kind() { + &ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args).kind() + } + kind => kind, + }; + + match kind { + ty::Tuple(fields) => { + let Some(f_ty) = fields.get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, *f_ty); + } + ty::Adt(adt_def, args) => { + // see <https://github.com/rust-lang/rust/blob/7601adcc764d42c9f2984082b49948af652df986/compiler/rustc_middle/src/ty/layout.rs#L861-L864> + if Some(adt_def.did()) == self.tcx.lang_items().dyn_metadata() { + self.fail( + location, + format!("You can't project to field {f:?} of `DynMetadata` because \ + layout is weird and thinks it doesn't have fields."), + ); + } + + let var = parent_ty.variant_index.unwrap_or(FIRST_VARIANT); + let Some(field) = adt_def.variant(var).fields.get(f) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, field.ty(self.tcx, args)); + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + &ty::Coroutine(def_id, args) => { + let f_ty = if let Some(var) = parent_ty.variant_index { + // If we're currently validating an inlined copy of this body, + // then it will no longer be parameterized over the original + // args of the coroutine. Otherwise, we prefer to use this body + // since we may be in the process of computing this MIR in the + // first place. 
+ let layout = if def_id == self.caller_body.source.def_id() { + // FIXME: This is not right for async closures. + self.caller_body.coroutine_layout_raw() + } else { + self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) + }; + + let Some(layout) = layout else { + self.fail( + location, + format!("No coroutine layout for {parent_ty:?}"), + ); + return; + }; + + let Some(&local) = layout.variant_fields[var].get(f) else { + fail_out_of_bounds(self, location); + return; + }; + + let Some(f_ty) = layout.field_tys.get(local) else { + self.fail( + location, + format!("Out of bounds local {local:?} for {parent_ty:?}"), + ); + return; + }; + + ty::EarlyBinder::bind(f_ty.ty).instantiate(self.tcx, args) + } else { + let Some(&f_ty) = args.as_coroutine().prefix_tys().get(f.index()) + else { + fail_out_of_bounds(self, location); + return; + }; + + f_ty + }; + + check_equal(self, location, f_ty); + } + _ => { + self.fail(location, format!("{:?} does not have fields", parent_ty.ty)); + } + } + } + ProjectionElem::Subtype(ty) => { + if !relate_types( + self.tcx, + self.param_env, + Variance::Covariant, + ty, + place_ref.ty(&self.body.local_decls, self.tcx).ty, + ) { + self.fail( + location, + format!( + "Failed subtyping {ty:#?} and {:#?}", + place_ref.ty(&self.body.local_decls, self.tcx).ty + ), + ) + } + } + _ => {} + } + self.super_projection_elem(place_ref, elem, context, location); + } + + fn visit_var_debug_info(&mut self, debuginfo: &VarDebugInfo<'tcx>) { + if let Some(box VarDebugInfoFragment { ty, ref projection }) = debuginfo.composite { + if ty.is_union() || ty.is_enum() { + self.fail( + START_BLOCK.start_location(), + format!("invalid type {ty:?} in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.is_empty() { + self.fail( + START_BLOCK.start_location(), + format!("invalid empty projection in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.iter().any(|p| !matches!(p, PlaceElem::Field(..))) { + self.fail( + START_BLOCK.start_location(), + format!( + "illegal projection {:?} in debuginfo for {:?}", + projection, debuginfo.name + ), + ); + } + } + match debuginfo.value { + VarDebugInfoContents::Const(_) => {} + VarDebugInfoContents::Place(place) => { + if place.projection.iter().any(|p| !p.can_use_in_debuginfo()) { + self.fail( + START_BLOCK.start_location(), + format!("illegal place {:?} in debuginfo for {:?}", place, debuginfo.name), + ); + } + } + } + self.super_var_debug_info(debuginfo); + } + + fn visit_place(&mut self, place: &Place<'tcx>, cntxt: PlaceContext, location: Location) { + // Set off any `bug!`s in the type computation code + let _ = place.ty(&self.body.local_decls, self.tcx); + + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) + && place.projection.len() > 1 + && cntxt != PlaceContext::NonUse(NonUseContext::VarDebugInfo) + && place.projection[1..].contains(&ProjectionElem::Deref) + { + self.fail(location, format!("{place:?}, has deref at the wrong place")); + } + + self.super_place(place, cntxt, location); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + macro_rules! 
check_kinds { + ($t:expr, $text:literal, $typat:pat) => { + if !matches!(($t).kind(), $typat) { + self.fail(location, format!($text, $t)); + } + }; + } + match rvalue { + Rvalue::Use(_) | Rvalue::CopyForDeref(_) => {} + Rvalue::Aggregate(kind, fields) => match **kind { + AggregateKind::Tuple => {} + AggregateKind::Array(dest) => { + for src in fields { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "array field has the wrong type"); + } + } + } + AggregateKind::Adt(def_id, idx, args, _, Some(field)) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(adt_def.is_union()); + assert_eq!(idx, FIRST_VARIANT); + let dest_ty = self.tcx.normalize_erasing_regions( + self.param_env, + adt_def.non_enum_variant().fields[field].ty(self.tcx, args), + ); + if fields.len() == 1 { + let src_ty = fields.raw[0].ty(self.body, self.tcx); + if !self.mir_assign_valid_types(src_ty, dest_ty) { + self.fail(location, "union field has the wrong type"); + } + } else { + self.fail(location, "unions should have one initialized field"); + } + } + AggregateKind::Adt(def_id, idx, args, _, None) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(!adt_def.is_union()); + let variant = &adt_def.variants()[idx]; + if variant.fields.len() != fields.len() { + self.fail(location, "adt has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, &variant.fields) { + let dest_ty = self + .tcx + .normalize_erasing_regions(self.param_env, dest.ty(self.tcx, args)); + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest_ty) { + self.fail(location, "adt field has the wrong type"); + } + } + } + AggregateKind::Closure(_, args) => { + let upvars = args.as_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "closure has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "closure field has the wrong type"); + } + } + } + AggregateKind::Coroutine(_, args) => { + let upvars = args.as_coroutine().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "coroutine has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine field has the wrong type"); + } + } + } + AggregateKind::CoroutineClosure(_, args) => { + let upvars = args.as_coroutine_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail( + location, + "coroutine-closure has the wrong number of initialized fields", + ); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine-closure field has the wrong type"); + } + } + } + AggregateKind::RawPtr(pointee_ty, mutability) => { + if !matches!(self.mir_phase, MirPhase::Runtime(_)) { + // It would probably be fine to support this in earlier phases, + // but at the time of writing it's only ever introduced from intrinsic lowering, + // so earlier things just `bug!` on it. 
+ self.fail(location, "RawPtr should be in runtime MIR only"); + } + + if fields.len() != 2 { + self.fail(location, "raw pointer aggregate must have 2 fields"); + } else { + let data_ptr_ty = fields.raw[0].ty(self.body, self.tcx); + let metadata_ty = fields.raw[1].ty(self.body, self.tcx); + if let ty::RawPtr(in_pointee, in_mut) = data_ptr_ty.kind() { + if *in_mut != mutability { + self.fail(location, "input and output mutability must match"); + } + + // FIXME: check `Thin` instead of `Sized` + if !in_pointee.is_sized(self.tcx, self.param_env) { + self.fail(location, "input pointer must be thin"); + } + } else { + self.fail( + location, + "first operand to raw pointer aggregate must be a raw pointer", + ); + } + + // FIXME: Check metadata more generally + if pointee_ty.is_slice() { + if !self.mir_assign_valid_types(metadata_ty, self.tcx.types.usize) { + self.fail(location, "slice metadata must be usize"); + } + } else if pointee_ty.is_sized(self.tcx, self.param_env) { + if metadata_ty != self.tcx.types.unit { + self.fail(location, "metadata for pointer-to-thin must be unit"); + } + } + } + } + }, + Rvalue::Ref(_, BorrowKind::Fake(_), _) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`Assign` statement with a `Fake` borrow should have been removed in runtime MIR", + ); + } + } + Rvalue::Ref(..) => {} + Rvalue::Len(p) => { + let pty = p.ty(&self.body.local_decls, self.tcx).ty; + check_kinds!( + pty, + "Cannot compute length of non-array type {:?}", + ty::Array(..) | ty::Slice(..) + ); + } + Rvalue::BinaryOp(op, vals) => { + use BinOp::*; + let a = vals.0.ty(&self.body.local_decls, self.tcx); + let b = vals.1.ty(&self.body.local_decls, self.tcx); + if crate::util::binop_right_homogeneous(*op) { + if let Eq | Lt | Le | Ne | Ge | Gt = op { + // The function pointer types can have lifetimes + if !self.mir_assign_valid_types(a, b) { + self.fail( + location, + format!("Cannot {op:?} compare incompatible types {a:?} and {b:?}"), + ); + } + } else if a != b { + self.fail( + location, + format!( + "Cannot perform binary op {op:?} on unequal types {a:?} and {b:?}" + ), + ); + } + } + + match op { + Offset => { + check_kinds!(a, "Cannot offset non-pointer type {:?}", ty::RawPtr(..)); + if b != self.tcx.types.isize && b != self.tcx.types.usize { + self.fail(location, format!("Cannot offset by non-isize type {b:?}")); + } + } + Eq | Lt | Le | Ne | Ge | Gt => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} compare type {:?}", + ty::Bool + | ty::Char + | ty::Int(..) + | ty::Uint(..) + | ty::Float(..) + | ty::RawPtr(..) + | ty::FnPtr(..) + ) + } + } + Cmp => { + for x in [a, b] { + check_kinds!( + x, + "Cannot three-way compare non-integer type {:?}", + ty::Char | ty::Uint(..) | ty::Int(..) + ) + } + } + AddUnchecked | AddWithOverflow | SubUnchecked | SubWithOverflow + | MulUnchecked | MulWithOverflow | Shl | ShlUnchecked | Shr | ShrUnchecked => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} non-integer type {:?}", + ty::Uint(..) | ty::Int(..) + ) + } + } + BitAnd | BitOr | BitXor => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform bitwise op {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) | ty::Bool + ) + } + } + Add | Sub | Mul | Div | Rem => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform arithmetic {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) | ty::Float(..) 
+ ) + } + } + } + } + Rvalue::UnaryOp(op, operand) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + match op { + UnOp::Neg => { + check_kinds!(a, "Cannot negate type {:?}", ty::Int(..) | ty::Float(..)) + } + UnOp::Not => { + check_kinds!( + a, + "Cannot binary not type {:?}", + ty::Int(..) | ty::Uint(..) | ty::Bool + ); + } + } + } + Rvalue::ShallowInitBox(operand, _) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..)); + } + Rvalue::Cast(kind, operand, target_type) => { + let op_ty = operand.ty(self.body, self.tcx); + match kind { + CastKind::DynStar => { + // FIXME(dyn-star): make sure nothing needs to be done here. + } + // FIXME: Add Checks for these + CastKind::PointerWithExposedProvenance + | CastKind::PointerExposeProvenance + | CastKind::PointerCoercion(_) => {} + CastKind::IntToInt | CastKind::IntToFloat => { + let input_valid = op_ty.is_integral() || op_ty.is_char() || op_ty.is_bool(); + let target_valid = target_type.is_numeric() || target_type.is_char(); + if !input_valid || !target_valid { + self.fail( + location, + format!("Wrong cast kind {kind:?} for the type {op_ty}",), + ); + } + } + CastKind::FnPtrToPtr | CastKind::PtrToPtr => { + if !(op_ty.is_any_ptr() && target_type.is_unsafe_ptr()) { + self.fail(location, "Can't cast {op_ty} into 'Ptr'"); + } + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + if !op_ty.is_floating_point() || !target_type.is_numeric() { + self.fail( + location, + format!( + "Trying to cast non 'Float' as {kind:?} into {target_type:?}" + ), + ); + } + } + CastKind::Transmute => { + if let MirPhase::Runtime(..) = self.mir_phase { + // Unlike `mem::transmute`, a MIR `Transmute` is well-formed + // for any two `Sized` types, just potentially UB to run. 
+ + if !self + .tcx + .normalize_erasing_regions(self.param_env, op_ty) + .is_sized(self.tcx, self.param_env) + { + self.fail( + location, + format!("Cannot transmute from non-`Sized` type {op_ty:?}"), + ); + } + if !self + .tcx + .normalize_erasing_regions(self.param_env, *target_type) + .is_sized(self.tcx, self.param_env) + { + self.fail( + location, + format!("Cannot transmute to non-`Sized` type {target_type:?}"), + ); + } + } else { + self.fail( + location, + format!( + "Transmute is not supported in non-runtime phase {:?}.", + self.mir_phase + ), + ); + } + } + } + } + Rvalue::NullaryOp(NullOp::OffsetOf(indices), container) => { + let fail_out_of_bounds = |this: &mut Self, location, field, ty| { + this.fail(location, format!("Out of bounds field {field:?} for {ty:?}")); + }; + + let mut current_ty = *container; + + for (variant, field) in indices.iter() { + match current_ty.kind() { + ty::Tuple(fields) => { + if variant != FIRST_VARIANT { + self.fail( + location, + format!("tried to get variant {variant:?} of tuple"), + ); + return; + } + let Some(&f_ty) = fields.get(field.as_usize()) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + current_ty = self.tcx.normalize_erasing_regions(self.param_env, f_ty); + } + ty::Adt(adt_def, args) => { + let Some(field) = adt_def.variant(variant).fields.get(field) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + let f_ty = field.ty(self.tcx, args); + current_ty = self.tcx.normalize_erasing_regions(self.param_env, f_ty); + } + _ => { + self.fail( + location, + format!("Cannot get offset ({variant:?}, {field:?}) from type {current_ty:?}"), + ); + return; + } + } + } + } + Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::AddressOf(_, _) + | Rvalue::NullaryOp(NullOp::SizeOf | NullOp::AlignOf | NullOp::UbChecks, _) + | Rvalue::Discriminant(_) => {} + } + self.super_rvalue(rvalue, location); + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (dest, rvalue)) => { + // LHS and RHS of the assignment must have the same type. + let left_ty = dest.ty(&self.body.local_decls, self.tcx).ty; + let right_ty = rvalue.ty(&self.body.local_decls, self.tcx); + + if !self.mir_assign_valid_types(right_ty, left_ty) { + self.fail( + location, + format!( + "encountered `{:?}` with incompatible types:\n\ + left-hand side has type: {}\n\ + right-hand side has type: {}", + statement.kind, left_ty, right_ty, + ), + ); + } + if let Rvalue::CopyForDeref(place) = rvalue { + if place.ty(&self.body.local_decls, self.tcx).ty.builtin_deref(true).is_none() { + self.fail( + location, + "`CopyForDeref` should only be used for dereferenceable types", + ) + } + } + } + StatementKind::AscribeUserType(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) 
=> { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => { + let ty = op.ty(&self.body.local_decls, self.tcx); + if !ty.is_bool() { + self.fail( + location, + format!("`assume` argument must be `bool`, but got: `{ty}`"), + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping( + CopyNonOverlapping { src, dst, count }, + )) => { + let src_ty = src.ty(&self.body.local_decls, self.tcx); + let op_src_ty = if let Some(src_deref) = src_ty.builtin_deref(true) { + src_deref + } else { + self.fail( + location, + format!("Expected src to be ptr in copy_nonoverlapping, got: {src_ty}"), + ); + return; + }; + let dst_ty = dst.ty(&self.body.local_decls, self.tcx); + let op_dst_ty = if let Some(dst_deref) = dst_ty.builtin_deref(true) { + dst_deref + } else { + self.fail( + location, + format!("Expected dst to be ptr in copy_nonoverlapping, got: {dst_ty}"), + ); + return; + }; + // since CopyNonOverlapping is parametrized by 1 type, + // we only need to check that they are equal and not keep an extra parameter. + if !self.mir_assign_valid_types(op_src_ty, op_dst_ty) { + self.fail(location, format!("bad arg ({op_src_ty:?} != {op_dst_ty:?})")); + } + + let op_cnt_ty = count.ty(&self.body.local_decls, self.tcx); + if op_cnt_ty != self.tcx.types.usize { + self.fail(location, format!("bad arg ({op_cnt_ty:?} != usize)")) + } + } + StatementKind::SetDiscriminant { place, .. } => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + let pty = place.ty(&self.body.local_decls, self.tcx).ty.kind(); + if !matches!(pty, ty::Adt(..) | ty::Coroutine(..) | ty::Alias(ty::Opaque, ..)) { + self.fail( + location, + format!( + "`SetDiscriminant` is only allowed on ADTs and coroutines, not {pty:?}" + ), + ); + } + } + StatementKind::Deinit(..) => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.mir_phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Coverage(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) 
+ | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::SwitchInt { targets, discr } => { + let switch_ty = discr.ty(&self.body.local_decls, self.tcx); + + let target_width = self.tcx.sess.target.pointer_width; + + let size = Size::from_bits(match switch_ty.kind() { + ty::Uint(uint) => uint.normalize(target_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(target_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + for (value, _) in targets.iter() { + if Scalar::<()>::try_from_uint(value, size).is_none() { + self.fail( + location, + format!("the value {value:#x} is not a proper {switch_ty:?}"), + ) + } + } + } + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(&self.body.local_decls, self.tcx); + match func_ty.kind() { + ty::FnPtr(..) | ty::FnDef(..) => {} + _ => self.fail( + location, + format!("encountered non-callable type {func_ty} in `Call` terminator"), + ), + } + } + TerminatorKind::Assert { cond, .. } => { + let cond_ty = cond.ty(&self.body.local_decls, self.tcx); + if cond_ty != self.tcx.types.bool { + self.fail( + location, + format!( + "encountered non-boolean condition of type {cond_ty} in `Assert` terminator" + ), + ); + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } +} | 
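A note on the variance switch in `mir_assign_valid_types` above: before the runtime phases an assignment may still rely on subtyping (hence the covariant check), while after borrowck subtyping is made explicit through `Subtype` projections and both sides must match invariantly. The sketch below is plain surface Rust showing the same covariant/invariant distinction with lifetimes; it is only an illustration, not compiler code.

fn takes_short<'a>(s: &'a str) -> usize {
    s.len()
}

fn overwrite<'a>(slot: &mut &'a str, new: &'a str) {
    *slot = new;
}

fn main() {
    let s: &'static str = "hello";
    // Covariance: `&'static str` is accepted where some shorter `&'a str` is expected.
    assert_eq!(takes_short(s), 5);

    // Invariance behind `&mut`: `'a` in `overwrite` is forced to `'static` here,
    // so a non-'static argument would be rejected.
    let mut slot: &'static str = "initial";
    overwrite(&mut slot, "replacement");
    // let local = String::from("temporary");
    // overwrite(&mut slot, &local); // error: `local` does not live long enough
    assert_eq!(slot, "replacement");
}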
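The `check_kinds!` macro defined above is essentially a `matches!` on the operand type's kind that records a failure instead of aborting. A self-contained analogue, using a toy `TyKind` enum and an explicit checker in place of `ty::TyKind` and `self.fail` (all names here are illustrative):

#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum TyKind { Bool, Int, Float, Array }

struct Checker {
    failures: Vec<String>,
}

impl Checker {
    fn fail(&mut self, msg: String) {
        self.failures.push(msg);
    }
}

macro_rules! check_kinds {
    ($checker:expr, $t:expr, $text:literal, $typat:pat) => {
        // Record a failure unless the kind matches the expected pattern.
        if !matches!($t, $typat) {
            $checker.fail(format!($text, $t));
        }
    };
}

fn main() {
    let mut checker = Checker { failures: Vec::new() };

    // Mirrors the `Rvalue::Len` check: only arrays and slices have a length.
    check_kinds!(checker, TyKind::Bool, "Cannot compute length of non-array type {:?}", TyKind::Array);
    // Mirrors the `UnOp::Neg` check: negation needs an int or float.
    check_kinds!(checker, TyKind::Int, "Cannot negate type {:?}", TyKind::Int | TyKind::Float);

    assert_eq!(checker.failures.len(), 1);
    println!("{:?}", checker.failures);
}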
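The `AggregateKind::RawPtr` arm above validates the (data pointer, metadata) pair that a wide raw pointer is assembled from. At the source level this is the shape produced by `std::ptr::slice_from_raw_parts`: a thin pointer plus a `usize` length, which is the "slice metadata must be usize" case. A small sketch of that surface-level counterpart:

fn main() {
    let data: [u32; 3] = [1, 2, 3];
    let thin: *const u32 = data.as_ptr();

    // Thin data pointer + usize metadata -> wide `*const [u32]`.
    let wide: *const [u32] = std::ptr::slice_from_raw_parts(thin, data.len());

    // `data` is still live, so dereferencing the assembled pointer is sound here.
    assert_eq!(unsafe { (*wide)[1] }, 2);
}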
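The `CastKind::Transmute` arm above only requires that both sides of a runtime-MIR transmute are `Sized`; the equal-size requirement of `mem::transmute` itself is checked elsewhere in the compiler. A surface-level example between two `Sized`, same-size types:

fn main() {
    // Reinterpret the bits of an f32 as a u32.
    let bits: u32 = unsafe { std::mem::transmute(1.0_f32) };
    assert_eq!(bits, 0x3f80_0000); // IEEE 754 encoding of 1.0
}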
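The `CopyNonOverlapping` intrinsic checks above (both operands must be pointers to the same pointee type, the count must be `usize`) correspond directly to the signature of `std::ptr::copy_nonoverlapping` at the source level:

fn main() {
    let src = [1u8, 2, 3, 4];
    let mut dst = [0u8; 4];

    // `*const u8` and `*mut u8` share a pointee type; `src.len()` is a usize count.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), src.len()) };

    assert_eq!(src, dst);
}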
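A simplified model of the `SwitchInt` range check above: every branch value must be representable in the bit width of the type being switched on (1 bit for `bool`, 32 for `char`, the normalized width for integers). The real code goes through `Scalar::try_from_uint`; this unsigned-only toy just compares against two to the power of the bit width.

fn fits_in_bits(value: u128, bits: u32) -> bool {
    // A value is a proper branch value if it fits in `bits` bits.
    bits >= 128 || value < (1u128 << bits)
}

fn main() {
    assert!(fits_in_bits(1, 1));                      // a valid `bool` branch value
    assert!(!fits_in_bits(2, 1));                     // 2 is not a proper bool
    assert!(fits_in_bits(0x10FFFF, 32));              // char scalar values fit in 32 bits
    assert!(!fits_in_bits(u64::MAX as u128 + 1, 64)); // too wide for a 64-bit switch type
}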
