diff options
| author | Ell <ahunpochoevjamshed@gmail.com> | 2025-05-30 10:57:08 +0300 |
|---|---|---|
| committer | Ell <ahunpochoevjamshed@gmail.com> | 2025-06-02 15:29:34 +0300 |
| commit | a6a1c1b247aa0fa404983efa3c226c25cafdd704 (patch) | |
| tree | 02fbcb8315df8ef7af779ee1b508d5843254ba3b /compiler | |
| parent | ebe9b0060240953d721508ceb4d02a745efda88f (diff) | |
| download | rust-a6a1c1b247aa0fa404983efa3c226c25cafdd704.tar.gz rust-a6a1c1b247aa0fa404983efa3c226c25cafdd704.zip | |
Separately check equality of the scalar types and compound types in the order of declaration.
Diffstat (limited to 'compiler')
| -rw-r--r-- | compiler/rustc_ast/src/ast.rs | 33 | ||||
| -rw-r--r-- | compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs | 215 | ||||
| -rw-r--r-- | compiler/rustc_builtin_macros/src/deriving/generic/mod.rs | 5 |
3 files changed, 195 insertions, 58 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index a16219361c0..1d4df97da58 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2452,6 +2452,39 @@ impl TyKind { None } } + + /// Returns `true` if this type is considered a scalar primitive (e.g., + /// `i32`, `u8`, `bool`, etc). + /// + /// This check is based on **symbol equality** and does **not** remove any + /// path prefixes or references. If a type alias or shadowing is present + /// (e.g., `type i32 = CustomType;`), this method will still return `true` + /// for `i32`, even though it may not refer to the primitive type. + pub fn maybe_scalar(&self) -> bool { + let Some(ty_sym) = self.is_simple_path() else { + // unit type + return self.is_unit(); + }; + matches!( + ty_sym, + sym::i8 + | sym::i16 + | sym::i32 + | sym::i64 + | sym::i128 + | sym::u8 + | sym::u16 + | sym::u32 + | sym::u64 + | sym::u128 + | sym::f16 + | sym::f32 + | sym::f64 + | sym::f128 + | sym::char + | sym::bool + ) + } } /// A pattern type pattern. diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs index 4b93b3414c7..b1d950b8d89 100644 --- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs +++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs @@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*; use crate::deriving::generic::*; use crate::deriving::{path_local, path_std}; +/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the +/// target item. pub(crate) fn expand_deriving_partial_eq( cx: &ExtCtxt<'_>, span: Span, @@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq( push: &mut dyn FnMut(Annotatable), is_const: bool, ) { - fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr { - let base = true; - let expr = cs_fold( - true, // use foldl - cx, - span, - substr, - |cx, fold| match fold { - CsFold::Single(field) => { - let [other_expr] = &field.other_selflike_exprs[..] else { - cx.dcx() - .span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); - }; - - // We received arguments of type `&T`. Convert them to type `T` by stripping - // any leading `&`. This isn't necessary for type checking, but - // it results in better error messages if something goes wrong. - // - // Note: for arguments that look like `&{ x }`, which occur with packed - // structs, this would cause expressions like `{ self.x } == { other.x }`, - // which isn't valid Rust syntax. This wouldn't break compilation because these - // AST nodes are constructed within the compiler. But it would mean that code - // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid - // syntax, which would be suboptimal. So we wrap these in parens, giving - // `({ self.x }) == ({ other.x })`, which is valid syntax. - let convert = |expr: &P<Expr>| { - if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = - &expr.kind - { - if let ExprKind::Block(..) = &inner.kind { - // `&{ x }` form: remove the `&`, add parens. - cx.expr_paren(field.span, inner.clone()) - } else { - // `&x` form: remove the `&`. - inner.clone() - } - } else { - expr.clone() - } - }; - cx.expr_binary( - field.span, - BinOpKind::Eq, - convert(&field.self_expr), - convert(other_expr), - ) - } - CsFold::Combine(span, expr1, expr2) => { - cx.expr_binary(span, BinOpKind::And, expr1, expr2) - } - CsFold::Fieldless => cx.expr_bool(span, base), - }, - ); - BlockOrExpr::new_expr(expr) - } - let structural_trait_def = TraitDef { span, path: path_std!(marker::StructuralPartialEq), @@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq( ret_ty: Path(path_local!(bool)), attributes: thin_vec![cx.attr_word(sym::inline, span)], fieldless_variants_strategy: FieldlessVariantsStrategy::Unify, - combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))), + combine_substructure: combine_substructure(Box::new(|a, b, c| { + BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c)) + })), }]; let trait_def = TraitDef { @@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq( }; trait_def.expand(cx, mitem, item, push) } + +/// Generates the equality expression for a struct or enum variant when deriving +/// `PartialEq`. +/// +/// This function generates an expression that checks if all fields of a struct +/// or enum variant are equal. +/// - Scalar fields are compared first for efficiency, followed by compound +/// fields. +/// - If there are no fields, returns `true` (fieldless types are always equal). +/// +/// Whether a field is considered "scalar" is determined by comparing the symbol +/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc). +/// This check is based on the type's symbol. +/// +/// ### Example 1 +/// ``` +/// #[derive(PartialEq)] +/// struct i32; +/// +/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not +/// // the primitive), it will not be treated as scalar. The function will still +/// // check equality of `field_2` first because the symbol matches `i32`. +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: i32, +/// } +/// ``` +/// +/// ### Example 2 +/// ``` +/// mod ty { +/// pub type i32 = i32; +/// } +/// +/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`. +/// // However, the function will not reorder the fields because the symbol for +/// // `ty::i32` does not match the symbol for the primitive `i32` +/// // ("ty::i32" != "i32"). +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: ty::i32, +/// } +/// ``` +/// +/// For enums, the discriminant is compared first, then the rest of the fields. +/// +/// # Panics +/// +/// If called on static or all-fieldless enums/structs, which should not occur +/// during derive expansion. +fn get_substructure_equality_expr( + cx: &ExtCtxt<'_>, + span: Span, + substructure: &Substructure<'_>, +) -> P<Expr> { + use SubstructureFields::*; + + match substructure.fields { + EnumMatching(.., fields) | Struct(.., fields) => { + let combine = move |acc, field| { + let rhs = get_field_equality_expr(cx, field); + if let Some(lhs) = acc { + // Combine the previous comparison with the current field + // using logical AND. + return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs)); + } + // Start the chain with the first field's comparison. + Some(rhs) + }; + + // First compare scalar fields, then compound fields, combining all + // with logical AND. + return fields + .iter() + .filter(|field| !field.maybe_scalar) + .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine) + // If there are no fields, treat as always equal. + .unwrap_or_else(|| cx.expr_bool(span, true)); + } + EnumDiscr(disc, match_expr) => { + let lhs = get_field_equality_expr(cx, disc); + let Some(match_expr) = match_expr else { + return lhs; + }; + // Compare the discriminant first (cheaper), then the rest of the + // fields. + return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone()); + } + StaticEnum(..) => cx.dcx().span_bug( + span, + "unexpected static enum encountered during `derive(PartialEq)` expansion", + ), + StaticStruct(..) => cx.dcx().span_bug( + span, + "unexpected static struct encountered during `derive(PartialEq)` expansion", + ), + AllFieldlessEnum(..) => cx.dcx().span_bug( + span, + "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion", + ), + } +} + +/// Generates an equality comparison expression for a single struct or enum +/// field. +/// +/// This function produces an AST expression that compares the `self` and +/// `other` values for a field using `==`. It removes any leading references +/// from both sides for readability. If the field is a block expression, it is +/// wrapped in parentheses to ensure valid syntax. +/// +/// # Panics +/// +/// Panics if there are not exactly two arguments to compare (should be `self` +/// and `other`). +fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> { + let [rhs] = &field.other_selflike_exprs[..] else { + cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); + }; + + cx.expr_binary( + field.span, + BinOpKind::Eq, + wrap_block_expr(cx, peel_refs(&field.self_expr)), + wrap_block_expr(cx, peel_refs(rhs)), + ) +} + +/// Removes all leading immutable references from an expression. +/// +/// This is used to strip away any number of leading `&` from an expression +/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable +/// references are preserved. +fn peel_refs(mut expr: &P<Expr>) -> P<Expr> { + while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind { + expr = &inner; + } + expr.clone() +} + +/// Wraps a block expression in parentheses to ensure valid AST in macro +/// expansion output. +/// +/// If the given expression is a block, it is wrapped in parentheses; otherwise, +/// it is returned unchanged. +fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> { + if matches!(&expr.kind, ExprKind::Block(..)) { + return cx.expr_paren(expr.span, expr); + } + expr +} diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs index 9aa53f9e4f7..e0e44841acb 100644 --- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs +++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs @@ -284,6 +284,7 @@ pub(crate) struct FieldInfo { /// The expressions corresponding to references to this field in /// the other selflike arguments. pub other_selflike_exprs: Vec<P<Expr>>, + pub maybe_scalar: bool, } #[derive(Copy, Clone)] @@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> { let self_expr = discr_exprs.remove(0); let other_selflike_exprs = discr_exprs; - let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs }; + let discr_field = + FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true }; let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args) .map(|(&ident, selflike_arg)| { @@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> { name: struct_field.ident, self_expr, other_selflike_exprs, + maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(), } }) .collect() |
