about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEll <ahunpochoevjamshed@gmail.com>2025-05-30 10:57:08 +0300
committerEll <ahunpochoevjamshed@gmail.com>2025-06-02 15:29:34 +0300
commita6a1c1b247aa0fa404983efa3c226c25cafdd704 (patch)
tree02fbcb8315df8ef7af779ee1b508d5843254ba3b
parentebe9b0060240953d721508ceb4d02a745efda88f (diff)
downloadrust-a6a1c1b247aa0fa404983efa3c226c25cafdd704.tar.gz
rust-a6a1c1b247aa0fa404983efa3c226c25cafdd704.zip
Separately check equality of the scalar types and compound types in the order of declaration.
-rw-r--r--compiler/rustc_ast/src/ast.rs33
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs215
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/generic/mod.rs5
-rw-r--r--tests/ui/deriving/deriving-all-codegen.rs30
-rw-r--r--tests/ui/deriving/deriving-all-codegen.stdout165
5 files changed, 390 insertions, 58 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs
index a16219361c0..1d4df97da58 100644
--- a/compiler/rustc_ast/src/ast.rs
+++ b/compiler/rustc_ast/src/ast.rs
@@ -2452,6 +2452,39 @@ impl TyKind {
             None
         }
     }
+
+    /// Returns `true` if this type is considered a scalar primitive (e.g.,
+    /// `i32`, `u8`, `bool`, etc).
+    ///
+    /// This check is based on **symbol equality** and does **not** remove any
+    /// path prefixes or references. If a type alias or shadowing is present
+    /// (e.g., `type i32 = CustomType;`), this method will still return `true`
+    /// for `i32`, even though it may not refer to the primitive type.
+    pub fn maybe_scalar(&self) -> bool {
+        let Some(ty_sym) = self.is_simple_path() else {
+            // unit type
+            return self.is_unit();
+        };
+        matches!(
+            ty_sym,
+            sym::i8
+                | sym::i16
+                | sym::i32
+                | sym::i64
+                | sym::i128
+                | sym::u8
+                | sym::u16
+                | sym::u32
+                | sym::u64
+                | sym::u128
+                | sym::f16
+                | sym::f32
+                | sym::f64
+                | sym::f128
+                | sym::char
+                | sym::bool
+        )
+    }
 }
 
 /// A pattern type pattern.
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
index 4b93b3414c7..b1d950b8d89 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
@@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*;
 use crate::deriving::generic::*;
 use crate::deriving::{path_local, path_std};
 
