diff options
Diffstat (limited to 'compiler/rustc_mir_build/src')
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/expr/as_place.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/expr/category.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/expr/into.rs | 122 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/expr/stmt.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/matches/mod.rs | 144 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/builder/scope.rs | 264 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/check_unsafety.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/errors.rs | 77 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/thir/cx/expr.rs | 143 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/thir/pattern/check_match.rs | 7 | ||||
| -rw-r--r-- | compiler/rustc_mir_build/src/thir/print.rs | 21 |
12 files changed, 752 insertions, 37 deletions
diff --git a/compiler/rustc_mir_build/src/builder/expr/as_place.rs b/compiler/rustc_mir_build/src/builder/expr/as_place.rs index f8c64d7d13e..99148504a87 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_place.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_place.rs @@ -565,12 +565,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::Match { .. } | ExprKind::If { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Block { .. } | ExprKind::Let { .. } | ExprKind::Assign { .. } | ExprKind::AssignOp { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::Literal { .. } diff --git a/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs b/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs index b23bc089cd4..9e07dd5da7e 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs @@ -538,6 +538,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::RawBorrow { .. } | ExprKind::Adt { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::LogicalOp { .. } | ExprKind::Call { .. } | ExprKind::Field { .. } @@ -548,6 +549,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::UpvarRef { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::InlineAsm { .. } diff --git a/compiler/rustc_mir_build/src/builder/expr/category.rs b/compiler/rustc_mir_build/src/builder/expr/category.rs index 34524aed406..5e4219dbf5b 100644 --- a/compiler/rustc_mir_build/src/builder/expr/category.rs +++ b/compiler/rustc_mir_build/src/builder/expr/category.rs @@ -83,9 +83,11 @@ impl Category { | ExprKind::NamedConst { .. } => Some(Category::Constant), ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Block { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } => // FIXME(#27840) these probably want their own diff --git a/compiler/rustc_mir_build/src/builder/expr/into.rs b/compiler/rustc_mir_build/src/builder/expr/into.rs index 2074fbce0ae..fe3d072fa88 100644 --- a/compiler/rustc_mir_build/src/builder/expr/into.rs +++ b/compiler/rustc_mir_build/src/builder/expr/into.rs @@ -8,15 +8,16 @@ use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::span_bug; use rustc_middle::thir::*; -use rustc_middle::ty::{CanonicalUserTypeAnnotation, Ty}; +use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty}; use rustc_span::DUMMY_SP; use rustc_span::source_map::Spanned; use rustc_trait_selection::infer::InferCtxtExt; use tracing::{debug, instrument}; use crate::builder::expr::category::{Category, RvalueFunc}; -use crate::builder::matches::DeclareLetBindings; +use crate::builder::matches::{DeclareLetBindings, HasMatchGuard}; use crate::builder::{BlockAnd, BlockAndExtension, BlockFrame, Builder, NeedsTemporary}; +use crate::errors::{LoopMatchArmWithGuard, LoopMatchUnsupportedType}; impl<'a, 'tcx> Builder<'a, 'tcx> { /// Compile `expr`, storing the result into `destination`, which @@ -244,6 +245,122 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { None }) } + ExprKind::LoopMatch { state, region_scope, match_span, ref arms } => { + // Intuitively, this is a combination of a loop containing a labeled block + // containing a match. + // + // The only new bit here is that the lowering of the match is wrapped in a + // `in_const_continuable_scope`, which makes the match arms and their target basic + // block available to the lowering of `#[const_continue]`. + + fn is_supported_loop_match_type(ty: Ty<'_>) -> bool { + match ty.kind() { + ty::Uint(_) | ty::Int(_) | ty::Float(_) | ty::Bool | ty::Char => true, + ty::Adt(adt_def, _) => match adt_def.adt_kind() { + ty::AdtKind::Struct | ty::AdtKind::Union => false, + ty::AdtKind::Enum => { + adt_def.variants().iter().all(|v| v.fields.is_empty()) + } + }, + _ => false, + } + } + + let state_ty = this.thir.exprs[state].ty; + if !is_supported_loop_match_type(state_ty) { + let span = this.thir.exprs[state].span; + this.tcx.dcx().emit_fatal(LoopMatchUnsupportedType { span, ty: state_ty }) + } + + let loop_block = this.cfg.start_new_block(); + + // Start the loop. + this.cfg.goto(block, source_info, loop_block); + + this.in_breakable_scope(Some(loop_block), destination, expr_span, |this| { + // Logic for `loop`. + let mut body_block = this.cfg.start_new_block(); + this.cfg.terminate( + loop_block, + source_info, + TerminatorKind::FalseUnwind { + real_target: body_block, + unwind: UnwindAction::Continue, + }, + ); + this.diverge_from(loop_block); + + // Logic for `match`. + let scrutinee_place_builder = + unpack!(body_block = this.as_place_builder(body_block, state)); + let scrutinee_span = this.thir.exprs[state].span; + let match_start_span = match_span.shrink_to_lo().to(scrutinee_span); + + let mut patterns = Vec::with_capacity(arms.len()); + for &arm_id in arms.iter() { + let arm = &this.thir[arm_id]; + + if let Some(guard) = arm.guard { + let span = this.thir.exprs[guard].span; + this.tcx.dcx().emit_fatal(LoopMatchArmWithGuard { span }) + } + + patterns.push((&*arm.pattern, HasMatchGuard::No)); + } + + // The `built_tree` maps match arms to their basic block (where control flow + // jumps to when a value matches the arm). This structure is stored so that a + // `#[const_continue]` can figure out what basic block to jump to. + let built_tree = this.lower_match_tree( + body_block, + scrutinee_span, + &scrutinee_place_builder, + match_start_span, + patterns, + false, + ); + + let state_place = scrutinee_place_builder.to_place(this); + + // This is logic for the labeled block: a block is a drop scope, hence + // `in_scope`, and a labeled block can be broken out of with a `break 'label`, + // hence the `in_breakable_scope`. + // + // Then `in_const_continuable_scope` stores information for the lowering of + // `#[const_continue]`, and finally the match is lowered in the standard way. + unpack!( + body_block = this.in_scope( + (region_scope, source_info), + LintLevel::Inherited, + move |this| { + this.in_breakable_scope(None, state_place, expr_span, |this| { + Some(this.in_const_continuable_scope( + arms.clone(), + built_tree.clone(), + state_place, + expr_span, + |this| { + this.lower_match_arms( + destination, + scrutinee_place_builder, + scrutinee_span, + arms, + built_tree, + this.source_info(match_span), + ) + }, + )) + }) + } + ) + ); + + this.cfg.goto(body_block, source_info, loop_block); + + // Loops are only exited by `break` expressions. + None + }) + } ExprKind::Call { ty: _, fun, ref args, from_hir_call, fn_span } => { let fun = unpack!(block = this.as_local_operand(block, fun)); let args: Box<[_]> = args @@ -601,6 +718,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Break { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } => { diff --git a/compiler/rustc_mir_build/src/builder/expr/stmt.rs b/compiler/rustc_mir_build/src/builder/expr/stmt.rs index 2dff26f02f3..675beceea14 100644 --- a/compiler/rustc_mir_build/src/builder/expr/stmt.rs +++ b/compiler/rustc_mir_build/src/builder/expr/stmt.rs @@ -98,6 +98,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ExprKind::Break { label, value } => { this.break_scope(block, value, BreakableTarget::Break(label), source_info) } + ExprKind::ConstContinue { label, value } => { + this.break_const_continuable_scope(block, value, label, source_info) + } ExprKind::Return { value } => { this.break_scope(block, value, BreakableTarget::Return, source_info) } diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 977d4f3e931..270a7d4b154 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -18,7 +18,9 @@ use rustc_middle::bug; use rustc_middle::middle::region; use rustc_middle::mir::{self, *}; use rustc_middle::thir::{self, *}; -use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty}; +use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, ValTree, ValTreeKind}; +use rustc_pattern_analysis::constructor::RangeEnd; +use rustc_pattern_analysis::rustc::{DeconstructedPat, RustcPatCtxt}; use rustc_span::{BytePos, Pos, Span, Symbol, sym}; use tracing::{debug, instrument}; @@ -426,7 +428,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// (by [Builder::lower_match_tree]). /// /// `outer_source_info` is the SourceInfo for the whole match. - fn lower_match_arms( + pub(crate) fn lower_match_arms( &mut self, destination: Place<'tcx>, scrutinee_place_builder: PlaceBuilder<'tcx>, @@ -1395,7 +1397,7 @@ pub(crate) struct ArmHasGuard(pub(crate) bool); /// A sub-branch in the output of match lowering. Match lowering has generated MIR code that will /// branch to `success_block` when the matched value matches the corresponding pattern. If there is /// a guard, its failure must continue to `otherwise_block`, which will resume testing patterns. -#[derive(Debug)] +#[derive(Debug, Clone)] struct MatchTreeSubBranch<'tcx> { span: Span, /// The block that is branched to if the corresponding subpattern matches. @@ -1411,7 +1413,7 @@ struct MatchTreeSubBranch<'tcx> { } /// A branch in the output of match lowering. -#[derive(Debug)] +#[derive(Debug, Clone)] struct MatchTreeBranch<'tcx> { sub_branches: Vec<MatchTreeSubBranch<'tcx>>, } @@ -1430,8 +1432,8 @@ struct MatchTreeBranch<'tcx> { /// Here the first arm gives the first `MatchTreeBranch`, which has two sub-branches, one for each /// alternative of the or-pattern. They are kept separate because each needs to bind `x` to a /// different place. -#[derive(Debug)] -struct BuiltMatchTree<'tcx> { +#[derive(Debug, Clone)] +pub(crate) struct BuiltMatchTree<'tcx> { branches: Vec<MatchTreeBranch<'tcx>>, otherwise_block: BasicBlock, /// If any of the branches had a guard, we collect here the places and locals to fakely borrow @@ -1489,7 +1491,7 @@ impl<'tcx> MatchTreeBranch<'tcx> { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum HasMatchGuard { +pub(crate) enum HasMatchGuard { Yes, No, } @@ -1504,7 +1506,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// `refutable` indicates whether the candidate list is refutable (for `if let` and `let else`) /// or not (for `let` and `match`). In the refutable case we return the block to which we branch /// on failure. - fn lower_match_tree( + pub(crate) fn lower_match_tree( &mut self, block: BasicBlock, scrutinee_span: Span, @@ -1890,7 +1892,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { debug!("expanding or-pattern: candidate={:#?}\npats={:#?}", candidate, pats); candidate.or_span = Some(match_pair.pattern_span); candidate.subcandidates = pats - .into_vec() .into_iter() .map(|flat_pat| Candidate::from_flat_pat(flat_pat, candidate.has_guard)) .collect(); @@ -2864,4 +2865,129 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { true } + + /// Attempt to statically pick the `BasicBlock` that a value would resolve to at runtime. + pub(crate) fn static_pattern_match( + &self, + cx: &RustcPatCtxt<'_, 'tcx>, + valtree: ValTree<'tcx>, + arms: &[ArmId], + built_match_tree: &BuiltMatchTree<'tcx>, + ) -> Option<BasicBlock> { + let it = arms.iter().zip(built_match_tree.branches.iter()); + for (&arm_id, branch) in it { + let pat = cx.lower_pat(&*self.thir.arms[arm_id].pattern); + + // Peel off or-patterns if they exist. + if let rustc_pattern_analysis::rustc::Constructor::Or = pat.ctor() { + for pat in pat.iter_fields() { + // For top-level or-patterns (the only ones we accept right now), when the + // bindings are the same (e.g. there are none), the sub_branch is stored just + // once. + let sub_branch = branch + .sub_branches + .get(pat.idx) + .or_else(|| branch.sub_branches.last()) + .unwrap(); + + match self.static_pattern_match_inner(valtree, &pat.pat) { + true => return Some(sub_branch.success_block), + false => continue, + } + } + } else if self.static_pattern_match_inner(valtree, &pat) { + return Some(branch.sub_branches[0].success_block); + } + } + + None + } + + /// Helper for [`Self::static_pattern_match`], checking whether the value represented by the + /// `ValTree` matches the given pattern. This function does not recurse, meaning that it does + /// not handle or-patterns, or patterns for types with fields. + fn static_pattern_match_inner( + &self, + valtree: ty::ValTree<'tcx>, + pat: &DeconstructedPat<'_, 'tcx>, + ) -> bool { + use rustc_pattern_analysis::constructor::{IntRange, MaybeInfiniteInt}; + use rustc_pattern_analysis::rustc::Constructor; + + match pat.ctor() { + Constructor::Variant(variant_index) => { + let ValTreeKind::Branch(box [actual_variant_idx]) = *valtree else { + bug!("malformed valtree for an enum") + }; + + let ValTreeKind::Leaf(actual_variant_idx) = ***actual_variant_idx else { + bug!("malformed valtree for an enum") + }; + + *variant_index == VariantIdx::from_u32(actual_variant_idx.to_u32()) + } + Constructor::IntRange(int_range) => { + let size = pat.ty().primitive_size(self.tcx); + let actual_int = valtree.unwrap_leaf().to_bits(size); + let actual_int = if pat.ty().is_signed() { + MaybeInfiniteInt::new_finite_int(actual_int, size.bits()) + } else { + MaybeInfiniteInt::new_finite_uint(actual_int) + }; + IntRange::from_singleton(actual_int).is_subrange(int_range) + } + Constructor::Bool(pattern_value) => match valtree.unwrap_leaf().try_to_bool() { + Ok(actual_value) => *pattern_value == actual_value, + Err(()) => bug!("bool value with invalid bits"), + }, + Constructor::F16Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f16(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F32Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f32(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F64Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f64(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F128Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f128(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::Wildcard => true, + + // These we may eventually support: + Constructor::Struct + | Constructor::Ref + | Constructor::DerefPattern(_) + | Constructor::Slice(_) + | Constructor::UnionField + | Constructor::Or + | Constructor::Str(_) => bug!("unsupported pattern constructor {:?}", pat.ctor()), + + // These should never occur here: + Constructor::Opaque(_) + | Constructor::Never + | Constructor::NonExhaustive + | Constructor::Hidden + | Constructor::Missing + | Constructor::PrivateUninhabited => { + bug!("unsupported pattern constructor {:?}", pat.ctor()) + } + } + } } diff --git a/compiler/rustc_mir_build/src/builder/scope.rs b/compiler/rustc_mir_build/src/builder/scope.rs index 67988f1fcbc..1d15e7e126f 100644 --- a/compiler/rustc_mir_build/src/builder/scope.rs +++ b/compiler/rustc_mir_build/src/builder/scope.rs @@ -83,20 +83,24 @@ that contains only loops and breakable blocks. It tracks where a `break`, use std::mem; +use interpret::ErrorHandled; use rustc_data_structures::fx::FxHashMap; use rustc_hir::HirId; use rustc_index::{IndexSlice, IndexVec}; use rustc_middle::middle::region; -use rustc_middle::mir::*; -use rustc_middle::thir::{ExprId, LintLevel}; -use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::mir::{self, *}; +use rustc_middle::thir::{AdtExpr, AdtExprBase, ArmId, ExprId, ExprKind, LintLevel}; +use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt, ValTree}; use rustc_middle::{bug, span_bug}; +use rustc_pattern_analysis::rustc::RustcPatCtxt; use rustc_session::lint::Level; use rustc_span::source_map::Spanned; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument}; +use super::matches::BuiltMatchTree; use crate::builder::{BlockAnd, BlockAndExtension, BlockFrame, Builder, CFG}; +use crate::errors::{ConstContinueBadConst, ConstContinueUnknownJumpTarget}; #[derive(Debug)] pub(crate) struct Scopes<'tcx> { @@ -105,6 +109,8 @@ pub(crate) struct Scopes<'tcx> { /// The current set of breakable scopes. See module comment for more details. breakable_scopes: Vec<BreakableScope<'tcx>>, + const_continuable_scopes: Vec<ConstContinuableScope<'tcx>>, + /// The scope of the innermost if-then currently being lowered. if_then_scope: Option<IfThenScope>, @@ -175,6 +181,20 @@ struct BreakableScope<'tcx> { } #[derive(Debug)] +struct ConstContinuableScope<'tcx> { + /// The scope for the `#[loop_match]` which its `#[const_continue]`s will jump to. + region_scope: region::Scope, + /// The place of the state of a `#[loop_match]`, which a `#[const_continue]` must update. + state_place: Place<'tcx>, + + arms: Box<[ArmId]>, + built_match_tree: BuiltMatchTree<'tcx>, + + /// Drops that happen on a `#[const_continue]` + const_continue_drops: DropTree, +} + +#[derive(Debug)] struct IfThenScope { /// The if-then scope or arm scope region_scope: region::Scope, @@ -461,6 +481,7 @@ impl<'tcx> Scopes<'tcx> { Self { scopes: Vec::new(), breakable_scopes: Vec::new(), + const_continuable_scopes: Vec::new(), if_then_scope: None, unwind_drops: DropTree::new(), coroutine_drops: DropTree::new(), @@ -552,6 +573,59 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + /// Start a const-continuable scope, which tracks where `#[const_continue] break` should + /// branch to. + pub(crate) fn in_const_continuable_scope<F>( + &mut self, + arms: Box<[ArmId]>, + built_match_tree: BuiltMatchTree<'tcx>, + state_place: Place<'tcx>, + span: Span, + f: F, + ) -> BlockAnd<()> + where + F: FnOnce(&mut Builder<'a, 'tcx>) -> BlockAnd<()>, + { + let region_scope = self.scopes.topmost(); + let scope = ConstContinuableScope { + region_scope, + state_place, + const_continue_drops: DropTree::new(), + arms, + built_match_tree, + }; + self.scopes.const_continuable_scopes.push(scope); + let normal_exit_block = f(self); + let const_continue_scope = self.scopes.const_continuable_scopes.pop().unwrap(); + assert!(const_continue_scope.region_scope == region_scope); + + let break_block = self.build_exit_tree( + const_continue_scope.const_continue_drops, + region_scope, + span, + None, + ); + + match (normal_exit_block, break_block) { + (block, None) => block, + (normal_block, Some(exit_block)) => { + let target = self.cfg.start_new_block(); + let source_info = self.source_info(span); + self.cfg.terminate( + normal_block.into_block(), + source_info, + TerminatorKind::Goto { target }, + ); + self.cfg.terminate( + exit_block.into_block(), + source_info, + TerminatorKind::Goto { target }, + ); + target.unit() + } + } + } + /// Start an if-then scope which tracks drop for `if` expressions and `if` /// guards. /// @@ -742,6 +816,190 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.cfg.start_new_block().unit() } + /// Based on `FunctionCx::eval_unevaluated_mir_constant_to_valtree`. + fn eval_unevaluated_mir_constant_to_valtree( + &self, + constant: ConstOperand<'tcx>, + ) -> Result<(ty::ValTree<'tcx>, Ty<'tcx>), interpret::ErrorHandled> { + assert!(!constant.const_.ty().has_param()); + let (uv, ty) = match constant.const_ { + mir::Const::Unevaluated(uv, ty) => (uv.shrink(), ty), + mir::Const::Ty(_, c) => match c.kind() { + // A constant that came from a const generic but was then used as an argument to + // old-style simd_shuffle (passing as argument instead of as a generic param). + ty::ConstKind::Value(cv) => return Ok((cv.valtree, cv.ty)), + other => span_bug!(constant.span, "{other:#?}"), + }, + mir::Const::Val(mir::ConstValue::Scalar(mir::interpret::Scalar::Int(val)), ty) => { + return Ok((ValTree::from_scalar_int(self.tcx, val), ty)); + } + // We should never encounter `Const::Val` unless MIR opts (like const prop) evaluate + // a constant and write that value back into `Operand`s. This could happen, but is + // unlikely. Also: all users of `simd_shuffle` are on unstable and already need to take + // a lot of care around intrinsics. For an issue to happen here, it would require a + // macro expanding to a `simd_shuffle` call without wrapping the constant argument in a + // `const {}` block, but the user pass through arbitrary expressions. + + // FIXME(oli-obk): Replace the magic const generic argument of `simd_shuffle` with a + // real const generic, and get rid of this entire function. + other => span_bug!(constant.span, "{other:#?}"), + }; + + match self.tcx.const_eval_resolve_for_typeck(self.typing_env(), uv, constant.span) { + Ok(Ok(valtree)) => Ok((valtree, ty)), + Ok(Err(ty)) => span_bug!(constant.span, "could not convert {ty:?} to a valtree"), + Err(e) => Err(e), + } + } + + /// Sets up the drops for jumping from `block` to `scope`. + pub(crate) fn break_const_continuable_scope( + &mut self, + mut block: BasicBlock, + value: ExprId, + scope: region::Scope, + source_info: SourceInfo, + ) -> BlockAnd<()> { + let span = source_info.span; + + // A break can only break out of a scope, so the value should be a scope. + let rustc_middle::thir::ExprKind::Scope { value, .. } = self.thir[value].kind else { + span_bug!(span, "break value must be a scope") + }; + + let constant = match &self.thir[value].kind { + ExprKind::Adt(box AdtExpr { variant_index, fields, base, .. }) => { + assert!(matches!(base, AdtExprBase::None)); + assert!(fields.is_empty()); + ConstOperand { + span: self.thir[value].span, + user_ty: None, + const_: Const::Ty( + self.thir[value].ty, + ty::Const::new_value( + self.tcx, + ValTree::from_branches( + self.tcx, + [ValTree::from_scalar_int(self.tcx, variant_index.as_u32().into())], + ), + self.thir[value].ty, + ), + ), + } + } + _ => self.as_constant(&self.thir[value]), + }; + + let break_index = self + .scopes + .const_continuable_scopes + .iter() + .rposition(|const_continuable_scope| const_continuable_scope.region_scope == scope) + .unwrap_or_else(|| span_bug!(span, "no enclosing const-continuable scope found")); + + let scope = &self.scopes.const_continuable_scopes[break_index]; + + let state_decl = &self.local_decls[scope.state_place.as_local().unwrap()]; + let state_ty = state_decl.ty; + let (discriminant_ty, rvalue) = match state_ty.kind() { + ty::Adt(adt_def, _) if adt_def.is_enum() => { + (state_ty.discriminant_ty(self.tcx), Rvalue::Discriminant(scope.state_place)) + } + ty::Uint(_) | ty::Int(_) | ty::Float(_) | ty::Bool | ty::Char => { + (state_ty, Rvalue::Use(Operand::Copy(scope.state_place))) + } + _ => span_bug!(state_decl.source_info.span, "unsupported #[loop_match] state"), + }; + + // The `PatCtxt` is normally used in pattern exhaustiveness checking, but reused + // here because it performs normalization and const evaluation. + let dropless_arena = rustc_arena::DroplessArena::default(); + let typeck_results = self.tcx.typeck(self.def_id); + let cx = RustcPatCtxt { + tcx: self.tcx, + typeck_results, + module: self.tcx.parent_module(self.hir_id).to_def_id(), + // FIXME(#132279): We're in a body, should handle opaques. + typing_env: rustc_middle::ty::TypingEnv::non_body_analysis(self.tcx, self.def_id), + dropless_arena: &dropless_arena, + match_lint_level: self.hir_id, + whole_match_span: Some(rustc_span::Span::default()), + scrut_span: rustc_span::Span::default(), + refutable: true, + known_valid_scrutinee: true, + }; + + let valtree = match self.eval_unevaluated_mir_constant_to_valtree(constant) { + Ok((valtree, ty)) => { + // Defensively check that the type is monomorphic. + assert!(!ty.has_param()); + + valtree + } + Err(ErrorHandled::Reported(..)) => return self.cfg.start_new_block().unit(), + Err(ErrorHandled::TooGeneric(_)) => { + self.tcx.dcx().emit_fatal(ConstContinueBadConst { span: constant.span }); + } + }; + + let Some(real_target) = + self.static_pattern_match(&cx, valtree, &*scope.arms, &scope.built_match_tree) + else { + self.tcx.dcx().emit_fatal(ConstContinueUnknownJumpTarget { span }) + }; + + self.block_context.push(BlockFrame::SubExpr); + let state_place = scope.state_place; + block = self.expr_into_dest(state_place, block, value).into_block(); + self.block_context.pop(); + + let discr = self.temp(discriminant_ty, source_info.span); + let scope_index = self + .scopes + .scope_index(self.scopes.const_continuable_scopes[break_index].region_scope, span); + let scope = &mut self.scopes.const_continuable_scopes[break_index]; + self.cfg.push_assign(block, source_info, discr, rvalue); + let drop_and_continue_block = self.cfg.start_new_block(); + let imaginary_target = self.cfg.start_new_block(); + self.cfg.terminate( + block, + source_info, + TerminatorKind::FalseEdge { real_target: drop_and_continue_block, imaginary_target }, + ); + + let drops = &mut scope.const_continue_drops; + + let drop_idx = self.scopes.scopes[scope_index + 1..] + .iter() + .flat_map(|scope| &scope.drops) + .fold(ROOT_NODE, |drop_idx, &drop| drops.add_drop(drop, drop_idx)); + + drops.add_entry_point(imaginary_target, drop_idx); + + self.cfg.terminate(imaginary_target, source_info, TerminatorKind::UnwindResume); + + let region_scope = scope.region_scope; + let scope_index = self.scopes.scope_index(region_scope, span); + let mut drops = DropTree::new(); + + let drop_idx = self.scopes.scopes[scope_index + 1..] + .iter() + .flat_map(|scope| &scope.drops) + .fold(ROOT_NODE, |drop_idx, &drop| drops.add_drop(drop, drop_idx)); + + drops.add_entry_point(drop_and_continue_block, drop_idx); + + // `build_drop_trees` doesn't have access to our source_info, so we + // create a dummy terminator now. `TerminatorKind::UnwindResume` is used + // because MIR type checking will panic if it hasn't been overwritten. + // (See `<ExitScopes as DropTreeBuilder>::link_entry_point`.) + self.cfg.terminate(drop_and_continue_block, source_info, TerminatorKind::UnwindResume); + + self.build_exit_tree(drops, region_scope, span, Some(real_target)); + + return self.cfg.start_new_block().unit(); + } + /// Sets up the drops for breaking from `block` due to an `if` condition /// that turned out to be false. /// diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs index d5061b71699..0b6b36640e9 100644 --- a/compiler/rustc_mir_build/src/check_unsafety.rs +++ b/compiler/rustc_mir_build/src/check_unsafety.rs @@ -465,10 +465,12 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { | ExprKind::Break { .. } | ExprKind::Closure { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::Yield { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Let { .. } | ExprKind::Match { .. } | ExprKind::Box { .. } diff --git a/compiler/rustc_mir_build/src/errors.rs b/compiler/rustc_mir_build/src/errors.rs index ae09db50235..32df191cbca 100644 --- a/compiler/rustc_mir_build/src/errors.rs +++ b/compiler/rustc_mir_build/src/errors.rs @@ -1149,3 +1149,80 @@ impl Subdiagnostic for Rust2024IncompatiblePatSugg { } } } + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_invalid_update)] +pub(crate) struct LoopMatchInvalidUpdate { + #[primary_span] + pub lhs: Span, + #[label] + pub scrutinee: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_invalid_match)] +#[note] +pub(crate) struct LoopMatchInvalidMatch { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_unsupported_type)] +#[note] +pub(crate) struct LoopMatchUnsupportedType<'tcx> { + #[primary_span] + pub span: Span, + pub ty: Ty<'tcx>, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_bad_statements)] +pub(crate) struct LoopMatchBadStatements { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_bad_rhs)] +pub(crate) struct LoopMatchBadRhs { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_missing_assignment)] +pub(crate) struct LoopMatchMissingAssignment { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_arm_with_guard)] +pub(crate) struct LoopMatchArmWithGuard { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_bad_const)] +pub(crate) struct ConstContinueBadConst { + #[primary_span] + #[label] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_missing_value)] +pub(crate) struct ConstContinueMissingValue { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_unknown_jump_target)] +#[note] +pub(crate) struct ConstContinueUnknownJumpTarget { + #[primary_span] + pub span: Span, +} diff --git a/compiler/rustc_mir_build/src/thir/cx/expr.rs b/compiler/rustc_mir_build/src/thir/cx/expr.rs index 764b7efe2a3..5197e93fda7 100644 --- a/compiler/rustc_mir_build/src/thir/cx/expr.rs +++ b/compiler/rustc_mir_build/src/thir/cx/expr.rs @@ -1,6 +1,7 @@ use itertools::Itertools; use rustc_abi::{FIRST_VARIANT, FieldIdx}; use rustc_ast::UnsafeBinderCastKind; +use rustc_attr_data_structures::{AttributeKind, find_attr}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_hir as hir; use rustc_hir::def::{CtorKind, CtorOf, DefKind, Res}; @@ -21,6 +22,7 @@ use rustc_middle::{bug, span_bug}; use rustc_span::{Span, sym}; use tracing::{debug, info, instrument, trace}; +use crate::errors::*; use crate::thir::cx::ThirBuildCx; impl<'tcx> ThirBuildCx<'tcx> { @@ -845,16 +847,38 @@ impl<'tcx> ThirBuildCx<'tcx> { } hir::ExprKind::Ret(v) => ExprKind::Return { value: v.map(|v| self.mirror_expr(v)) }, hir::ExprKind::Become(call) => ExprKind::Become { value: self.mirror_expr(call) }, - hir::ExprKind::Break(dest, ref value) => match dest.target_id { - Ok(target_id) => ExprKind::Break { - label: region::Scope { - local_id: target_id.local_id, - data: region::ScopeData::Node, - }, - value: value.map(|value| self.mirror_expr(value)), - }, - Err(err) => bug!("invalid loop id for break: {}", err), - }, + hir::ExprKind::Break(dest, ref value) => { + if find_attr!(self.tcx.hir_attrs(expr.hir_id), AttributeKind::ConstContinue(_)) { + match dest.target_id { + Ok(target_id) => { + let Some(value) = value else { + let span = expr.span; + self.tcx.dcx().emit_fatal(ConstContinueMissingValue { span }) + }; + + ExprKind::ConstContinue { + label: region::Scope { + local_id: target_id.local_id, + data: region::ScopeData::Node, + }, + value: self.mirror_expr(value), + } + } + Err(err) => bug!("invalid loop id for break: {}", err), + } + } else { + match dest.target_id { + Ok(target_id) => ExprKind::Break { + label: region::Scope { + local_id: target_id.local_id, + data: region::ScopeData::Node, + }, + value: value.map(|value| self.mirror_expr(value)), + }, + Err(err) => bug!("invalid loop id for break: {}", err), + } + } + } hir::ExprKind::Continue(dest) => match dest.target_id { Ok(loop_id) => ExprKind::Continue { label: region::Scope { @@ -889,18 +913,93 @@ impl<'tcx> ThirBuildCx<'tcx> { match_source, }, hir::ExprKind::Loop(body, ..) => { - let block_ty = self.typeck_results.node_type(body.hir_id); - let (temp_lifetime, backwards_incompatible) = self - .rvalue_scopes - .temporary_scope(self.region_scope_tree, body.hir_id.local_id); - let block = self.mirror_block(body); - let body = self.thir.exprs.push(Expr { - ty: block_ty, - temp_lifetime: TempLifetime { temp_lifetime, backwards_incompatible }, - span: self.thir[block].span, - kind: ExprKind::Block { block }, - }); - ExprKind::Loop { body } + if find_attr!(self.tcx.hir_attrs(expr.hir_id), AttributeKind::LoopMatch(_)) { + let dcx = self.tcx.dcx(); + + // Accept either `state = expr` or `state = expr;`. + let loop_body_expr = match body.stmts { + [] => match body.expr { + Some(expr) => expr, + None => dcx.emit_fatal(LoopMatchMissingAssignment { span: body.span }), + }, + [single] if body.expr.is_none() => match single.kind { + hir::StmtKind::Expr(expr) | hir::StmtKind::Semi(expr) => expr, + _ => dcx.emit_fatal(LoopMatchMissingAssignment { span: body.span }), + }, + [first @ last] | [first, .., last] => dcx + .emit_fatal(LoopMatchBadStatements { span: first.span.to(last.span) }), + }; + + let hir::ExprKind::Assign(state, rhs_expr, _) = loop_body_expr.kind else { + dcx.emit_fatal(LoopMatchMissingAssignment { span: loop_body_expr.span }) + }; + + let hir::ExprKind::Block(block_body, _) = rhs_expr.kind else { + dcx.emit_fatal(LoopMatchBadRhs { span: rhs_expr.span }) + }; + + // The labeled block should contain one match expression, but defining items is + // allowed. + for stmt in block_body.stmts { + if !matches!(stmt.kind, rustc_hir::StmtKind::Item(_)) { + dcx.emit_fatal(LoopMatchBadStatements { span: stmt.span }) + } + } + + let Some(block_body_expr) = block_body.expr else { + dcx.emit_fatal(LoopMatchBadRhs { span: block_body.span }) + }; + + let hir::ExprKind::Match(scrutinee, arms, _match_source) = block_body_expr.kind + else { + dcx.emit_fatal(LoopMatchBadRhs { span: block_body_expr.span }) + }; + + fn local(expr: &rustc_hir::Expr<'_>) -> Option<hir::HirId> { + if let hir::ExprKind::Path(hir::QPath::Resolved(_, path)) = expr.kind { + if let Res::Local(hir_id) = path.res { + return Some(hir_id); + } + } + + None + } + + let Some(scrutinee_hir_id) = local(scrutinee) else { + dcx.emit_fatal(LoopMatchInvalidMatch { span: scrutinee.span }) + }; + + if local(state) != Some(scrutinee_hir_id) { + dcx.emit_fatal(LoopMatchInvalidUpdate { + scrutinee: scrutinee.span, + lhs: state.span, + }) + } + + ExprKind::LoopMatch { + state: self.mirror_expr(state), + region_scope: region::Scope { + local_id: block_body.hir_id.local_id, + data: region::ScopeData::Node, + }, + + arms: arms.iter().map(|a| self.convert_arm(a)).collect(), + match_span: block_body_expr.span, + } + } else { + let block_ty = self.typeck_results.node_type(body.hir_id); + let (temp_lifetime, backwards_incompatible) = self + .rvalue_scopes + .temporary_scope(self.region_scope_tree, body.hir_id.local_id); + let block = self.mirror_block(body); + let body = self.thir.exprs.push(Expr { + ty: block_ty, + temp_lifetime: TempLifetime { temp_lifetime, backwards_incompatible }, + span: self.thir[block].span, + kind: ExprKind::Block { block }, + }); + ExprKind::Loop { body } + } } hir::ExprKind::Field(source, ..) => ExprKind::Field { lhs: self.mirror_expr(source), diff --git a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs index 245bd866030..9fd410e6bf1 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs @@ -331,7 +331,11 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> { | WrapUnsafeBinder { source } => self.is_known_valid_scrutinee(&self.thir()[*source]), // These diverge. - Become { .. } | Break { .. } | Continue { .. } | Return { .. } => true, + Become { .. } + | Break { .. } + | Continue { .. } + | ConstContinue { .. } + | Return { .. } => true, // These are statements that evaluate to `()`. Assign { .. } | AssignOp { .. } | InlineAsm { .. } | Let { .. } => true, @@ -353,6 +357,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> { | Literal { .. } | LogicalOp { .. } | Loop { .. } + | LoopMatch { .. } | Match { .. } | NamedConst { .. } | NonHirLiteral { .. } diff --git a/compiler/rustc_mir_build/src/thir/print.rs b/compiler/rustc_mir_build/src/thir/print.rs index db9547a481f..1507b6b8c06 100644 --- a/compiler/rustc_mir_build/src/thir/print.rs +++ b/compiler/rustc_mir_build/src/thir/print.rs @@ -318,6 +318,20 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> { self.print_expr(*body, depth_lvl + 2); print_indented!(self, ")", depth_lvl); } + LoopMatch { state, region_scope, match_span, arms } => { + print_indented!(self, "LoopMatch {", depth_lvl); + print_indented!(self, "state:", depth_lvl + 1); + self.print_expr(*state, depth_lvl + 2); + print_indented!(self, format!("region_scope: {:?}", region_scope), depth_lvl + 1); + print_indented!(self, format!("match_span: {:?}", match_span), depth_lvl + 1); + + print_indented!(self, "arms: [", depth_lvl + 1); + for arm_id in arms.iter() { + self.print_arm(*arm_id, depth_lvl + 2); + } + print_indented!(self, "]", depth_lvl + 1); + print_indented!(self, "}", depth_lvl); + } Let { expr, pat } => { print_indented!(self, "Let {", depth_lvl); print_indented!(self, "expr:", depth_lvl + 1); @@ -415,6 +429,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> { print_indented!(self, format!("label: {:?}", label), depth_lvl + 1); print_indented!(self, "}", depth_lvl); } + ConstContinue { label, value } => { + print_indented!(self, "ConstContinue (", depth_lvl); + print_indented!(self, format!("label: {:?}", label), depth_lvl + 1); + print_indented!(self, "value:", depth_lvl + 1); + self.print_expr(*value, depth_lvl + 2); + print_indented!(self, ")", depth_lvl); + } Return { value } => { print_indented!(self, "Return {", depth_lvl); print_indented!(self, "value:", depth_lvl + 1); |
