diff options
34 files changed, 504 insertions, 14 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 8986430141b..9d216ef3dd8 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2469,6 +2469,8 @@ pub enum TyPatKind { /// A range pattern (e.g., `1...2`, `1..2`, `1..`, `..2`, `1..=2`, `..=2`). Range(Option<P<AnonConst>>, Option<P<AnonConst>>, Spanned<RangeEnd>), + Or(ThinVec<P<TyPat>>), + /// Placeholder for a pattern that wasn't syntactically well formed in some way. Err(ErrorGuaranteed), } diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 6aae2e481a5..cd2293423db 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -612,6 +612,7 @@ pub fn walk_ty_pat<T: MutVisitor>(vis: &mut T, ty: &mut P<TyPat>) { visit_opt(start, |c| vis.visit_anon_const(c)); visit_opt(end, |c| vis.visit_anon_const(c)); } + TyPatKind::Or(variants) => visit_thin_vec(variants, |p| vis.visit_ty_pat(p)), TyPatKind::Err(_) => {} } visit_lazy_tts(vis, tokens); diff --git a/compiler/rustc_ast/src/visit.rs b/compiler/rustc_ast/src/visit.rs index 79193fcec63..69a186c8cf1 100644 --- a/compiler/rustc_ast/src/visit.rs +++ b/compiler/rustc_ast/src/visit.rs @@ -608,6 +608,7 @@ pub fn walk_ty_pat<'a, V: Visitor<'a>>(visitor: &mut V, tp: &'a TyPat) -> V::Res visit_opt!(visitor, visit_anon_const, start); visit_opt!(visitor, visit_anon_const, end); } + TyPatKind::Or(variants) => walk_list!(visitor, visit_ty_pat, variants), TyPatKind::Err(_) => {} } V::Result::output() diff --git a/compiler/rustc_ast_lowering/src/pat.rs b/compiler/rustc_ast_lowering/src/pat.rs index f94d788a9b0..4a6929ef011 100644 --- a/compiler/rustc_ast_lowering/src/pat.rs +++ b/compiler/rustc_ast_lowering/src/pat.rs @@ -464,6 +464,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { ) }), ), + TyPatKind::Or(variants) => { + hir::TyPatKind::Or(self.arena.alloc_from_iter( + variants.iter().map(|pat| self.lower_ty_pat_mut(pat, base_type)), + )) + } TyPatKind::Err(guar) => hir::TyPatKind::Err(*guar), }; diff --git a/compiler/rustc_ast_pretty/src/pprust/state.rs b/compiler/rustc_ast_pretty/src/pprust/state.rs index 62a50c73855..9411237421d 100644 --- a/compiler/rustc_ast_pretty/src/pprust/state.rs +++ b/compiler/rustc_ast_pretty/src/pprust/state.rs @@ -1162,6 +1162,17 @@ impl<'a> State<'a> { self.print_expr_anon_const(end, &[]); } } + rustc_ast::TyPatKind::Or(variants) => { + let mut first = true; + for pat in variants { + if first { + first = false + } else { + self.word(" | "); + } + self.print_ty_pat(pat); + } + } rustc_ast::TyPatKind::Err(_) => { self.popen(); self.word("/*ERROR*/"); diff --git a/compiler/rustc_builtin_macros/src/pattern_type.rs b/compiler/rustc_builtin_macros/src/pattern_type.rs index aaf5b233651..3529e5525fc 100644 --- a/compiler/rustc_builtin_macros/src/pattern_type.rs +++ b/compiler/rustc_builtin_macros/src/pattern_type.rs @@ -4,6 +4,7 @@ use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token}; use rustc_errors::PResult; use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult}; use rustc_parse::exp; +use rustc_parse::parser::{CommaRecoveryMode, RecoverColon, RecoverComma}; use rustc_span::Span; pub(crate) fn expand<'cx>( @@ -27,7 +28,17 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P let ty = parser.parse_ty()?; parser.expect_keyword(exp!(Is))?; - let pat = pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner()); + let pat = pat_to_ty_pat( + cx, + parser + .parse_pat_no_top_guard( + None, + RecoverComma::No, + RecoverColon::No, + CommaRecoveryMode::EitherTupleOrPipe, + )? + .into_inner(), + ); if parser.token != token::Eof { parser.unexpected()?; @@ -47,6 +58,9 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P<TyPat> { end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), include_end, ), + ast::PatKind::Or(variants) => TyPatKind::Or( + variants.into_iter().map(|pat| pat_to_ty_pat(cx, pat.into_inner())).collect(), + ), ast::PatKind::Err(guar) => TyPatKind::Err(guar), _ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")), }; diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 40c63f2b250..97d066ffe3f 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -61,16 +61,21 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>( ensure_monomorphic_enough(tcx, tp_ty)?; ConstValue::from_u128(tcx.type_id_hash(tp_ty).as_u128()) } - sym::variant_count => match tp_ty.kind() { + sym::variant_count => match match tp_ty.kind() { + // Pattern types have the same number of variants as their base type. + // Even if we restrict e.g. which variants are valid, the variants are essentially just uninhabited. + // And `Result<(), !>` still has two variants according to `variant_count`. + ty::Pat(base, _) => *base, + _ => tp_ty, + } + .kind() + { // Correctly handles non-monomorphic calls, so there is no need for ensure_monomorphic_enough. ty::Adt(adt, _) => ConstValue::from_target_usize(adt.variants().len() as u64, &tcx), ty::Alias(..) | ty::Param(_) | ty::Placeholder(_) | ty::Infer(_) => { throw_inval!(TooGeneric) } - ty::Pat(_, pat) => match **pat { - ty::PatternKind::Range { .. } => ConstValue::from_target_usize(0u64, &tcx), - // Future pattern kinds may have more variants - }, + ty::Pat(..) => unreachable!(), ty::Bound(_, _) => bug!("bound ty during ctfe"), ty::Bool | ty::Char diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index fb7ba6d7ef5..c86af5a9a4b 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -1248,6 +1248,14 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt, // Range patterns are precisely reflected into `valid_range` and thus // handled fully by `visit_scalar` (called below). ty::PatternKind::Range { .. } => {}, + + // FIXME(pattern_types): check that the value is covered by one of the variants. + // For now, we rely on layout computation setting the scalar's `valid_range` to + // match the pattern. However, this cannot always work; the layout may + // pessimistically cover actually illegal ranges and Miri would miss that UB. + // The consolation here is that codegen also will miss that UB, so at least + // we won't see optimizations actually breaking such programs. + ty::PatternKind::Or(_patterns) => {} } } _ => { diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 2f8a8534247..af587ee5bdc 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1813,6 +1813,9 @@ pub enum TyPatKind<'hir> { /// A range pattern (e.g., `1..=2` or `1..2`). Range(&'hir ConstArg<'hir>, &'hir ConstArg<'hir>), + /// A list of patterns where only one needs to be satisfied + Or(&'hir [TyPat<'hir>]), + /// A placeholder for a pattern that wasn't well formed in some way. Err(ErrorGuaranteed), } diff --git a/compiler/rustc_hir/src/intravisit.rs b/compiler/rustc_hir/src/intravisit.rs index 3c2897ef1d9..a60de4b1fc3 100644 --- a/compiler/rustc_hir/src/intravisit.rs +++ b/compiler/rustc_hir/src/intravisit.rs @@ -710,6 +710,7 @@ pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>) try_visit!(visitor.visit_const_arg_unambig(lower_bound)); try_visit!(visitor.visit_const_arg_unambig(upper_bound)); } + TyPatKind::Or(patterns) => walk_list!(visitor, visit_pattern_type_pattern, patterns), TyPatKind::Err(_) => (), } V::Result::output() diff --git a/compiler/rustc_hir_analysis/src/collect/type_of.rs b/compiler/rustc_hir_analysis/src/collect/type_of.rs index 694c1228859..c20b14df770 100644 --- a/compiler/rustc_hir_analysis/src/collect/type_of.rs +++ b/compiler/rustc_hir_analysis/src/collect/type_of.rs @@ -94,10 +94,12 @@ fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: S } Node::TyPat(pat) => { - let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else { - bug!() + let node = match tcx.parent_hir_node(pat.hir_id) { + // Or patterns can be nested one level deep + Node::TyPat(p) => tcx.parent_hir_node(p.hir_id), + other => other, }; - assert_eq!(p.hir_id, pat.hir_id); + let hir::TyKind::Pat(ty, _) = node.expect_ty().kind else { bug!() }; icx.lower_ty(ty) } diff --git a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs index 6e1e650a817..53792c7b093 100644 --- a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs +++ b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs @@ -2735,6 +2735,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { ty_span: Span, pat: &hir::TyPat<'tcx>, ) -> Result<ty::PatternKind<'tcx>, ErrorGuaranteed> { + let tcx = self.tcx(); match pat.kind { hir::TyPatKind::Range(start, end) => { match ty.kind() { @@ -2750,6 +2751,13 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { .span_delayed_bug(ty_span, "invalid base type for range pattern")), } } + hir::TyPatKind::Or(patterns) => { + self.tcx() + .mk_patterns_from_iter(patterns.iter().map(|pat| { + self.lower_pat_ty_pat(ty, ty_span, pat).map(|pat| tcx.mk_pat(pat)) + })) + .map(ty::PatternKind::Or) + } hir::TyPatKind::Err(e) => Err(e), } } diff --git a/compiler/rustc_hir_analysis/src/variance/constraints.rs b/compiler/rustc_hir_analysis/src/variance/constraints.rs index 8123326a47f..ff73d163f84 100644 --- a/compiler/rustc_hir_analysis/src/variance/constraints.rs +++ b/compiler/rustc_hir_analysis/src/variance/constraints.rs @@ -340,6 +340,11 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> { self.add_constraints_from_const(current, start, variance); self.add_constraints_from_const(current, end, variance); } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.add_constraints_from_pat(current, variance, pat) + } + } } } diff --git a/compiler/rustc_hir_pretty/src/lib.rs b/compiler/rustc_hir_pretty/src/lib.rs index c95d6a277c7..37c7e613b2c 100644 --- a/compiler/rustc_hir_pretty/src/lib.rs +++ b/compiler/rustc_hir_pretty/src/lib.rs @@ -1866,6 +1866,19 @@ impl<'a> State<'a> { self.word("..="); self.print_const_arg(end); } + TyPatKind::Or(patterns) => { + self.popen(); + let mut first = true; + for pat in patterns { + if first { + first = false; + } else { + self.word(" | "); + } + self.print_ty_pat(pat); + } + self.pclose(); + } TyPatKind::Err(_) => { self.popen(); self.word("/*ERROR*/"); diff --git a/compiler/rustc_lint/src/types.rs b/compiler/rustc_lint/src/types.rs index 2511d3f6b16..42194bac3dc 100644 --- a/compiler/rustc_lint/src/types.rs +++ b/compiler/rustc_lint/src/types.rs @@ -900,6 +900,9 @@ fn pat_ty_is_known_nonnull<'tcx>( // to ensure we aren't wrapping over zero. start > 0 && end >= start } + ty::PatternKind::Or(patterns) => { + patterns.iter().all(|pat| pat_ty_is_known_nonnull(tcx, typing_env, pat)) + } } }, ) @@ -1046,13 +1049,29 @@ pub(crate) fn repr_nullable_ptr<'tcx>( } None } - ty::Pat(base, pat) => match **pat { - ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, *base), - }, + ty::Pat(base, pat) => get_nullable_type_from_pat(tcx, typing_env, *base, *pat), _ => None, } } +fn get_nullable_type_from_pat<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + base: Ty<'tcx>, + pat: ty::Pattern<'tcx>, +) -> Option<Ty<'tcx>> { + match *pat { + ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, base), + ty::PatternKind::Or(patterns) => { + let first = get_nullable_type_from_pat(tcx, typing_env, base, patterns[0])?; + for &pat in &patterns[1..] { + assert_eq!(first, get_nullable_type_from_pat(tcx, typing_env, base, pat)?); + } + Some(first) + } + } +} + impl<'a, 'tcx> ImproperCTypesVisitor<'a, 'tcx> { /// Check if the type is array and emit an unsafe type lint. fn check_for_array_ty(&mut self, sp: Span, ty: Ty<'tcx>) -> bool { diff --git a/compiler/rustc_middle/src/ty/codec.rs b/compiler/rustc_middle/src/ty/codec.rs index 23927c112bc..5ff87959a80 100644 --- a/compiler/rustc_middle/src/ty/codec.rs +++ b/compiler/rustc_middle/src/ty/codec.rs @@ -442,6 +442,15 @@ impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::BoundVaria } } +impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Pattern<'tcx>> { + fn decode(decoder: &mut D) -> &'tcx Self { + let len = decoder.read_usize(); + decoder.interner().mk_patterns_from_iter( + (0..len).map::<ty::Pattern<'tcx>, _>(|_| Decodable::decode(decoder)), + ) + } +} + impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Const<'tcx>> { fn decode(decoder: &mut D) -> &'tcx Self { let len = decoder.read_usize(); @@ -503,6 +512,7 @@ impl_decodable_via_ref! { &'tcx mir::Body<'tcx>, &'tcx mir::ConcreteOpaqueTypes<'tcx>, &'tcx ty::List<ty::BoundVariableKind>, + &'tcx ty::List<ty::Pattern<'tcx>>, &'tcx ty::ListWithCachedTypeInfo<ty::Clause<'tcx>>, &'tcx ty::List<FieldIdx>, &'tcx ty::List<(VariantIdx, FieldIdx)>, diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index e8dad1e056c..f97e13f13b0 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -136,6 +136,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> { type AllocId = crate::mir::interpret::AllocId; type Pat = Pattern<'tcx>; + type PatList = &'tcx List<Pattern<'tcx>>; type Safety = hir::Safety; type Abi = ExternAbi; type Const = ty::Const<'tcx>; @@ -843,6 +844,7 @@ pub struct CtxtInterners<'tcx> { captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>, offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>, valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>, + patterns: InternedSet<'tcx, List<ty::Pattern<'tcx>>>, } impl<'tcx> CtxtInterners<'tcx> { @@ -879,6 +881,7 @@ impl<'tcx> CtxtInterners<'tcx> { captures: InternedSet::with_capacity(N), offset_of: InternedSet::with_capacity(N), valtree: InternedSet::with_capacity(N), + patterns: InternedSet::with_capacity(N), } } @@ -2659,6 +2662,7 @@ slice_interners!( local_def_ids: intern_local_def_ids(LocalDefId), captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>), offset_of: pub mk_offset_of((VariantIdx, FieldIdx)), + patterns: pub mk_patterns(Pattern<'tcx>), ); impl<'tcx> TyCtxt<'tcx> { @@ -2932,6 +2936,14 @@ impl<'tcx> TyCtxt<'tcx> { self.intern_local_def_ids(def_ids) } + pub fn mk_patterns_from_iter<I, T>(self, iter: I) -> T::Output + where + I: Iterator<Item = T>, + T: CollectAndApply<ty::Pattern<'tcx>, &'tcx List<ty::Pattern<'tcx>>>, + { + T::collect_and_apply(iter, |xs| self.mk_patterns(xs)) + } + pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output where I: Iterator<Item = T>, diff --git a/compiler/rustc_middle/src/ty/pattern.rs b/compiler/rustc_middle/src/ty/pattern.rs index 758adc42e3e..5af9b17dd77 100644 --- a/compiler/rustc_middle/src/ty/pattern.rs +++ b/compiler/rustc_middle/src/ty/pattern.rs @@ -23,6 +23,13 @@ impl<'tcx> Flags for Pattern<'tcx> { FlagComputation::for_const_kind(&start.kind()).flags | FlagComputation::for_const_kind(&end.kind()).flags } + ty::PatternKind::Or(pats) => { + let mut flags = pats[0].flags(); + for pat in pats[1..].iter() { + flags |= pat.flags(); + } + flags + } } } @@ -31,6 +38,13 @@ impl<'tcx> Flags for Pattern<'tcx> { ty::PatternKind::Range { start, end } => { start.outer_exclusive_binder().max(end.outer_exclusive_binder()) } + ty::PatternKind::Or(pats) => { + let mut idx = pats[0].outer_exclusive_binder(); + for pat in pats[1..].iter() { + idx = idx.max(pat.outer_exclusive_binder()); + } + idx + } } } } @@ -77,6 +91,19 @@ impl<'tcx> IrPrint<PatternKind<'tcx>> for TyCtxt<'tcx> { write!(f, "..={end}") } + PatternKind::Or(patterns) => { + write!(f, "(")?; + let mut first = true; + for pat in patterns { + if first { + first = false + } else { + write!(f, " | ")?; + } + write!(f, "{pat:?}")?; + } + write!(f, ")") + } } } diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index c3ee72bcaed..6ad4e5276b2 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -59,6 +59,15 @@ impl<'tcx> Relate<TyCtxt<'tcx>> for ty::Pattern<'tcx> { let end = relation.relate(end_a, end_b)?; Ok(tcx.mk_pat(ty::PatternKind::Range { start, end })) } + (&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => { + if a.len() != b.len() { + return Err(TypeError::Mismatch); + } + let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b)); + let patterns = tcx.mk_patterns_from_iter(v)?; + Ok(tcx.mk_pat(ty::PatternKind::Or(patterns))) + } + (ty::PatternKind::Range { .. } | ty::PatternKind::Or(_), _) => Err(TypeError::Mismatch), } } } diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 26861666c1d..2fcb2a1572a 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -779,5 +779,6 @@ list_fold! { ty::Clauses<'tcx> : mk_clauses, &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates, &'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems, + &'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns, CanonicalVarInfos<'tcx> : mk_canonical_var_infos, } diff --git a/compiler/rustc_resolve/src/late.rs b/compiler/rustc_resolve/src/late.rs index bae2fdeecaf..faee0e7dd5f 100644 --- a/compiler/rustc_resolve/src/late.rs +++ b/compiler/rustc_resolve/src/late.rs @@ -958,6 +958,11 @@ impl<'ra: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'_, 'ast, 'r self.resolve_anon_const(end, AnonConstKind::ConstArg(IsRepeatExpr::No)); } } + TyPatKind::Or(patterns) => { + for pat in patterns { + self.visit_ty_pat(pat) + } + } TyPatKind::Err(_) => {} } } diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index c0ed3b90eb4..1c33101da35 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -412,6 +412,7 @@ impl<'tcx> Stable<'tcx> for ty::Pattern<'tcx> { end: Some(end.stable(tables)), include_end: true, }, + ty::PatternKind::Or(_) => todo!(), } } } diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs index 6cde28d0ee9..cc608736b67 100644 --- a/compiler/rustc_symbol_mangling/src/v0.rs +++ b/compiler/rustc_symbol_mangling/src/v0.rs @@ -256,6 +256,11 @@ impl<'tcx> SymbolMangler<'tcx> { Ty::new_array_with_const_len(self.tcx, self.tcx.types.unit, ct).print(self)?; } } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.print_pat(pat)?; + } + } }) } } diff --git a/compiler/rustc_trait_selection/src/traits/wf.rs b/compiler/rustc_trait_selection/src/traits/wf.rs index 0a6b3b29990..6e85d3b5b13 100644 --- a/compiler/rustc_trait_selection/src/traits/wf.rs +++ b/compiler/rustc_trait_selection/src/traits/wf.rs @@ -695,6 +695,11 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> { check(start); check(end); } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.add_wf_preds_for_pat_ty(base_ty, pat) + } + } } } } diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs index 1915ba623cb..b962979bec7 100644 --- a/compiler/rustc_ty_utils/src/layout.rs +++ b/compiler/rustc_ty_utils/src/layout.rs @@ -255,13 +255,95 @@ fn layout_of_uncached<'tcx>( }; layout.largest_niche = Some(niche); - - tcx.mk_layout(layout) } else { bug!("pattern type with range but not scalar layout: {ty:?}, {layout:?}") } } + ty::PatternKind::Or(variants) => match *variants[0] { + ty::PatternKind::Range { .. } => { + if let BackendRepr::Scalar(scalar) = &mut layout.backend_repr { + let variants: Result<Vec<_>, _> = variants + .iter() + .map(|pat| match *pat { + ty::PatternKind::Range { start, end } => Ok(( + extract_const_value(cx, ty, start) + .unwrap() + .try_to_bits(tcx, cx.typing_env) + .ok_or_else(|| error(cx, LayoutError::Unknown(ty)))?, + extract_const_value(cx, ty, end) + .unwrap() + .try_to_bits(tcx, cx.typing_env) + .ok_or_else(|| error(cx, LayoutError::Unknown(ty)))?, + )), + ty::PatternKind::Or(_) => { + unreachable!("mixed or patterns are not allowed") + } + }) + .collect(); + let mut variants = variants?; + if !scalar.is_signed() { + let guar = tcx.dcx().err(format!( + "only signed integer base types are allowed for or-pattern pattern types at present" + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + variants.sort(); + if variants.len() != 2 { + let guar = tcx + .dcx() + .err(format!("the only or-pattern types allowed are two range patterns that are directly connected at their overflow site")); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + + // first is the one starting at the signed in range min + let mut first = variants[0]; + let mut second = variants[1]; + if second.0 + == layout.size.truncate(layout.size.signed_int_min() as u128) + { + (second, first) = (first, second); + } + + if layout.size.sign_extend(first.1) >= layout.size.sign_extend(second.0) + { + let guar = tcx.dcx().err(format!( + "only non-overlapping pattern type ranges are allowed at present" + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + if layout.size.signed_int_max() as u128 != second.1 { + let guar = tcx.dcx().err(format!( + "one pattern needs to end at `{ty}::MAX`, but was {} instead", + second.1 + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + + // Now generate a wrapping range (which aren't allowed in surface syntax). + scalar.valid_range_mut().start = second.0; + scalar.valid_range_mut().end = first.1; + + let niche = Niche { + offset: Size::ZERO, + value: scalar.primitive(), + valid_range: scalar.valid_range(cx), + }; + + layout.largest_niche = Some(niche); + } else { + bug!( + "pattern type with range but not scalar layout: {ty:?}, {layout:?}" + ) + } + } + ty::PatternKind::Or(..) => bug!("patterns cannot have nested or patterns"), + }, } + tcx.mk_layout(layout) } // Basic scalars. diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 9758cecaf6a..6410da1f740 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -113,6 +113,13 @@ pub trait Interner: + Relate<Self> + Flags + IntoKind<Kind = ty::PatternKind<Self>>; + type PatList: Copy + + Debug + + Hash + + Default + + Eq + + TypeVisitable<Self> + + SliceLike<Item = Self::Pat>; type Safety: Safety<Self>; type Abi: Abi<Self>; diff --git a/compiler/rustc_type_ir/src/pattern.rs b/compiler/rustc_type_ir/src/pattern.rs index d74a82da1f9..7e56565917c 100644 --- a/compiler/rustc_type_ir/src/pattern.rs +++ b/compiler/rustc_type_ir/src/pattern.rs @@ -13,4 +13,5 @@ use crate::Interner; )] pub enum PatternKind<I: Interner> { Range { start: I::Const, end: I::Const }, + Or(I::PatList), } diff --git a/compiler/rustc_type_ir/src/walk.rs b/compiler/rustc_type_ir/src/walk.rs index ebfb3f786e8..737550eb73e 100644 --- a/compiler/rustc_type_ir/src/walk.rs +++ b/compiler/rustc_type_ir/src/walk.rs @@ -173,5 +173,10 @@ fn push_ty_pat<I: Interner>(stack: &mut TypeWalkerStack<I>, pat: I::Pat) { stack.push(end.into()); stack.push(start.into()); } + ty::PatternKind::Or(pats) => { + for pat in pats.iter() { + push_ty_pat::<I>(stack, pat) + } + } } } diff --git a/src/tools/clippy/clippy_utils/src/hir_utils.rs b/src/tools/clippy/clippy_utils/src/hir_utils.rs index fe1fd70a9fa..17368a7530d 100644 --- a/src/tools/clippy/clippy_utils/src/hir_utils.rs +++ b/src/tools/clippy/clippy_utils/src/hir_utils.rs @@ -1117,6 +1117,11 @@ impl<'a, 'tcx> SpanlessHash<'a, 'tcx> { self.hash_const_arg(s); self.hash_const_arg(e); }, + TyPatKind::Or(variants) => { + for variant in variants.iter() { + self.hash_ty_pat(variant) + } + }, TyPatKind::Err(_) => {}, } } diff --git a/src/tools/rustfmt/src/types.rs b/src/tools/rustfmt/src/types.rs index 75a5a8532b8..7ec1032dcb4 100644 --- a/src/tools/rustfmt/src/types.rs +++ b/src/tools/rustfmt/src/types.rs @@ -1093,6 +1093,19 @@ impl Rewrite for ast::TyPat { ast::TyPatKind::Range(ref lhs, ref rhs, ref end_kind) => { rewrite_range_pat(context, shape, lhs, rhs, end_kind, self.span) } + ast::TyPatKind::Or(ref variants) => { + let mut first = true; + let mut s = String::new(); + for variant in variants { + if first { + first = false + } else { + s.push_str(" | "); + } + s.push_str(&variant.rewrite_result(context, shape)?); + } + Ok(s) + } ast::TyPatKind::Err(_) => Err(RewriteError::Unknown), } } diff --git a/tests/ui/type/pattern_types/or_patterns.rs b/tests/ui/type/pattern_types/or_patterns.rs new file mode 100644 index 00000000000..25cb1867047 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns.rs @@ -0,0 +1,45 @@ +//! Demonstrate some use cases of or patterns + +//@ normalize-stderr: "pref: Align\([1-8] bytes\)" -> "pref: $$SOME_ALIGN" +//@ normalize-stderr: "randomization_seed: \d+" -> "randomization_seed: $$SEED" + +#![feature( + pattern_type_macro, + pattern_types, + rustc_attrs, + const_trait_impl, + pattern_type_range_trait +)] + +use std::pat::pattern_type; + +#[rustc_layout(debug)] +type NonNullI8 = pattern_type!(i8 is ..0 | 1..); +//~^ ERROR: layout_of + +#[rustc_layout(debug)] +type NonNegOneI8 = pattern_type!(i8 is ..-1 | 0..); +//~^ ERROR: layout_of + +fn main() { + let _: NonNullI8 = 42; + let _: NonNullI8 = 1; + let _: NonNullI8 = 0; + //~^ ERROR: mismatched types + let _: NonNullI8 = -1; + //~^ ERROR: cannot apply unary operator + let _: NonNullI8 = -128; + //~^ ERROR: cannot apply unary operator + let _: NonNullI8 = 127; + + let _: NonNegOneI8 = 42; + let _: NonNegOneI8 = 1; + let _: NonNegOneI8 = 0; + let _: NonNegOneI8 = -1; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = -2; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = -128; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = 127; +} diff --git a/tests/ui/type/pattern_types/or_patterns.stderr b/tests/ui/type/pattern_types/or_patterns.stderr new file mode 100644 index 00000000000..58ca585f4a9 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns.stderr @@ -0,0 +1,123 @@ +error[E0308]: mismatched types + --> $DIR/or_patterns.rs:27:24 + | +LL | let _: NonNullI8 = 0; + | --------- ^ expected `(i8) is (i8::MIN..=-1 | 1..)`, found integer + | | + | expected due to this + | + = note: expected pattern type `(i8) is (i8::MIN..=-1 | 1..)` + found type `{integer}` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-1 | 1..)` + --> $DIR/or_patterns.rs:29:24 + | +LL | let _: NonNullI8 = -1; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-1 | 1..)` + --> $DIR/or_patterns.rs:31:24 + | +LL | let _: NonNullI8 = -128; + | ^^^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:38:26 + | +LL | let _: NonNegOneI8 = -1; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:40:26 + | +LL | let _: NonNegOneI8 = -2; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:42:26 + | +LL | let _: NonNegOneI8 = -128; + | ^^^^ cannot apply unary operator `-` + +error: layout_of((i8) is (i8::MIN..=-1 | 1..)) = Layout { + size: Size(1 bytes), + align: AbiAndPrefAlign { + abi: Align(1 bytes), + pref: $SOME_ALIGN, + }, + backend_repr: Scalar( + Initialized { + value: Int( + I8, + true, + ), + valid_range: 1..=255, + }, + ), + fields: Primitive, + largest_niche: Some( + Niche { + offset: Size(0 bytes), + value: Int( + I8, + true, + ), + valid_range: 1..=255, + }, + ), + uninhabited: false, + variants: Single { + index: 0, + }, + max_repr_align: None, + unadjusted_abi_align: Align(1 bytes), + randomization_seed: $SEED, + } + --> $DIR/or_patterns.rs:17:1 + | +LL | type NonNullI8 = pattern_type!(i8 is ..0 | 1..); + | ^^^^^^^^^^^^^^ + +error: layout_of((i8) is (i8::MIN..=-2 | 0..)) = Layout { + size: Size(1 bytes), + align: AbiAndPrefAlign { + abi: Align(1 bytes), + pref: $SOME_ALIGN, + }, + backend_repr: Scalar( + Initialized { + value: Int( + I8, + true, + ), + valid_range: 0..=254, + }, + ), + fields: Primitive, + largest_niche: Some( + Niche { + offset: Size(0 bytes), + value: Int( + I8, + true, + ), + valid_range: 0..=254, + }, + ), + uninhabited: false, + variants: Single { + index: 0, + }, + max_repr_align: None, + unadjusted_abi_align: Align(1 bytes), + randomization_seed: $SEED, + } + --> $DIR/or_patterns.rs:21:1 + | +LL | type NonNegOneI8 = pattern_type!(i8 is ..-1 | 0..); + | ^^^^^^^^^^^^^^^^ + +error: aborting due to 8 previous errors + +Some errors have detailed explanations: E0308, E0600. +For more information about an error, try `rustc --explain E0308`. diff --git a/tests/ui/type/pattern_types/or_patterns_invalid.rs b/tests/ui/type/pattern_types/or_patterns_invalid.rs new file mode 100644 index 00000000000..d341927601d --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns_invalid.rs @@ -0,0 +1,26 @@ +//! Demonstrate some use cases of or patterns + +#![feature( + pattern_type_macro, + pattern_types, + rustc_attrs, + const_trait_impl, + pattern_type_range_trait +)] + +use std::pat::pattern_type; + +fn main() { + //~? ERROR: only non-overlapping pattern type ranges are allowed at present + let not_adjacent: pattern_type!(i8 is -127..0 | 1..) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: one pattern needs to end at `i8::MAX`, but was 29 instead + let not_wrapping: pattern_type!(i8 is 10..20 | 20..30) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: only signed integer base types are allowed for or-pattern pattern types + let not_signed: pattern_type!(u8 is 10.. | 0..5) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: allowed are two range patterns that are directly connected + let not_simple_enough_for_mvp: pattern_type!(i8 is ..0 | 1..10 | 10..) = + unsafe { std::mem::transmute(0) }; +} diff --git a/tests/ui/type/pattern_types/or_patterns_invalid.stderr b/tests/ui/type/pattern_types/or_patterns_invalid.stderr new file mode 100644 index 00000000000..6964788a6c2 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns_invalid.stderr @@ -0,0 +1,10 @@ +error: only non-overlapping pattern type ranges are allowed at present + +error: one pattern needs to end at `i8::MAX`, but was 29 instead + +error: only signed integer base types are allowed for or-pattern pattern types at present + +error: the only or-pattern types allowed are two range patterns that are directly connected at their overflow site + +error: aborting due to 4 previous errors + |