+/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
+/// target item.
 pub(crate) fn expand_deriving_partial_eq(
     cx: &ExtCtxt<'_>,
     span: Span,
@@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq(
     push: &mut dyn FnMut(Annotatable),
     is_const: bool,
 ) {
-    fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
-        let base = true;
-        let expr = cs_fold(
-            true, // use foldl
-            cx,
-            span,
-            substr,
-            |cx, fold| match fold {
-                CsFold::Single(field) => {
-                    let [other_expr] = &field.other_selflike_exprs[..] else {
-                        cx.dcx()
-                            .span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
-                    };
-
-                    // We received arguments of type `&T`. Convert them to type `T` by stripping
-                    // any leading `&`. This isn't necessary for type checking, but
-                    // it results in better error messages if something goes wrong.
-                    //
-                    // Note: for arguments that look like `&{ x }`, which occur with packed
-                    // structs, this would cause expressions like `{ self.x } == { other.x }`,
-                    // which isn't valid Rust syntax. This wouldn't break compilation because these
-                    // AST nodes are constructed within the compiler. But it would mean that code
-                    // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
-                    // syntax, which would be suboptimal. So we wrap these in parens, giving
-                    // `({ self.x }) == ({ other.x })`, which is valid syntax.
-                    let convert = |expr: &P<Expr>| {
-                        if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
-                            &expr.kind
-                        {
-                            if let ExprKind::Block(..) = &inner.kind {
-                                // `&{ x }` form: remove the `&`, add parens.
-                                cx.expr_paren(field.span, inner.clone())
-                            } else {
-                                // `&x` form: remove the `&`.
-                                inner.clone()
-                            }
-                        } else {
-                            expr.clone()
-                        }
-                    };
-                    cx.expr_binary(
-                        field.span,
-                        BinOpKind::Eq,
-                        convert(&field.self_expr),
-                        convert(other_expr),
-                    )
-                }
-                CsFold::Combine(span, expr1, expr2) => {
-                    cx.expr_binary(span, BinOpKind::And, expr1, expr2)
-                }
-                CsFold::Fieldless => cx.expr_bool(span, base),
-            },
-        );
-        BlockOrExpr::new_expr(expr)
-    }
-
     let structural_trait_def = TraitDef {
         span,
         path: path_std!(marker::StructuralPartialEq),
@@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq(
         ret_ty: Path(path_local!(bool)),
         attributes: thin_vec![cx.attr_word(sym::inline, span)],
         fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
-        combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
+        combine_substructure: combine_substructure(Box::new(|a, b, c| {
+            BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
+        })),
     }];
 
     let trait_def = TraitDef {
@@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq(
     };
     trait_def.expand(cx, mitem, item, push)
 }
+
+/// Generates the equality expression for a struct or enum variant when deriving
+/// `PartialEq`.
+///
+/// This function generates an expression that checks if all fields of a struct
+/// or enum variant are equal.
+/// - Scalar fields are compared first for efficiency, followed by compound
+///   fields.
+/// - If there are no fields, returns `true` (fieldless types are always equal).
+///
+/// Whether a field is considered "scalar" is determined by comparing the symbol
+/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
+/// This check is based on the type's symbol.
+///
+/// ### Example 1
+/// ```
+/// #[derive(PartialEq)]
+/// struct i32;
+///
+/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
+/// // the primitive), it will not be treated as scalar. The function will still
+/// // check equality of `field_2` first because the symbol matches `i32`.
+/// #[derive(PartialEq)]
+/// struct Struct {
+///     field_1: &'static str,
+///     field_2: i32,
+/// }
+/// ```
+///
+/// ### Example 2
+/// ```
+/// mod ty {
+///     pub type i32 = i32;
+/// }
+///
+/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
+/// // However, the function will not reorder the fields because the symbol for
+/// // `ty::i32` does not match the symbol for the primitive `i32`
+/// // ("ty::i32" != "i32").
+/// #[derive(PartialEq)]
+/// struct Struct {
+///     field_1: &'static str,
+///     field_2: ty::i32,
+/// }
+/// ```
+///
+/// For enums, the discriminant is compared first, then the rest of the fields.
+///
+/// # Panics
+///
+/// If called on static or all-fieldless enums/structs, which should not occur
+/// during derive expansion.
+fn get_substructure_equality_expr(
+    cx: &ExtCtxt<'_>,
+    span: Span,
+    substructure: &Substructure<'_>,
+) -> P<Expr> {
+    use SubstructureFields::*;
+
+    match substructure.fields {
+        EnumMatching(.., fields) | Struct(.., fields) => {
+            let combine = move |acc, field| {
+                let rhs = get_field_equality_expr(cx, field);
+                if let Some(lhs) = acc {
+                    // Combine the previous comparison with the current field
+                    // using logical AND.
+                    return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
+                }
+                // Start the chain with the first field's comparison.
+                Some(rhs)
+            };
+
+            // First compare scalar fields, then compound fields, combining all
+            // with logical AND.
+            return fields
+                .iter()
+                .filter(|field| !field.maybe_scalar)
+                .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
+                // If there are no fields, treat as always equal.
+                .unwrap_or_else(|| cx.expr_bool(span, true));
+        }
+        EnumDiscr(disc, match_expr) => {
+            let lhs = get_field_equality_expr(cx, disc);
+            let Some(match_expr) = match_expr else {
+                return lhs;
+            };
+            // Compare the discriminant first (cheaper), then the rest of the
+            // fields.
+            return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
+        }
+        StaticEnum(..) => cx.dcx().span_bug(
+            span,
+            "unexpected static enum encountered during `derive(PartialEq)` expansion",
+        ),
+        StaticStruct(..) => cx.dcx().span_bug(
+            span,
+            "unexpected static struct encountered during `derive(PartialEq)` expansion",
+        ),
+        AllFieldlessEnum(..) => cx.dcx().span_bug(
+            span,
+            "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
+        ),
+    }
+}
+
+/// Generates an equality comparison expression for a single struct or enum
+/// field.
+///
+/// This function produces an AST expression that compares the `self` and
+/// `other` values for a field using `==`. It removes any leading references
+/// from both sides for readability. If the field is a block expression, it is
+/// wrapped in parentheses to ensure valid syntax.
+///
+/// # Panics
+///
+/// Panics if there are not exactly two arguments to compare (should be `self`
+/// and `other`).
+fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
+    let [rhs] = &field.other_selflike_exprs[..] else {
+        cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
+    };
+
+    cx.expr_binary(
+        field.span,
+        BinOpKind::Eq,
+        wrap_block_expr(cx, peel_refs(&field.self_expr)),
+        wrap_block_expr(cx, peel_refs(rhs)),
+    )
+}
+
+/// Removes all leading immutable references from an expression.
+///
+/// This is used to strip away any number of leading `&` from an expression
+/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
+/// references are preserved.
+fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
+    while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
+        expr = &inner;
+    }
+    expr.clone()
+}
+
+/// Wraps a block expression in parentheses to ensure valid AST in macro
+/// expansion output.
+///
+/// If the given expression is a block, it is wrapped in parentheses; otherwise,
+/// it is returned unchanged.
+fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
+    if matches!(&expr.kind, ExprKind::Block(..)) {
+        return cx.expr_paren(expr.span, expr);
+    }
+    expr
+}
diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
index 9aa53f9e4f7..e0e44841acb 100644
--- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
     /// The expressions corresponding to references to this field in
     /// the other selflike arguments.
     pub other_selflike_exprs: Vec<P<Expr>>,
