diff options
Diffstat (limited to 'compiler/rustc_builtin_macros')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs | 215 | ||||
| -rw-r--r-- | compiler/rustc_builtin_macros/src/deriving/generic/mod.rs | 5 |
2 files changed, 162 insertions, 58 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs index 4b93b3414c7..b1d950b8d89 100644 --- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs +++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs @@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*; use crate::deriving::generic::*; use crate::deriving::{path_local, path_std}; +/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the +/// target item. pub(crate) fn expand_deriving_partial_eq( cx: &ExtCtxt<'_>, span: Span, @@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq( push: &mut dyn FnMut(Annotatable), is_const: bool, ) { - fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr { - let base = true; - let expr = cs_fold( - true, // use foldl - cx, - span, - substr, - |cx, fold| match fold { - CsFold::Single(field) => { - let [other_expr] = &field.other_selflike_exprs[..] else { - cx.dcx() - .span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); - }; - - // We received arguments of type `&T`. Convert them to type `T` by stripping - // any leading `&`. This isn't necessary for type checking, but - // it results in better error messages if something goes wrong. - // - // Note: for arguments that look like `&{ x }`, which occur with packed - // structs, this would cause expressions like `{ self.x } == { other.x }`, - // which isn't valid Rust syntax. This wouldn't break compilation because these - // AST nodes are constructed within the compiler. But it would mean that code - // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid - // syntax, which would be suboptimal. So we wrap these in parens, giving - // `({ self.x }) == ({ other.x })`, which is valid syntax. - let convert = |expr: &P<Expr>| { - if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = - &expr.kind - { - if let ExprKind::Block(..) = &inner.kind { - // `&{ x }` form: remove the `&`, add parens. - cx.expr_paren(field.span, inner.clone()) - } else { - // `&x` form: remove the `&`. - inner.clone() - } - } else { - expr.clone() - } - }; - cx.expr_binary( - field.span, - BinOpKind::Eq, - convert(&field.self_expr), - convert(other_expr), - ) - } - CsFold::Combine(span, expr1, expr2) => { - cx.expr_binary(span, BinOpKind::And, expr1, expr2) - } - CsFold::Fieldless => cx.expr_bool(span, base), - }, - ); - BlockOrExpr::new_expr(expr) - } - let structural_trait_def = TraitDef { span, path: path_std!(marker::StructuralPartialEq), @@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq( ret_ty: Path(path_local!(bool)), attributes: thin_vec![cx.attr_word(sym::inline, span)], fieldless_variants_strategy: FieldlessVariantsStrategy::Unify, - combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))), + combine_substructure: combine_substructure(Box::new(|a, b, c| { + BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c)) + })), }]; let trait_def = TraitDef { @@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq( }; trait_def.expand(cx, mitem, item, push) } + +/// Generates the equality expression for a struct or enum variant when deriving +/// `PartialEq`. +/// +/// This function generates an expression that checks if all fields of a struct +/// or enum variant are equal. +/// - Scalar fields are compared first for efficiency, followed by compound +/// fields. +/// - If there are no fields, returns `true` (fieldless types are always equal). +/// +/// Whether a field is considered "scalar" is determined by comparing the symbol +/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc). +/// This check is based on the type's symbol. +/// +/// ### Example 1 +/// ``` +/// #[derive(PartialEq)] +/// struct i32; +/// +/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not +/// // the primitive), it will not be treated as scalar. The function will still +/// // check equality of `field_2` first because the symbol matches `i32`. +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: i32, +/// } +/// ``` +/// +/// ### Example 2 +/// ``` +/// mod ty { +/// pub type i32 = i32; +/// } +/// +/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`. +/// // However, the function will not reorder the fields because the symbol for +/// // `ty::i32` does not match the symbol for the primitive `i32` +/// // ("ty::i32" != "i32"). +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: ty::i32, +/// } +/// ``` +/// +/// For enums, the discriminant is compared first, then the rest of the fields. +/// +/// # Panics +/// +/// If called on static or all-fieldless enums/structs, which should not occur +/// during derive expansion. +fn get_substructure_equality_expr( + cx: &ExtCtxt<'_>, + span: Span, + substructure: &Substructure<'_>, +) -> P<Expr> { + use SubstructureFields::*; + + match substructure.fields { + EnumMatching(.., fields) | Struct(.., fields) => { + let combine = move |acc, field| { + let rhs = get_field_equality_expr(cx, field); + if let Some(lhs) = acc { + // Combine the previous comparison with the current field + // using logical AND. + return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs)); + } + // Start the chain with the first field's comparison. + Some(rhs) + }; + + // First compare scalar fields, then compound fields, combining all + // with logical AND. + return fields + .iter() + .filter(|field| !field.maybe_scalar) + .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine) + // If there are no fields, treat as always equal. + .unwrap_or_else(|| cx.expr_bool(span, true)); + } + EnumDiscr(disc, match_expr) => { + let lhs = get_field_equality_expr(cx, disc); + let Some(match_expr) = match_expr else { + return lhs; + }; + // Compare the discriminant first (cheaper), then the rest of the + // fields. + return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone()); + } + StaticEnum(..) => cx.dcx().span_bug( + span, + "unexpected static enum encountered during `derive(PartialEq)` expansion", + ), + StaticStruct(..) => cx.dcx().span_bug( + span, + "unexpected static struct encountered during `derive(PartialEq)` expansion", + ), + AllFieldlessEnum(..) => cx.dcx().span_bug( + span, + "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion", + ), + } +} + +/// Generates an equality comparison expression for a single struct or enum +/// field. +/// +/// This function produces an AST expression that compares the `self` and +/// `other` values for a field using `==`. It removes any leading references +/// from both sides for readability. If the field is a block expression, it is +/// wrapped in parentheses to ensure valid syntax. +/// +/// # Panics +/// +/// Panics if there are not exactly two arguments to compare (should be `self` +/// and `other`). +fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> { + let [rhs] = &field.other_selflike_exprs[..] else { + cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); + }; + + cx.expr_binary( + field.span, + BinOpKind::Eq, + wrap_block_expr(cx, peel_refs(&field.self_expr)), + wrap_block_expr(cx, peel_refs(rhs)), + ) +} + +/// Removes all leading immutable references from an expression. +/// +/// This is used to strip away any number of leading `&` from an expression +/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable +/// references are preserved. +fn peel_refs(mut expr: &P<Expr>) -> P<Expr> { + while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind { + expr = &inner; + } + expr.clone() +} + +/// Wraps a block expression in parentheses to ensure valid AST in macro +/// expansion output. +/// +/// If the given expression is a block, it is wrapped in parentheses; otherwise, +/// it is returned unchanged. +fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> { + if matches!(&expr.kind, ExprKind::Block(..)) { + return cx.expr_paren(expr.span, expr); + } + expr +} diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs index 9aa53f9e4f7..e0e44841acb 100644 --- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs +++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs @@ -284,6 +284,7 @@ pub(crate) struct FieldInfo { /// The expressions corresponding to references to this field in /// the other selflike arguments. pub other_selflike_exprs: Vec<P<Expr>>, + pub maybe_scalar: bool, } #[derive(Copy, Clone)] @@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> { let self_expr = discr_exprs.remove(0); let other_selflike_exprs = discr_exprs; - let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs }; + let discr_field = + FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true }; let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args) .map(|(&ident, selflike_arg)| { @@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> { name: struct_field.ident, self_expr, other_selflike_exprs, + maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(), } }) .collect() |
