diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
47 files changed, 3150 insertions, 471 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index ba70a4453d6..d43fca3dc7e 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -1,5 +1,6 @@ use rustc_ast::InlineAsmOptions; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::layout; use rustc_middle::ty::{self, TyCtxt}; use rustc_target::spec::abi::Abi; diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs index 0af88729887..5199c41c58c 100644 --- a/compiler/rustc_mir_transform/src/check_alignment.rs +++ b/compiler/rustc_mir_transform/src/check_alignment.rs @@ -106,7 +106,7 @@ impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { } let pointee_ty = - pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty; + pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer"); // Ideally we'd support this in the future, but for now we are limited to sized types. if !pointee_ty.is_sized(self.tcx, self.param_env) { debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty); diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs index a405ed6088d..5f67bd75c48 100644 --- a/compiler/rustc_mir_transform/src/check_packed_ref.rs +++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs @@ -1,5 +1,6 @@ use rustc_middle::mir::visit::{PlaceContext, Visitor}; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::{self, TyCtxt}; use crate::MirLint; diff --git a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs index da82f8de781..48a6a83e146 100644 --- a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs +++ b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs @@ -29,7 +29,7 @@ impl<'tcx> MirPass<'tcx> for CleanupPostBorrowck { for statement in basic_block.statements.iter_mut() { match statement.kind { StatementKind::AscribeUserType(..) - | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Fake, _))) + | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Fake(_), _))) | StatementKind::Coverage( // These kinds of coverage statements are markers inserted during // MIR building, and are not needed after InstrumentCoverage. diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs index 0119b95cced..c1f9313a377 100644 --- a/compiler/rustc_mir_transform/src/copy_prop.rs +++ b/compiler/rustc_mir_transform/src/copy_prop.rs @@ -3,7 +3,6 @@ use rustc_index::IndexSlice; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; -use rustc_mir_dataflow::impls::borrowed_locals; use crate::ssa::SsaLocals; @@ -32,8 +31,8 @@ impl<'tcx> MirPass<'tcx> for CopyProp { } fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let borrowed_locals = borrowed_locals(body); - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let fully_moved = fully_moved_locals(&ssa, body); debug!(?fully_moved); @@ -51,7 +50,7 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { tcx, copy_classes: ssa.copy_classes(), fully_moved, - borrowed_locals, + borrowed_locals: ssa.borrowed_locals(), storage_to_remove, } .visit_body_preserves_cfg(body); @@ -101,7 +100,7 @@ struct Replacer<'a, 'tcx> { tcx: TyCtxt<'tcx>, fully_moved: BitSet<Local>, storage_to_remove: BitSet<Local>, - borrowed_locals: BitSet<Local>, + borrowed_locals: &'a BitSet<Local>, copy_classes: &'a IndexSlice<Local, Local>, } @@ -112,6 +111,12 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) { let new_local = self.copy_classes[*local]; + // We must not unify two locals that are borrowed. But this is fine if one is borrowed and + // the other is not. We chose to check the original local, and not the target. That way, if + // the original local is borrowed and the target is not, we do not pessimize the whole class. + if self.borrowed_locals.contains(*local) { + return; + } match ctxt { // Do not modify the local in storage statements. PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {} @@ -122,32 +127,14 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { } } - fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) { + fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, loc: Location) { if let Some(new_projection) = self.process_projection(place.projection, loc) { place.projection = self.tcx().mk_place_elems(&new_projection); } - let observes_address = match ctxt { - PlaceContext::NonMutatingUse( - NonMutatingUseContext::SharedBorrow - | NonMutatingUseContext::FakeBorrow - | NonMutatingUseContext::AddressOf, - ) => true, - // For debuginfo, merging locals is ok. - PlaceContext::NonUse(NonUseContext::VarDebugInfo) => { - self.borrowed_locals.contains(place.local) - } - _ => false, - }; - if observes_address && !place.is_indirect() { - // We observe the address of `place.local`. Do not replace it. - } else { - self.visit_local( - &mut place.local, - PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy), - loc, - ) - } + // Any non-mutating use context is ok. + let ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy); + self.visit_local(&mut place.local, ctxt, loc) } fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) { diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index e2a911f0dc7..a3e6e5a5a91 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -70,6 +70,7 @@ use rustc_middle::mir::*; use rustc_middle::ty::CoroutineArgs; use rustc_middle::ty::InstanceDef; use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_mir_dataflow::impls::{ MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, }; @@ -80,6 +81,10 @@ use rustc_span::symbol::sym; use rustc_span::Span; use rustc_target::abi::{FieldIdx, VariantIdx}; use rustc_target::spec::PanicStrategy; +use rustc_trait_selection::infer::TyCtxtInferExt as _; +use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _; +use rustc_trait_selection::traits::ObligationCtxt; +use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode}; use std::{iter, ops}; pub struct StateTransform; @@ -1256,7 +1261,7 @@ fn create_coroutine_drop_shim<'tcx>( } // Replace the return variable - body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(Ty::new_unit(tcx), source_info); + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info); make_coroutine_state_argument_indirect(tcx, &mut body); @@ -1584,10 +1589,46 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>( let (_, coroutine_layout, _) = compute_layout(liveness_info, body); check_suspend_tys(tcx, &coroutine_layout, body); + check_field_tys_sized(tcx, &coroutine_layout, def_id); Some(coroutine_layout) } +fn check_field_tys_sized<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_layout: &CoroutineLayout<'tcx>, + def_id: LocalDefId, +) { + // No need to check if unsized_locals/unsized_fn_params is disabled, + // since we will error during typeck. + if !tcx.features().unsized_locals && !tcx.features().unsized_fn_params { + return; + } + + let infcx = tcx.infer_ctxt().ignoring_regions().build(); + let param_env = tcx.param_env(def_id); + + let ocx = ObligationCtxt::new(&infcx); + for field_ty in &coroutine_layout.field_tys { + ocx.register_bound( + ObligationCause::new( + field_ty.source_info.span, + def_id, + ObligationCauseCode::SizedCoroutineInterior(def_id), + ), + param_env, + field_ty.ty, + tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)), + ); + } + + let errors = ocx.select_all_or_error(); + debug!(?errors); + if !errors.is_empty() { + infcx.err_ctxt().report_fulfillment_errors(errors); + } +} + impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(old_yield_ty) = body.yield_ty() else { diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index 3d6c1a95204..10c0567eb4b 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -71,6 +71,7 @@ use rustc_data_structures::unord::UnordMap; use rustc_hir as hir; +use rustc_middle::bug; use rustc_middle::hir::place::{Projection, ProjectionKind}; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::{self, dump_mir, MirPass}; diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 6e73a476421..b5968517d77 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -4,13 +4,14 @@ use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::graph::DirectedGraph; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::coverage::{CounterId, CovTerm, Expression, ExpressionId, Op}; use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, TraverseCoverageGraphWithLoops}; /// The coverage counter or counter expression associated with a particular /// BCB node or BCB edge. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub(super) enum BcbCounter { Counter { id: CounterId }, Expression { id: ExpressionId }, @@ -34,6 +35,13 @@ impl Debug for BcbCounter { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +struct BcbExpression { + lhs: BcbCounter, + op: Op, + rhs: BcbCounter, +} + #[derive(Debug)] pub(super) enum CounterIncrementSite { Node { bcb: BasicCoverageBlock }, @@ -55,9 +63,13 @@ pub(super) struct CoverageCounters { /// We currently don't iterate over this map, but if we do in the future, /// switch it back to `FxIndexMap` to avoid query stability hazards. bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, + /// Table of expression data, associating each expression ID with its /// corresponding operator (+ or -) and its LHS/RHS operands. - expressions: IndexVec<ExpressionId, Expression>, + expressions: IndexVec<ExpressionId, BcbExpression>, + /// Remember expressions that have already been created (or simplified), + /// so that we don't create unnecessary duplicates. + expressions_memo: FxHashMap<BcbExpression, BcbCounter>, } impl CoverageCounters { @@ -75,6 +87,7 @@ impl CoverageCounters { bcb_counters: IndexVec::from_elem_n(None, num_bcbs), bcb_edge_counters: FxHashMap::default(), expressions: IndexVec::new(), + expressions_memo: FxHashMap::default(), }; MakeBcbCounters::new(&mut this, basic_coverage_blocks) @@ -89,8 +102,57 @@ impl CoverageCounters { } fn make_expression(&mut self, lhs: BcbCounter, op: Op, rhs: BcbCounter) -> BcbCounter { - let expression = Expression { lhs: lhs.as_term(), op, rhs: rhs.as_term() }; - let id = self.expressions.push(expression); + let new_expr = BcbExpression { lhs, op, rhs }; + *self + .expressions_memo + .entry(new_expr) + .or_insert_with(|| Self::make_expression_inner(&mut self.expressions, new_expr)) + } + + /// This is an associated function so that we can call it while borrowing + /// `&mut self.expressions_memo`. + fn make_expression_inner( + expressions: &mut IndexVec<ExpressionId, BcbExpression>, + new_expr: BcbExpression, + ) -> BcbCounter { + // Simplify expressions using basic algebra. + // + // Some of these cases might not actually occur in practice, depending + // on the details of how the instrumentor builds expressions. + let BcbExpression { lhs, op, rhs } = new_expr; + + if let BcbCounter::Expression { id } = lhs { + let lhs_expr = &expressions[id]; + + // Simplify `(a - b) + b` to `a`. + if lhs_expr.op == Op::Subtract && op == Op::Add && lhs_expr.rhs == rhs { + return lhs_expr.lhs; + } + // Simplify `(a + b) - b` to `a`. + if lhs_expr.op == Op::Add && op == Op::Subtract && lhs_expr.rhs == rhs { + return lhs_expr.lhs; + } + // Simplify `(a + b) - a` to `b`. + if lhs_expr.op == Op::Add && op == Op::Subtract && lhs_expr.lhs == rhs { + return lhs_expr.rhs; + } + } + + if let BcbCounter::Expression { id } = rhs { + let rhs_expr = &expressions[id]; + + // Simplify `a + (b - a)` to `b`. + if op == Op::Add && rhs_expr.op == Op::Subtract && lhs == rhs_expr.rhs { + return rhs_expr.lhs; + } + // Simplify `a - (a - b)` to `b`. + if op == Op::Subtract && rhs_expr.op == Op::Subtract && lhs == rhs_expr.lhs { + return rhs_expr.rhs; + } + } + + // Simplification failed, so actually create the new expression. + let id = expressions.push(new_expr); BcbCounter::Expression { id } } @@ -165,7 +227,21 @@ impl CoverageCounters { } pub(super) fn into_expressions(self) -> IndexVec<ExpressionId, Expression> { - self.expressions + let old_len = self.expressions.len(); + let expressions = self + .expressions + .into_iter() + .map(|BcbExpression { lhs, op, rhs }| Expression { + lhs: lhs.as_term(), + op, + rhs: rhs.as_term(), + }) + .collect::<IndexVec<ExpressionId, _>>(); + + // Expression IDs are indexes into this vector, so make sure we didn't + // accidentally invalidate them by changing its length. + assert_eq!(old_len, expressions.len()); + expressions } } diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index 1895735ab35..fd74a2a97e2 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -4,6 +4,7 @@ use rustc_data_structures::graph::dominators::{self, Dominators}; use rustc_data_structures::graph::{self, DirectedGraph, StartNode}; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::{self, BasicBlock, Terminator, TerminatorKind}; use std::cmp::Ordering; diff --git a/compiler/rustc_mir_transform/src/coverage/mappings.rs b/compiler/rustc_mir_transform/src/coverage/mappings.rs new file mode 100644 index 00000000000..61aabea1d8b --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/mappings.rs @@ -0,0 +1,282 @@ +use std::collections::BTreeSet; + +use rustc_data_structures::graph::DirectedGraph; +use rustc_index::bit_set::BitSet; +use rustc_index::IndexVec; +use rustc_middle::mir::coverage::{BlockMarkerId, BranchSpan, ConditionInfo, CoverageKind}; +use rustc_middle::mir::{self, BasicBlock, StatementKind}; +use rustc_span::Span; + +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::spans::{ + extract_refined_covspans, unexpand_into_body_span_with_visible_macro, +}; +use crate::coverage::ExtractedHirInfo; + +/// Associates an ordinary executable code span with its corresponding BCB. +#[derive(Debug)] +pub(super) struct CodeMapping { + pub(super) span: Span, + pub(super) bcb: BasicCoverageBlock, +} + +/// This is separate from [`MCDCBranch`] to help prepare for larger changes +/// that will be needed for improved branch coverage in the future. +/// (See <https://github.com/rust-lang/rust/pull/124217>.) +#[derive(Debug)] +pub(super) struct BranchPair { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, +} + +/// Associates an MC/DC branch span with condition info besides fields for normal branch. +#[derive(Debug)] +pub(super) struct MCDCBranch { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, + /// If `None`, this actually represents a normal branch mapping inserted + /// for code that was too complex for MC/DC. + pub(super) condition_info: Option<ConditionInfo>, + pub(super) decision_depth: u16, +} + +/// Associates an MC/DC decision with its join BCBs. +#[derive(Debug)] +pub(super) struct MCDCDecision { + pub(super) span: Span, + pub(super) end_bcbs: BTreeSet<BasicCoverageBlock>, + pub(super) bitmap_idx: u32, + pub(super) conditions_num: u16, + pub(super) decision_depth: u16, +} + +#[derive(Default)] +pub(super) struct ExtractedMappings { + pub(super) code_mappings: Vec<CodeMapping>, + pub(super) branch_pairs: Vec<BranchPair>, + pub(super) mcdc_bitmap_bytes: u32, + pub(super) mcdc_branches: Vec<MCDCBranch>, + pub(super) mcdc_decisions: Vec<MCDCDecision>, +} + +/// Extracts coverage-relevant spans from MIR, and associates them with +/// their corresponding BCBs. +pub(super) fn extract_all_mapping_info_from_mir( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> ExtractedMappings { + if hir_info.is_async_fn { + // An async function desugars into a function that returns a future, + // with the user code wrapped in a closure. Any spans in the desugared + // outer function will be unhelpful, so just keep the signature span + // and ignore all of the spans in the MIR body. + let mut mappings = ExtractedMappings::default(); + if let Some(span) = hir_info.fn_sig_span_extended { + mappings.code_mappings.push(CodeMapping { span, bcb: START_BCB }); + } + return mappings; + } + + let mut code_mappings = vec![]; + let mut branch_pairs = vec![]; + let mut mcdc_bitmap_bytes = 0; + let mut mcdc_branches = vec![]; + let mut mcdc_decisions = vec![]; + + extract_refined_covspans(mir_body, hir_info, basic_coverage_blocks, &mut code_mappings); + + branch_pairs.extend(extract_branch_pairs(mir_body, hir_info, basic_coverage_blocks)); + + extract_mcdc_mappings( + mir_body, + hir_info.body_span, + basic_coverage_blocks, + &mut mcdc_bitmap_bytes, + &mut mcdc_branches, + &mut mcdc_decisions, + ); + + ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes, + mcdc_branches, + mcdc_decisions, + } +} + +impl ExtractedMappings { + pub(super) fn all_bcbs_with_counter_mappings( + &self, + basic_coverage_blocks: &CoverageGraph, // Only used for allocating a correctly-sized set + ) -> BitSet<BasicCoverageBlock> { + // Fully destructure self to make sure we don't miss any fields that have mappings. + let Self { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes: _, + mcdc_branches, + mcdc_decisions, + } = self; + + // Identify which BCBs have one or more mappings. + let mut bcbs_with_counter_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); + let mut insert = |bcb| { + bcbs_with_counter_mappings.insert(bcb); + }; + + for &CodeMapping { span: _, bcb } in code_mappings { + insert(bcb); + } + for &BranchPair { true_bcb, false_bcb, .. } in branch_pairs { + insert(true_bcb); + insert(false_bcb); + } + for &MCDCBranch { true_bcb, false_bcb, .. } in mcdc_branches { + insert(true_bcb); + insert(false_bcb); + } + + // MC/DC decisions refer to BCBs, but don't require those BCBs to have counters. + if bcbs_with_counter_mappings.is_empty() { + debug_assert!( + mcdc_decisions.is_empty(), + "A function with no counter mappings shouldn't have any decisions: {mcdc_decisions:?}", + ); + } + + bcbs_with_counter_mappings + } +} + +fn resolve_block_markers( + branch_info: &mir::coverage::BranchInfo, + mir_body: &mir::Body<'_>, +) -> IndexVec<BlockMarkerId, Option<BasicBlock>> { + let mut block_markers = IndexVec::<BlockMarkerId, Option<BasicBlock>>::from_elem_n( + None, + branch_info.num_block_markers, + ); + + // Fill out the mapping from block marker IDs to their enclosing blocks. + for (bb, data) in mir_body.basic_blocks.iter_enumerated() { + for statement in &data.statements { + if let StatementKind::Coverage(CoverageKind::BlockMarker { id }) = statement.kind { + block_markers[id] = Some(bb); + } + } + } + + block_markers +} + +// FIXME: There is currently a lot of redundancy between +// `extract_branch_pairs` and `extract_mcdc_mappings`. This is needed so +// that they can each be modified without interfering with the other, but in +// the long term we should try to bring them together again when branch coverage +// and MC/DC coverage support are more mature. + +pub(super) fn extract_branch_pairs( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> Vec<BranchPair> { + let Some(branch_info) = mir_body.coverage_branch_info.as_deref() else { return vec![] }; + + let block_markers = resolve_block_markers(branch_info, mir_body); + + branch_info + .branch_spans + .iter() + .filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| { + // For now, ignore any branch span that was introduced by + // expansion. This makes things like assert macros less noisy. + if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let (span, _) = + unexpand_into_body_span_with_visible_macro(raw_span, hir_info.body_span)?; + + let bcb_from_marker = + |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + + Some(BranchPair { span, true_bcb, false_bcb }) + }) + .collect::<Vec<_>>() +} + +pub(super) fn extract_mcdc_mappings( + mir_body: &mir::Body<'_>, + body_span: Span, + basic_coverage_blocks: &CoverageGraph, + mcdc_bitmap_bytes: &mut u32, + mcdc_branches: &mut impl Extend<MCDCBranch>, + mcdc_decisions: &mut impl Extend<MCDCDecision>, +) { + let Some(branch_info) = mir_body.coverage_branch_info.as_deref() else { return }; + + let block_markers = resolve_block_markers(branch_info, mir_body); + + let bcb_from_marker = + |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); + + let check_branch_bcb = + |raw_span: Span, true_marker: BlockMarkerId, false_marker: BlockMarkerId| { + // For now, ignore any branch span that was introduced by + // expansion. This makes things like assert macros less noisy. + if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let (span, _) = unexpand_into_body_span_with_visible_macro(raw_span, body_span)?; + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + Some((span, true_bcb, false_bcb)) + }; + + mcdc_branches.extend(branch_info.mcdc_branch_spans.iter().filter_map( + |&mir::coverage::MCDCBranchSpan { + span: raw_span, + condition_info, + true_marker, + false_marker, + decision_depth, + }| { + let (span, true_bcb, false_bcb) = + check_branch_bcb(raw_span, true_marker, false_marker)?; + Some(MCDCBranch { span, true_bcb, false_bcb, condition_info, decision_depth }) + }, + )); + + mcdc_decisions.extend(branch_info.mcdc_decision_spans.iter().filter_map( + |decision: &mir::coverage::MCDCDecisionSpan| { + let (span, _) = unexpand_into_body_span_with_visible_macro(decision.span, body_span)?; + + let end_bcbs = decision + .end_markers + .iter() + .map(|&marker| bcb_from_marker(marker)) + .collect::<Option<_>>()?; + + // Each decision containing N conditions needs 2^N bits of space in + // the bitmap, rounded up to a whole number of bytes. + // The decision's "bitmap index" points to its first byte in the bitmap. + let bitmap_idx = *mcdc_bitmap_bytes; + *mcdc_bitmap_bytes += (1_u32 << decision.conditions_num).div_ceil(8); + + Some(MCDCDecision { + span, + end_bcbs, + bitmap_idx, + conditions_num: decision.conditions_num as u16, + decision_depth: decision.decision_depth, + }) + }, + )); +} diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index d382d2c03c2..28e0c633d5a 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -2,18 +2,14 @@ pub mod query; mod counters; mod graph; +mod mappings; mod spans; - #[cfg(test)] mod tests; -use self::counters::{CounterIncrementSite, CoverageCounters}; -use self::graph::{BasicCoverageBlock, CoverageGraph}; -use self::spans::{BcbMapping, BcbMappingKind, CoverageSpans}; - -use crate::MirPass; - -use rustc_middle::mir::coverage::*; +use rustc_middle::mir::coverage::{ + CodeRegion, CoverageKind, DecisionInfo, FunctionCoverageInfo, Mapping, MappingKind, +}; use rustc_middle::mir::{ self, BasicBlock, BasicBlockData, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, @@ -23,6 +19,11 @@ use rustc_span::def_id::LocalDefId; use rustc_span::source_map::SourceMap; use rustc_span::{BytePos, Pos, RelativeBytePos, Span, Symbol}; +use crate::coverage::counters::{CounterIncrementSite, CoverageCounters}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::mappings::ExtractedMappings; +use crate::MirPass; + /// Inserts `StatementKind::Coverage` statements that either instrument the binary with injected /// counters, via intrinsic `llvm.instrprof.increment`, and/or inject metadata used during codegen /// to construct the coverage map. @@ -69,24 +70,27 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); //////////////////////////////////////////////////// - // Compute coverage spans from the `CoverageGraph`. - let Some(coverage_spans) = - spans::generate_coverage_spans(mir_body, &hir_info, &basic_coverage_blocks) - else { - // No relevant spans were found in MIR, so skip instrumenting this function. - return; - }; + // Extract coverage spans and other mapping info from MIR. + let extracted_mappings = + mappings::extract_all_mapping_info_from_mir(mir_body, &hir_info, &basic_coverage_blocks); //////////////////////////////////////////////////// // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` // and all `Expression` dependencies (operands) are also generated, for any other // `BasicCoverageBlock`s not already associated with a coverage span. - let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); + let bcbs_with_counter_mappings = + extracted_mappings.all_bcbs_with_counter_mappings(&basic_coverage_blocks); + if bcbs_with_counter_mappings.is_empty() { + // No relevant spans were found in MIR, so skip instrumenting this function. + return; + } + + let bcb_has_counter_mappings = |bcb| bcbs_with_counter_mappings.contains(bcb); let coverage_counters = - CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); + CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_counter_mappings); - let mappings = create_mappings(tcx, &hir_info, &coverage_spans, &coverage_counters); + let mappings = create_mappings(tcx, &hir_info, &extracted_mappings, &coverage_counters); if mappings.is_empty() { // No spans could be converted into valid mappings, so skip this function. debug!("no spans could be converted into valid mappings; skipping"); @@ -96,15 +100,26 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: inject_coverage_statements( mir_body, &basic_coverage_blocks, - bcb_has_coverage_spans, + bcb_has_counter_mappings, &coverage_counters, ); + inject_mcdc_statements(mir_body, &basic_coverage_blocks, &extracted_mappings); + + let mcdc_num_condition_bitmaps = extracted_mappings + .mcdc_decisions + .iter() + .map(|&mappings::MCDCDecision { decision_depth, .. }| decision_depth) + .max() + .map_or(0, |max| usize::from(max) + 1); + mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { function_source_hash: hir_info.function_source_hash, num_counters: coverage_counters.num_counters(), + mcdc_bitmap_bytes: extracted_mappings.mcdc_bitmap_bytes, expressions: coverage_counters.into_expressions(), mappings, + mcdc_num_condition_bitmaps, })); } @@ -116,7 +131,7 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir: fn create_mappings<'tcx>( tcx: TyCtxt<'tcx>, hir_info: &ExtractedHirInfo, - coverage_spans: &CoverageSpans, + extracted_mappings: &ExtractedMappings, coverage_counters: &CoverageCounters, ) -> Vec<Mapping> { let source_map = tcx.sess.source_map(); @@ -135,21 +150,59 @@ fn create_mappings<'tcx>( .expect("all BCBs with spans were given counters") .as_term() }; - - coverage_spans - .all_bcb_mappings() - .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { - let kind = match bcb_mapping_kind { - BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), - BcbMappingKind::Branch { true_bcb, false_bcb } => MappingKind::Branch { - true_term: term_for_bcb(true_bcb), - false_term: term_for_bcb(false_bcb), - }, + let region_for_span = |span: Span| make_code_region(source_map, file_name, span, body_span); + + // Fully destructure the mappings struct to make sure we don't miss any kinds. + let ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bytes: _, + mcdc_branches, + mcdc_decisions, + } = extracted_mappings; + let mut mappings = Vec::new(); + + mappings.extend(code_mappings.iter().filter_map( + // Ordinary code mappings are the simplest kind. + |&mappings::CodeMapping { span, bcb }| { + let code_region = region_for_span(span)?; + let kind = MappingKind::Code(term_for_bcb(bcb)); + Some(Mapping { kind, code_region }) + }, + )); + + mappings.extend(branch_pairs.iter().filter_map( + |&mappings::BranchPair { span, true_bcb, false_bcb }| { + let true_term = term_for_bcb(true_bcb); + let false_term = term_for_bcb(false_bcb); + let kind = MappingKind::Branch { true_term, false_term }; + let code_region = region_for_span(span)?; + Some(Mapping { kind, code_region }) + }, + )); + + mappings.extend(mcdc_branches.iter().filter_map( + |&mappings::MCDCBranch { span, true_bcb, false_bcb, condition_info, decision_depth: _ }| { + let code_region = region_for_span(span)?; + let true_term = term_for_bcb(true_bcb); + let false_term = term_for_bcb(false_bcb); + let kind = match condition_info { + Some(mcdc_params) => MappingKind::MCDCBranch { true_term, false_term, mcdc_params }, + None => MappingKind::Branch { true_term, false_term }, }; - let code_region = make_code_region(source_map, file_name, span, body_span)?; Some(Mapping { kind, code_region }) - }) - .collect::<Vec<_>>() + }, + )); + + mappings.extend(mcdc_decisions.iter().filter_map( + |&mappings::MCDCDecision { span, bitmap_idx, conditions_num, .. }| { + let code_region = region_for_span(span)?; + let kind = MappingKind::MCDCDecision(DecisionInfo { bitmap_idx, conditions_num }); + Some(Mapping { kind, code_region }) + }, + )); + + mappings } /// For each BCB node or BCB edge that has an associated coverage counter, @@ -204,6 +257,53 @@ fn inject_coverage_statements<'tcx>( } } +/// For each conditions inject statements to update condition bitmap after it has been evaluated. +/// For each decision inject statements to update test vector bitmap after it has been evaluated. +fn inject_mcdc_statements<'tcx>( + mir_body: &mut mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + extracted_mappings: &ExtractedMappings, +) { + // Inject test vector update first because `inject_statement` always insert new statement at head. + for &mappings::MCDCDecision { + span: _, + ref end_bcbs, + bitmap_idx, + conditions_num: _, + decision_depth, + } in &extracted_mappings.mcdc_decisions + { + for end in end_bcbs { + let end_bb = basic_coverage_blocks[*end].leader_bb(); + inject_statement( + mir_body, + CoverageKind::TestVectorBitmapUpdate { bitmap_idx, decision_depth }, + end_bb, + ); + } + } + + for &mappings::MCDCBranch { span: _, true_bcb, false_bcb, condition_info, decision_depth } in + &extracted_mappings.mcdc_branches + { + let Some(condition_info) = condition_info else { continue }; + let id = condition_info.condition_id; + + let true_bb = basic_coverage_blocks[true_bcb].leader_bb(); + inject_statement( + mir_body, + CoverageKind::CondBitmapUpdate { id, value: true, decision_depth }, + true_bb, + ); + let false_bb = basic_coverage_blocks[false_bcb].leader_bb(); + inject_statement( + mir_body, + CoverageKind::CondBitmapUpdate { id, value: false, decision_depth }, + false_bb, + ); + } +} + /// Given two basic blocks that have a control-flow edge between them, creates /// and returns a new block that sits between those blocks. fn inject_edge_counter_basic_block( diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 03ede886688..f2f76ac70c2 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -1,101 +1,32 @@ -use rustc_data_structures::graph::DirectedGraph; -use rustc_index::bit_set::BitSet; +use rustc_middle::bug; use rustc_middle::mir; use rustc_span::{BytePos, Span}; -use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::mappings; use crate::coverage::spans::from_mir::SpanFromMir; use crate::coverage::ExtractedHirInfo; mod from_mir; -#[derive(Clone, Copy, Debug)] -pub(super) enum BcbMappingKind { - /// Associates an ordinary executable code span with its corresponding BCB. - Code(BasicCoverageBlock), - /// Associates a branch span with BCBs for its true and false arms. - Branch { true_bcb: BasicCoverageBlock, false_bcb: BasicCoverageBlock }, -} - -#[derive(Debug)] -pub(super) struct BcbMapping { - pub(super) kind: BcbMappingKind, - pub(super) span: Span, -} - -pub(super) struct CoverageSpans { - bcb_has_mappings: BitSet<BasicCoverageBlock>, - mappings: Vec<BcbMapping>, -} - -impl CoverageSpans { - pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { - self.bcb_has_mappings.contains(bcb) - } - - pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { - self.mappings.iter() - } -} +// FIXME(#124545) It's awkward that we have to re-export this, because it's an +// internal detail of `from_mir` that is also needed when handling branch and +// MC/DC spans. Ideally we would find a more natural home for it. +pub(super) use from_mir::unexpand_into_body_span_with_visible_macro; -/// Extracts coverage-relevant spans from MIR, and associates them with -/// their corresponding BCBs. -/// -/// Returns `None` if no coverage-relevant spans could be extracted. -pub(super) fn generate_coverage_spans( +pub(super) fn extract_refined_covspans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, basic_coverage_blocks: &CoverageGraph, -) -> Option<CoverageSpans> { - let mut mappings = vec![]; - - if hir_info.is_async_fn { - // An async function desugars into a function that returns a future, - // with the user code wrapped in a closure. Any spans in the desugared - // outer function will be unhelpful, so just keep the signature span - // and ignore all of the spans in the MIR body. - if let Some(span) = hir_info.fn_sig_span_extended { - mappings.push(BcbMapping { kind: BcbMappingKind::Code(START_BCB), span }); - } - } else { - let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( - mir_body, - hir_info, - basic_coverage_blocks, - ); - let coverage_spans = SpansRefiner::refine_sorted_spans(sorted_spans); - mappings.extend(coverage_spans.into_iter().map(|RefinedCovspan { bcb, span, .. }| { - // Each span produced by the generator represents an ordinary code region. - BcbMapping { kind: BcbMappingKind::Code(bcb), span } - })); - - mappings.extend(from_mir::extract_branch_mappings( - mir_body, - hir_info.body_span, - basic_coverage_blocks, - )); - } - - if mappings.is_empty() { - return None; - } - - // Identify which BCBs have one or more mappings. - let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); - let mut insert = |bcb| { - bcb_has_mappings.insert(bcb); - }; - for &BcbMapping { kind, span: _ } in &mappings { - match kind { - BcbMappingKind::Code(bcb) => insert(bcb), - BcbMappingKind::Branch { true_bcb, false_bcb } => { - insert(true_bcb); - insert(false_bcb); - } - } - } - - Some(CoverageSpans { bcb_has_mappings, mappings }) + code_mappings: &mut impl Extend<mappings::CodeMapping>, +) { + let sorted_spans = + from_mir::mir_to_initial_sorted_coverage_spans(mir_body, hir_info, basic_coverage_blocks); + let coverage_spans = SpansRefiner::refine_sorted_spans(sorted_spans); + code_mappings.extend(coverage_spans.into_iter().map(|RefinedCovspan { bcb, span, .. }| { + // Each span produced by the generator represents an ordinary code region. + mappings::CodeMapping { span, bcb } + })); } #[derive(Debug)] diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index adb0c9f1929..d1727a94a35 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -1,9 +1,9 @@ use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxHashSet; -use rustc_index::IndexVec; -use rustc_middle::mir::coverage::{BlockMarkerId, BranchSpan, CoverageKind}; +use rustc_middle::bug; +use rustc_middle::mir::coverage::CoverageKind; use rustc_middle::mir::{ - self, AggregateKind, BasicBlock, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, + self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }; use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; @@ -11,7 +11,6 @@ use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; use crate::coverage::graph::{ BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB, }; -use crate::coverage::spans::{BcbMapping, BcbMappingKind}; use crate::coverage::ExtractedHirInfo; /// Traverses the MIR body to produce an initial collection of coverage-relevant @@ -227,7 +226,10 @@ fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> { // These coverage statements should not exist prior to coverage instrumentation. StatementKind::Coverage( - CoverageKind::CounterIncrement { .. } | CoverageKind::ExpressionUsed { .. }, + CoverageKind::CounterIncrement { .. } + | CoverageKind::ExpressionUsed { .. } + | CoverageKind::CondBitmapUpdate { .. } + | CoverageKind::TestVectorBitmapUpdate { .. }, ) => bug!( "Unexpected coverage statement found during coverage instrumentation: {statement:?}" ), @@ -282,7 +284,7 @@ fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { /// /// [^1]Expansions result from Rust syntax including macros, syntactic sugar, /// etc.). -fn unexpand_into_body_span_with_visible_macro( +pub(crate) fn unexpand_into_body_span_with_visible_macro( original_span: Span, body_span: Span, ) -> Option<(Span, Option<Symbol>)> { @@ -360,48 +362,3 @@ impl SpanFromMir { Self { span, visible_macro, bcb, is_hole } } } - -pub(super) fn extract_branch_mappings( - mir_body: &mir::Body<'_>, - body_span: Span, - basic_coverage_blocks: &CoverageGraph, -) -> Vec<BcbMapping> { - let Some(branch_info) = mir_body.coverage_branch_info.as_deref() else { - return vec![]; - }; - - let mut block_markers = IndexVec::<BlockMarkerId, Option<BasicBlock>>::from_elem_n( - None, - branch_info.num_block_markers, - ); - - // Fill out the mapping from block marker IDs to their enclosing blocks. - for (bb, data) in mir_body.basic_blocks.iter_enumerated() { - for statement in &data.statements { - if let StatementKind::Coverage(CoverageKind::BlockMarker { id }) = statement.kind { - block_markers[id] = Some(bb); - } - } - } - - branch_info - .branch_spans - .iter() - .filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| { - // For now, ignore any branch span that was introduced by - // expansion. This makes things like assert macros less noisy. - if !raw_span.ctxt().outer_expn_data().is_root() { - return None; - } - let (span, _) = unexpand_into_body_span_with_visible_macro(raw_span, body_span)?; - - let bcb_from_marker = - |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); - - let true_bcb = bcb_from_marker(true_marker)?; - let false_bcb = bcb_from_marker(false_marker)?; - - Some(BcbMapping { kind: BcbMappingKind::Branch { true_bcb, false_bcb }, span }) - }) - .collect::<Vec<_>>() -} diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs index cf1a2b399f9..ca64688e6b8 100644 --- a/compiler/rustc_mir_transform/src/coverage/tests.rs +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -30,6 +30,7 @@ use super::graph::{self, BasicCoverageBlock}; use itertools::Itertools; use rustc_data_structures::graph::{DirectedGraph, Successors}; use rustc_index::{Idx, IndexVec}; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::ty; use rustc_span::{BytePos, Pos, Span, DUMMY_SP}; diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index d0f6ec8f21f..e88b727a21e 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -6,6 +6,7 @@ use rustc_const_eval::const_eval::{throw_machine_stop_str, DummyMachine}; use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; +use rustc_middle::bug; use rustc_middle::mir::interpret::{InterpResult, Scalar}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; @@ -68,7 +69,7 @@ struct ConstAnalysis<'a, 'tcx> { map: Map, tcx: TyCtxt<'tcx>, local_decls: &'a LocalDecls<'tcx>, - ecx: InterpCx<'tcx, 'tcx, DummyMachine>, + ecx: InterpCx<'tcx, DummyMachine>, param_env: ty::ParamEnv<'tcx>, } @@ -141,11 +142,10 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { _ => return, }; if let Some(variant_target_idx) = variant_target { - for (field_index, operand) in operands.iter().enumerate() { - if let Some(field) = self.map().apply( - variant_target_idx, - TrackElem::Field(FieldIdx::from_usize(field_index)), - ) { + for (field_index, operand) in operands.iter_enumerated() { + if let Some(field) = + self.map().apply(variant_target_idx, TrackElem::Field(field_index)) + { self.assign_operand(state, field, operand); } } @@ -164,7 +164,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { } } } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { + Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => { // Flood everything now, so we can use `insert_value_idx` directly later. state.flood(target.as_ref(), self.map()); @@ -183,7 +183,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { if let Some(overflow_target) = overflow_target { let overflow = match overflow { FlatSet::Top => FlatSet::Top, - FlatSet::Elem(overflow) => FlatSet::Elem(Scalar::from_bool(overflow)), + FlatSet::Elem(overflow) => FlatSet::Elem(overflow), FlatSet::Bottom => FlatSet::Bottom, }; // We have flooded `target` earlier. @@ -202,7 +202,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { if let Some(target_len) = self.map().find_len(target.as_ref()) && let operand_ty = operand.ty(self.local_decls, self.tcx) && let Some(operand_ty) = operand_ty.builtin_deref(true) - && let ty::Array(_, len) = operand_ty.ty.kind() + && let ty::Array(_, len) = operand_ty.kind() && let Some(len) = Const::Ty(*len).try_eval_scalar_int(self.tcx, self.param_env) { state.insert_value_idx(target_len, FlatSet::Elem(len.into()), self.map()); @@ -263,15 +263,16 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { FlatSet::Top => FlatSet::Top, } } - Rvalue::BinaryOp(op, box (left, right)) => { + Rvalue::BinaryOp(op, box (left, right)) if !op.is_overflowing() => { // Overflows must be ignored here. + // The overflowing operators are handled in `handle_assign`. let (val, _overflow) = self.binary_op(state, *op, left, right); val } Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) { FlatSet::Elem(value) => self .ecx - .wrapping_unary_op(*op, &value) + .unary_op(*op, &value) .map_or(FlatSet::Top, |val| self.wrap_immediate(*val)), FlatSet::Bottom => FlatSet::Bottom, FlatSet::Top => FlatSet::Top, @@ -436,7 +437,7 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { op: BinOp, left: &Operand<'tcx>, right: &Operand<'tcx>, - ) -> (FlatSet<Scalar>, FlatSet<bool>) { + ) -> (FlatSet<Scalar>, FlatSet<Scalar>) { let left = self.eval_operand(left, state); let right = self.eval_operand(right, state); @@ -444,9 +445,17 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), // Both sides are known, do the actual computation. (FlatSet::Elem(left), FlatSet::Elem(right)) => { - match self.ecx.overflowing_binary_op(op, &left, &right) { - Ok((val, overflow)) => { - (FlatSet::Elem(val.to_scalar()), FlatSet::Elem(overflow)) + match self.ecx.binary_op(op, &left, &right) { + // Ideally this would return an Immediate, since it's sometimes + // a pair and sometimes not. But as a hack we always return a pair + // and just make the 2nd component `Bottom` when it does not exist. + Ok(val) => { + if matches!(val.layout.abi, Abi::ScalarPair(..)) { + let (val, overflow) = val.to_scalar_pair(); + (FlatSet::Elem(val), FlatSet::Elem(overflow)) + } else { + (FlatSet::Elem(val.to_scalar()), FlatSet::Bottom) + } } _ => (FlatSet::Top, FlatSet::Top), } @@ -472,7 +481,7 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { (FlatSet::Elem(arg_scalar), FlatSet::Bottom) } BinOp::Mul if layout.ty.is_integral() && arg_value == 0 => { - (FlatSet::Elem(arg_scalar), FlatSet::Elem(false)) + (FlatSet::Elem(arg_scalar), FlatSet::Elem(Scalar::from_bool(false))) } _ => (FlatSet::Top, FlatSet::Top), } @@ -555,7 +564,7 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> { fn try_make_constant( &self, - ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>, + ecx: &mut InterpCx<'tcx, DummyMachine>, place: Place<'tcx>, state: &State<FlatSet<Scalar>>, map: &Map, @@ -608,7 +617,7 @@ fn propagatable_scalar( #[instrument(level = "trace", skip(ecx, state, map))] fn try_write_constant<'tcx>( - ecx: &mut InterpCx<'_, 'tcx, DummyMachine>, + ecx: &mut InterpCx<'tcx, DummyMachine>, dest: &PlaceTy<'tcx>, place: PlaceIndex, ty: Ty<'tcx>, @@ -826,7 +835,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> { struct OperandCollector<'tcx, 'map, 'locals, 'a> { state: &'a State<FlatSet<Scalar>>, visitor: &'a mut Collector<'tcx, 'locals>, - ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>, + ecx: &'map mut InterpCx<'tcx, DummyMachine>, map: &'map Map, } diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs index e6317e5469c..08dba1de500 100644 --- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -13,6 +13,7 @@ //! use crate::util::is_within_packed; +use rustc_middle::bug; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index ca63f5550ae..370e930b740 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -160,7 +160,7 @@ pub fn deduced_param_attrs<'tcx>( return &[]; } - // If the Freeze language item isn't present, then don't bother. + // If the Freeze lang item isn't present, then don't bother. if tcx.lang_items().freeze_trait().is_none() { return &[]; } diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 10fea09531a..b1016c0867c 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -135,6 +135,7 @@ use crate::MirPass; use rustc_data_structures::fx::{FxIndexMap, IndexEntry, IndexOccupiedEntry}; use rustc_index::bit_set::BitSet; use rustc_index::interval::SparseIntervalMatrix; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::HasLocalDecls; use rustc_middle::mir::{dump_mir, PassWhere}; @@ -563,7 +564,7 @@ impl WriteInfo { | Rvalue::ShallowInitBox(op, _) => { self.add_operand(op); } - Rvalue::BinaryOp(_, ops) | Rvalue::CheckedBinaryOp(_, ops) => { + Rvalue::BinaryOp(_, ops) => { for op in [&ops.0, &ops.1] { self.add_operand(op); } diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index 318674f24e7..d955b96d06a 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -6,6 +6,7 @@ use rustc_hir::def_id::DefId; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::ty::{Ty, TyCtxt}; use rustc_target::abi::FieldIdx; diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs index 0634e321ea3..b28dcb38cb6 100644 --- a/compiler/rustc_mir_transform/src/errors.rs +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -1,4 +1,4 @@ -use rustc_errors::{codes::*, Diag, DiagMessage, LintDiagnostic}; +use rustc_errors::{codes::*, Diag, LintDiagnostic}; use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; use rustc_middle::mir::AssertKind; use rustc_middle::ty::TyCtxt; @@ -50,18 +50,15 @@ pub(crate) enum AssertLintKind { impl<'a, P: std::fmt::Debug> LintDiagnostic<'a, ()> for AssertLint<P> { fn decorate_lint<'b>(self, diag: &'b mut Diag<'a, ()>) { - let message = self.assert_kind.diagnostic_message(); + diag.primary_message(match self.lint_kind { + AssertLintKind::ArithmeticOverflow => fluent::mir_transform_arithmetic_overflow, + AssertLintKind::UnconditionalPanic => fluent::mir_transform_operation_will_panic, + }); + let label = self.assert_kind.diagnostic_message(); self.assert_kind.add_args(&mut |name, value| { diag.arg(name, value); }); - diag.span_label(self.span, message); - } - - fn msg(&self) -> DiagMessage { - match self.lint_kind { - AssertLintKind::ArithmeticOverflow => fluent::mir_transform_arithmetic_overflow, - AssertLintKind::UnconditionalPanic => fluent::mir_transform_operation_will_panic, - } + diag.span_label(self.span, label); } } @@ -104,6 +101,7 @@ pub(crate) struct MustNotSupend<'tcx, 'a> { // Needed for def_path_str impl<'a> LintDiagnostic<'a, ()> for MustNotSupend<'_, '_> { fn decorate_lint<'b>(self, diag: &'b mut rustc_errors::Diag<'a, ()>) { + diag.primary_message(fluent::mir_transform_must_not_suspend); diag.span_label(self.yield_sp, fluent::_subdiag::label); if let Some(reason) = self.reason { diag.subdiagnostic(diag.dcx, reason); @@ -113,10 +111,6 @@ impl<'a> LintDiagnostic<'a, ()> for MustNotSupend<'_, '_> { diag.arg("def_path", self.tcx.def_path_str(self.def_id)); diag.arg("post", self.post); } - - fn msg(&self) -> rustc_errors::DiagMessage { - fluent::mir_transform_must_not_suspend - } } #[derive(Subdiagnostic)] diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index 0970c4de19f..5e3cd853675 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -4,6 +4,7 @@ use rustc_middle::query::LocalCrate; use rustc_middle::query::Providers; use rustc_middle::ty::layout; use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_session::lint::builtin::FFI_UNWIND_CALLS; use rustc_target::spec::abi::Abi; use rustc_target::spec::PanicStrategy; diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index 30b1ca67800..434529ccff4 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -158,7 +158,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { .lint_root; // FIXME: use existing printing routines to print the function signature let fn_sig = self.tcx.fn_sig(fn_id).instantiate(self.tcx, fn_args); - let unsafety = fn_sig.unsafety().prefix_str(); + let unsafety = fn_sig.safety().prefix_str(); let abi = match fn_sig.abi() { Abi::Rust => String::from(""), other_abi => { diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 8e8d78226c3..fadb5edefdf 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -91,6 +91,7 @@ use rustc_hir::def::DefKind; use rustc_index::bit_set::BitSet; use rustc_index::newtype_index; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; @@ -121,7 +122,7 @@ impl<'tcx> MirPass<'tcx> for GVN { fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - let ssa = SsaLocals::new(body); + let ssa = SsaLocals::new(tcx, body, param_env); // Clone dominators as we need them while mutating the body. let dominators = body.basic_blocks.dominators().clone(); @@ -222,7 +223,7 @@ enum Value<'tcx> { NullaryOp(NullOp<'tcx>, Ty<'tcx>), UnaryOp(UnOp, VnIndex), BinaryOp(BinOp, VnIndex, VnIndex), - CheckedBinaryOp(BinOp, VnIndex, VnIndex), + CheckedBinaryOp(BinOp, VnIndex, VnIndex), // FIXME get rid of this, work like MIR instead Cast { kind: CastKind, value: VnIndex, @@ -233,7 +234,7 @@ enum Value<'tcx> { struct VnState<'body, 'tcx> { tcx: TyCtxt<'tcx>, - ecx: InterpCx<'tcx, 'tcx, DummyMachine>, + ecx: InterpCx<'tcx, DummyMachine>, param_env: ty::ParamEnv<'tcx>, local_decls: &'body LocalDecls<'tcx>, /// Value stored in each local. @@ -496,7 +497,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { UnaryOp(un_op, operand) => { let operand = self.evaluated[operand].as_ref()?; let operand = self.ecx.read_immediate(operand).ok()?; - let (val, _) = self.ecx.overflowing_unary_op(un_op, &operand).ok()?; + let val = self.ecx.unary_op(un_op, &operand).ok()?; val.into() } BinaryOp(bin_op, lhs, rhs) => { @@ -504,7 +505,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let lhs = self.ecx.read_immediate(lhs).ok()?; let rhs = self.evaluated[rhs].as_ref()?; let rhs = self.ecx.read_immediate(rhs).ok()?; - let (val, _) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + let val = self.ecx.binary_op(bin_op, &lhs, &rhs).ok()?; val.into() } CheckedBinaryOp(bin_op, lhs, rhs) => { @@ -512,14 +513,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let lhs = self.ecx.read_immediate(lhs).ok()?; let rhs = self.evaluated[rhs].as_ref()?; let rhs = self.ecx.read_immediate(rhs).ok()?; - let (val, overflowed) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; - let tuple = Ty::new_tup_from_iter( - self.tcx, - [val.layout.ty, self.tcx.types.bool].into_iter(), - ); - let tuple = self.ecx.layout_of(tuple).ok()?; - ImmTy::from_scalar_pair(val.to_scalar(), Scalar::from_bool(overflowed), tuple) - .into() + let val = self + .ecx + .binary_op(bin_op.wrapping_to_overflowing().unwrap(), &lhs, &rhs) + .ok()?; + val.into() } Cast { kind, value, from: _, to } => match kind { CastKind::IntToInt | CastKind::IntToFloat => { @@ -594,7 +592,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let ty = place.ty(self.local_decls, self.tcx).ty; if let Some(Mutability::Not) = ty.ref_mutability() && let Some(pointee_ty) = ty.builtin_deref(true) - && pointee_ty.ty.is_freeze(self.tcx, self.param_env) + && pointee_ty.is_freeze(self.tcx, self.param_env) { // An immutable borrow `_x` always points to the same value for the // lifetime of the borrow, so we can merge all instances of `*_x`. @@ -724,6 +722,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // Invariant: `value` holds the value up-to the `index`th projection excluded. let mut value = self.locals[place.local]?; for (index, proj) in place.projection.iter().enumerate() { + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. } = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&place.projection[index..], self.tcx).as_ref(); + } if let Some(local) = self.try_as_local(value, location) { // Both `local` and `Place { local: place.local, projection: projection[..index] }` // hold the same value. Therefore, following place holds the value in the original @@ -735,6 +741,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { value = self.project(base, value, proj)?; } + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. } = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&[], self.tcx).as_ref(); + } if let Some(new_local) = self.try_as_local(value, location) { place_ref = PlaceRef { local: new_local, projection: &[] }; } @@ -814,23 +828,18 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // on both operands for side effect. let lhs = lhs?; let rhs = rhs?; - if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) { - return Some(value); - } - Value::BinaryOp(op, lhs, rhs) - } - Rvalue::CheckedBinaryOp(op, box (ref mut lhs, ref mut rhs)) => { - let ty = lhs.ty(self.local_decls, self.tcx); - let lhs = self.simplify_operand(lhs, location); - let rhs = self.simplify_operand(rhs, location); - // Only short-circuit options after we called `simplify_operand` - // on both operands for side effect. - let lhs = lhs?; - let rhs = rhs?; - if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) { - return Some(value); + + if let Some(op) = op.overflowing_to_wrapping() { + if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) { + return Some(value); + } + Value::CheckedBinaryOp(op, lhs, rhs) + } else { + if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) { + return Some(value); + } + Value::BinaryOp(op, lhs, rhs) } - Value::CheckedBinaryOp(op, lhs, rhs) } Rvalue::UnaryOp(op, ref mut arg) => { let arg = self.simplify_operand(arg, location)?; @@ -885,6 +894,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, // Coroutines are never ZST, as they at least contain the implicit states. AggregateKind::Coroutine(..) => false, + AggregateKind::RawPtr(..) => bug!("MIR for RawPtr aggregate must have 2 fields"), }; if is_zst { @@ -910,6 +920,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } // Do not track unions. AggregateKind::Adt(_, _, _, _, Some(_)) => return None, + // FIXME: Do the extra work to GVN `from_raw_parts` + AggregateKind::RawPtr(..) => return None, }; let fields: Option<Vec<_>> = fields @@ -1114,9 +1126,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { if let Value::Cast { kind, from, to, .. } = self.get(inner) && let CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize) = kind && let Some(from) = from.builtin_deref(true) - && let ty::Array(_, len) = from.ty.kind() + && let ty::Array(_, len) = from.kind() && let Some(to) = to.builtin_deref(true) - && let ty::Slice(..) = to.ty.kind() + && let ty::Slice(..) = to.kind() { return self.insert_constant(Const::from_ty_const(*len, self.tcx)); } @@ -1127,7 +1139,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } fn op_to_prop_const<'tcx>( - ecx: &mut InterpCx<'_, 'tcx, DummyMachine>, + ecx: &mut InterpCx<'tcx, DummyMachine>, op: &OpTy<'tcx>, ) -> Option<ConstValue<'tcx>> { // Do not attempt to propagate unsized locals. diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 60513a674af..fe2237dd2e9 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -1,17 +1,17 @@ //! Inlining pass for MIR functions use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; -use rustc_const_eval::transform::validate::validate_types; use rustc_hir::def::DefKind; use rustc_hir::def_id::DefId; use rustc_index::bit_set::BitSet; use rustc_index::Idx; +use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TypeVisitableExt; use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; -use rustc_session::config::OptLevel; +use rustc_session::config::{DebugInfo, OptLevel}; use rustc_span::source_map::Spanned; use rustc_span::sym; use rustc_target::abi::FieldIdx; @@ -20,6 +20,7 @@ use rustc_target::spec::abi::Abi; use crate::cost_checker::CostChecker; use crate::simplify::simplify_cfg; use crate::util; +use crate::validate::validate_types; use std::iter; use std::ops::{Range, RangeFrom}; @@ -332,7 +333,8 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) - | InstanceDef::FnPtrAddrShim(..) => return Ok(()), + | InstanceDef::FnPtrAddrShim(..) + | InstanceDef::AsyncDropGlueCtorShim(..) => return Ok(()), } if self.tcx.is_constructor(callee_def_id) { @@ -699,7 +701,19 @@ impl<'tcx> Inliner<'tcx> { // Insert all of the (mapped) parts of the callee body into the caller. caller_body.local_decls.extend(callee_body.drain_vars_and_temps()); caller_body.source_scopes.extend(&mut callee_body.source_scopes.drain(..)); - caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + if self + .tcx + .sess + .opts + .unstable_opts + .inline_mir_preserve_debug + .unwrap_or(self.tcx.sess.opts.debuginfo != DebugInfo::None) + { + // Note that we need to preserve these in the standard library so that + // people working on rust can build with or without debuginfo while + // still getting consistent results from the mir-opt tests. + caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + } caller_body.basic_blocks_mut().extend(callee_body.basic_blocks_mut().drain(..)); caller_body[callsite.block].terminator = Some(Terminator { @@ -707,18 +721,12 @@ impl<'tcx> Inliner<'tcx> { kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) }, }); - // Copy only unevaluated constants from the callee_body into the caller_body. - // Although we are only pushing `ConstKind::Unevaluated` consts to - // `required_consts`, here we may not only have `ConstKind::Unevaluated` - // because we are calling `instantiate_and_normalize_erasing_regions`. - caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter( - |&ct| match ct.const_ { - Const::Ty(_) => { - bug!("should never encounter ty::UnevaluatedConst in `required_consts`") - } - Const::Val(..) | Const::Unevaluated(..) => true, - }, - )); + // Copy required constants from the callee_body into the caller_body. Although we are only + // pushing unevaluated consts to `required_consts`, here they may have been evaluated + // because we are calling `instantiate_and_normalize_erasing_regions` -- so we filter again. + caller_body.required_consts.extend( + callee_body.required_consts.into_iter().filter(|ct| ct.const_.is_required_const()), + ); // Now that we incorporated the callee's `required_consts`, we can remove the callee from // `mentioned_items` -- but we have to take their `mentioned_items` in return. This does // some extra work here to save the monomorphization collector work later. It helps a lot, @@ -734,8 +742,9 @@ impl<'tcx> Inliner<'tcx> { caller_body.mentioned_items.remove(idx); caller_body.mentioned_items.extend(callee_body.mentioned_items); } else { - // If we can't find the callee, there's no point in adding its items. - // Probably it already got removed by being inlined elsewhere in the same function. + // If we can't find the callee, there's no point in adding its items. Probably it + // already got removed by being inlined elsewhere in the same function, so we already + // took its items. } } @@ -1071,7 +1080,8 @@ fn try_instance_mir<'tcx>( tcx: TyCtxt<'tcx>, instance: InstanceDef<'tcx>, ) -> Result<&'tcx Body<'tcx>, &'static str> { - if let ty::InstanceDef::DropGlue(_, Some(ty)) = instance + if let ty::InstanceDef::DropGlue(_, Some(ty)) + | ty::InstanceDef::AsyncDropGlueCtorShim(_, Some(ty)) = instance && let ty::Adt(def, args) = ty.kind() { let fields = def.all_fields(); diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 99c7b616f1b..8c5f965108b 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -94,8 +94,10 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::CloneShim(..) => {} // This shim does not call any other functions, thus there can be no recursion. - InstanceDef::FnPtrAddrShim(..) => continue, - InstanceDef::DropGlue(..) => { + InstanceDef::FnPtrAddrShim(..) => { + continue; + } + InstanceDef::DropGlue(..) | InstanceDef::AsyncDropGlueCtorShim(..) => { // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this // needs some more analysis. diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs index ff786d44d6a..a54332b6f25 100644 --- a/compiler/rustc_mir_transform/src/instsimplify.rs +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -2,13 +2,13 @@ use crate::simplify::simplify_duplicate_switch_targets; use rustc_ast::attr; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::ty::layout; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{self, GenericArgsRef, ParamEnv, Ty, TyCtxt}; use rustc_span::sym; use rustc_span::symbol::Symbol; -use rustc_target::abi::FieldIdx; use rustc_target::spec::abi::Abi; pub struct InstSimplify; @@ -36,6 +36,7 @@ impl<'tcx> MirPass<'tcx> for InstSimplify { ctx.simplify_bool_cmp(&statement.source_info, rvalue); ctx.simplify_ref_deref(&statement.source_info, rvalue); ctx.simplify_len(&statement.source_info, rvalue); + ctx.simplify_ptr_aggregate(&statement.source_info, rvalue); ctx.simplify_cast(rvalue); } _ => {} @@ -58,8 +59,17 @@ struct InstSimplifyContext<'tcx, 'a> { impl<'tcx> InstSimplifyContext<'tcx, '_> { fn should_simplify(&self, source_info: &SourceInfo, rvalue: &Rvalue<'tcx>) -> bool { + self.should_simplify_custom(source_info, "Rvalue", rvalue) + } + + fn should_simplify_custom( + &self, + source_info: &SourceInfo, + label: &str, + value: impl std::fmt::Debug, + ) -> bool { self.tcx.consider_optimizing(|| { - format!("InstSimplify - Rvalue: {rvalue:?} SourceInfo: {source_info:?}") + format!("InstSimplify - {label}: {value:?} SourceInfo: {source_info:?}") }) } @@ -111,7 +121,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None } } - /// Transform "&(*a)" ==> "a". + /// Transform `&(*a)` ==> `a`. fn simplify_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Ref(_, _, place) = rvalue { if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { @@ -131,7 +141,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } - /// Transform "Len([_; N])" ==> "N". + /// Transform `Len([_; N])` ==> `N`. fn simplify_len(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Len(ref place) = *rvalue { let place_ty = place.ty(self.local_decls, self.tcx).ty; @@ -147,6 +157,30 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } + /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`. + fn simplify_ptr_aggregate(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue + { + let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx); + if meta_ty.is_unit() { + // The mutable borrows we're holding prevent printing `rvalue` here + if !self.should_simplify_custom( + source_info, + "Aggregate::RawPtr", + (&pointee_ty, *mutability, &fields), + ) { + return; + } + + let mut fields = std::mem::take(fields); + let _meta = fields.pop().unwrap(); + let data = fields.pop().unwrap(); + let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability); + *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty); + } + } + } + fn simplify_ub_check(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue { let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks()); @@ -182,13 +216,11 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { && let Some(place) = operand.place() { let variant = adt_def.non_enum_variant(); - for (i, field) in variant.fields.iter().enumerate() { + for (i, field) in variant.fields.iter_enumerated() { let field_ty = field.ty(self.tcx, args); if field_ty == *cast_ty { - let place = place.project_deeper( - &[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)], - self.tcx, - ); + let place = place + .project_deeper(&[ProjectionElem::Field(i, *cast_ty)], self.tcx); let operand = if operand.is_move() { Operand::Move(place) } else { diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index a458297210d..23cc0c46e73 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -41,6 +41,7 @@ use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable} use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; @@ -154,7 +155,7 @@ struct ThreadingOpportunity { struct TOFinder<'tcx, 'a> { tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, - ecx: InterpCx<'tcx, 'tcx, DummyMachine>, + ecx: InterpCx<'tcx, DummyMachine>, body: &'a Body<'tcx>, map: &'a Map, loop_headers: &'a BitSet<BasicBlock>, diff --git a/compiler/rustc_mir_transform/src/known_panics_lint.rs b/compiler/rustc_mir_transform/src/known_panics_lint.rs index 2218154ea5e..9ba22870403 100644 --- a/compiler/rustc_mir_transform/src/known_panics_lint.rs +++ b/compiler/rustc_mir_transform/src/known_panics_lint.rs @@ -14,6 +14,7 @@ use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; use rustc_hir::HirId; use rustc_index::{bit_set::BitSet, IndexVec}; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; @@ -63,7 +64,7 @@ impl<'tcx> MirLint<'tcx> for KnownPanicsLint { /// Visits MIR nodes, performs const propagation /// and runs lint checks as it goes struct ConstPropagator<'mir, 'tcx> { - ecx: InterpCx<'mir, 'tcx, DummyMachine>, + ecx: InterpCx<'tcx, DummyMachine>, tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, worklist: Vec<BasicBlock>, @@ -303,20 +304,25 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { let arg = self.eval_operand(arg)?; - if let (val, true) = self.use_ecx(|this| { - let val = this.ecx.read_immediate(&arg)?; - let (_res, overflow) = this.ecx.overflowing_unary_op(op, &val)?; - Ok((val, overflow)) - })? { - // `AssertKind` only has an `OverflowNeg` variant, so make sure that is - // appropriate to use. - assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - self.report_assert_as_lint( - location, - AssertLintKind::ArithmeticOverflow, - AssertKind::OverflowNeg(val.to_const_int()), - ); - return None; + // The only operator that can overflow is `Neg`. + if op == UnOp::Neg && arg.layout.ty.is_integral() { + // Compute this as `0 - arg` so we can use `SubWithOverflow` to check for overflow. + let (arg, overflow) = self.use_ecx(|this| { + let arg = this.ecx.read_immediate(&arg)?; + let (_res, overflow) = this + .ecx + .binary_op(BinOp::SubWithOverflow, &ImmTy::from_int(0, arg.layout), &arg)? + .to_scalar_pair(); + Ok((arg, overflow.to_bool()?)) + })?; + if overflow { + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::OverflowNeg(arg.to_const_int()), + ); + return None; + } } Some(()) @@ -362,11 +368,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } } - if let (Some(l), Some(r)) = (l, r) { - // The remaining operators are handled through `overflowing_binary_op`. + // Div/Rem are handled via the assertions they trigger. + // But for Add/Sub/Mul, those assertions only exist in debug builds, and we want to + // lint in release builds as well, so we check on the operation instead. + // So normalize to the "overflowing" operator, and then ensure that it + // actually is an overflowing operator. + let op = op.wrapping_to_overflowing().unwrap_or(op); + // The remaining operators are handled through `wrapping_to_overflowing`. + if let (Some(l), Some(r)) = (l, r) + && l.layout.ty.is_integral() + && op.is_overflowing() + { if self.use_ecx(|this| { - let (_res, overflow) = this.ecx.overflowing_binary_op(op, &l, &r)?; - Ok(overflow) + let (_res, overflow) = this.ecx.binary_op(op, &l, &r)?.to_scalar_pair(); + overflow.to_bool() })? { self.report_assert_as_lint( location, @@ -400,15 +415,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); self.check_binary_op(*op, left, right, location)?; } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { - trace!( - "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", - op, - left, - right - ); - self.check_binary_op(*op, left, right, location)?; - } // Do not try creating references (#67862) Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { @@ -554,24 +560,16 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let right = self.eval_operand(right)?; let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; - let val = - self.use_ecx(|this| this.ecx.wrapping_binary_op(bin_op, &left, &right))?; - val.into() - } - - CheckedBinaryOp(bin_op, box (ref left, ref right)) => { - let left = self.eval_operand(left)?; - let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; - - let right = self.eval_operand(right)?; - let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; - - let (val, overflowed) = - self.use_ecx(|this| this.ecx.overflowing_binary_op(bin_op, &left, &right))?; - let overflowed = ImmTy::from_bool(overflowed, self.tcx); - Value::Aggregate { - variant: VariantIdx::ZERO, - fields: [Value::from(val), overflowed.into()].into_iter().collect(), + let val = self.use_ecx(|this| this.ecx.binary_op(bin_op, &left, &right))?; + if matches!(val.layout.abi, Abi::ScalarPair(..)) { + // FIXME `Value` should properly support pairs in `Immediate`... but currently it does not. + let (val, overflow) = val.to_pair(&self.ecx); + Value::Aggregate { + variant: VariantIdx::ZERO, + fields: [val.into(), overflow.into()].into_iter().collect(), + } + } else { + val.into() } } @@ -579,36 +577,25 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let operand = self.eval_operand(operand)?; let val = self.use_ecx(|this| this.ecx.read_immediate(&operand))?; - let val = self.use_ecx(|this| this.ecx.wrapping_unary_op(un_op, &val))?; + let val = self.use_ecx(|this| this.ecx.unary_op(un_op, &val))?; val.into() } - Aggregate(ref kind, ref fields) => { - // Do not const prop union fields as they can be - // made to produce values that don't match their - // underlying layout's type (see ICE #121534). - // If the last element of the `Adt` tuple - // is `Some` it indicates the ADT is a union - if let AggregateKind::Adt(_, _, _, _, Some(_)) = **kind { - return None; - }; - Value::Aggregate { - fields: fields - .iter() - .map(|field| { - self.eval_operand(field).map_or(Value::Uninit, Value::Immediate) - }) - .collect(), - variant: match **kind { - AggregateKind::Adt(_, variant, _, _, _) => variant, - AggregateKind::Array(_) - | AggregateKind::Tuple - | AggregateKind::Closure(_, _) - | AggregateKind::Coroutine(_, _) - | AggregateKind::CoroutineClosure(_, _) => VariantIdx::ZERO, - }, - } - } + Aggregate(ref kind, ref fields) => Value::Aggregate { + fields: fields + .iter() + .map(|field| self.eval_operand(field).map_or(Value::Uninit, Value::Immediate)) + .collect(), + variant: match **kind { + AggregateKind::Adt(_, variant, _, _, _) => variant, + AggregateKind::Array(_) + | AggregateKind::Tuple + | AggregateKind::RawPtr(_, _) + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::ZERO, + }, + }, Repeat(ref op, n) => { trace!(?op, ?n); @@ -796,7 +783,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { if let Some(ref value) = self.eval_operand(discr) && let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value)) && let Ok(constant) = value_const.try_to_int() - && let Ok(constant) = constant.to_bits(constant.size()) + && let Ok(constant) = constant.try_to_bits(constant.size()) { // We managed to evaluate the discriminant, so we know we only need to visit // one target. @@ -895,13 +882,20 @@ impl CanConstProp { }; for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { let ty = body.local_decls[local].ty; - match tcx.layout_of(param_env.and(ty)) { - Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} - // Either the layout fails to compute, then we can't use this local anyway - // or the local is too large, then we don't want to. - _ => { - *val = ConstPropMode::NoPropagation; - continue; + if ty.is_union() { + // Unions are incompatible with the current implementation of + // const prop because Rust has no concept of an active + // variant of a union + *val = ConstPropMode::NoPropagation; + } else { + match tcx.layout_of(param_env.and(ty)) { + Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} + // Either the layout fails to compute, then we can't use this local anyway + // or the local is too large, then we don't want to. + _ => { + *val = ConstPropMode::NoPropagation; + continue; + } } } } diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs index 8be96b6ba8f..e407929c9a7 100644 --- a/compiler/rustc_mir_transform/src/large_enums.rs +++ b/compiler/rustc_mir_transform/src/large_enums.rs @@ -1,7 +1,7 @@ -use crate::rustc_middle::ty::util::IntTypeExt; use rustc_data_structures::fx::FxHashMap; use rustc_middle::mir::interpret::AllocId; use rustc_middle::mir::*; +use rustc_middle::ty::util::IntTypeExt; use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt}; use rustc_session::Session; use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants}; diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index e477c068229..93ae105150c 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -4,7 +4,6 @@ #![feature(cow_is_borrowed)] #![feature(decl_macro)] #![feature(impl_trait_in_assoc_type)] -#![feature(inline_const)] #![feature(is_sorted)] #![feature(let_chains)] #![feature(map_try_insert)] @@ -17,8 +16,6 @@ #[macro_use] extern crate tracing; -#[macro_use] -extern crate rustc_middle; use hir::ConstContext; use required_consts::RequiredConstsVisitor; @@ -39,6 +36,7 @@ use rustc_middle::mir::{ use rustc_middle::query; use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt}; use rustc_middle::util::Providers; +use rustc_middle::{bug, span_bug}; use rustc_span::{source_map::Spanned, sym, DUMMY_SP}; use rustc_trait_selection::traits; @@ -111,9 +109,9 @@ mod simplify_comparison_integral; mod sroa; mod unreachable_enum_branching; mod unreachable_prop; +mod validate; -use rustc_const_eval::transform::check_consts::{self, ConstCx}; -use rustc_const_eval::transform::validate; +use rustc_const_eval::check_consts::{self, ConstCx}; use rustc_mir_dataflow::rustc_peek; rustc_fluent_macro::fluent_messages! { "../messages.ftl" } @@ -213,7 +211,7 @@ fn remap_mir_for_const_eval_select<'tcx>( } fn is_mir_available(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { - tcx.mir_keys(()).contains(&def_id) + tcx.hir().maybe_body_owned_by(def_id).is_some() } /// Finds the full set of `DefId`s within the current crate that have @@ -224,6 +222,15 @@ fn mir_keys(tcx: TyCtxt<'_>, (): ()) -> FxIndexSet<LocalDefId> { // All body-owners have MIR associated with them. set.extend(tcx.hir().body_owners()); + // Inline consts' bodies are created in + // typeck instead of during ast lowering, like all other bodies so far. + for def_id in tcx.hir().body_owners() { + // Incremental performance optimization: only load typeck results for things that actually have inline consts + if tcx.hir_owner_nodes(tcx.hir().body_owned_by(def_id).hir_id.owner).has_inline_consts { + set.extend(tcx.typeck(def_id).inline_consts.values()) + } + } + // Additionally, tuple struct/variant constructors have MIR, but // they don't have a BodyId, so we need to build them separately. struct GatherCtors<'a> { @@ -333,6 +340,8 @@ fn mir_promoted( body.tainted_by_errors = Some(error_reported); } + // Collect `required_consts` *before* promotion, so if there are any consts being promoted + // we still add them to the list in the outer MIR body. let mut required_consts = Vec::new(); let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts); for (bb, bb_data) in traversal::reverse_postorder(&body) { diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs index 7e8920604c1..3ffc447217d 100644 --- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -2,6 +2,7 @@ use rustc_middle::mir::*; use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_span::symbol::sym; pub struct LowerIntrinsics; @@ -139,16 +140,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { rhs = args.next().unwrap(); } let bin_op = match intrinsic.name { - sym::add_with_overflow => BinOp::Add, - sym::sub_with_overflow => BinOp::Sub, - sym::mul_with_overflow => BinOp::Mul, + sym::add_with_overflow => BinOp::AddWithOverflow, + sym::sub_with_overflow => BinOp::SubWithOverflow, + sym::mul_with_overflow => BinOp::MulWithOverflow, _ => bug!("unexpected intrinsic"), }; block.statements.push(Statement { source_info: terminator.source_info, kind: StatementKind::Assign(Box::new(( *destination, - Rvalue::CheckedBinaryOp(bin_op, Box::new((lhs.node, rhs.node))), + Rvalue::BinaryOp(bin_op, Box::new((lhs.node, rhs.node))), ))), }); terminator.kind = TerminatorKind::Goto { target }; @@ -287,6 +288,51 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { terminator.kind = TerminatorKind::Unreachable; } } + sym::aggregate_raw_ptr => { + let Ok([data, meta]) = <[_; 2]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for aggregate_raw_ptr intrinsic", + ); + }; + let target = target.unwrap(); + let pointer_ty = generic_args.type_at(0); + let kind = if let ty::RawPtr(pointee_ty, mutability) = pointer_ty.kind() { + AggregateKind::RawPtr(*pointee_ty, *mutability) + } else { + span_bug!( + terminator.source_info.span, + "Return type of aggregate_raw_ptr intrinsic must be a raw pointer", + ); + }; + let fields = [data.node, meta.node]; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Aggregate(Box::new(kind), fields.into()), + ))), + }); + + terminator.kind = TerminatorKind::Goto { target }; + } + sym::ptr_metadata => { + let Ok([ptr]) = <[_; 1]>::try_from(std::mem::take(args)) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for ptr_metadata intrinsic", + ); + }; + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::UnaryOp(UnOp::PtrMetadata, ptr.node), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } _ => {} } } diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs index 8137525a332..2267a621a83 100644 --- a/compiler/rustc_mir_transform/src/lower_slice_len.rs +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -21,7 +21,7 @@ impl<'tcx> MirPass<'tcx> for LowerSliceLenCalls { pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let language_items = tcx.lang_items(); let Some(slice_len_fn_item_def_id) = language_items.slice_len_fn() else { - // there is no language item to compare to :) + // there is no lang item to compare to :) return; }; diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 4d9a198eeb2..1411d9be223 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -41,7 +41,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { should_cleanup = true; continue; } - if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some() { + // unsound: https://github.com/rust-lang/rust/issues/124150 + if tcx.sess.opts.unstable_opts.unsound_mir_opts + && SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some() + { should_cleanup = true; continue; } @@ -369,8 +372,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { } fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool { - l.try_to_int(l.size()).unwrap() - == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap() + l.assert_int(l.size()) == ScalarInt::try_from_uint(r, size).unwrap().assert_int(size) } // We first compare the two branches, and then the other branches need to fulfill the same conditions. diff --git a/compiler/rustc_mir_transform/src/mentioned_items.rs b/compiler/rustc_mir_transform/src/mentioned_items.rs index 57b6126dece..db2bb60bdac 100644 --- a/compiler/rustc_mir_transform/src/mentioned_items.rs +++ b/compiler/rustc_mir_transform/src/mentioned_items.rs @@ -79,8 +79,8 @@ impl<'tcx> Visitor<'tcx> for MentionedItemsVisitor<'_, 'tcx> { // add everything that may involve a vtable. let source_ty = operand.ty(self.body, self.tcx); let may_involve_vtable = match ( - source_ty.builtin_deref(true).map(|t| t.ty.kind()), - target_ty.builtin_deref(true).map(|t| t.ty.kind()), + source_ty.builtin_deref(true).map(|t| t.kind()), + target_ty.builtin_deref(true).map(|t| t.kind()), ) { (Some(ty::Array(..)), Some(ty::Str | ty::Slice(..))) => false, // &str/&[T] unsizing _ => true, diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs index 128634bd7f2..2070895c900 100644 --- a/compiler/rustc_mir_transform/src/normalize_array_len.rs +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -22,7 +22,8 @@ impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { } fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let slice_lengths = compute_slice_length(tcx, &ssa, body); debug!(?slice_lengths); @@ -47,9 +48,9 @@ fn compute_slice_length<'tcx>( let operand_ty = operand.ty(body, tcx); debug!(?operand_ty); if let Some(operand_ty) = operand_ty.builtin_deref(true) - && let ty::Array(_, len) = operand_ty.ty.kind() + && let ty::Array(_, len) = operand_ty.kind() && let Some(cast_ty) = cast_ty.builtin_deref(true) - && let ty::Slice(..) = cast_ty.ty.kind() + && let ty::Slice(..) = cast_ty.kind() { slice_lengths[local] = Some(*len); } diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs index 232c290e0fb..885dbd5f339 100644 --- a/compiler/rustc_mir_transform/src/nrvo.rs +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -2,6 +2,7 @@ use rustc_hir::Mutability; use rustc_index::bit_set::BitSet; +use rustc_middle::bug; use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor}; use rustc_middle::mir::{self, BasicBlock, Local, Location}; use rustc_middle::ty::TyCtxt; diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index a9d4b860b7a..7ec59cc983f 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -13,12 +13,14 @@ //! move analysis runs after promotion on broken MIR. use either::{Left, Right}; +use rustc_data_structures::fx::FxHashSet; use rustc_hir as hir; use rustc_middle::mir; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, List, Ty, TyCtxt, TypeVisitableExt}; +use rustc_middle::{bug, span_bug}; use rustc_span::Span; use rustc_index::{Idx, IndexSlice, IndexVec}; @@ -28,7 +30,7 @@ use std::assert_matches::assert_matches; use std::cell::Cell; use std::{cmp, iter, mem}; -use rustc_const_eval::transform::check_consts::{qualifs, ConstCx}; +use rustc_const_eval::check_consts::{qualifs, ConstCx}; /// A `MirPass` for promotion. /// @@ -175,6 +177,12 @@ fn collect_temps_and_candidates<'tcx>( struct Validator<'a, 'tcx> { ccx: &'a ConstCx<'a, 'tcx>, temps: &'a mut IndexSlice<Local, TempState>, + /// For backwards compatibility, we are promoting function calls in `const`/`static` + /// initializers. But we want to avoid evaluating code that might panic and that otherwise would + /// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed + /// to execute. In other words, we only promote such calls in basic blocks that are definitely + /// not dead code. Here we cache the result of computing that set of basic blocks. + promotion_safe_blocks: Option<FxHashSet<BasicBlock>>, } impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> { @@ -260,7 +268,9 @@ impl<'tcx> Validator<'_, 'tcx> { self.validate_rvalue(rhs) } Right(terminator) => match &terminator.kind { - TerminatorKind::Call { func, args, .. } => self.validate_call(func, args), + TerminatorKind::Call { func, args, .. } => { + self.validate_call(func, args, loc.block) + } TerminatorKind::Yield { .. } => Err(Unpromotable), kind => { span_bug!(terminator.source_info.span, "{:?} not promotable", kind); @@ -384,7 +394,7 @@ impl<'tcx> Validator<'_, 'tcx> { match kind { // Reject these borrow types just to be safe. // FIXME(RalfJung): could we allow them? Should we? No point in it until we have a usecase. - BorrowKind::Fake | BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture } => { + BorrowKind::Fake(_) | BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture } => { return Err(Unpromotable); } @@ -454,13 +464,13 @@ impl<'tcx> Validator<'_, 'tcx> { Rvalue::UnaryOp(op, operand) => { match op { // These operations can never fail. - UnOp::Neg | UnOp::Not => {} + UnOp::Neg | UnOp::Not | UnOp::PtrMetadata => {} } self.validate_operand(operand)?; } - Rvalue::BinaryOp(op, box (lhs, rhs)) | Rvalue::CheckedBinaryOp(op, box (lhs, rhs)) => { + Rvalue::BinaryOp(op, box (lhs, rhs)) => { let op = *op; let lhs_ty = lhs.ty(self.body, self.tcx); @@ -490,14 +500,14 @@ impl<'tcx> Validator<'_, 'tcx> { } _ => None, }; - match rhs_val.map(|x| x.try_to_uint(sz).unwrap()) { + match rhs_val.map(|x| x.assert_uint(sz)) { // for the zero test, int vs uint does not matter Some(x) if x != 0 => {} // okay _ => return Err(Unpromotable), // value not known or 0 -- not okay } // Furthermore, for signed divison, we also have to exclude `int::MIN / -1`. if lhs_ty.is_signed() { - match rhs_val.map(|x| x.try_to_int(sz).unwrap()) { + match rhs_val.map(|x| x.assert_int(sz)) { Some(-1) | None => { // The RHS is -1 or unknown, so we have to be careful. // But is the LHS int::MIN? @@ -508,7 +518,7 @@ impl<'tcx> Validator<'_, 'tcx> { _ => None, }; let lhs_min = sz.signed_int_min(); - match lhs_val.map(|x| x.try_to_int(sz).unwrap()) { + match lhs_val.map(|x| x.assert_int(sz)) { Some(x) if x != lhs_min => {} // okay _ => return Err(Unpromotable), // value not known or int::MIN -- not okay } @@ -529,10 +539,13 @@ impl<'tcx> Validator<'_, 'tcx> { | BinOp::Offset | BinOp::Add | BinOp::AddUnchecked + | BinOp::AddWithOverflow | BinOp::Sub | BinOp::SubUnchecked + | BinOp::SubWithOverflow | BinOp::Mul | BinOp::MulUnchecked + | BinOp::MulWithOverflow | BinOp::BitXor | BinOp::BitAnd | BinOp::BitOr @@ -588,29 +601,79 @@ impl<'tcx> Validator<'_, 'tcx> { Ok(()) } + /// Computes the sets of blocks of this MIR that are definitely going to be executed + /// if the function returns successfully. That makes it safe to promote calls in them + /// that might fail. + fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> { + let mut safe_blocks = FxHashSet::default(); + let mut safe_block = START_BLOCK; + loop { + safe_blocks.insert(safe_block); + // Let's see if we can find another safe block. + safe_block = match body.basic_blocks[safe_block].terminator().kind { + TerminatorKind::Goto { target } => target, + TerminatorKind::Call { target: Some(target), .. } + | TerminatorKind::Drop { target, .. } => { + // This calls a function or the destructor. `target` does not get executed if + // the callee loops or panics. But in both cases the const already fails to + // evaluate, so we are fine considering `target` a safe block for promotion. + target + } + TerminatorKind::Assert { target, .. } => { + // Similar to above, we only consider successful execution. + target + } + _ => { + // No next safe block. + break; + } + }; + } + safe_blocks + } + + /// Returns whether the block is "safe" for promotion, which means it cannot be dead code. + /// We use this to avoid promoting operations that can fail in dead code. + fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool { + let body = self.body; + let safe_blocks = + self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body)); + safe_blocks.contains(&block) + } + fn validate_call( &mut self, callee: &Operand<'tcx>, args: &[Spanned<Operand<'tcx>>], + block: BasicBlock, ) -> Result<(), Unpromotable> { + // Validate the operands. If they fail, there's no question -- we cannot promote. + self.validate_operand(callee)?; + for arg in args { + self.validate_operand(&arg.node)?; + } + + // Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can + // accept them at this point. let fn_ty = callee.ty(self.body, self.tcx); + if let ty::FnDef(def_id, _) = *fn_ty.kind() { + if self.tcx.is_promotable_const_fn(def_id) { + return Ok(()); + } + } - // Inside const/static items, we promote all (eligible) function calls. - // Everywhere else, we require `#[rustc_promotable]` on the callee. - let promote_all_const_fn = matches!( + // Ideally, we'd stop here and reject the rest. + // But for backward compatibility, we have to accept some promotion in const/static + // initializers. Inline consts are explicitly excluded, they are more recent so we have no + // backwards compatibility reason to allow more promotion inside of them. + let promote_all_fn = matches!( self.const_kind, Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false }) ); - if !promote_all_const_fn { - if let ty::FnDef(def_id, _) = *fn_ty.kind() { - // Never promote runtime `const fn` calls of - // functions without `#[rustc_promotable]`. - if !self.tcx.is_promotable_const_fn(def_id) { - return Err(Unpromotable); - } - } + if !promote_all_fn { + return Err(Unpromotable); } - + // Make sure the callee is a `const fn`. let is_const_fn = match *fn_ty.kind() { ty::FnDef(def_id, _) => self.tcx.is_const_fn_raw(def_id), _ => false, @@ -618,23 +681,23 @@ impl<'tcx> Validator<'_, 'tcx> { if !is_const_fn { return Err(Unpromotable); } - - self.validate_operand(callee)?; - for arg in args { - self.validate_operand(&arg.node)?; + // The problem is, this may promote calls to functions that panic. + // We don't want to introduce compilation errors if there's a panic in a call in dead code. + // So we ensure that this is not dead code. + if !self.is_promotion_safe_block(block) { + return Err(Unpromotable); } - + // This passed all checks, so let's accept. Ok(()) } } -// FIXME(eddyb) remove the differences for promotability in `static`, `const`, `const fn`. fn validate_candidates( ccx: &ConstCx<'_, '_>, temps: &mut IndexSlice<Local, TempState>, candidates: &[Candidate], ) -> Vec<Candidate> { - let mut validator = Validator { ccx, temps }; + let mut validator = Validator { ccx, temps, promotion_safe_blocks: None }; candidates .iter() @@ -653,6 +716,10 @@ struct Promoter<'a, 'tcx> { /// If true, all nested temps are also kept in the /// source MIR, not moved to the promoted MIR. keep_original: bool, + + /// If true, add the new const (the promoted) to the required_consts of the parent MIR. + /// This is initially false and then set by the visitor when it encounters a `Call` terminator. + add_to_required: bool, } impl<'a, 'tcx> Promoter<'a, 'tcx> { @@ -755,6 +822,10 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { TerminatorKind::Call { mut func, mut args, call_source: desugar, fn_span, .. } => { + // This promoted involves a function call, so it may fail to evaluate. + // Let's make sure it is added to `required_consts` so that that failure cannot get lost. + self.add_to_required = true; + self.visit_operand(&mut func, loc); for arg in &mut args { self.visit_operand(&mut arg.node, loc); @@ -789,7 +860,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { fn promote_candidate(mut self, candidate: Candidate, next_promoted_id: usize) -> Body<'tcx> { let def = self.source.source.def_id(); - let mut rvalue = { + let (mut rvalue, promoted_op) = { let promoted = &mut self.promoted; let promoted_id = Promoted::new(next_promoted_id); let tcx = self.tcx; @@ -799,11 +870,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def)); let uneval = mir::UnevaluatedConst { def, args, promoted: Some(promoted_id) }; - Operand::Constant(Box::new(ConstOperand { - span, - user_ty: None, - const_: Const::Unevaluated(uneval, ty), - })) + ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) } }; let blocks = self.source.basic_blocks.as_mut(); @@ -836,22 +903,26 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let promoted_ref = local_decls.push(promoted_ref); assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref); + let promoted_operand = promoted_operand(ref_ty, span); let promoted_ref_statement = Statement { source_info: statement.source_info, kind: StatementKind::Assign(Box::new(( Place::from(promoted_ref), - Rvalue::Use(promoted_operand(ref_ty, span)), + Rvalue::Use(Operand::Constant(Box::new(promoted_operand))), ))), }; self.extra_statements.push((loc, promoted_ref_statement)); - Rvalue::Ref( - tcx.lifetimes.re_erased, - *borrow_kind, - Place { - local: mem::replace(&mut place.local, promoted_ref), - projection: List::empty(), - }, + ( + Rvalue::Ref( + tcx.lifetimes.re_erased, + *borrow_kind, + Place { + local: mem::replace(&mut place.local, promoted_ref), + projection: List::empty(), + }, + ), + promoted_operand, ) }; @@ -863,6 +934,12 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { let span = self.promoted.span; self.assign(RETURN_PLACE, rvalue, span); + + // Now that we did promotion, we know whether we'll want to add this to `required_consts`. + if self.add_to_required { + self.source.required_consts.push(promoted_op); + } + self.promoted } } @@ -878,6 +955,14 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Promoter<'a, 'tcx> { *local = self.promote_temp(*local); } } + + fn visit_constant(&mut self, constant: &mut ConstOperand<'tcx>, _location: Location) { + if constant.const_.is_required_const() { + self.promoted.required_consts.push(*constant); + } + + // Skipping `super_constant` as the visitor is otherwise only looking for locals. + } } fn promote_candidates<'tcx>( @@ -931,8 +1016,10 @@ fn promote_candidates<'tcx>( temps: &mut temps, extra_statements: &mut extra_statements, keep_original: false, + add_to_required: false, }; + // `required_consts` of the promoted itself gets filled while building the MIR body. let mut promoted = promoter.promote_candidate(candidate, promotions.len()); promoted.source.promoted = Some(promotions.next_index()); promotions.push(promoted); diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs index d5642be5513..801ef14c9cd 100644 --- a/compiler/rustc_mir_transform/src/ref_prop.rs +++ b/compiler/rustc_mir_transform/src/ref_prop.rs @@ -1,6 +1,7 @@ use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; @@ -82,7 +83,8 @@ impl<'tcx> MirPass<'tcx> for ReferencePropagation { } fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { - let ssa = SsaLocals::new(body); + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let ssa = SsaLocals::new(tcx, body, param_env); let mut replacer = compute_replacement(tcx, body, &ssa); debug!(?replacer.targets); diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs index abde6a47e83..71ac929d35e 100644 --- a/compiler/rustc_mir_transform/src/required_consts.rs +++ b/compiler/rustc_mir_transform/src/required_consts.rs @@ -1,6 +1,5 @@ use rustc_middle::mir::visit::Visitor; -use rustc_middle::mir::{Const, ConstOperand, Location}; -use rustc_middle::ty::ConstKind; +use rustc_middle::mir::{ConstOperand, Location}; pub struct RequiredConstsVisitor<'a, 'tcx> { required_consts: &'a mut Vec<ConstOperand<'tcx>>, @@ -14,14 +13,8 @@ impl<'a, 'tcx> RequiredConstsVisitor<'a, 'tcx> { impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> { fn visit_constant(&mut self, constant: &ConstOperand<'tcx>, _: Location) { - let const_ = constant.const_; - match const_ { - Const::Ty(c) => match c.kind() { - ConstKind::Param(_) | ConstKind::Error(_) | ConstKind::Value(_) => {} - _ => bug!("only ConstKind::Param/Value should be encountered here, got {:#?}", c), - }, - Const::Unevaluated(..) => self.required_consts.push(*constant), - Const::Val(..) => {} + if constant.const_.is_required_const() { + self.required_consts.push(*constant); } } } diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index fa6906bdd55..dcf54ad2cfc 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -5,6 +5,7 @@ use rustc_middle::mir::*; use rustc_middle::query::Providers; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug}; use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_index::{Idx, IndexVec}; @@ -22,6 +23,8 @@ use crate::{ use rustc_middle::mir::patch::MirPatch; use rustc_mir_dataflow::elaborate_drops::{self, DropElaborator, DropFlagMode, DropStyle}; +mod async_destructor_ctor; + pub fn provide(providers: &mut Providers) { providers.mir_shims = make_shim; } @@ -127,6 +130,9 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' ty::InstanceDef::ThreadLocalShim(..) => build_thread_local_shim(tcx, instance), ty::InstanceDef::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty), ty::InstanceDef::FnPtrAddrShim(def_id, ty) => build_fn_ptr_addr_shim(tcx, def_id, ty), + ty::InstanceDef::AsyncDropGlueCtorShim(def_id, ty) => { + async_destructor_ctor::build_async_destructor_ctor_shim(tcx, def_id, ty) + } ty::InstanceDef::Virtual(..) => { bug!("InstanceDef::Virtual ({:?}) is for direct calls only", instance) } @@ -1041,7 +1047,7 @@ fn build_construct_coroutine_by_move_shim<'tcx>( args.as_coroutine_closure().coroutine_captures_by_ref_ty(), ), sig.c_variadic, - sig.unsafety, + sig.safety, sig.abi, ) }); diff --git a/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs new file mode 100644 index 00000000000..f4481c22fc1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs @@ -0,0 +1,618 @@ +use std::iter; + +use itertools::Itertools; +use rustc_ast::Mutability; +use rustc_const_eval::interpret; +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, Body, CallSource, CastKind, Const, ConstOperand, ConstValue, Local, + LocalDecl, MirSource, Operand, Place, PlaceElem, Rvalue, SourceInfo, Statement, StatementKind, + Terminator, TerminatorKind, UnwindAction, UnwindTerminateReason, RETURN_PLACE, +}; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::util::Discr; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug}; +use rustc_span::source_map::respan; +use rustc_span::{Span, Symbol}; +use rustc_target::abi::{FieldIdx, VariantIdx}; +use rustc_target::spec::PanicStrategy; + +use super::{local_decls_for_sig, new_body}; + +pub fn build_async_destructor_ctor_shim<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, + ty: Option<Ty<'tcx>>, +) -> Body<'tcx> { + debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty); + + AsyncDestructorCtorShimBuilder::new(tcx, def_id, ty).build() +} + +/// Builder for async_drop_in_place shim. Functions as a stack machine +/// to build up an expression using combinators. Stack contains pairs +/// of locals and types. Combinator is a not yet instantiated pair of a +/// function and a type, is considered to be an operator which consumes +/// operands from the stack by instantiating its function and its type +/// with operand types and moving locals into the function call. Top +/// pair is considered to be the last operand. +// FIXME: add mir-opt tests +struct AsyncDestructorCtorShimBuilder<'tcx> { + tcx: TyCtxt<'tcx>, + def_id: DefId, + self_ty: Option<Ty<'tcx>>, + span: Span, + source_info: SourceInfo, + param_env: ty::ParamEnv<'tcx>, + + stack: Vec<Operand<'tcx>>, + last_bb: BasicBlock, + top_cleanup_bb: Option<BasicBlock>, + + locals: IndexVec<Local, LocalDecl<'tcx>>, + bbs: IndexVec<BasicBlock, BasicBlockData<'tcx>>, +} + +#[derive(Clone, Copy)] +enum SurfaceDropKind { + Async, + Sync, +} + +impl<'tcx> AsyncDestructorCtorShimBuilder<'tcx> { + const SELF_PTR: Local = Local::from_u32(1); + const INPUT_COUNT: usize = 1; + const MAX_STACK_LEN: usize = 2; + + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Option<Ty<'tcx>>) -> Self { + let args = if let Some(ty) = self_ty { + tcx.mk_args(&[ty.into()]) + } else { + ty::GenericArgs::identity_for_item(tcx, def_id) + }; + let sig = tcx.fn_sig(def_id).instantiate(tcx, args); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + let span = tcx.def_span(def_id); + + let source_info = SourceInfo::outermost(span); + + debug_assert_eq!(sig.inputs().len(), Self::INPUT_COUNT); + let locals = local_decls_for_sig(&sig, span); + + // Usual case: noop() + unwind resume + return + let mut bbs = IndexVec::with_capacity(3); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + AsyncDestructorCtorShimBuilder { + tcx, + def_id, + self_ty, + span, + source_info, + param_env, + + stack: Vec::with_capacity(Self::MAX_STACK_LEN), + last_bb: bbs.push(BasicBlockData::new(None)), + top_cleanup_bb: match tcx.sess.panic_strategy() { + PanicStrategy::Unwind => { + // Don't drop input arg because it's just a pointer + Some(bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::UnwindResume, + }), + is_cleanup: true, + })) + } + PanicStrategy::Abort => None, + }, + + locals, + bbs, + } + } + + fn build(self) -> Body<'tcx> { + let (tcx, def_id, Some(self_ty)) = (self.tcx, self.def_id, self.self_ty) else { + return self.build_zst_output(); + }; + + let surface_drop_kind = || { + let param_env = tcx.param_env_reveal_all_normalized(def_id); + if self_ty.has_surface_async_drop(tcx, param_env) { + Some(SurfaceDropKind::Async) + } else if self_ty.has_surface_drop(tcx, param_env) { + Some(SurfaceDropKind::Sync) + } else { + None + } + }; + + match self_ty.kind() { + ty::Array(elem_ty, _) => self.build_slice(true, *elem_ty), + ty::Slice(elem_ty) => self.build_slice(false, *elem_ty), + + ty::Tuple(elem_tys) => self.build_chain(None, elem_tys.iter()), + ty::Adt(adt_def, args) if adt_def.is_struct() => { + let field_tys = adt_def.non_enum_variant().fields.iter().map(|f| f.ty(tcx, args)); + self.build_chain(surface_drop_kind(), field_tys) + } + ty::Closure(_, args) => self.build_chain(None, args.as_closure().upvar_tys().iter()), + ty::CoroutineClosure(_, args) => { + self.build_chain(None, args.as_coroutine_closure().upvar_tys().iter()) + } + + ty::Adt(adt_def, args) if adt_def.is_enum() => { + self.build_enum(*adt_def, *args, surface_drop_kind()) + } + + ty::Adt(adt_def, _) => { + assert!(adt_def.is_union()); + match surface_drop_kind().unwrap() { + SurfaceDropKind::Async => self.build_fused_async_surface(), + SurfaceDropKind::Sync => self.build_fused_sync_surface(), + } + } + + ty::Bound(..) + | ty::Foreign(_) + | ty::Placeholder(_) + | ty::Infer(ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) | ty::TyVar(_)) + | ty::Param(_) + | ty::Alias(..) => { + bug!("Building async destructor for unexpected type: {self_ty:?}") + } + + _ => { + bug!( + "Building async destructor constructor shim is not yet implemented for type: {self_ty:?}" + ) + } + } + } + + fn build_enum( + mut self, + adt_def: ty::AdtDef<'tcx>, + args: ty::GenericArgsRef<'tcx>, + surface_drop: Option<SurfaceDropKind>, + ) -> Body<'tcx> { + let tcx = self.tcx; + + let surface = match surface_drop { + None => None, + Some(kind) => { + self.put_self(); + Some(match kind { + SurfaceDropKind::Async => self.combine_async_surface(), + SurfaceDropKind::Sync => self.combine_sync_surface(), + }) + } + }; + + let mut other = None; + for (variant_idx, discr) in adt_def.discriminants(tcx) { + let variant = adt_def.variant(variant_idx); + + let mut chain = None; + for (field_idx, field) in variant.fields.iter_enumerated() { + let field_ty = field.ty(tcx, args); + self.put_variant_field(variant.name, variant_idx, field_idx, field_ty); + let defer = self.combine_defer(field_ty); + chain = Some(match chain { + None => defer, + Some(chain) => self.combine_chain(chain, defer), + }) + } + let variant_dtor = chain.unwrap_or_else(|| self.put_noop()); + + other = Some(match other { + None => variant_dtor, + Some(other) => { + self.put_self(); + self.put_discr(discr); + self.combine_either(other, variant_dtor) + } + }); + } + let variants_dtor = other.unwrap_or_else(|| self.put_noop()); + + let dtor = match surface { + None => variants_dtor, + Some(surface) => self.combine_chain(surface, variants_dtor), + }; + self.combine_fuse(dtor); + self.return_() + } + + fn build_chain<I>(mut self, surface_drop: Option<SurfaceDropKind>, elem_tys: I) -> Body<'tcx> + where + I: Iterator<Item = Ty<'tcx>> + ExactSizeIterator, + { + let surface = match surface_drop { + None => None, + Some(kind) => { + self.put_self(); + Some(match kind { + SurfaceDropKind::Async => self.combine_async_surface(), + SurfaceDropKind::Sync => self.combine_sync_surface(), + }) + } + }; + + let mut chain = None; + for (field_idx, field_ty) in elem_tys.enumerate().map(|(i, ty)| (FieldIdx::new(i), ty)) { + self.put_field(field_idx, field_ty); + let defer = self.combine_defer(field_ty); + chain = Some(match chain { + None => defer, + Some(chain) => self.combine_chain(chain, defer), + }) + } + let chain = chain.unwrap_or_else(|| self.put_noop()); + + let dtor = match surface { + None => chain, + Some(surface) => self.combine_chain(surface, chain), + }; + self.combine_fuse(dtor); + self.return_() + } + + fn build_zst_output(mut self) -> Body<'tcx> { + self.put_zst_output(); + self.return_() + } + + fn build_fused_async_surface(mut self) -> Body<'tcx> { + self.put_self(); + let surface = self.combine_async_surface(); + self.combine_fuse(surface); + self.return_() + } + + fn build_fused_sync_surface(mut self) -> Body<'tcx> { + self.put_self(); + let surface = self.combine_sync_surface(); + self.combine_fuse(surface); + self.return_() + } + + fn build_slice(mut self, is_array: bool, elem_ty: Ty<'tcx>) -> Body<'tcx> { + if is_array { + self.put_array_as_slice(elem_ty) + } else { + self.put_self() + } + let dtor = self.combine_slice(elem_ty); + self.combine_fuse(dtor); + self.return_() + } + + fn put_zst_output(&mut self) { + let return_ty = self.locals[RETURN_PLACE].ty; + self.put_operand(Operand::Constant(Box::new(ConstOperand { + span: self.span, + user_ty: None, + const_: Const::zero_sized(return_ty), + }))); + } + + /// Puts `to_drop: *mut Self` on top of the stack. + fn put_self(&mut self) { + self.put_operand(Operand::Copy(Self::SELF_PTR.into())) + } + + /// Given that `Self is [ElemTy; N]` puts `to_drop: *mut [ElemTy]` + /// on top of the stack. + fn put_array_as_slice(&mut self, elem_ty: Ty<'tcx>) { + let slice_ptr_ty = Ty::new_mut_ptr(self.tcx, Ty::new_slice(self.tcx, elem_ty)); + self.put_temp_rvalue(Rvalue::Cast( + CastKind::PointerCoercion(PointerCoercion::Unsize), + Operand::Copy(Self::SELF_PTR.into()), + slice_ptr_ty, + )) + } + + /// If given Self is a struct puts `to_drop: *mut FieldTy` on top + /// of the stack. + fn put_field(&mut self, field: FieldIdx, field_ty: Ty<'tcx>) { + let place = Place { + local: Self::SELF_PTR, + projection: self + .tcx + .mk_place_elems(&[PlaceElem::Deref, PlaceElem::Field(field, field_ty)]), + }; + self.put_temp_rvalue(Rvalue::AddressOf(Mutability::Mut, place)) + } + + /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of + /// the stack. + fn put_variant_field( + &mut self, + variant_sym: Symbol, + variant: VariantIdx, + field: FieldIdx, + field_ty: Ty<'tcx>, + ) { + let place = Place { + local: Self::SELF_PTR, + projection: self.tcx.mk_place_elems(&[ + PlaceElem::Deref, + PlaceElem::Downcast(Some(variant_sym), variant), + PlaceElem::Field(field, field_ty), + ]), + }; + self.put_temp_rvalue(Rvalue::AddressOf(Mutability::Mut, place)) + } + + /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of + /// the stack. + fn put_discr(&mut self, discr: Discr<'tcx>) { + let (size, _) = discr.ty.int_size_and_signed(self.tcx); + self.put_operand(Operand::const_from_scalar( + self.tcx, + discr.ty, + interpret::Scalar::from_uint(discr.val, size), + self.span, + )); + } + + /// Puts `x: RvalueType` on top of the stack. + fn put_temp_rvalue(&mut self, rvalue: Rvalue<'tcx>) { + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + let source_info = self.source_info; + + let local_ty = rvalue.ty(&self.locals, self.tcx); + // We need to create a new local to be able to "consume" it with + // a combinator + let local = self.locals.push(LocalDecl::with_source_info(local_ty, source_info)); + last_bb.statements.extend_from_slice(&[ + Statement { source_info, kind: StatementKind::StorageLive(local) }, + Statement { + source_info, + kind: StatementKind::Assign(Box::new((local.into(), rvalue))), + }, + ]); + + self.put_operand(Operand::Move(local.into())); + } + + /// Puts operand on top of the stack. + fn put_operand(&mut self, operand: Operand<'tcx>) { + if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb { + let source_info = self.source_info; + match &operand { + Operand::Copy(_) | Operand::Constant(_) => { + *top_cleanup_bb = self.bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Goto { target: *top_cleanup_bb }, + }), + is_cleanup: true, + }); + } + Operand::Move(place) => { + let local = place.as_local().unwrap(); + *top_cleanup_bb = self.bbs.push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: if self.locals[local].ty.needs_drop(self.tcx, self.param_env) { + TerminatorKind::Drop { + place: local.into(), + target: *top_cleanup_bb, + unwind: UnwindAction::Terminate( + UnwindTerminateReason::InCleanup, + ), + replace: false, + } + } else { + TerminatorKind::Goto { target: *top_cleanup_bb } + }, + }), + is_cleanup: true, + }); + } + }; + } + self.stack.push(operand); + } + + /// Puts `noop: async_drop::Noop` on top of the stack + fn put_noop(&mut self) -> Ty<'tcx> { + self.apply_combinator(0, LangItem::AsyncDropNoop, &[]) + } + + fn combine_async_surface(&mut self) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::SurfaceAsyncDropInPlace, &[self.self_ty.unwrap().into()]) + } + + fn combine_sync_surface(&mut self) -> Ty<'tcx> { + self.apply_combinator( + 1, + LangItem::AsyncDropSurfaceDropInPlace, + &[self.self_ty.unwrap().into()], + ) + } + + fn combine_fuse(&mut self, inner_future_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropFuse, &[inner_future_ty.into()]) + } + + fn combine_slice(&mut self, elem_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropSlice, &[elem_ty.into()]) + } + + fn combine_defer(&mut self, to_drop_ty: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(1, LangItem::AsyncDropDefer, &[to_drop_ty.into()]) + } + + fn combine_chain(&mut self, first: Ty<'tcx>, second: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator(2, LangItem::AsyncDropChain, &[first.into(), second.into()]) + } + + fn combine_either(&mut self, other: Ty<'tcx>, matched: Ty<'tcx>) -> Ty<'tcx> { + self.apply_combinator( + 4, + LangItem::AsyncDropEither, + &[other.into(), matched.into(), self.self_ty.unwrap().into()], + ) + } + + fn return_(mut self) -> Body<'tcx> { + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + let source_info = self.source_info; + + let (1, Some(output)) = (self.stack.len(), self.stack.pop()) else { + span_bug!( + self.span, + "async destructor ctor shim builder finished with invalid number of stack items: expected 1 found {}", + self.stack.len(), + ) + }; + #[cfg(debug_assertions)] + if let Some(ty) = self.self_ty { + debug_assert_eq!( + output.ty(&self.locals, self.tcx), + ty.async_destructor_ty(self.tcx, self.param_env), + "output async destructor types did not match for type: {ty:?}", + ); + } + + let dead_storage = match &output { + Operand::Move(place) => Some(Statement { + source_info, + kind: StatementKind::StorageDead(place.as_local().unwrap()), + }), + _ => None, + }; + + last_bb.statements.extend( + iter::once(Statement { + source_info, + kind: StatementKind::Assign(Box::new((RETURN_PLACE.into(), Rvalue::Use(output)))), + }) + .chain(dead_storage), + ); + + last_bb.terminator = Some(Terminator { source_info, kind: TerminatorKind::Return }); + + let source = MirSource::from_instance(ty::InstanceDef::AsyncDropGlueCtorShim( + self.def_id, + self.self_ty, + )); + new_body(source, self.bbs, self.locals, Self::INPUT_COUNT, self.span) + } + + fn apply_combinator( + &mut self, + arity: usize, + function: LangItem, + args: &[ty::GenericArg<'tcx>], + ) -> Ty<'tcx> { + let function = self.tcx.require_lang_item(function, Some(self.span)); + let operands_split = self + .stack + .len() + .checked_sub(arity) + .expect("async destructor ctor shim combinator tried to consume too many items"); + let operands = &self.stack[operands_split..]; + + let func_ty = Ty::new_fn_def(self.tcx, function, args.iter().copied()); + let func_sig = func_ty.fn_sig(self.tcx).no_bound_vars().unwrap(); + #[cfg(debug_assertions)] + operands.iter().zip(func_sig.inputs()).for_each(|(operand, expected_ty)| { + let operand_ty = operand.ty(&self.locals, self.tcx); + if operand_ty == *expected_ty { + return; + } + + // If projection of Discriminant then compare with `Ty::discriminant_ty` + if let ty::Alias(ty::Projection, ty::AliasTy { args, def_id, .. }) = expected_ty.kind() + && Some(*def_id) == self.tcx.lang_items().discriminant_type() + && args.first().unwrap().as_type().unwrap().discriminant_ty(self.tcx) == operand_ty + { + return; + } + + span_bug!( + self.span, + "Operand type and combinator argument type are not equal. + operand_ty: {:?} + argument_ty: {:?} +", + operand_ty, + expected_ty + ); + }); + + let target = self.bbs.push(BasicBlockData { + statements: operands + .iter() + .rev() + .filter_map(|o| { + if let Operand::Move(Place { local, projection }) = o { + assert!(projection.is_empty()); + Some(Statement { + source_info: self.source_info, + kind: StatementKind::StorageDead(*local), + }) + } else { + None + } + }) + .collect(), + terminator: None, + is_cleanup: false, + }); + + let dest_ty = func_sig.output(); + let dest = + self.locals.push(LocalDecl::with_source_info(dest_ty, self.source_info).immutable()); + + let unwind = if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb { + for _ in 0..arity { + *top_cleanup_bb = + self.bbs[*top_cleanup_bb].terminator().successors().exactly_one().ok().unwrap(); + } + UnwindAction::Cleanup(*top_cleanup_bb) + } else { + UnwindAction::Unreachable + }; + + let last_bb = &mut self.bbs[self.last_bb]; + debug_assert!(last_bb.terminator.is_none()); + last_bb.statements.push(Statement { + source_info: self.source_info, + kind: StatementKind::StorageLive(dest), + }); + last_bb.terminator = Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Call { + func: Operand::Constant(Box::new(ConstOperand { + span: self.span, + user_ty: None, + const_: Const::Val(ConstValue::ZeroSized, func_ty), + })), + destination: dest.into(), + target: Some(target), + unwind, + call_source: CallSource::Misc, + fn_span: self.span, + args: self.stack.drain(operands_split..).map(|o| respan(self.span, o)).collect(), + }, + }); + + self.put_operand(Operand::Move(dest.into())); + self.last_bb = target; + + dest_ty + } +} diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs index 1a8cfc41178..03907babf2b 100644 --- a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -2,6 +2,7 @@ use std::iter; use super::MirPass; use rustc_middle::{ + bug, mir::{ interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets, TerminatorKind, diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 06d5e17fdd6..f19c34cae7a 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,6 +1,7 @@ use rustc_data_structures::flat_map_in_place::FlatMapInPlace; use rustc_index::bit_set::{BitSet, GrowableBitSet}; use rustc_index::IndexVec; +use rustc_middle::bug; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; @@ -69,6 +70,11 @@ fn escaping_locals<'tcx>( // Exclude #[repr(simd)] types so that they are not de-optimized into an array return true; } + if Some(def.did()) == tcx.lang_items().dyn_metadata() { + // codegen wants to see the `DynMetadata<T>`, + // not the inner reference-to-opaque-type. + return true; + } // We already excluded unions and enums, so this ADT must have one variant let variant = def.variant(FIRST_VARIANT); if variant.fields.len() > 1 { diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs index fddc62e6652..fb870425f6e 100644 --- a/compiler/rustc_mir_transform/src/ssa.rs +++ b/compiler/rustc_mir_transform/src/ssa.rs @@ -2,15 +2,18 @@ //! 1/ They are only assigned-to once, either as a function parameter, or in an assign statement; //! 2/ This single assignment dominates all uses; //! -//! As a consequence of rule 2, we consider that borrowed locals are not SSA, even if they are -//! `Freeze`, as we do not track that the assignment dominates all uses of the borrow. +//! As we do not track indirect assignments, a local that has its address taken (either by +//! AddressOf or by borrowing) is considered non-SSA. However, it is UB to modify through an +//! immutable borrow of a `Freeze` local. Those can still be considered to be SSA. use rustc_data_structures::graph::dominators::Dominators; use rustc_index::bit_set::BitSet; use rustc_index::{IndexSlice, IndexVec}; +use rustc_middle::bug; use rustc_middle::middle::resolve_bound_vars::Set1; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; +use rustc_middle::ty::{ParamEnv, TyCtxt}; pub struct SsaLocals { /// Assignments to each local. This defines whether the local is SSA. @@ -24,6 +27,8 @@ pub struct SsaLocals { /// Number of "direct" uses of each local, ie. uses that are not dereferences. /// We ignore non-uses (Storage statements, debuginfo). direct_uses: IndexVec<Local, u32>, + /// Set of SSA locals that are immutably borrowed. + borrowed_locals: BitSet<Local>, } pub enum AssignedValue<'a, 'tcx> { @@ -33,15 +38,22 @@ pub enum AssignedValue<'a, 'tcx> { } impl SsaLocals { - pub fn new<'tcx>(body: &Body<'tcx>) -> SsaLocals { + pub fn new<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ParamEnv<'tcx>) -> SsaLocals { let assignment_order = Vec::with_capacity(body.local_decls.len()); let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls); let dominators = body.basic_blocks.dominators(); let direct_uses = IndexVec::from_elem(0, &body.local_decls); - let mut visitor = - SsaVisitor { body, assignments, assignment_order, dominators, direct_uses }; + let borrowed_locals = BitSet::new_empty(body.local_decls.len()); + let mut visitor = SsaVisitor { + body, + assignments, + assignment_order, + dominators, + direct_uses, + borrowed_locals, + }; for local in body.args_iter() { visitor.assignments[local] = Set1::One(DefLocation::Argument); @@ -58,6 +70,16 @@ impl SsaLocals { visitor.visit_var_debug_info(var_debug_info); } + // The immutability of shared borrows only works on `Freeze` locals. If the visitor found + // borrows, we need to check the types. For raw pointers and mutable borrows, the locals + // have already been marked as non-SSA. + debug!(?visitor.borrowed_locals); + for local in visitor.borrowed_locals.iter() { + if !body.local_decls[local].ty.is_freeze(tcx, param_env) { + visitor.assignments[local] = Set1::Many; + } + } + debug!(?visitor.assignments); debug!(?visitor.direct_uses); @@ -70,6 +92,7 @@ impl SsaLocals { assignments: visitor.assignments, assignment_order: visitor.assignment_order, direct_uses: visitor.direct_uses, + borrowed_locals: visitor.borrowed_locals, // This is filled by `compute_copy_classes`. copy_classes: IndexVec::default(), }; @@ -174,6 +197,11 @@ impl SsaLocals { &self.copy_classes } + /// Set of SSA locals that are immutably borrowed. + pub fn borrowed_locals(&self) -> &BitSet<Local> { + &self.borrowed_locals + } + /// Make a property uniform on a copy equivalence class by removing elements. pub fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) { // Consolidate to have a local iff all its copies are. @@ -208,6 +236,8 @@ struct SsaVisitor<'tcx, 'a> { assignments: IndexVec<Local, Set1<DefLocation>>, assignment_order: Vec<Local>, direct_uses: IndexVec<Local, u32>, + // Track locals that are immutably borrowed, so we can check their type is `Freeze` later. + borrowed_locals: BitSet<Local>, } impl SsaVisitor<'_, '_> { @@ -232,16 +262,18 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'tcx, '_> { PlaceContext::MutatingUse(MutatingUseContext::Projection) | PlaceContext::NonMutatingUse(NonMutatingUseContext::Projection) => bug!(), // Anything can happen with raw pointers, so remove them. - // We do not verify that all uses of the borrow dominate the assignment to `local`, - // so we have to remove them too. - PlaceContext::NonMutatingUse( - NonMutatingUseContext::SharedBorrow - | NonMutatingUseContext::FakeBorrow - | NonMutatingUseContext::AddressOf, - ) + PlaceContext::NonMutatingUse(NonMutatingUseContext::AddressOf) | PlaceContext::MutatingUse(_) => { self.assignments[local] = Set1::Many; } + // Immutable borrows are ok, but we need to delay a check that the type is `Freeze`. + PlaceContext::NonMutatingUse( + NonMutatingUseContext::SharedBorrow | NonMutatingUseContext::FakeBorrow, + ) => { + self.borrowed_locals.insert(local); + self.check_dominates(local, loc); + self.direct_uses[local] += 1; + } PlaceContext::NonMutatingUse(_) => { self.check_dominates(local, loc); self.direct_uses[local] += 1; diff --git a/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs index 66b6235eb93..1404a45f4d2 100644 --- a/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs @@ -2,6 +2,7 @@ use crate::MirPass; use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::{ BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind, diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index 8ad7bc394c5..a6c3c3b189e 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -3,6 +3,7 @@ //! post-order traversal of the blocks. use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs new file mode 100644 index 00000000000..851e1655958 --- /dev/null +++ b/compiler/rustc_mir_transform/src/validate.rs @@ -0,0 +1,1436 @@ +//! Validates the MIR to ensure that invariants are upheld. + +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_index::bit_set::BitSet; +use rustc_index::IndexVec; +use rustc_infer::traits::Reveal; +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeVisitableExt, Variance}; +use rustc_middle::{bug, span_bug}; +use rustc_target::abi::{Size, FIRST_VARIANT}; +use rustc_target::spec::abi::Abi; + +use crate::util::is_within_packed; + +use crate::util::relate_types; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum EdgeKind { + Unwind, + Normal, +} + +pub struct Validator { + /// Describes at which point in the pipeline this validation is happening. + pub when: String, + /// The phase for which we are upholding the dialect. If the given phase forbids a specific + /// element, this validator will now emit errors if that specific element is encountered. + /// Note that phases that change the dialect cause all *following* phases to check the + /// invariants of the new dialect. A phase that changes dialects never checks the new invariants + /// itself. + pub mir_phase: MirPhase, +} + +impl<'tcx> MirPass<'tcx> for Validator { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // FIXME(JakobDegen): These bodies never instantiated in codegend anyway, so it's not + // terribly important that they pass the validator. However, I think other passes might + // still see them, in which case they might be surprised. It would probably be better if we + // didn't put this through the MIR pipeline at all. + if matches!(body.source.instance, InstanceDef::Intrinsic(..) | InstanceDef::Virtual(..)) { + return; + } + let def_id = body.source.def_id(); + let mir_phase = self.mir_phase; + let param_env = match mir_phase.reveal() { + Reveal::UserFacing => tcx.param_env(def_id), + Reveal::All => tcx.param_env_reveal_all_normalized(def_id), + }; + + let can_unwind = if mir_phase <= MirPhase::Runtime(RuntimePhase::Initial) { + // In this case `AbortUnwindingCalls` haven't yet been executed. + true + } else if !tcx.def_kind(def_id).is_fn_like() { + true + } else { + let body_ty = tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, + ty::Coroutine(..) => Abi::Rust, + // No need to do MIR validation on error bodies + ty::Error(_) => return, + _ => { + span_bug!(body.span, "unexpected body ty: {:?} phase {:?}", body_ty, mir_phase) + } + }; + + ty::layout::fn_can_unwind(tcx, Some(def_id), body_abi) + }; + + let mut cfg_checker = CfgChecker { + when: &self.when, + body, + tcx, + mir_phase, + unwind_edge_count: 0, + reachable_blocks: traversal::reachable_as_bitset(body), + value_cache: FxHashSet::default(), + can_unwind, + }; + cfg_checker.visit_body(body); + cfg_checker.check_cleanup_control_flow(); + + // Also run the TypeChecker. + for (location, msg) in validate_types(tcx, self.mir_phase, param_env, body, body) { + cfg_checker.fail(location, msg); + } + + if let MirPhase::Runtime(_) = body.phase { + if let ty::InstanceDef::Item(_) = body.source.instance { + if body.has_free_regions() { + cfg_checker.fail( + Location::START, + format!("Free regions in optimized {} MIR", body.phase.name()), + ); + } + } + } + + // Enforce that coroutine-closure layouts are identical. + if let Some(layout) = body.coroutine_layout_raw() + && let Some(by_move_body) = body.coroutine_by_move_body() + && let Some(by_move_layout) = by_move_body.coroutine_layout_raw() + { + // FIXME(async_closures): We could do other validation here? + if layout.variant_fields.len() != by_move_layout.variant_fields.len() { + cfg_checker.fail( + Location::START, + format!( + "Coroutine layout has different number of variant fields from \ + by-move coroutine layout:\n\ + layout: {layout:#?}\n\ + by_move_layout: {by_move_layout:#?}", + ), + ); + } + } + } +} + +struct CfgChecker<'a, 'tcx> { + when: &'a str, + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + mir_phase: MirPhase, + unwind_edge_count: usize, + reachable_blocks: BitSet<BasicBlock>, + value_cache: FxHashSet<u128>, + // If `false`, then the MIR must not contain `UnwindAction::Continue` or + // `TerminatorKind::Resume`. + can_unwind: bool, +} + +impl<'a, 'tcx> CfgChecker<'a, 'tcx> { + #[track_caller] + fn fail(&self, location: Location, msg: impl AsRef<str>) { + // We might see broken MIR when other errors have already occurred. + assert!( + self.tcx.dcx().has_errors().is_some(), + "broken MIR in {:?} ({}) at {:?}:\n{}", + self.body.source.instance, + self.when, + location, + msg.as_ref(), + ); + } + + fn check_edge(&mut self, location: Location, bb: BasicBlock, edge_kind: EdgeKind) { + if bb == START_BLOCK { + self.fail(location, "start block must not have predecessors") + } + if let Some(bb) = self.body.basic_blocks.get(bb) { + let src = self.body.basic_blocks.get(location.block).unwrap(); + match (src.is_cleanup, bb.is_cleanup, edge_kind) { + // Non-cleanup blocks can jump to non-cleanup blocks along non-unwind edges + (false, false, EdgeKind::Normal) + // Cleanup blocks can jump to cleanup blocks along non-unwind edges + | (true, true, EdgeKind::Normal) => {} + // Non-cleanup blocks can jump to cleanup blocks along unwind edges + (false, true, EdgeKind::Unwind) => { + self.unwind_edge_count += 1; + } + // All other jumps are invalid + _ => { + self.fail( + location, + format!( + "{:?} edge to {:?} violates unwind invariants (cleanup {:?} -> {:?})", + edge_kind, + bb, + src.is_cleanup, + bb.is_cleanup, + ) + ) + } + } + } else { + self.fail(location, format!("encountered jump to invalid basic block {bb:?}")) + } + } + + fn check_cleanup_control_flow(&self) { + if self.unwind_edge_count <= 1 { + return; + } + let doms = self.body.basic_blocks.dominators(); + let mut post_contract_node = FxHashMap::default(); + // Reusing the allocation across invocations of the closure + let mut dom_path = vec![]; + let mut get_post_contract_node = |mut bb| { + let root = loop { + if let Some(root) = post_contract_node.get(&bb) { + break *root; + } + let parent = doms.immediate_dominator(bb).unwrap(); + dom_path.push(bb); + if !self.body.basic_blocks[parent].is_cleanup { + break bb; + } + bb = parent; + }; + for bb in dom_path.drain(..) { + post_contract_node.insert(bb, root); + } + root + }; + + let mut parent = IndexVec::from_elem(None, &self.body.basic_blocks); + for (bb, bb_data) in self.body.basic_blocks.iter_enumerated() { + if !bb_data.is_cleanup || !self.reachable_blocks.contains(bb) { + continue; + } + let bb = get_post_contract_node(bb); + for s in bb_data.terminator().successors() { + let s = get_post_contract_node(s); + if s == bb { + continue; + } + let parent = &mut parent[bb]; + match parent { + None => { + *parent = Some(s); + } + Some(e) if *e == s => (), + Some(e) => self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: The blocks dominated by {:?} have edges to both {:?} and {:?}", + bb, + s, + *e + ) + ), + } + } + } + + // Check for cycles + let mut stack = FxHashSet::default(); + for i in 0..parent.len() { + let mut bb = BasicBlock::from_usize(i); + stack.clear(); + stack.insert(bb); + loop { + let Some(parent) = parent[bb].take() else { break }; + let no_cycle = stack.insert(parent); + if !no_cycle { + self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: Cycle involving edge {bb:?} -> {parent:?}", + ), + ); + break; + } + bb = parent; + } + } + } + + fn check_unwind_edge(&mut self, location: Location, unwind: UnwindAction) { + let is_cleanup = self.body.basic_blocks[location.block].is_cleanup; + match unwind { + UnwindAction::Cleanup(unwind) => { + if is_cleanup { + self.fail(location, "`UnwindAction::Cleanup` in cleanup block"); + } + self.check_edge(location, unwind, EdgeKind::Unwind); + } + UnwindAction::Continue => { + if is_cleanup { + self.fail(location, "`UnwindAction::Continue` in cleanup block"); + } + + if !self.can_unwind { + self.fail(location, "`UnwindAction::Continue` in no-unwind function"); + } + } + UnwindAction::Terminate(UnwindTerminateReason::InCleanup) => { + if !is_cleanup { + self.fail( + location, + "`UnwindAction::Terminate(InCleanup)` in a non-cleanup block", + ); + } + } + // These are allowed everywhere. + UnwindAction::Unreachable | UnwindAction::Terminate(UnwindTerminateReason::Abi) => (), + } + } + + fn is_critical_call_edge(&self, target: Option<BasicBlock>, unwind: UnwindAction) -> bool { + let Some(target) = target else { return false }; + matches!(unwind, UnwindAction::Cleanup(_) | UnwindAction::Terminate(_)) + && self.body.basic_blocks.predecessors()[target].len() > 1 + } +} + +impl<'a, 'tcx> Visitor<'tcx> for CfgChecker<'a, 'tcx> { + fn visit_local(&mut self, local: Local, _context: PlaceContext, location: Location) { + if self.body.local_decls.get(local).is_none() { + self.fail( + location, + format!("local {local:?} has no corresponding declaration in `body.local_decls`"), + ); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::AscribeUserType(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::SetDiscriminant { .. } => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + } + StatementKind::Deinit(..) => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.mir_phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::Coverage(kind) => { + if self.mir_phase >= MirPhase::Analysis(AnalysisPhase::PostCleanup) + && let CoverageKind::BlockMarker { .. } | CoverageKind::SpanMarker { .. } = kind + { + self.fail( + location, + format!("{kind:?} should have been removed after analysis"), + ); + } + } + StatementKind::Assign(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) + | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::Goto { target } => { + self.check_edge(location, *target, EdgeKind::Normal); + } + TerminatorKind::SwitchInt { targets, discr: _ } => { + for (_, target) in targets.iter() { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_edge(location, targets.otherwise(), EdgeKind::Normal); + + self.value_cache.clear(); + self.value_cache.extend(targets.iter().map(|(value, _)| value)); + let has_duplicates = targets.iter().len() != self.value_cache.len(); + if has_duplicates { + self.fail( + location, + format!( + "duplicated values in `SwitchInt` terminator: {:?}", + terminator.kind, + ), + ); + } + } + TerminatorKind::Drop { target, unwind, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::Call { args, destination, target, unwind, .. } => { + if let Some(target) = target { + self.check_edge(location, *target, EdgeKind::Normal); + } + self.check_unwind_edge(location, *unwind); + + // The code generation assumes that there are no critical call edges. The assumption + // is used to simplify inserting code that should be executed along the return edge + // from the call. FIXME(tmiasko): Since this is a strictly code generation concern, + // the code generation should be responsible for handling it. + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Optimized) + && self.is_critical_call_edge(*target, *unwind) + { + self.fail( + location, + format!( + "encountered critical edge in `Call` terminator {:?}", + terminator.kind, + ), + ); + } + + // The call destination place and Operand::Move place used as an argument might be + // passed by a reference to the callee. Consequently they cannot be packed. + if is_within_packed(self.tcx, &self.body.local_decls, *destination).is_some() { + // This is bad! The callee will expect the memory to be aligned. + self.fail( + location, + format!( + "encountered packed place in `Call` terminator destination: {:?}", + terminator.kind, + ), + ); + } + for arg in args { + if let Operand::Move(place) = &arg.node { + if is_within_packed(self.tcx, &self.body.local_decls, *place).is_some() { + // This is bad! The callee will expect the memory to be aligned. + self.fail( + location, + format!( + "encountered `Move` of a packed place in `Call` terminator: {:?}", + terminator.kind, + ), + ); + } + } + } + } + TerminatorKind::Assert { target, unwind, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::Yield { resume, drop, .. } => { + if self.body.coroutine.is_none() { + self.fail(location, "`Yield` cannot appear outside coroutine bodies"); + } + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Yield` should have been replaced by coroutine lowering"); + } + self.check_edge(location, *resume, EdgeKind::Normal); + if let Some(drop) = drop { + self.check_edge(location, *drop, EdgeKind::Normal); + } + } + TerminatorKind::FalseEdge { real_target, imaginary_target } => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseEdge` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_edge(location, *imaginary_target, EdgeKind::Normal); + } + TerminatorKind::FalseUnwind { real_target, unwind } => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseUnwind` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::InlineAsm { targets, unwind, .. } => { + for &target in targets { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::CoroutineDrop => { + if self.body.coroutine.is_none() { + self.fail(location, "`CoroutineDrop` cannot appear outside coroutine bodies"); + } + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`CoroutineDrop` should have been replaced by coroutine lowering", + ); + } + } + TerminatorKind::UnwindResume => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindResume` from non-cleanup basic block") + } + if !self.can_unwind { + self.fail(location, "Cannot `UnwindResume` in a function that cannot unwind") + } + } + TerminatorKind::UnwindTerminate(_) => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindTerminate` from non-cleanup basic block") + } + } + TerminatorKind::Return => { + let bb = location.block; + if self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `Return` from cleanup basic block") + } + } + TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } + + fn visit_source_scope(&mut self, scope: SourceScope) { + if self.body.source_scopes.get(scope).is_none() { + self.tcx.dcx().span_bug( + self.body.span, + format!( + "broken MIR in {:?} ({}):\ninvalid source scope {:?}", + self.body.source.instance, self.when, scope, + ), + ); + } + } +} + +/// A faster version of the validation pass that only checks those things which may break when +/// instantiating any generic parameters. +/// +/// `caller_body` is used to detect cycles in MIR inlining and MIR validation before +/// `optimized_mir` is available. +pub fn validate_types<'tcx>( + tcx: TyCtxt<'tcx>, + mir_phase: MirPhase, + param_env: ty::ParamEnv<'tcx>, + body: &Body<'tcx>, + caller_body: &Body<'tcx>, +) -> Vec<(Location, String)> { + let mut type_checker = + TypeChecker { body, caller_body, tcx, param_env, mir_phase, failures: Vec::new() }; + type_checker.visit_body(body); + type_checker.failures +} + +struct TypeChecker<'a, 'tcx> { + body: &'a Body<'tcx>, + caller_body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + mir_phase: MirPhase, + failures: Vec<(Location, String)>, +} + +impl<'a, 'tcx> TypeChecker<'a, 'tcx> { + fn fail(&mut self, location: Location, msg: impl Into<String>) { + self.failures.push((location, msg.into())); + } + + /// Check if src can be assigned into dest. + /// This is not precise, it will accept some incorrect assignments. + fn mir_assign_valid_types(&self, src: Ty<'tcx>, dest: Ty<'tcx>) -> bool { + // Fast path before we normalize. + if src == dest { + // Equal types, all is good. + return true; + } + + // We sometimes have to use `defining_opaque_types` for subtyping + // to succeed here and figuring out how exactly that should work + // is annoying. It is harmless enough to just not validate anything + // in that case. We still check this after analysis as all opaque + // types have been revealed at this point. + if (src, dest).has_opaque_types() { + return true; + } + + // After borrowck subtyping should be fully explicit via + // `Subtype` projections. + let variance = if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + Variance::Invariant + } else { + Variance::Covariant + }; + + crate::util::relate_types(self.tcx, self.param_env, variance, src, dest) + } +} + +impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + // This check is somewhat expensive, so only run it when -Zvalidate-mir is passed. + if self.tcx.sess.opts.unstable_opts.validate_mir + && self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) + { + // `Operand::Copy` is only supposed to be used with `Copy` types. + if let Operand::Copy(place) = operand { + let ty = place.ty(&self.body.local_decls, self.tcx).ty; + + if !ty.is_copy_modulo_regions(self.tcx, self.param_env) { + self.fail(location, format!("`Operand::Copy` with non-`Copy` type {ty}")); + } + } + } + + self.super_operand(operand, location); + } + + fn visit_projection_elem( + &mut self, + place_ref: PlaceRef<'tcx>, + elem: PlaceElem<'tcx>, + context: PlaceContext, + location: Location, + ) { + match elem { + ProjectionElem::OpaqueCast(ty) + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) => + { + self.fail( + location, + format!("explicit opaque type cast to `{ty}` after `RevealAll`"), + ) + } + ProjectionElem::Index(index) => { + let index_ty = self.body.local_decls[index].ty; + if index_ty != self.tcx.types.usize { + self.fail(location, format!("bad index ({index_ty:?} != usize)")) + } + } + ProjectionElem::Deref + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::PostCleanup) => + { + let base_ty = place_ref.ty(&self.body.local_decls, self.tcx).ty; + + if base_ty.is_box() { + self.fail( + location, + format!("{base_ty:?} dereferenced after ElaborateBoxDerefs"), + ) + } + } + ProjectionElem::Field(f, ty) => { + let parent_ty = place_ref.ty(&self.body.local_decls, self.tcx); + let fail_out_of_bounds = |this: &mut Self, location| { + this.fail(location, format!("Out of bounds field {f:?} for {parent_ty:?}")); + }; + let check_equal = |this: &mut Self, location, f_ty| { + if !this.mir_assign_valid_types(ty, f_ty) { + this.fail( + location, + format!( + "Field projection `{place_ref:?}.{f:?}` specified type `{ty:?}`, but actual type is `{f_ty:?}`" + ) + ) + } + }; + + let kind = match parent_ty.ty.kind() { + &ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args).kind() + } + kind => kind, + }; + + match kind { + ty::Tuple(fields) => { + let Some(f_ty) = fields.get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, *f_ty); + } + ty::Adt(adt_def, args) => { + // see <https://github.com/rust-lang/rust/blob/7601adcc764d42c9f2984082b49948af652df986/compiler/rustc_middle/src/ty/layout.rs#L861-L864> + if Some(adt_def.did()) == self.tcx.lang_items().dyn_metadata() { + self.fail( + location, + format!( + "You can't project to field {f:?} of `DynMetadata` because \ + layout is weird and thinks it doesn't have fields." + ), + ); + } + + let var = parent_ty.variant_index.unwrap_or(FIRST_VARIANT); + let Some(field) = adt_def.variant(var).fields.get(f) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, field.ty(self.tcx, args)); + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + &ty::Coroutine(def_id, args) => { + let f_ty = if let Some(var) = parent_ty.variant_index { + // If we're currently validating an inlined copy of this body, + // then it will no longer be parameterized over the original + // args of the coroutine. Otherwise, we prefer to use this body + // since we may be in the process of computing this MIR in the + // first place. + let layout = if def_id == self.caller_body.source.def_id() { + // FIXME: This is not right for async closures. + self.caller_body.coroutine_layout_raw() + } else { + self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) + }; + + let Some(layout) = layout else { + self.fail( + location, + format!("No coroutine layout for {parent_ty:?}"), + ); + return; + }; + + let Some(&local) = layout.variant_fields[var].get(f) else { + fail_out_of_bounds(self, location); + return; + }; + + let Some(f_ty) = layout.field_tys.get(local) else { + self.fail( + location, + format!("Out of bounds local {local:?} for {parent_ty:?}"), + ); + return; + }; + + ty::EarlyBinder::bind(f_ty.ty).instantiate(self.tcx, args) + } else { + let Some(&f_ty) = args.as_coroutine().prefix_tys().get(f.index()) + else { + fail_out_of_bounds(self, location); + return; + }; + + f_ty + }; + + check_equal(self, location, f_ty); + } + _ => { + self.fail(location, format!("{:?} does not have fields", parent_ty.ty)); + } + } + } + ProjectionElem::Subtype(ty) => { + if !relate_types( + self.tcx, + self.param_env, + Variance::Covariant, + ty, + place_ref.ty(&self.body.local_decls, self.tcx).ty, + ) { + self.fail( + location, + format!( + "Failed subtyping {ty:#?} and {:#?}", + place_ref.ty(&self.body.local_decls, self.tcx).ty + ), + ) + } + } + _ => {} + } + self.super_projection_elem(place_ref, elem, context, location); + } + + fn visit_var_debug_info(&mut self, debuginfo: &VarDebugInfo<'tcx>) { + if let Some(box VarDebugInfoFragment { ty, ref projection }) = debuginfo.composite { + if ty.is_union() || ty.is_enum() { + self.fail( + START_BLOCK.start_location(), + format!("invalid type {ty:?} in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.is_empty() { + self.fail( + START_BLOCK.start_location(), + format!("invalid empty projection in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.iter().any(|p| !matches!(p, PlaceElem::Field(..))) { + self.fail( + START_BLOCK.start_location(), + format!( + "illegal projection {:?} in debuginfo for {:?}", + projection, debuginfo.name + ), + ); + } + } + match debuginfo.value { + VarDebugInfoContents::Const(_) => {} + VarDebugInfoContents::Place(place) => { + if place.projection.iter().any(|p| !p.can_use_in_debuginfo()) { + self.fail( + START_BLOCK.start_location(), + format!("illegal place {:?} in debuginfo for {:?}", place, debuginfo.name), + ); + } + } + } + self.super_var_debug_info(debuginfo); + } + + fn visit_place(&mut self, place: &Place<'tcx>, cntxt: PlaceContext, location: Location) { + // Set off any `bug!`s in the type computation code + let _ = place.ty(&self.body.local_decls, self.tcx); + + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) + && place.projection.len() > 1 + && cntxt != PlaceContext::NonUse(NonUseContext::VarDebugInfo) + && place.projection[1..].contains(&ProjectionElem::Deref) + { + self.fail( + location, + format!("place {place:?} has deref as a later projection (it is only permitted as the first projection)"), + ); + } + + // Ensure all downcast projections are followed by field projections. + let mut projections_iter = place.projection.iter(); + while let Some(proj) = projections_iter.next() { + if matches!(proj, ProjectionElem::Downcast(..)) { + if !matches!(projections_iter.next(), Some(ProjectionElem::Field(..))) { + self.fail( + location, + format!( + "place {place:?} has `Downcast` projection not followed by `Field`" + ), + ); + } + } + } + + self.super_place(place, cntxt, location); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + macro_rules! check_kinds { + ($t:expr, $text:literal, $typat:pat) => { + if !matches!(($t).kind(), $typat) { + self.fail(location, format!($text, $t)); + } + }; + } + match rvalue { + Rvalue::Use(_) | Rvalue::CopyForDeref(_) => {} + Rvalue::Aggregate(kind, fields) => match **kind { + AggregateKind::Tuple => {} + AggregateKind::Array(dest) => { + for src in fields { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "array field has the wrong type"); + } + } + } + AggregateKind::Adt(def_id, idx, args, _, Some(field)) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(adt_def.is_union()); + assert_eq!(idx, FIRST_VARIANT); + let dest_ty = self.tcx.normalize_erasing_regions( + self.param_env, + adt_def.non_enum_variant().fields[field].ty(self.tcx, args), + ); + if fields.len() == 1 { + let src_ty = fields.raw[0].ty(self.body, self.tcx); + if !self.mir_assign_valid_types(src_ty, dest_ty) { + self.fail(location, "union field has the wrong type"); + } + } else { + self.fail(location, "unions should have one initialized field"); + } + } + AggregateKind::Adt(def_id, idx, args, _, None) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(!adt_def.is_union()); + let variant = &adt_def.variants()[idx]; + if variant.fields.len() != fields.len() { + self.fail(location, "adt has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, &variant.fields) { + let dest_ty = self + .tcx + .normalize_erasing_regions(self.param_env, dest.ty(self.tcx, args)); + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest_ty) { + self.fail(location, "adt field has the wrong type"); + } + } + } + AggregateKind::Closure(_, args) => { + let upvars = args.as_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "closure has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "closure field has the wrong type"); + } + } + } + AggregateKind::Coroutine(_, args) => { + let upvars = args.as_coroutine().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "coroutine has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine field has the wrong type"); + } + } + } + AggregateKind::CoroutineClosure(_, args) => { + let upvars = args.as_coroutine_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail( + location, + "coroutine-closure has the wrong number of initialized fields", + ); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine-closure field has the wrong type"); + } + } + } + AggregateKind::RawPtr(pointee_ty, mutability) => { + if !matches!(self.mir_phase, MirPhase::Runtime(_)) { + // It would probably be fine to support this in earlier phases, + // but at the time of writing it's only ever introduced from intrinsic lowering, + // so earlier things just `bug!` on it. + self.fail(location, "RawPtr should be in runtime MIR only"); + } + + if fields.len() != 2 { + self.fail(location, "raw pointer aggregate must have 2 fields"); + } else { + let data_ptr_ty = fields.raw[0].ty(self.body, self.tcx); + let metadata_ty = fields.raw[1].ty(self.body, self.tcx); + if let ty::RawPtr(in_pointee, in_mut) = data_ptr_ty.kind() { + if *in_mut != mutability { + self.fail(location, "input and output mutability must match"); + } + + // FIXME: check `Thin` instead of `Sized` + if !in_pointee.is_sized(self.tcx, self.param_env) { + self.fail(location, "input pointer must be thin"); + } + } else { + self.fail( + location, + "first operand to raw pointer aggregate must be a raw pointer", + ); + } + + // FIXME: Check metadata more generally + if pointee_ty.is_slice() { + if !self.mir_assign_valid_types(metadata_ty, self.tcx.types.usize) { + self.fail(location, "slice metadata must be usize"); + } + } else if pointee_ty.is_sized(self.tcx, self.param_env) { + if metadata_ty != self.tcx.types.unit { + self.fail(location, "metadata for pointer-to-thin must be unit"); + } + } + } + } + }, + Rvalue::Ref(_, BorrowKind::Fake(_), _) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`Assign` statement with a `Fake` borrow should have been removed in runtime MIR", + ); + } + } + Rvalue::Ref(..) => {} + Rvalue::Len(p) => { + let pty = p.ty(&self.body.local_decls, self.tcx).ty; + check_kinds!( + pty, + "Cannot compute length of non-array type {:?}", + ty::Array(..) | ty::Slice(..) + ); + } + Rvalue::BinaryOp(op, vals) => { + use BinOp::*; + let a = vals.0.ty(&self.body.local_decls, self.tcx); + let b = vals.1.ty(&self.body.local_decls, self.tcx); + if crate::util::binop_right_homogeneous(*op) { + if let Eq | Lt | Le | Ne | Ge | Gt = op { + // The function pointer types can have lifetimes + if !self.mir_assign_valid_types(a, b) { + self.fail( + location, + format!("Cannot {op:?} compare incompatible types {a:?} and {b:?}"), + ); + } + } else if a != b { + self.fail( + location, + format!( + "Cannot perform binary op {op:?} on unequal types {a:?} and {b:?}" + ), + ); + } + } + + match op { + Offset => { + check_kinds!(a, "Cannot offset non-pointer type {:?}", ty::RawPtr(..)); + if b != self.tcx.types.isize && b != self.tcx.types.usize { + self.fail(location, format!("Cannot offset by non-isize type {b:?}")); + } + } + Eq | Lt | Le | Ne | Ge | Gt => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} compare type {:?}", + ty::Bool + | ty::Char + | ty::Int(..) + | ty::Uint(..) + | ty::Float(..) + | ty::RawPtr(..) + | ty::FnPtr(..) + ) + } + } + Cmp => { + for x in [a, b] { + check_kinds!( + x, + "Cannot three-way compare non-integer type {:?}", + ty::Char | ty::Uint(..) | ty::Int(..) + ) + } + } + AddUnchecked | AddWithOverflow | SubUnchecked | SubWithOverflow + | MulUnchecked | MulWithOverflow | Shl | ShlUnchecked | Shr | ShrUnchecked => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} non-integer type {:?}", + ty::Uint(..) | ty::Int(..) + ) + } + } + BitAnd | BitOr | BitXor => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform bitwise op {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) | ty::Bool + ) + } + } + Add | Sub | Mul | Div | Rem => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform arithmetic {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) | ty::Float(..) + ) + } + } + } + } + Rvalue::UnaryOp(op, operand) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + match op { + UnOp::Neg => { + check_kinds!(a, "Cannot negate type {:?}", ty::Int(..) | ty::Float(..)) + } + UnOp::Not => { + check_kinds!( + a, + "Cannot binary not type {:?}", + ty::Int(..) | ty::Uint(..) | ty::Bool + ); + } + UnOp::PtrMetadata => { + if !matches!(self.mir_phase, MirPhase::Runtime(_)) { + // It would probably be fine to support this in earlier phases, + // but at the time of writing it's only ever introduced from intrinsic lowering, + // so earlier things can just `bug!` on it. + self.fail(location, "PtrMetadata should be in runtime MIR only"); + } + + check_kinds!(a, "Cannot PtrMetadata non-pointer type {:?}", ty::RawPtr(..)); + } + } + } + Rvalue::ShallowInitBox(operand, _) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..)); + } + Rvalue::Cast(kind, operand, target_type) => { + let op_ty = operand.ty(self.body, self.tcx); + match kind { + CastKind::DynStar => { + // FIXME(dyn-star): make sure nothing needs to be done here. + } + // FIXME: Add Checks for these + CastKind::PointerWithExposedProvenance + | CastKind::PointerExposeProvenance + | CastKind::PointerCoercion(_) => {} + CastKind::IntToInt | CastKind::IntToFloat => { + let input_valid = op_ty.is_integral() || op_ty.is_char() || op_ty.is_bool(); + let target_valid = target_type.is_numeric() || target_type.is_char(); + if !input_valid || !target_valid { + self.fail( + location, + format!("Wrong cast kind {kind:?} for the type {op_ty}",), + ); + } + } + CastKind::FnPtrToPtr | CastKind::PtrToPtr => { + if !(op_ty.is_any_ptr() && target_type.is_unsafe_ptr()) { + self.fail(location, "Can't cast {op_ty} into 'Ptr'"); + } + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + if !op_ty.is_floating_point() || !target_type.is_numeric() { + self.fail( + location, + format!( + "Trying to cast non 'Float' as {kind:?} into {target_type:?}" + ), + ); + } + } + CastKind::Transmute => { + if let MirPhase::Runtime(..) = self.mir_phase { + // Unlike `mem::transmute`, a MIR `Transmute` is well-formed + // for any two `Sized` types, just potentially UB to run. + + if !self + .tcx + .normalize_erasing_regions(self.param_env, op_ty) + .is_sized(self.tcx, self.param_env) + { + self.fail( + location, + format!("Cannot transmute from non-`Sized` type {op_ty:?}"), + ); + } + if !self + .tcx + .normalize_erasing_regions(self.param_env, *target_type) + .is_sized(self.tcx, self.param_env) + { + self.fail( + location, + format!("Cannot transmute to non-`Sized` type {target_type:?}"), + ); + } + } else { + self.fail( + location, + format!( + "Transmute is not supported in non-runtime phase {:?}.", + self.mir_phase + ), + ); + } + } + } + } + Rvalue::NullaryOp(NullOp::OffsetOf(indices), container) => { + let fail_out_of_bounds = |this: &mut Self, location, field, ty| { + this.fail(location, format!("Out of bounds field {field:?} for {ty:?}")); + }; + + let mut current_ty = *container; + + for (variant, field) in indices.iter() { + match current_ty.kind() { + ty::Tuple(fields) => { + if variant != FIRST_VARIANT { + self.fail( + location, + format!("tried to get variant {variant:?} of tuple"), + ); + return; + } + let Some(&f_ty) = fields.get(field.as_usize()) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + current_ty = self.tcx.normalize_erasing_regions(self.param_env, f_ty); + } + ty::Adt(adt_def, args) => { + let Some(field) = adt_def.variant(variant).fields.get(field) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + let f_ty = field.ty(self.tcx, args); + current_ty = self.tcx.normalize_erasing_regions(self.param_env, f_ty); + } + _ => { + self.fail( + location, + format!("Cannot get offset ({variant:?}, {field:?}) from type {current_ty:?}"), + ); + return; + } + } + } + } + Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::AddressOf(_, _) + | Rvalue::NullaryOp(NullOp::SizeOf | NullOp::AlignOf | NullOp::UbChecks, _) + | Rvalue::Discriminant(_) => {} + } + self.super_rvalue(rvalue, location); + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (dest, rvalue)) => { + // LHS and RHS of the assignment must have the same type. + let left_ty = dest.ty(&self.body.local_decls, self.tcx).ty; + let right_ty = rvalue.ty(&self.body.local_decls, self.tcx); + + if !self.mir_assign_valid_types(right_ty, left_ty) { + self.fail( + location, + format!( + "encountered `{:?}` with incompatible types:\n\ + left-hand side has type: {}\n\ + right-hand side has type: {}", + statement.kind, left_ty, right_ty, + ), + ); + } + if let Rvalue::CopyForDeref(place) = rvalue { + if place.ty(&self.body.local_decls, self.tcx).ty.builtin_deref(true).is_none() { + self.fail( + location, + "`CopyForDeref` should only be used for dereferenceable types", + ) + } + } + } + StatementKind::AscribeUserType(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) => { + if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => { + let ty = op.ty(&self.body.local_decls, self.tcx); + if !ty.is_bool() { + self.fail( + location, + format!("`assume` argument must be `bool`, but got: `{ty}`"), + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping( + CopyNonOverlapping { src, dst, count }, + )) => { + let src_ty = src.ty(&self.body.local_decls, self.tcx); + let op_src_ty = if let Some(src_deref) = src_ty.builtin_deref(true) { + src_deref + } else { + self.fail( + location, + format!("Expected src to be ptr in copy_nonoverlapping, got: {src_ty}"), + ); + return; + }; + let dst_ty = dst.ty(&self.body.local_decls, self.tcx); + let op_dst_ty = if let Some(dst_deref) = dst_ty.builtin_deref(true) { + dst_deref + } else { + self.fail( + location, + format!("Expected dst to be ptr in copy_nonoverlapping, got: {dst_ty}"), + ); + return; + }; + // since CopyNonOverlapping is parametrized by 1 type, + // we only need to check that they are equal and not keep an extra parameter. + if !self.mir_assign_valid_types(op_src_ty, op_dst_ty) { + self.fail(location, format!("bad arg ({op_src_ty:?} != {op_dst_ty:?})")); + } + + let op_cnt_ty = count.ty(&self.body.local_decls, self.tcx); + if op_cnt_ty != self.tcx.types.usize { + self.fail(location, format!("bad arg ({op_cnt_ty:?} != usize)")) + } + } + StatementKind::SetDiscriminant { place, .. } => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + let pty = place.ty(&self.body.local_decls, self.tcx).ty.kind(); + if !matches!(pty, ty::Adt(..) | ty::Coroutine(..) | ty::Alias(ty::Opaque, ..)) { + self.fail( + location, + format!( + "`SetDiscriminant` is only allowed on ADTs and coroutines, not {pty:?}" + ), + ); + } + } + StatementKind::Deinit(..) => { + if self.mir_phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.mir_phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Coverage(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) + | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::SwitchInt { targets, discr } => { + let switch_ty = discr.ty(&self.body.local_decls, self.tcx); + + let target_width = self.tcx.sess.target.pointer_width; + + let size = Size::from_bits(match switch_ty.kind() { + ty::Uint(uint) => uint.normalize(target_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(target_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + for (value, _) in targets.iter() { + if Scalar::<()>::try_from_uint(value, size).is_none() { + self.fail( + location, + format!("the value {value:#x} is not a proper {switch_ty:?}"), + ) + } + } + } + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(&self.body.local_decls, self.tcx); + match func_ty.kind() { + ty::FnPtr(..) | ty::FnDef(..) => {} + _ => self.fail( + location, + format!("encountered non-callable type {func_ty} in `Call` terminator"), + ), + } + } + TerminatorKind::Assert { cond, .. } => { + let cond_ty = cond.ty(&self.body.local_decls, self.tcx); + if cond_ty != self.tcx.types.bool { + self.fail( + location, + format!( + "encountered non-boolean condition of type {cond_ty} in `Assert` terminator" + ), + ); + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } +} |