+    pub maybe_scalar: bool,
 }
 
 #[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
 
             let self_expr = discr_exprs.remove(0);
             let other_selflike_exprs = discr_exprs;
-            let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
+            let discr_field =
+                FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true };
 
             let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
                 .map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
                     name: struct_field.ident,
                     self_expr,
                     other_selflike_exprs,
+                    maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(),
                 }
             })
             .collect()
diff --git a/tests/ui/deriving/deriving-all-codegen.rs b/tests/ui/deriving/deriving-all-codegen.rs
index eab2b4f1f53..e2b6804fbd1 100644
--- a/tests/ui/deriving/deriving-all-codegen.rs
+++ b/tests/ui/deriving/deriving-all-codegen.rs
@@ -45,6 +45,22 @@ struct Big {
     b1: u32, b2: u32, b3: u32, b4: u32, b5: u32, b6: u32, b7: u32, b8: u32,
 }
 
+// It is more efficient to compare scalar types before non-scalar types.
+#[derive(PartialEq, PartialOrd)]
+struct Reorder {
+    b1: Option<f32>,
+    b2: u16,
+    b3: &'static str,
+    b4: i8,
+    b5: u128,
+    _b: *mut &'static dyn FnMut() -> (),
+    b6: f64,
+    b7: &'static mut (),
+    b8: char,
+    b9: &'static [i64],
+    b10: &'static *const bool,
+}
+
 // A struct that doesn't impl `Copy`, which means it gets the non-simple
 // `clone` implemention that clones the fields individually.
 #[derive(Clone)]
