diff options
| author | bjorn3 <17426603+bjorn3@users.noreply.github.com> | 2025-02-18 14:16:57 +0100 |
|---|---|---|
| committer | Folkert de Vries <folkert@folkertdev.nl> | 2025-06-23 20:43:04 +0200 |
| commit | ba5556d239c11232dc8d95123ea70a2783019476 (patch) | |
| tree | 0fc9a5eb5964d18de0ab43107b5964818f8bb49f /compiler | |
| parent | 42245d34d22ade32b3f276dcf74deb826841594c (diff) | |
| download | rust-ba5556d239c11232dc8d95123ea70a2783019476.tar.gz rust-ba5556d239c11232dc8d95123ea70a2783019476.zip | |
Add `#[loop_match]` for improved DFA codegen
Co-authored-by: Folkert de Vries <folkert@folkertdev.nl>
Diffstat (limited to 'compiler')
31 files changed, 1002 insertions, 45 deletions
diff --git a/compiler/rustc_attr_data_structures/src/attributes.rs b/compiler/rustc_attr_data_structures/src/attributes.rs index 9227b81f12f..4995b855f32 100644 --- a/compiler/rustc_attr_data_structures/src/attributes.rs +++ b/compiler/rustc_attr_data_structures/src/attributes.rs @@ -212,6 +212,9 @@ pub enum AttributeKind { first_span: Span, }, + /// Represents `#[const_continue]`. + ConstContinue(Span), + /// Represents `#[rustc_const_stable]` and `#[rustc_const_unstable]`. ConstStability { stability: PartialConstStability, @@ -231,6 +234,9 @@ pub enum AttributeKind { /// Represents `#[inline]` and `#[rustc_force_inline]`. Inline(InlineAttr, Span), + /// Represents `#[loop_match]`. + LoopMatch(Span), + /// Represents `#[rustc_macro_transparency]`. MacroTransparency(Transparency), diff --git a/compiler/rustc_attr_parsing/src/attributes/loop_match.rs b/compiler/rustc_attr_parsing/src/attributes/loop_match.rs new file mode 100644 index 00000000000..f6c7ac5e3a3 --- /dev/null +++ b/compiler/rustc_attr_parsing/src/attributes/loop_match.rs @@ -0,0 +1,31 @@ +use rustc_attr_data_structures::AttributeKind; +use rustc_feature::{AttributeTemplate, template}; +use rustc_span::{Symbol, sym}; + +use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser}; +use crate::context::{AcceptContext, Stage}; +use crate::parser::ArgParser; + +pub(crate) struct LoopMatchParser; +impl<S: Stage> SingleAttributeParser<S> for LoopMatchParser { + const PATH: &[Symbol] = &[sym::loop_match]; + const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepFirst; + const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Warn; + const TEMPLATE: AttributeTemplate = template!(Word); + + fn convert(cx: &mut AcceptContext<'_, '_, S>, _args: &ArgParser<'_>) -> Option<AttributeKind> { + Some(AttributeKind::LoopMatch(cx.attr_span)) + } +} + +pub(crate) struct ConstContinueParser; +impl<S: Stage> SingleAttributeParser<S> for ConstContinueParser { + const PATH: &[Symbol] = &[sym::const_continue]; + const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepFirst; + const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Warn; + const TEMPLATE: AttributeTemplate = template!(Word); + + fn convert(cx: &mut AcceptContext<'_, '_, S>, _args: &ArgParser<'_>) -> Option<AttributeKind> { + Some(AttributeKind::ConstContinue(cx.attr_span)) + } +} diff --git a/compiler/rustc_attr_parsing/src/attributes/mod.rs b/compiler/rustc_attr_parsing/src/attributes/mod.rs index 738d8735b69..bc18ec8a9c0 100644 --- a/compiler/rustc_attr_parsing/src/attributes/mod.rs +++ b/compiler/rustc_attr_parsing/src/attributes/mod.rs @@ -32,6 +32,7 @@ pub(crate) mod confusables; pub(crate) mod deprecation; pub(crate) mod inline; pub(crate) mod lint_helpers; +pub(crate) mod loop_match; pub(crate) mod must_use; pub(crate) mod repr; pub(crate) mod semantics; diff --git a/compiler/rustc_attr_parsing/src/context.rs b/compiler/rustc_attr_parsing/src/context.rs index 457e073c488..c89ee8131bb 100644 --- a/compiler/rustc_attr_parsing/src/context.rs +++ b/compiler/rustc_attr_parsing/src/context.rs @@ -20,6 +20,7 @@ use crate::attributes::confusables::ConfusablesParser; use crate::attributes::deprecation::DeprecationParser; use crate::attributes::inline::{InlineParser, RustcForceInlineParser}; use crate::attributes::lint_helpers::{AsPtrParser, PubTransparentParser}; +use crate::attributes::loop_match::{ConstContinueParser, LoopMatchParser}; use crate::attributes::must_use::MustUseParser; use crate::attributes::repr::{AlignParser, ReprParser}; use crate::attributes::semantics::MayDangleParser; @@ -110,9 +111,11 @@ attribute_parsers!( // tidy-alphabetical-start Single<AsPtrParser>, Single<ColdParser>, + Single<ConstContinueParser>, Single<ConstStabilityIndirectParser>, Single<DeprecationParser>, Single<InlineParser>, + Single<LoopMatchParser>, Single<MayDangleParser>, Single<MustUseParser>, Single<NoMangleParser>, diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 280b33f0723..85f3f51cfdf 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -657,6 +657,19 @@ pub static BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ EncodeCrossCrate::Yes, min_generic_const_args, experimental!(type_const), ), + // The `#[loop_match]` and `#[const_continue]` attributes are part of the + // lang experiment for RFC 3720 tracked in: + // + // - https://github.com/rust-lang/rust/issues/132306 + gated!( + const_continue, Normal, template!(Word), ErrorFollowing, + EncodeCrossCrate::No, loop_match, experimental!(const_continue) + ), + gated!( + loop_match, Normal, template!(Word), ErrorFollowing, + EncodeCrossCrate::No, loop_match, experimental!(loop_match) + ), + // ========================================================================== // Internal attributes: Stability, deprecation, and unsafe: // ========================================================================== diff --git a/compiler/rustc_feature/src/unstable.rs b/compiler/rustc_feature/src/unstable.rs index 91715851226..d9d5334615a 100644 --- a/compiler/rustc_feature/src/unstable.rs +++ b/compiler/rustc_feature/src/unstable.rs @@ -557,6 +557,8 @@ declare_features! ( /// Allows using `#[link(kind = "link-arg", name = "...")]` /// to pass custom arguments to the linker. (unstable, link_arg_attribute, "1.76.0", Some(99427)), + /// Allows fused `loop`/`match` for direct intraprocedural jumps. + (incomplete, loop_match, "CURRENT_RUSTC_VERSION", Some(132306)), /// Give access to additional metadata about declarative macro meta-variables. (unstable, macro_metavar_expr, "1.61.0", Some(83527)), /// Provides a way to concatenate identifiers using metavariable expressions. diff --git a/compiler/rustc_hir_typeck/messages.ftl b/compiler/rustc_hir_typeck/messages.ftl index 258535f3742..c21b16c9f9f 100644 --- a/compiler/rustc_hir_typeck/messages.ftl +++ b/compiler/rustc_hir_typeck/messages.ftl @@ -79,6 +79,9 @@ hir_typeck_cast_unknown_pointer = cannot cast {$to -> .note = the type information given here is insufficient to check whether the pointer cast is valid .label_from = the type information given here is insufficient to check whether the pointer cast is valid +hir_typeck_const_continue_bad_label = + `#[const_continue]` must break to a labeled block that participates in a `#[loop_match]` + hir_typeck_const_select_must_be_const = this argument must be a `const fn` .help = consult the documentation on `const_eval_select` for more information diff --git a/compiler/rustc_hir_typeck/src/errors.rs b/compiler/rustc_hir_typeck/src/errors.rs index 5fea0c62843..3606c778fc4 100644 --- a/compiler/rustc_hir_typeck/src/errors.rs +++ b/compiler/rustc_hir_typeck/src/errors.rs @@ -1167,3 +1167,10 @@ pub(crate) struct AbiCannotBeCalled { pub span: Span, pub abi: ExternAbi, } + +#[derive(Diagnostic)] +#[diag(hir_typeck_const_continue_bad_label)] +pub(crate) struct ConstContinueBadLabel { + #[primary_span] + pub span: Span, +} diff --git a/compiler/rustc_hir_typeck/src/loops.rs b/compiler/rustc_hir_typeck/src/loops.rs index b06e0704b6f..80eab578f13 100644 --- a/compiler/rustc_hir_typeck/src/loops.rs +++ b/compiler/rustc_hir_typeck/src/loops.rs @@ -2,6 +2,8 @@ use std::collections::BTreeMap; use std::fmt; use Context::*; +use rustc_ast::Label; +use rustc_attr_data_structures::{AttributeKind, find_attr}; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::LocalDefId; @@ -14,8 +16,9 @@ use rustc_span::hygiene::DesugaringKind; use rustc_span::{BytePos, Span}; use crate::errors::{ - BreakInsideClosure, BreakInsideCoroutine, BreakNonLoop, ContinueLabeledBlock, OutsideLoop, - OutsideLoopSuggestion, UnlabeledCfInWhileCondition, UnlabeledInLabeledBlock, + BreakInsideClosure, BreakInsideCoroutine, BreakNonLoop, ConstContinueBadLabel, + ContinueLabeledBlock, OutsideLoop, OutsideLoopSuggestion, UnlabeledCfInWhileCondition, + UnlabeledInLabeledBlock, }; /// The context in which a block is encountered. @@ -37,6 +40,11 @@ enum Context { AnonConst, /// E.g. `const { ... }`. ConstBlock, + /// E.g. `#[loop_match] loop { state = 'label: { /* ... */ } }`. + LoopMatch { + /// The label of the labeled block (not of the loop itself). + labeled_block: Label, + }, } #[derive(Clone)] @@ -141,7 +149,12 @@ impl<'hir> Visitor<'hir> for CheckLoopVisitor<'hir> { } } hir::ExprKind::Loop(ref b, _, source, _) => { - self.with_context(Loop(source), |v| v.visit_block(b)); + let cx = match self.is_loop_match(e, b) { + Some(labeled_block) => LoopMatch { labeled_block }, + None => Loop(source), + }; + + self.with_context(cx, |v| v.visit_block(b)); } hir::ExprKind::Closure(&hir::Closure { ref fn_decl, body, fn_decl_span, kind, .. @@ -197,6 +210,23 @@ impl<'hir> Visitor<'hir> for CheckLoopVisitor<'hir> { Err(hir::LoopIdError::UnresolvedLabel) => None, }; + // A `#[const_continue]` must break to a block in a `#[loop_match]`. + if find_attr!(self.tcx.hir_attrs(e.hir_id), AttributeKind::ConstContinue(_)) { + if let Some(break_label) = break_label.label { + let is_target_label = |cx: &Context| match cx { + Context::LoopMatch { labeled_block } => { + break_label.ident.name == labeled_block.ident.name + } + _ => false, + }; + + if !self.cx_stack.iter().rev().any(is_target_label) { + let span = break_label.ident.span; + self.tcx.dcx().emit_fatal(ConstContinueBadLabel { span }); + } + } + } + if let Some(Node::Block(_)) = loop_id.map(|id| self.tcx.hir_node(id)) { return; } @@ -299,7 +329,7 @@ impl<'hir> CheckLoopVisitor<'hir> { cx_pos: usize, ) { match self.cx_stack[cx_pos] { - LabeledBlock | Loop(_) => {} + LabeledBlock | Loop(_) | LoopMatch { .. } => {} Closure(closure_span) => { self.tcx.dcx().emit_err(BreakInsideClosure { span, @@ -380,4 +410,36 @@ impl<'hir> CheckLoopVisitor<'hir> { }); } } + + /// Is this a loop annotated with `#[loop_match]` that looks syntactically sound? + fn is_loop_match( + &self, + e: &'hir hir::Expr<'hir>, + body: &'hir hir::Block<'hir>, + ) -> Option<Label> { + if !find_attr!(self.tcx.hir_attrs(e.hir_id), AttributeKind::LoopMatch(_)) { + return None; + } + + // NOTE: Diagnostics are emitted during MIR construction. + + // Accept either `state = expr` or `state = expr;`. + let loop_body_expr = match body.stmts { + [] => match body.expr { + Some(expr) => expr, + None => return None, + }, + [single] if body.expr.is_none() => match single.kind { + hir::StmtKind::Expr(expr) | hir::StmtKind::Semi(expr) => expr, + _ => return None, + }, + [..] => return None, + }; + + let hir::ExprKind::Assign(_, rhs_expr, _) = loop_body_expr.kind else { return None }; + + let hir::ExprKind::Block(_, label) = rhs_expr.kind else { return None }; + + label + } } diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index b9a014d14c0..d0e72a86d8a 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -378,6 +378,14 @@ pub enum ExprKind<'tcx> { Loop { body: ExprId, }, + /// A `#[loop_match] loop { state = 'blk: { match state { ... } } }` expression. + LoopMatch { + /// The state variable that is updated, and also the scrutinee of the match. + state: ExprId, + region_scope: region::Scope, + arms: Box<[ArmId]>, + match_span: Span, + }, /// Special expression representing the `let` part of an `if let` or similar construct /// (including `if let` guards in match arms, and let-chains formed by `&&`). /// @@ -454,6 +462,11 @@ pub enum ExprKind<'tcx> { Continue { label: region::Scope, }, + /// A `#[const_continue] break` expression. + ConstContinue { + label: region::Scope, + value: ExprId, + }, /// A `return` expression. Return { value: Option<ExprId>, diff --git a/compiler/rustc_middle/src/thir/visit.rs b/compiler/rustc_middle/src/thir/visit.rs index d8743814d79..c9ef723aea4 100644 --- a/compiler/rustc_middle/src/thir/visit.rs +++ b/compiler/rustc_middle/src/thir/visit.rs @@ -83,7 +83,7 @@ pub fn walk_expr<'thir, 'tcx: 'thir, V: Visitor<'thir, 'tcx>>( visitor.visit_pat(pat); } Loop { body } => visitor.visit_expr(&visitor.thir()[body]), - Match { scrutinee, ref arms, .. } => { + LoopMatch { state: scrutinee, ref arms, .. } | Match { scrutinee, ref arms, .. } => { visitor.visit_expr(&visitor.thir()[scrutinee]); for &arm in &**arms { visitor.visit_arm(&visitor.thir()[arm]); @@ -108,6 +108,7 @@ pub fn walk_expr<'thir, 'tcx: 'thir, V: Visitor<'thir, 'tcx>>( } } Continue { label: _ } => {} + ConstContinue { value, label: _ } => visitor.visit_expr(&visitor.thir()[value]), Return { value } => { if let Some(value) = value { visitor.visit_expr(&visitor.thir()[value]) diff --git a/compiler/rustc_mir_build/Cargo.toml b/compiler/rustc_mir_build/Cargo.toml index 1534865cdd3..c4c2d8a7ac8 100644 --- a/compiler/rustc_mir_build/Cargo.toml +++ b/compiler/rustc_mir_build/Cargo.toml @@ -11,6 +11,7 @@ rustc_abi = { path = "../rustc_abi" } rustc_apfloat = "0.2.0" rustc_arena = { path = "../rustc_arena" } rustc_ast = { path = "../rustc_ast" } +rustc_attr_data_structures = { path = "../rustc_attr_data_structures" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } diff --git a/compiler/rustc_mir_build/messages.ftl b/compiler/rustc_mir_build/messages.ftl index fae159103e7..e339520cd86 100644 --- a/compiler/rustc_mir_build/messages.ftl +++ b/compiler/rustc_mir_build/messages.ftl @@ -84,6 +84,15 @@ mir_build_call_to_unsafe_fn_requires_unsafe_unsafe_op_in_unsafe_fn_allowed = mir_build_confused = missing patterns are not covered because `{$variable}` is interpreted as a constant pattern, not a new variable +mir_build_const_continue_bad_const = could not determine the target branch for this `#[const_continue]` + .label = this value is too generic + .note = the value must be a literal or a monomorphic const + +mir_build_const_continue_missing_value = a `#[const_continue]` must break to a label with a value + +mir_build_const_continue_unknown_jump_target = the target of this `#[const_continue]` is not statically known + .label = this value must be a literal or a monomorphic const + mir_build_const_defined_here = constant defined here mir_build_const_param_in_pattern = constant parameters cannot be referenced in patterns @@ -212,6 +221,30 @@ mir_build_literal_in_range_out_of_bounds = literal out of range for `{$ty}` .label = this value does not fit into the type `{$ty}` whose range is `{$min}..={$max}` +mir_build_loop_match_arm_with_guard = + match arms that are part of a `#[loop_match]` cannot have guards + +mir_build_loop_match_bad_rhs = + this expression must be a single `match` wrapped in a labeled block + +mir_build_loop_match_bad_statements = + statements are not allowed in this position within a `#[loop_match]` + +mir_build_loop_match_invalid_match = + invalid match on `#[loop_match]` state + .note = a local variable must be the scrutinee within a `#[loop_match]` + +mir_build_loop_match_invalid_update = + invalid update of the `#[loop_match]` state + .label = the assignment must update this variable + +mir_build_loop_match_missing_assignment = + expected a single assignment expression + +mir_build_loop_match_unsupported_type = + this `#[loop_match]` state value has type `{$ty}`, which is not supported + .note = only integers, floats, bool, char, and enums without fields are supported + mir_build_lower_range_bound_must_be_less_than_or_equal_to_upper = lower range bound must be less than or equal to upper .label = lower bound larger than upper bound diff --git a/compiler/rustc_mir_build/src/builder/expr/as_place.rs b/compiler/rustc_mir_build/src/builder/expr/as_place.rs index f8c64d7d13e..99148504a87 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_place.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_place.rs @@ -565,12 +565,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::Match { .. } | ExprKind::If { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Block { .. } | ExprKind::Let { .. } | ExprKind::Assign { .. } | ExprKind::AssignOp { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::Literal { .. } diff --git a/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs b/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs index b23bc089cd4..9e07dd5da7e 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_rvalue.rs @@ -538,6 +538,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::RawBorrow { .. } | ExprKind::Adt { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::LogicalOp { .. } | ExprKind::Call { .. } | ExprKind::Field { .. } @@ -548,6 +549,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | ExprKind::UpvarRef { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::InlineAsm { .. } diff --git a/compiler/rustc_mir_build/src/builder/expr/category.rs b/compiler/rustc_mir_build/src/builder/expr/category.rs index 34524aed406..5e4219dbf5b 100644 --- a/compiler/rustc_mir_build/src/builder/expr/category.rs +++ b/compiler/rustc_mir_build/src/builder/expr/category.rs @@ -83,9 +83,11 @@ impl Category { | ExprKind::NamedConst { .. } => Some(Category::Constant), ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Block { .. } | ExprKind::Break { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } => // FIXME(#27840) these probably want their own diff --git a/compiler/rustc_mir_build/src/builder/expr/into.rs b/compiler/rustc_mir_build/src/builder/expr/into.rs index 2074fbce0ae..fe3d072fa88 100644 --- a/compiler/rustc_mir_build/src/builder/expr/into.rs +++ b/compiler/rustc_mir_build/src/builder/expr/into.rs @@ -8,15 +8,16 @@ use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::span_bug; use rustc_middle::thir::*; -use rustc_middle::ty::{CanonicalUserTypeAnnotation, Ty}; +use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty}; use rustc_span::DUMMY_SP; use rustc_span::source_map::Spanned; use rustc_trait_selection::infer::InferCtxtExt; use tracing::{debug, instrument}; use crate::builder::expr::category::{Category, RvalueFunc}; -use crate::builder::matches::DeclareLetBindings; +use crate::builder::matches::{DeclareLetBindings, HasMatchGuard}; use crate::builder::{BlockAnd, BlockAndExtension, BlockFrame, Builder, NeedsTemporary}; +use crate::errors::{LoopMatchArmWithGuard, LoopMatchUnsupportedType}; impl<'a, 'tcx> Builder<'a, 'tcx> { /// Compile `expr`, storing the result into `destination`, which @@ -244,6 +245,122 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { None }) } + ExprKind::LoopMatch { state, region_scope, match_span, ref arms } => { + // Intuitively, this is a combination of a loop containing a labeled block + // containing a match. + // + // The only new bit here is that the lowering of the match is wrapped in a + // `in_const_continuable_scope`, which makes the match arms and their target basic + // block available to the lowering of `#[const_continue]`. + + fn is_supported_loop_match_type(ty: Ty<'_>) -> bool { + match ty.kind() { + ty::Uint(_) | ty::Int(_) | ty::Float(_) | ty::Bool | ty::Char => true, + ty::Adt(adt_def, _) => match adt_def.adt_kind() { + ty::AdtKind::Struct | ty::AdtKind::Union => false, + ty::AdtKind::Enum => { + adt_def.variants().iter().all(|v| v.fields.is_empty()) + } + }, + _ => false, + } + } + + let state_ty = this.thir.exprs[state].ty; + if !is_supported_loop_match_type(state_ty) { + let span = this.thir.exprs[state].span; + this.tcx.dcx().emit_fatal(LoopMatchUnsupportedType { span, ty: state_ty }) + } + + let loop_block = this.cfg.start_new_block(); + + // Start the loop. + this.cfg.goto(block, source_info, loop_block); + + this.in_breakable_scope(Some(loop_block), destination, expr_span, |this| { + // Logic for `loop`. + let mut body_block = this.cfg.start_new_block(); + this.cfg.terminate( + loop_block, + source_info, + TerminatorKind::FalseUnwind { + real_target: body_block, + unwind: UnwindAction::Continue, + }, + ); + this.diverge_from(loop_block); + + // Logic for `match`. + let scrutinee_place_builder = + unpack!(body_block = this.as_place_builder(body_block, state)); + let scrutinee_span = this.thir.exprs[state].span; + let match_start_span = match_span.shrink_to_lo().to(scrutinee_span); + + let mut patterns = Vec::with_capacity(arms.len()); + for &arm_id in arms.iter() { + let arm = &this.thir[arm_id]; + + if let Some(guard) = arm.guard { + let span = this.thir.exprs[guard].span; + this.tcx.dcx().emit_fatal(LoopMatchArmWithGuard { span }) + } + + patterns.push((&*arm.pattern, HasMatchGuard::No)); + } + + // The `built_tree` maps match arms to their basic block (where control flow + // jumps to when a value matches the arm). This structure is stored so that a + // `#[const_continue]` can figure out what basic block to jump to. + let built_tree = this.lower_match_tree( + body_block, + scrutinee_span, + &scrutinee_place_builder, + match_start_span, + patterns, + false, + ); + + let state_place = scrutinee_place_builder.to_place(this); + + // This is logic for the labeled block: a block is a drop scope, hence + // `in_scope`, and a labeled block can be broken out of with a `break 'label`, + // hence the `in_breakable_scope`. + // + // Then `in_const_continuable_scope` stores information for the lowering of + // `#[const_continue]`, and finally the match is lowered in the standard way. + unpack!( + body_block = this.in_scope( + (region_scope, source_info), + LintLevel::Inherited, + move |this| { + this.in_breakable_scope(None, state_place, expr_span, |this| { + Some(this.in_const_continuable_scope( + arms.clone(), + built_tree.clone(), + state_place, + expr_span, + |this| { + this.lower_match_arms( + destination, + scrutinee_place_builder, + scrutinee_span, + arms, + built_tree, + this.source_info(match_span), + ) + }, + )) + }) + } + ) + ); + + this.cfg.goto(body_block, source_info, loop_block); + + // Loops are only exited by `break` expressions. + None + }) + } ExprKind::Call { ty: _, fun, ref args, from_hir_call, fn_span } => { let fun = unpack!(block = this.as_local_operand(block, fun)); let args: Box<[_]> = args @@ -601,6 +718,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Break { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } => { diff --git a/compiler/rustc_mir_build/src/builder/expr/stmt.rs b/compiler/rustc_mir_build/src/builder/expr/stmt.rs index 2dff26f02f3..675beceea14 100644 --- a/compiler/rustc_mir_build/src/builder/expr/stmt.rs +++ b/compiler/rustc_mir_build/src/builder/expr/stmt.rs @@ -98,6 +98,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ExprKind::Break { label, value } => { this.break_scope(block, value, BreakableTarget::Break(label), source_info) } + ExprKind::ConstContinue { label, value } => { + this.break_const_continuable_scope(block, value, label, source_info) + } ExprKind::Return { value } => { this.break_scope(block, value, BreakableTarget::Return, source_info) } diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 977d4f3e931..270a7d4b154 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -18,7 +18,9 @@ use rustc_middle::bug; use rustc_middle::middle::region; use rustc_middle::mir::{self, *}; use rustc_middle::thir::{self, *}; -use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty}; +use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, ValTree, ValTreeKind}; +use rustc_pattern_analysis::constructor::RangeEnd; +use rustc_pattern_analysis::rustc::{DeconstructedPat, RustcPatCtxt}; use rustc_span::{BytePos, Pos, Span, Symbol, sym}; use tracing::{debug, instrument}; @@ -426,7 +428,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// (by [Builder::lower_match_tree]). /// /// `outer_source_info` is the SourceInfo for the whole match. - fn lower_match_arms( + pub(crate) fn lower_match_arms( &mut self, destination: Place<'tcx>, scrutinee_place_builder: PlaceBuilder<'tcx>, @@ -1395,7 +1397,7 @@ pub(crate) struct ArmHasGuard(pub(crate) bool); /// A sub-branch in the output of match lowering. Match lowering has generated MIR code that will /// branch to `success_block` when the matched value matches the corresponding pattern. If there is /// a guard, its failure must continue to `otherwise_block`, which will resume testing patterns. -#[derive(Debug)] +#[derive(Debug, Clone)] struct MatchTreeSubBranch<'tcx> { span: Span, /// The block that is branched to if the corresponding subpattern matches. @@ -1411,7 +1413,7 @@ struct MatchTreeSubBranch<'tcx> { } /// A branch in the output of match lowering. -#[derive(Debug)] +#[derive(Debug, Clone)] struct MatchTreeBranch<'tcx> { sub_branches: Vec<MatchTreeSubBranch<'tcx>>, } @@ -1430,8 +1432,8 @@ struct MatchTreeBranch<'tcx> { /// Here the first arm gives the first `MatchTreeBranch`, which has two sub-branches, one for each /// alternative of the or-pattern. They are kept separate because each needs to bind `x` to a /// different place. -#[derive(Debug)] -struct BuiltMatchTree<'tcx> { +#[derive(Debug, Clone)] +pub(crate) struct BuiltMatchTree<'tcx> { branches: Vec<MatchTreeBranch<'tcx>>, otherwise_block: BasicBlock, /// If any of the branches had a guard, we collect here the places and locals to fakely borrow @@ -1489,7 +1491,7 @@ impl<'tcx> MatchTreeBranch<'tcx> { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum HasMatchGuard { +pub(crate) enum HasMatchGuard { Yes, No, } @@ -1504,7 +1506,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// `refutable` indicates whether the candidate list is refutable (for `if let` and `let else`) /// or not (for `let` and `match`). In the refutable case we return the block to which we branch /// on failure. - fn lower_match_tree( + pub(crate) fn lower_match_tree( &mut self, block: BasicBlock, scrutinee_span: Span, @@ -1890,7 +1892,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { debug!("expanding or-pattern: candidate={:#?}\npats={:#?}", candidate, pats); candidate.or_span = Some(match_pair.pattern_span); candidate.subcandidates = pats - .into_vec() .into_iter() .map(|flat_pat| Candidate::from_flat_pat(flat_pat, candidate.has_guard)) .collect(); @@ -2864,4 +2865,129 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { true } + + /// Attempt to statically pick the `BasicBlock` that a value would resolve to at runtime. + pub(crate) fn static_pattern_match( + &self, + cx: &RustcPatCtxt<'_, 'tcx>, + valtree: ValTree<'tcx>, + arms: &[ArmId], + built_match_tree: &BuiltMatchTree<'tcx>, + ) -> Option<BasicBlock> { + let it = arms.iter().zip(built_match_tree.branches.iter()); + for (&arm_id, branch) in it { + let pat = cx.lower_pat(&*self.thir.arms[arm_id].pattern); + + // Peel off or-patterns if they exist. + if let rustc_pattern_analysis::rustc::Constructor::Or = pat.ctor() { + for pat in pat.iter_fields() { + // For top-level or-patterns (the only ones we accept right now), when the + // bindings are the same (e.g. there are none), the sub_branch is stored just + // once. + let sub_branch = branch + .sub_branches + .get(pat.idx) + .or_else(|| branch.sub_branches.last()) + .unwrap(); + + match self.static_pattern_match_inner(valtree, &pat.pat) { + true => return Some(sub_branch.success_block), + false => continue, + } + } + } else if self.static_pattern_match_inner(valtree, &pat) { + return Some(branch.sub_branches[0].success_block); + } + } + + None + } + + /// Helper for [`Self::static_pattern_match`], checking whether the value represented by the + /// `ValTree` matches the given pattern. This function does not recurse, meaning that it does + /// not handle or-patterns, or patterns for types with fields. + fn static_pattern_match_inner( + &self, + valtree: ty::ValTree<'tcx>, + pat: &DeconstructedPat<'_, 'tcx>, + ) -> bool { + use rustc_pattern_analysis::constructor::{IntRange, MaybeInfiniteInt}; + use rustc_pattern_analysis::rustc::Constructor; + + match pat.ctor() { + Constructor::Variant(variant_index) => { + let ValTreeKind::Branch(box [actual_variant_idx]) = *valtree else { + bug!("malformed valtree for an enum") + }; + + let ValTreeKind::Leaf(actual_variant_idx) = ***actual_variant_idx else { + bug!("malformed valtree for an enum") + }; + + *variant_index == VariantIdx::from_u32(actual_variant_idx.to_u32()) + } + Constructor::IntRange(int_range) => { + let size = pat.ty().primitive_size(self.tcx); + let actual_int = valtree.unwrap_leaf().to_bits(size); + let actual_int = if pat.ty().is_signed() { + MaybeInfiniteInt::new_finite_int(actual_int, size.bits()) + } else { + MaybeInfiniteInt::new_finite_uint(actual_int) + }; + IntRange::from_singleton(actual_int).is_subrange(int_range) + } + Constructor::Bool(pattern_value) => match valtree.unwrap_leaf().try_to_bool() { + Ok(actual_value) => *pattern_value == actual_value, + Err(()) => bug!("bool value with invalid bits"), + }, + Constructor::F16Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f16(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F32Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f32(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F64Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f64(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::F128Range(l, h, end) => { + let actual = valtree.unwrap_leaf().to_f128(); + match end { + RangeEnd::Included => (*l..=*h).contains(&actual), + RangeEnd::Excluded => (*l..*h).contains(&actual), + } + } + Constructor::Wildcard => true, + + // These we may eventually support: + Constructor::Struct + | Constructor::Ref + | Constructor::DerefPattern(_) + | Constructor::Slice(_) + | Constructor::UnionField + | Constructor::Or + | Constructor::Str(_) => bug!("unsupported pattern constructor {:?}", pat.ctor()), + + // These should never occur here: + Constructor::Opaque(_) + | Constructor::Never + | Constructor::NonExhaustive + | Constructor::Hidden + | Constructor::Missing + | Constructor::PrivateUninhabited => { + bug!("unsupported pattern constructor {:?}", pat.ctor()) + } + } + } } diff --git a/compiler/rustc_mir_build/src/builder/scope.rs b/compiler/rustc_mir_build/src/builder/scope.rs index 67988f1fcbc..1d15e7e126f 100644 --- a/compiler/rustc_mir_build/src/builder/scope.rs +++ b/compiler/rustc_mir_build/src/builder/scope.rs @@ -83,20 +83,24 @@ that contains only loops and breakable blocks. It tracks where a `break`, use std::mem; +use interpret::ErrorHandled; use rustc_data_structures::fx::FxHashMap; use rustc_hir::HirId; use rustc_index::{IndexSlice, IndexVec}; use rustc_middle::middle::region; -use rustc_middle::mir::*; -use rustc_middle::thir::{ExprId, LintLevel}; -use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::mir::{self, *}; +use rustc_middle::thir::{AdtExpr, AdtExprBase, ArmId, ExprId, ExprKind, LintLevel}; +use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt, ValTree}; use rustc_middle::{bug, span_bug}; +use rustc_pattern_analysis::rustc::RustcPatCtxt; use rustc_session::lint::Level; use rustc_span::source_map::Spanned; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument}; +use super::matches::BuiltMatchTree; use crate::builder::{BlockAnd, BlockAndExtension, BlockFrame, Builder, CFG}; +use crate::errors::{ConstContinueBadConst, ConstContinueUnknownJumpTarget}; #[derive(Debug)] pub(crate) struct Scopes<'tcx> { @@ -105,6 +109,8 @@ pub(crate) struct Scopes<'tcx> { /// The current set of breakable scopes. See module comment for more details. breakable_scopes: Vec<BreakableScope<'tcx>>, + const_continuable_scopes: Vec<ConstContinuableScope<'tcx>>, + /// The scope of the innermost if-then currently being lowered. if_then_scope: Option<IfThenScope>, @@ -175,6 +181,20 @@ struct BreakableScope<'tcx> { } #[derive(Debug)] +struct ConstContinuableScope<'tcx> { + /// The scope for the `#[loop_match]` which its `#[const_continue]`s will jump to. + region_scope: region::Scope, + /// The place of the state of a `#[loop_match]`, which a `#[const_continue]` must update. + state_place: Place<'tcx>, + + arms: Box<[ArmId]>, + built_match_tree: BuiltMatchTree<'tcx>, + + /// Drops that happen on a `#[const_continue]` + const_continue_drops: DropTree, +} + +#[derive(Debug)] struct IfThenScope { /// The if-then scope or arm scope region_scope: region::Scope, @@ -461,6 +481,7 @@ impl<'tcx> Scopes<'tcx> { Self { scopes: Vec::new(), breakable_scopes: Vec::new(), + const_continuable_scopes: Vec::new(), if_then_scope: None, unwind_drops: DropTree::new(), coroutine_drops: DropTree::new(), @@ -552,6 +573,59 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + /// Start a const-continuable scope, which tracks where `#[const_continue] break` should + /// branch to. + pub(crate) fn in_const_continuable_scope<F>( + &mut self, + arms: Box<[ArmId]>, + built_match_tree: BuiltMatchTree<'tcx>, + state_place: Place<'tcx>, + span: Span, + f: F, + ) -> BlockAnd<()> + where + F: FnOnce(&mut Builder<'a, 'tcx>) -> BlockAnd<()>, + { + let region_scope = self.scopes.topmost(); + let scope = ConstContinuableScope { + region_scope, + state_place, + const_continue_drops: DropTree::new(), + arms, + built_match_tree, + }; + self.scopes.const_continuable_scopes.push(scope); + let normal_exit_block = f(self); + let const_continue_scope = self.scopes.const_continuable_scopes.pop().unwrap(); + assert!(const_continue_scope.region_scope == region_scope); + + let break_block = self.build_exit_tree( + const_continue_scope.const_continue_drops, + region_scope, + span, + None, + ); + + match (normal_exit_block, break_block) { + (block, None) => block, + (normal_block, Some(exit_block)) => { + let target = self.cfg.start_new_block(); + let source_info = self.source_info(span); + self.cfg.terminate( + normal_block.into_block(), + source_info, + TerminatorKind::Goto { target }, + ); + self.cfg.terminate( + exit_block.into_block(), + source_info, + TerminatorKind::Goto { target }, + ); + target.unit() + } + } + } + /// Start an if-then scope which tracks drop for `if` expressions and `if` /// guards. /// @@ -742,6 +816,190 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.cfg.start_new_block().unit() } + /// Based on `FunctionCx::eval_unevaluated_mir_constant_to_valtree`. + fn eval_unevaluated_mir_constant_to_valtree( + &self, + constant: ConstOperand<'tcx>, + ) -> Result<(ty::ValTree<'tcx>, Ty<'tcx>), interpret::ErrorHandled> { + assert!(!constant.const_.ty().has_param()); + let (uv, ty) = match constant.const_ { + mir::Const::Unevaluated(uv, ty) => (uv.shrink(), ty), + mir::Const::Ty(_, c) => match c.kind() { + // A constant that came from a const generic but was then used as an argument to + // old-style simd_shuffle (passing as argument instead of as a generic param). + ty::ConstKind::Value(cv) => return Ok((cv.valtree, cv.ty)), + other => span_bug!(constant.span, "{other:#?}"), + }, + mir::Const::Val(mir::ConstValue::Scalar(mir::interpret::Scalar::Int(val)), ty) => { + return Ok((ValTree::from_scalar_int(self.tcx, val), ty)); + } + // We should never encounter `Const::Val` unless MIR opts (like const prop) evaluate + // a constant and write that value back into `Operand`s. This could happen, but is + // unlikely. Also: all users of `simd_shuffle` are on unstable and already need to take + // a lot of care around intrinsics. For an issue to happen here, it would require a + // macro expanding to a `simd_shuffle` call without wrapping the constant argument in a + // `const {}` block, but the user pass through arbitrary expressions. + + // FIXME(oli-obk): Replace the magic const generic argument of `simd_shuffle` with a + // real const generic, and get rid of this entire function. + other => span_bug!(constant.span, "{other:#?}"), + }; + + match self.tcx.const_eval_resolve_for_typeck(self.typing_env(), uv, constant.span) { + Ok(Ok(valtree)) => Ok((valtree, ty)), + Ok(Err(ty)) => span_bug!(constant.span, "could not convert {ty:?} to a valtree"), + Err(e) => Err(e), + } + } + + /// Sets up the drops for jumping from `block` to `scope`. + pub(crate) fn break_const_continuable_scope( + &mut self, + mut block: BasicBlock, + value: ExprId, + scope: region::Scope, + source_info: SourceInfo, + ) -> BlockAnd<()> { + let span = source_info.span; + + // A break can only break out of a scope, so the value should be a scope. + let rustc_middle::thir::ExprKind::Scope { value, .. } = self.thir[value].kind else { + span_bug!(span, "break value must be a scope") + }; + + let constant = match &self.thir[value].kind { + ExprKind::Adt(box AdtExpr { variant_index, fields, base, .. }) => { + assert!(matches!(base, AdtExprBase::None)); + assert!(fields.is_empty()); + ConstOperand { + span: self.thir[value].span, + user_ty: None, + const_: Const::Ty( + self.thir[value].ty, + ty::Const::new_value( + self.tcx, + ValTree::from_branches( + self.tcx, + [ValTree::from_scalar_int(self.tcx, variant_index.as_u32().into())], + ), + self.thir[value].ty, + ), + ), + } + } + _ => self.as_constant(&self.thir[value]), + }; + + let break_index = self + .scopes + .const_continuable_scopes + .iter() + .rposition(|const_continuable_scope| const_continuable_scope.region_scope == scope) + .unwrap_or_else(|| span_bug!(span, "no enclosing const-continuable scope found")); + + let scope = &self.scopes.const_continuable_scopes[break_index]; + + let state_decl = &self.local_decls[scope.state_place.as_local().unwrap()]; + let state_ty = state_decl.ty; + let (discriminant_ty, rvalue) = match state_ty.kind() { + ty::Adt(adt_def, _) if adt_def.is_enum() => { + (state_ty.discriminant_ty(self.tcx), Rvalue::Discriminant(scope.state_place)) + } + ty::Uint(_) | ty::Int(_) | ty::Float(_) | ty::Bool | ty::Char => { + (state_ty, Rvalue::Use(Operand::Copy(scope.state_place))) + } + _ => span_bug!(state_decl.source_info.span, "unsupported #[loop_match] state"), + }; + + // The `PatCtxt` is normally used in pattern exhaustiveness checking, but reused + // here because it performs normalization and const evaluation. + let dropless_arena = rustc_arena::DroplessArena::default(); + let typeck_results = self.tcx.typeck(self.def_id); + let cx = RustcPatCtxt { + tcx: self.tcx, + typeck_results, + module: self.tcx.parent_module(self.hir_id).to_def_id(), + // FIXME(#132279): We're in a body, should handle opaques. + typing_env: rustc_middle::ty::TypingEnv::non_body_analysis(self.tcx, self.def_id), + dropless_arena: &dropless_arena, + match_lint_level: self.hir_id, + whole_match_span: Some(rustc_span::Span::default()), + scrut_span: rustc_span::Span::default(), + refutable: true, + known_valid_scrutinee: true, + }; + + let valtree = match self.eval_unevaluated_mir_constant_to_valtree(constant) { + Ok((valtree, ty)) => { + // Defensively check that the type is monomorphic. + assert!(!ty.has_param()); + + valtree + } + Err(ErrorHandled::Reported(..)) => return self.cfg.start_new_block().unit(), + Err(ErrorHandled::TooGeneric(_)) => { + self.tcx.dcx().emit_fatal(ConstContinueBadConst { span: constant.span }); + } + }; + + let Some(real_target) = + self.static_pattern_match(&cx, valtree, &*scope.arms, &scope.built_match_tree) + else { + self.tcx.dcx().emit_fatal(ConstContinueUnknownJumpTarget { span }) + }; + + self.block_context.push(BlockFrame::SubExpr); + let state_place = scope.state_place; + block = self.expr_into_dest(state_place, block, value).into_block(); + self.block_context.pop(); + + let discr = self.temp(discriminant_ty, source_info.span); + let scope_index = self + .scopes + .scope_index(self.scopes.const_continuable_scopes[break_index].region_scope, span); + let scope = &mut self.scopes.const_continuable_scopes[break_index]; + self.cfg.push_assign(block, source_info, discr, rvalue); + let drop_and_continue_block = self.cfg.start_new_block(); + let imaginary_target = self.cfg.start_new_block(); + self.cfg.terminate( + block, + source_info, + TerminatorKind::FalseEdge { real_target: drop_and_continue_block, imaginary_target }, + ); + + let drops = &mut scope.const_continue_drops; + + let drop_idx = self.scopes.scopes[scope_index + 1..] + .iter() + .flat_map(|scope| &scope.drops) + .fold(ROOT_NODE, |drop_idx, &drop| drops.add_drop(drop, drop_idx)); + + drops.add_entry_point(imaginary_target, drop_idx); + + self.cfg.terminate(imaginary_target, source_info, TerminatorKind::UnwindResume); + + let region_scope = scope.region_scope; + let scope_index = self.scopes.scope_index(region_scope, span); + let mut drops = DropTree::new(); + + let drop_idx = self.scopes.scopes[scope_index + 1..] + .iter() + .flat_map(|scope| &scope.drops) + .fold(ROOT_NODE, |drop_idx, &drop| drops.add_drop(drop, drop_idx)); + + drops.add_entry_point(drop_and_continue_block, drop_idx); + + // `build_drop_trees` doesn't have access to our source_info, so we + // create a dummy terminator now. `TerminatorKind::UnwindResume` is used + // because MIR type checking will panic if it hasn't been overwritten. + // (See `<ExitScopes as DropTreeBuilder>::link_entry_point`.) + self.cfg.terminate(drop_and_continue_block, source_info, TerminatorKind::UnwindResume); + + self.build_exit_tree(drops, region_scope, span, Some(real_target)); + + return self.cfg.start_new_block().unit(); + } + /// Sets up the drops for breaking from `block` due to an `if` condition /// that turned out to be false. /// diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs index d5061b71699..0b6b36640e9 100644 --- a/compiler/rustc_mir_build/src/check_unsafety.rs +++ b/compiler/rustc_mir_build/src/check_unsafety.rs @@ -465,10 +465,12 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { | ExprKind::Break { .. } | ExprKind::Closure { .. } | ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } | ExprKind::Yield { .. } | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } | ExprKind::Let { .. } | ExprKind::Match { .. } | ExprKind::Box { .. } diff --git a/compiler/rustc_mir_build/src/errors.rs b/compiler/rustc_mir_build/src/errors.rs index ae09db50235..32df191cbca 100644 --- a/compiler/rustc_mir_build/src/errors.rs +++ b/compiler/rustc_mir_build/src/errors.rs @@ -1149,3 +1149,80 @@ impl Subdiagnostic for Rust2024IncompatiblePatSugg { } } } + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_invalid_update)] +pub(crate) struct LoopMatchInvalidUpdate { + #[primary_span] + pub lhs: Span, + #[label] + pub scrutinee: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_invalid_match)] +#[note] +pub(crate) struct LoopMatchInvalidMatch { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_unsupported_type)] +#[note] +pub(crate) struct LoopMatchUnsupportedType<'tcx> { + #[primary_span] + pub span: Span, + pub ty: Ty<'tcx>, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_bad_statements)] +pub(crate) struct LoopMatchBadStatements { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_bad_rhs)] +pub(crate) struct LoopMatchBadRhs { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_missing_assignment)] +pub(crate) struct LoopMatchMissingAssignment { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_loop_match_arm_with_guard)] +pub(crate) struct LoopMatchArmWithGuard { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_bad_const)] +pub(crate) struct ConstContinueBadConst { + #[primary_span] + #[label] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_missing_value)] +pub(crate) struct ConstContinueMissingValue { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_build_const_continue_unknown_jump_target)] +#[note] +pub(crate) struct ConstContinueUnknownJumpTarget { + #[primary_span] + pub span: Span, +} diff --git a/compiler/rustc_mir_build/src/thir/cx/expr.rs b/compiler/rustc_mir_build/src/thir/cx/expr.rs index 3baeccf6409..4c970e8e3d7 100644 --- a/compiler/rustc_mir_build/src/thir/cx/expr.rs +++ b/compiler/rustc_mir_build/src/thir/cx/expr.rs @@ -1,6 +1,7 @@ use itertools::Itertools; use rustc_abi::{FIRST_VARIANT, FieldIdx}; use rustc_ast::UnsafeBinderCastKind; +use rustc_attr_data_structures::{AttributeKind, find_attr}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_hir as hir; use rustc_hir::def::{CtorKind, CtorOf, DefKind, Res}; @@ -21,6 +22,7 @@ use rustc_middle::{bug, span_bug}; use rustc_span::{Span, sym}; use tracing::{debug, info, instrument, trace}; +use crate::errors::*; use crate::thir::cx::ThirBuildCx; impl<'tcx> ThirBuildCx<'tcx> { @@ -796,16 +798,38 @@ impl<'tcx> ThirBuildCx<'tcx> { } hir::ExprKind::Ret(v) => ExprKind::Return { value: v.map(|v| self.mirror_expr(v)) }, hir::ExprKind::Become(call) => ExprKind::Become { value: self.mirror_expr(call) }, - hir::ExprKind::Break(dest, ref value) => match dest.target_id { - Ok(target_id) => ExprKind::Break { - label: region::Scope { - local_id: target_id.local_id, - data: region::ScopeData::Node, - }, - value: value.map(|value| self.mirror_expr(value)), - }, - Err(err) => bug!("invalid loop id for break: {}", err), - }, + hir::ExprKind::Break(dest, ref value) => { + if find_attr!(self.tcx.hir_attrs(expr.hir_id), AttributeKind::ConstContinue(_)) { + match dest.target_id { + Ok(target_id) => { + let Some(value) = value else { + let span = expr.span; + self.tcx.dcx().emit_fatal(ConstContinueMissingValue { span }) + }; + + ExprKind::ConstContinue { + label: region::Scope { + local_id: target_id.local_id, + data: region::ScopeData::Node, + }, + value: self.mirror_expr(value), + } + } + Err(err) => bug!("invalid loop id for break: {}", err), + } + } else { + match dest.target_id { + Ok(target_id) => ExprKind::Break { + label: region::Scope { + local_id: target_id.local_id, + data: region::ScopeData::Node, + }, + value: value.map(|value| self.mirror_expr(value)), + }, + Err(err) => bug!("invalid loop id for break: {}", err), + } + } + } hir::ExprKind::Continue(dest) => match dest.target_id { Ok(loop_id) => ExprKind::Continue { label: region::Scope { @@ -840,18 +864,93 @@ impl<'tcx> ThirBuildCx<'tcx> { match_source, }, hir::ExprKind::Loop(body, ..) => { - let block_ty = self.typeck_results.node_type(body.hir_id); - let (temp_lifetime, backwards_incompatible) = self - .rvalue_scopes - .temporary_scope(self.region_scope_tree, body.hir_id.local_id); - let block = self.mirror_block(body); - let body = self.thir.exprs.push(Expr { - ty: block_ty, - temp_lifetime: TempLifetime { temp_lifetime, backwards_incompatible }, - span: self.thir[block].span, - kind: ExprKind::Block { block }, - }); - ExprKind::Loop { body } + if find_attr!(self.tcx.hir_attrs(expr.hir_id), AttributeKind::LoopMatch(_)) { + let dcx = self.tcx.dcx(); + + // Accept either `state = expr` or `state = expr;`. + let loop_body_expr = match body.stmts { + [] => match body.expr { + Some(expr) => expr, + None => dcx.emit_fatal(LoopMatchMissingAssignment { span: body.span }), + }, + [single] if body.expr.is_none() => match single.kind { + hir::StmtKind::Expr(expr) | hir::StmtKind::Semi(expr) => expr, + _ => dcx.emit_fatal(LoopMatchMissingAssignment { span: body.span }), + }, + [first @ last] | [first, .., last] => dcx + .emit_fatal(LoopMatchBadStatements { span: first.span.to(last.span) }), + }; + + let hir::ExprKind::Assign(state, rhs_expr, _) = loop_body_expr.kind else { + dcx.emit_fatal(LoopMatchMissingAssignment { span: loop_body_expr.span }) + }; + + let hir::ExprKind::Block(block_body, _) = rhs_expr.kind else { + dcx.emit_fatal(LoopMatchBadRhs { span: rhs_expr.span }) + }; + + // The labeled block should contain one match expression, but defining items is + // allowed. + for stmt in block_body.stmts { + if !matches!(stmt.kind, rustc_hir::StmtKind::Item(_)) { + dcx.emit_fatal(LoopMatchBadStatements { span: stmt.span }) + } + } + + let Some(block_body_expr) = block_body.expr else { + dcx.emit_fatal(LoopMatchBadRhs { span: block_body.span }) + }; + + let hir::ExprKind::Match(scrutinee, arms, _match_source) = block_body_expr.kind + else { + dcx.emit_fatal(LoopMatchBadRhs { span: block_body_expr.span }) + }; + + fn local(expr: &rustc_hir::Expr<'_>) -> Option<hir::HirId> { + if let hir::ExprKind::Path(hir::QPath::Resolved(_, path)) = expr.kind { + if let Res::Local(hir_id) = path.res { + return Some(hir_id); + } + } + + None + } + + let Some(scrutinee_hir_id) = local(scrutinee) else { + dcx.emit_fatal(LoopMatchInvalidMatch { span: scrutinee.span }) + }; + + if local(state) != Some(scrutinee_hir_id) { + dcx.emit_fatal(LoopMatchInvalidUpdate { + scrutinee: scrutinee.span, + lhs: state.span, + }) + } + + ExprKind::LoopMatch { + state: self.mirror_expr(state), + region_scope: region::Scope { + local_id: block_body.hir_id.local_id, + data: region::ScopeData::Node, + }, + + arms: arms.iter().map(|a| self.convert_arm(a)).collect(), + match_span: block_body_expr.span, + } + } else { + let block_ty = self.typeck_results.node_type(body.hir_id); + let (temp_lifetime, backwards_incompatible) = self + .rvalue_scopes + .temporary_scope(self.region_scope_tree, body.hir_id.local_id); + let block = self.mirror_block(body); + let body = self.thir.exprs.push(Expr { + ty: block_ty, + temp_lifetime: TempLifetime { temp_lifetime, backwards_incompatible }, + span: self.thir[block].span, + kind: ExprKind::Block { block }, + }); + ExprKind::Loop { body } + } } hir::ExprKind::Field(source, ..) => ExprKind::Field { lhs: self.mirror_expr(source), diff --git a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs index 245bd866030..9fd410e6bf1 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs @@ -331,7 +331,11 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> { | WrapUnsafeBinder { source } => self.is_known_valid_scrutinee(&self.thir()[*source]), // These diverge. - Become { .. } | Break { .. } | Continue { .. } | Return { .. } => true, + Become { .. } + | Break { .. } + | Continue { .. } + | ConstContinue { .. } + | Return { .. } => true, // These are statements that evaluate to `()`. Assign { .. } | AssignOp { .. } | InlineAsm { .. } | Let { .. } => true, @@ -353,6 +357,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> { | Literal { .. } | LogicalOp { .. } | Loop { .. } + | LoopMatch { .. } | Match { .. } | NamedConst { .. } | NonHirLiteral { .. } diff --git a/compiler/rustc_mir_build/src/thir/print.rs b/compiler/rustc_mir_build/src/thir/print.rs index db9547a481f..1507b6b8c06 100644 --- a/compiler/rustc_mir_build/src/thir/print.rs +++ b/compiler/rustc_mir_build/src/thir/print.rs @@ -318,6 +318,20 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> { self.print_expr(*body, depth_lvl + 2); print_indented!(self, ")", depth_lvl); } + LoopMatch { state, region_scope, match_span, arms } => { + print_indented!(self, "LoopMatch {", depth_lvl); + print_indented!(self, "state:", depth_lvl + 1); + self.print_expr(*state, depth_lvl + 2); + print_indented!(self, format!("region_scope: {:?}", region_scope), depth_lvl + 1); + print_indented!(self, format!("match_span: {:?}", match_span), depth_lvl + 1); + + print_indented!(self, "arms: [", depth_lvl + 1); + for arm_id in arms.iter() { + self.print_arm(*arm_id, depth_lvl + 2); + } + print_indented!(self, "]", depth_lvl + 1); + print_indented!(self, "}", depth_lvl); + } Let { expr, pat } => { print_indented!(self, "Let {", depth_lvl); print_indented!(self, "expr:", depth_lvl + 1); @@ -415,6 +429,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> { print_indented!(self, format!("label: {:?}", label), depth_lvl + 1); print_indented!(self, "}", depth_lvl); } + ConstContinue { label, value } => { + print_indented!(self, "ConstContinue (", depth_lvl); + print_indented!(self, format!("label: {:?}", label), depth_lvl + 1); + print_indented!(self, "value:", depth_lvl + 1); + self.print_expr(*value, depth_lvl + 2); + print_indented!(self, ")", depth_lvl); + } Return { value } => { print_indented!(self, "Return {", depth_lvl); print_indented!(self, "value:", depth_lvl + 1); diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl index 29526817257..06acf7bf992 100644 --- a/compiler/rustc_passes/messages.ftl +++ b/compiler/rustc_passes/messages.ftl @@ -82,11 +82,14 @@ passes_collapse_debuginfo = passes_confusables = attribute should be applied to an inherent method .label = not an inherent method +passes_const_continue_attr = + `#[const_continue]` should be applied to a break expression + .label = not a break expression + passes_const_stable_not_stable = attribute `#[rustc_const_stable]` can only be applied to functions that are declared `#[stable]` .label = attribute specified here - passes_coroutine_on_non_closure = attribute should be applied to closures .label = not a closure @@ -446,6 +449,10 @@ passes_linkage = attribute should be applied to a function or static .label = not a function definition or static +passes_loop_match_attr = + `#[loop_match]` should be applied to a loop + .label = not a loop + passes_macro_export = `#[macro_export]` only has an effect on macro definitions diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index c2a58b4cd7d..3c40b597f41 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -132,6 +132,12 @@ impl<'tcx> CheckAttrVisitor<'tcx> { Attribute::Parsed(AttributeKind::Optimize(_, attr_span)) => { self.check_optimize(hir_id, *attr_span, span, target) } + Attribute::Parsed(AttributeKind::LoopMatch(attr_span)) => { + self.check_loop_match(hir_id, *attr_span, target) + } + Attribute::Parsed(AttributeKind::ConstContinue(attr_span)) => { + self.check_const_continue(hir_id, *attr_span, target) + } Attribute::Parsed(AttributeKind::AllowInternalUnstable(syms)) => self .check_allow_internal_unstable( hir_id, @@ -2630,6 +2636,32 @@ impl<'tcx> CheckAttrVisitor<'tcx> { } } } + + fn check_loop_match(&self, hir_id: HirId, attr_span: Span, target: Target) { + let node_span = self.tcx.hir_span(hir_id); + + if !matches!(target, Target::Expression) { + self.dcx().emit_err(errors::LoopMatchAttr { attr_span, node_span }); + return; + } + + if !matches!(self.tcx.hir_expect_expr(hir_id).kind, hir::ExprKind::Loop(..)) { + self.dcx().emit_err(errors::LoopMatchAttr { attr_span, node_span }); + }; + } + + fn check_const_continue(&self, hir_id: HirId, attr_span: Span, target: Target) { + let node_span = self.tcx.hir_span(hir_id); + + if !matches!(target, Target::Expression) { + self.dcx().emit_err(errors::ConstContinueAttr { attr_span, node_span }); + return; + } + + if !matches!(self.tcx.hir_expect_expr(hir_id).kind, hir::ExprKind::Break(..)) { + self.dcx().emit_err(errors::ConstContinueAttr { attr_span, node_span }); + }; + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index 94c8ae77ed7..c933647165c 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -31,6 +31,24 @@ pub(crate) struct AutoDiffAttr { pub attr_span: Span, } +#[derive(Diagnostic)] +#[diag(passes_loop_match_attr)] +pub(crate) struct LoopMatchAttr { + #[primary_span] + pub attr_span: Span, + #[label] + pub node_span: Span, +} + +#[derive(Diagnostic)] +#[diag(passes_const_continue_attr)] +pub(crate) struct ConstContinueAttr { + #[primary_span] + pub attr_span: Span, + #[label] + pub node_span: Span, +} + #[derive(LintDiagnostic)] #[diag(passes_outer_crate_level_attr)] pub(crate) struct OuterCrateLevelAttr; diff --git a/compiler/rustc_pattern_analysis/src/constructor.rs b/compiler/rustc_pattern_analysis/src/constructor.rs index f7a4931c111..09685640e50 100644 --- a/compiler/rustc_pattern_analysis/src/constructor.rs +++ b/compiler/rustc_pattern_analysis/src/constructor.rs @@ -314,7 +314,8 @@ impl IntRange { IntRange { lo, hi } } - fn is_subrange(&self, other: &Self) -> bool { + #[inline] + pub fn is_subrange(&self, other: &Self) -> bool { other.lo <= self.lo && self.hi <= other.hi } diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 684b1781b44..fb0f63d12c2 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -692,6 +692,7 @@ symbols! { const_closures, const_compare_raw_pointers, const_constructor, + const_continue, const_deallocate, const_destruct, const_eval_limit, @@ -1304,6 +1305,7 @@ symbols! { logf64, loongarch_target_feature, loop_break_value, + loop_match, lt, m68k_target_feature, macro_at_most_once_rep, diff --git a/compiler/rustc_ty_utils/src/consts.rs b/compiler/rustc_ty_utils/src/consts.rs index b275cd382ab..60f8bd9d83a 100644 --- a/compiler/rustc_ty_utils/src/consts.rs +++ b/compiler/rustc_ty_utils/src/consts.rs @@ -226,7 +226,11 @@ fn recurse_build<'tcx>( ExprKind::Yield { .. } => { error(GenericConstantTooComplexSub::YieldNotSupported(node.span))? } - ExprKind::Continue { .. } | ExprKind::Break { .. } | ExprKind::Loop { .. } => { + ExprKind::Continue { .. } + | ExprKind::ConstContinue { .. } + | ExprKind::Break { .. } + | ExprKind::Loop { .. } + | ExprKind::LoopMatch { .. } => { error(GenericConstantTooComplexSub::LoopNotSupported(node.span))? } ExprKind::Box { .. } => error(GenericConstantTooComplexSub::BoxNotSupported(node.span))?, @@ -329,6 +333,7 @@ impl<'a, 'tcx> IsThirPolymorphic<'a, 'tcx> { | thir::ExprKind::NeverToAny { .. } | thir::ExprKind::PointerCoercion { .. } | thir::ExprKind::Loop { .. } + | thir::ExprKind::LoopMatch { .. } | thir::ExprKind::Let { .. } | thir::ExprKind::Match { .. } | thir::ExprKind::Block { .. } @@ -342,6 +347,7 @@ impl<'a, 'tcx> IsThirPolymorphic<'a, 'tcx> { | thir::ExprKind::RawBorrow { .. } | thir::ExprKind::Break { .. } | thir::ExprKind::Continue { .. } + | thir::ExprKind::ConstContinue { .. } | thir::ExprKind::Return { .. } | thir::ExprKind::Become { .. } | thir::ExprKind::Array { .. } |
