about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_builtin_macros')
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs215
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/generic/mod.rs5
2 files changed, 162 insertions, 58 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
index 4b93b3414c7..b1d950b8d89 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
@@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*;
 use crate::deriving::generic::*;
 use crate::deriving::{path_local, path_std};
 
+/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
+/// target item.
 pub(crate) fn expand_deriving_partial_eq(
     cx: &ExtCtxt<'_>,
     span: Span,
@@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq(
     push: &mut dyn FnMut(Annotatable),
     is_const: bool,
 ) {
-    fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
-        let base = true;
-        let expr = cs_fold(
-            true, // use foldl
-            cx,
-            span,
-            substr,
-            |cx, fold| match fold {
-                CsFold::Single(field) => {
-                    let [other_expr] = &field.other_selflike_exprs[..] else {
-                        cx.dcx()
-                            .span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
-                    };
-
-                    // We received arguments of type `&T`. Convert them to type `T` by stripping
-                    // any leading `&`. This isn't necessary for type checking, but
-                    // it results in better error messages if something goes wrong.
-                    //
-                    // Note: for arguments that look like `&{ x }`, which occur with packed
-                    // structs, this would cause expressions like `{ self.x } == { other.x }`,
-                    // which isn't valid Rust syntax. This wouldn't break compilation because these
-                    // AST nodes are constructed within the compiler. But it would mean that code
-                    // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
-                    // syntax, which would be suboptimal. So we wrap these in parens, giving
-                    // `({ self.x }) == ({ other.x })`, which is valid syntax.
-                    let convert = |expr: &P<Expr>| {
-                        if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
-                            &expr.kind
-                        {
-                            if let ExprKind::Block(..) = &inner.kind {
-                                // `&{ x }` form: remove the `&`, add parens.
-                                cx.expr_paren(field.span, inner.clone())
-                            } else {
-                                // `&x` form: remove the `&`.
-                                inner.clone()
-                            }
-                        } else {
-                            expr.clone()
-                        }
-                    };
-                    cx.expr_binary(
-                        field.span,
-                        BinOpKind::Eq,
-                        convert(&field.self_expr),
-                        convert(other_expr),
-                    )
-                }
-                CsFold::Combine(span, expr1, expr2) => {
-                    cx.expr_binary(span, BinOpKind::And, expr1, expr2)
-                }
-                CsFold::Fieldless => cx.expr_bool(span, base),
-            },
-        );
-        BlockOrExpr::new_expr(expr)
-    }
-
     let structural_trait_def = TraitDef {
         span,
         path: path_std!(marker::StructuralPartialEq),
@@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq(
         ret_ty: Path(path_local!(bool)),
         attributes: thin_vec![cx.attr_word(sym::inline, span)],
         fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
-        combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
+        combine_substructure: combine_substructure(Box::new(|a, b, c| {
+            BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
+        })),
     }];
 
     let trait_def = TraitDef {
@@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq(
     };
     trait_def.expand(cx, mitem, item, push)
 }
+
+/// Generates the equality expression for a struct or enum variant when deriving
+/// `PartialEq`.
+///
+/// This function generates an expression that checks if all fields of a struct
+/// or enum variant are equal.
+/// - Scalar fields are compared first for efficiency, followed by compound
+///   fields.
+/// - If there are no fields, returns `true` (fieldless types are always equal).
+///
+/// Whether a field is considered "scalar" is determined by comparing the symbol
+/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
+/// This check is based on the type's symbol.
+///
+/// ### Example 1
+/// ```
+/// #[derive(PartialEq)]
+/// struct i32;
+///
+/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
+/// // the primitive), it will not be treated as scalar. The function will still
+/// // check equality of `field_2` first because the symbol matches `i32`.
+/// #[derive(PartialEq)]
+/// struct Struct {
+///     field_1: &'static str,
+///     field_2: i32,
+/// }
+/// ```
+///
+/// ### Example 2
+/// ```
+/// mod ty {
+///     pub type i32 = i32;
+/// }
+///
+/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
+/// // However, the function will not reorder the fields because the symbol for
+/// // `ty::i32` does not match the symbol for the primitive `i32`
+/// // ("ty::i32" != "i32").
+/// #[derive(PartialEq)]
+/// struct Struct {
+///     field_1: &'static str,
+///     field_2: ty::i32,
+/// }
+/// ```
+///
+/// For enums, the discriminant is compared first, then the rest of the fields.
+///
+/// # Panics
+///
+/// If called on static or all-fieldless enums/structs, which should not occur
+/// during derive expansion.
+fn get_substructure_equality_expr(
+    cx: &ExtCtxt<'_>,
+    span: Span,
+    substructure: &Substructure<'_>,
+) -> P<Expr> {
+    use SubstructureFields::*;
+
+    match substructure.fields {
+        EnumMatching(.., fields) | Struct(.., fields) => {
+            let combine = move |acc, field| {
+                let rhs = get_field_equality_expr(cx, field);
+                if let Some(lhs) = acc {
+                    // Combine the previous comparison with the current field
+                    // using logical AND.
+                    return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
+                }
+                // Start the chain with the first field's comparison.
+                Some(rhs)
+            };
+
+            // First compare scalar fields, then compound fields, combining all
+            // with logical AND.
+            return fields
+                .iter()
+                .filter(|field| !field.maybe_scalar)
+                .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
+                // If there are no fields, treat as always equal.
+                .unwrap_or_else(|| cx.expr_bool(span, true));
+        }
+        EnumDiscr(disc, match_expr) => {
+            let lhs = get_field_equality_expr(cx, disc);
+            let Some(match_expr) = match_expr else {
+                return lhs;
+            };
+            // Compare the discriminant first (cheaper), then the rest of the
+            // fields.
+            return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
+        }
+        StaticEnum(..) => cx.dcx().span_bug(
+            span,
+            "unexpected static enum encountered during `derive(PartialEq)` expansion",
+        ),
+        StaticStruct(..) => cx.dcx().span_bug(
+            span,
+            "unexpected static struct encountered during `derive(PartialEq)` expansion",
+        ),
+        AllFieldlessEnum(..) => cx.dcx().span_bug(
+            span,
+            "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
+        ),
+    }
+}
+
+/// Generates an equality comparison expression for a single struct or enum
+/// field.
+///
+/// This function produces an AST expression that compares the `self` and
+/// `other` values for a field using `==`. It removes any leading references
+/// from both sides for readability. If the field is a block expression, it is
+/// wrapped in parentheses to ensure valid syntax.
+///
+/// # Panics
+///
+/// Panics if there are not exactly two arguments to compare (should be `self`
+/// and `other`).
+fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
+    let [rhs] = &field.other_selflike_exprs[..] else {
+        cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
+    };
+
+    cx.expr_binary(
+        field.span,
+        BinOpKind::Eq,
+        wrap_block_expr(cx, peel_refs(&field.self_expr)),
+        wrap_block_expr(cx, peel_refs(rhs)),
+    )
+}
+
+/// Removes all leading immutable references from an expression.
+///
+/// This is used to strip away any number of leading `&` from an expression
+/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
+/// references are preserved.
+fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
+    while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
+        expr = &inner;
+    }
+    expr.clone()
+}
+
+/// Wraps a block expression in parentheses to ensure valid AST in macro
+/// expansion output.
+///
+/// If the given expression is a block, it is wrapped in parentheses; otherwise,
+/// it is returned unchanged.
+fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
+    if matches!(&expr.kind, ExprKind::Block(..)) {
+        return cx.expr_paren(expr.span, expr);
+    }
+    expr
+}
diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
index 9aa53f9e4f7..e0e44841acb 100644
--- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
     /// The expressions corresponding to references to this field in
     /// the other selflike arguments.
     pub other_selflike_exprs: Vec<P<Expr>>,
+    pub maybe_scalar: bool,
 }
 
 #[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
 
             let self_expr = discr_exprs.remove(0);
             let other_selflike_exprs = discr_exprs;
-            let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
+            let discr_field =
+                FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true };
 
             let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
                 .map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
                     name: struct_field.ident,
                     self_expr,
                     other_selflike_exprs,
+                    maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(),
                 }
             })
             .collect()