@@ -130,6 +146,20 @@ enum Mixed {
     S { d1: Option<u32>, d2: Option<i32> },
 }
 
+// When comparing enum variant it is more efficient to compare scalar types before non-scalar types.
+#[derive(PartialEq, PartialOrd)]
+enum ReorderEnum {
+    A(i32),
+    B,
+    C(i8),
+    D,
+    E,
+    F,
+    G(&'static mut str, *const u8, *const dyn Fn() -> ()),
+    H,
+    I,
+}
+
 // An enum with no fieldless variants. Note that `Default` cannot be derived
 // for this enum.
 #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
diff --git a/tests/ui/deriving/deriving-all-codegen.stdout b/tests/ui/deriving/deriving-all-codegen.stdout
index 6503c870990..fa8f249373d 100644
--- a/tests/ui/deriving/deriving-all-codegen.stdout
+++ b/tests/ui/deriving/deriving-all-codegen.stdout
@@ -419,6 +419,100 @@ impl ::core::cmp::Ord for Big {
     }
 }
 
+// It is more efficient to compare scalar types before non-scalar types.
+struct Reorder {
+    b1: Option<f32>,
+    b2: u16,
+    b3: &'static str,
+    b4: i8,
+    b5: u128,
+    _b: *mut &'static dyn FnMut() -> (),
+    b6: f64,
+    b7: &'static mut (),
+    b8: char,
+    b9: &'static [i64],
+    b10: &'static *const bool,
+}
+#[automatically_derived]
+impl ::core::marker::StructuralPartialEq for Reorder { }
+#[automatically_derived]
+impl ::core::cmp::PartialEq for Reorder {
+    #[inline]
+    fn eq(&self, other: &Reorder) -> bool {
+        self.b2 == other.b2 && self.b4 == other.b4 && self.b5 == other.b5 &&
+                                        self.b6 == other.b6 && self.b7 == other.b7 &&
+                                self.b8 == other.b8 && self.b10 == other.b10 &&
+                        self.b1 == other.b1 && self.b3 == other.b3 &&
+                self._b == other._b && self.b9 == other.b9
+    }
+}
+#[automatically_derived]
+impl ::core::cmp::PartialOrd for Reorder {
+    #[inline]
+    fn partial_cmp(&self, other: &Reorder)
+        -> ::core::option::Option<::core::cmp::Ordering> {
+        match ::core::cmp::PartialOrd::partial_cmp(&self.b1, &other.b1) {
+            ::core::option::Option::Some(::core::cmp::Ordering::Equal) =>
+                match ::core::cmp::PartialOrd::partial_cmp(&self.b2,
+                        &other.b2) {
+                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                        =>
+                        match ::core::cmp::PartialOrd::partial_cmp(&self.b3,
+                                &other.b3) {
+                            ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                =>
+                                match ::core::cmp::PartialOrd::partial_cmp(&self.b4,
+                                        &other.b4) {
+                                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                        =>
+                                        match ::core::cmp::PartialOrd::partial_cmp(&self.b5,
+                                                &other.b5) {
+                                            ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                =>
+                                                match ::core::cmp::PartialOrd::partial_cmp(&self._b,
+                                                        &other._b) {
+                                                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                        =>
+                                                        match ::core::cmp::PartialOrd::partial_cmp(&self.b6,
+                                                                &other.b6) {
+                                                            ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                                =>
+                                                                match ::core::cmp::PartialOrd::partial_cmp(&self.b7,
+                                                                        &other.b7) {
+                                                                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                                        =>
+                                                                        match ::core::cmp::PartialOrd::partial_cmp(&self.b8,
+                                                                                &other.b8) {
+                                                                            ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                                                =>
+                                                                                match ::core::cmp::PartialOrd::partial_cmp(&self.b9,
+                                                                                        &other.b9) {
+                                                                                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                                                                        =>
+                                                                                        ::core::cmp::PartialOrd::partial_cmp(&self.b10, &other.b10),
+                                                                                    cmp => cmp,
+                                                                                },
+                                                                            cmp => cmp,
+                                                                        },
+                                                                    cmp => cmp,
+                                                                },
+                                                            cmp => cmp,
+                                                        },
+                                                    cmp => cmp,
+                                                },
+                                            cmp => cmp,
+                                        },
+                                    cmp => cmp,
+                                },
+                            cmp => cmp,
+                        },
+                    cmp => cmp,
+                },
+            cmp => cmp,
+        }
+    }
+}
+
 // A struct that doesn't impl `Copy`, which means it gets the non-simple
 // `clone` implemention that clones the fields individually.
 struct NonCopy(u32);
@@ -1167,6 +1261,77 @@ impl ::core::cmp::Ord for Mixed {
     }
 }
 
+// When comparing enum variant it is more efficient to compare scalar types before non-scalar types.
+enum ReorderEnum {
+    A(i32),
+    B,
+    C(i8),
+    D,
+    E,
+    F,
+    G(&'static mut str, *const u8, *const dyn Fn() -> ()),
+    H,
+    I,
+}
+#[automatically_derived]
+impl ::core::marker::StructuralPartialEq for ReorderEnum { }
+#[automatically_derived]
+impl ::core::cmp::PartialEq for ReorderEnum {
+    #[inline]
+    fn eq(&self, other: &ReorderEnum) -> bool {
+        let __self_discr = ::core::intrinsics::discriminant_value(self);
+        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
+        __self_discr == __arg1_discr &&
+            match (self, other) {
+                (ReorderEnum::A(__self_0), ReorderEnum::A(__arg1_0)) =>
+                    __self_0 == __arg1_0,
+                (ReorderEnum::C(__self_0), ReorderEnum::C(__arg1_0)) =>
+                    __self_0 == __arg1_0,
+                (ReorderEnum::G(__self_0, __self_1, __self_2),
+                    ReorderEnum::G(__arg1_0, __arg1_1, __arg1_2)) =>
+                    __self_1 == __arg1_1 && __self_0 == __arg1_0 &&
+                        __self_2 == __arg1_2,
+                _ => true,
+            }
+    }
+}
+#[automatically_derived]
+impl ::core::cmp::PartialOrd for ReorderEnum {
+    #[inline]
+    fn partial_cmp(&self, other: &ReorderEnum)
+        -> ::core::option::Option<::core::cmp::Ordering> {
+        let __self_discr = ::core::intrinsics::discriminant_value(self);
+        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
+        match ::core::cmp::PartialOrd::partial_cmp(&__self_discr,
+                &__arg1_discr) {
+            ::core::option::Option::Some(::core::cmp::Ordering::Equal) =>
+                match (self, other) {
+                    (ReorderEnum::A(__self_0), ReorderEnum::A(__arg1_0)) =>
+                        ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
+                    (ReorderEnum::C(__self_0), ReorderEnum::C(__arg1_0)) =>
+                        ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
+                    (ReorderEnum::G(__self_0, __self_1, __self_2),
+                        ReorderEnum::G(__arg1_0, __arg1_1, __arg1_2)) =>
+                        match ::core::cmp::PartialOrd::partial_cmp(__self_0,
+                                __arg1_0) {
+                            ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                =>
+                                match ::core::cmp::PartialOrd::partial_cmp(__self_1,
+                                        __arg1_1) {
+                                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
+                                        => ::core::cmp::PartialOrd::partial_cmp(__self_2, __arg1_2),
+                                    cmp => cmp,
+                                },
+                            cmp => cmp,
+                        },
+                    _ =>
+                        ::core::option::Option::Some(::core::cmp::Ordering::Equal),
+                },
+            cmp => cmp,
+        }
+    }
+}
+
 // An enum with no fieldless variants. Note that `Default` cannot be derived
 // for this enum.
 enum Fielded { X(u32), Y(bool), Z(Option<i32>), }