about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNicholas Nethercote <n.nethercote@gmail.com>2022-07-05 11:23:55 +1000
committerNicholas Nethercote <n.nethercote@gmail.com>2022-07-09 09:02:50 +1000
commit16a286b003477fe07c06c5030f0ae8298c3e78ec (patch)
tree36e3af5e983c3bdefac099a51bb84e730acdef8d
parent559398fa7860fff2b4058c302efb6f14312b0fe4 (diff)
downloadrust-16a286b003477fe07c06c5030f0ae8298c3e78ec.tar.gz
rust-16a286b003477fe07c06c5030f0ae8298c3e78ec.zip
Simplify `cs_fold`.
`cs_fold` has four distinct cases, covered by three different function
arguments:

- first field
- combine current field with previous results
- no fields
- non-matching enum variants

This commit clarifies things by replacing the three function arguments
with one that takes a new `CsFold` type with four slightly different)
cases

- single field
- combine result for current field with results for previous fields
- no fields
- non-matching enum variants

This makes the code shorter and clearer.
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs75
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs37
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs81
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/generic/mod.rs83
4 files changed, 114 insertions, 162 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs
index 1856be87a20..859e995356e 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs
@@ -55,57 +55,38 @@ pub fn cs_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> Bl
         // foldr nests the if-elses correctly, leaving the first field
         // as the outermost one, and the last as the innermost.
         false,
-        |cx, span, old, self_expr, other_selflike_exprs| {
-            // match new {
-            //     ::core::cmp::Ordering::Equal => old,
-            //     cmp => cmp
-            // }
-            let new = {
-                let [other_expr] = other_selflike_exprs else {
-                    cx.span_bug(span, "not exactly 2 arguments in `derive(Ord)`");
-                };
+        cx,
+        span,
+        substr,
+        |cx, fold| match fold {
+            CsFold::Single(field) => {
+                let [other_expr] = &field.other_selflike_exprs[..] else {
+                        cx.span_bug(field.span, "not exactly 2 arguments in `derive(Ord)`");
+                    };
                 let args = vec![
-                    cx.expr_addr_of(span, self_expr),
-                    cx.expr_addr_of(span, other_expr.clone()),
+                    cx.expr_addr_of(field.span, field.self_expr.clone()),
+                    cx.expr_addr_of(field.span, other_expr.clone()),
                 ];
-                cx.expr_call_global(span, cmp_path.clone(), args)
-            };
-
-            let eq_arm = cx.arm(span, cx.pat_path(span, equal_path.clone()), old);
-            let neq_arm = cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
-
-            cx.expr_match(span, new, vec![eq_arm, neq_arm])
-        },
-        |cx, args| match args {
-            Some((span, self_expr, other_selflike_exprs)) => {
-                let new = {
-                    let [other_expr] = other_selflike_exprs else {
-                            cx.span_bug(span, "not exactly 2 arguments in `derive(Ord)`");
-                        };
-                    let args = vec![
-                        cx.expr_addr_of(span, self_expr),
-                        cx.expr_addr_of(span, other_expr.clone()),
-                    ];
-                    cx.expr_call_global(span, cmp_path.clone(), args)
-                };
-
-                new
+                cx.expr_call_global(field.span, cmp_path.clone(), args)
             }
-            None => cx.expr_path(equal_path.clone()),
-        },
-        Box::new(|cx, span, tag_tuple| {
-            if tag_tuple.len() != 2 {
-                cx.span_bug(span, "not exactly 2 arguments in `derive(Ord)`")
-            } else {
-                let lft = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[0]));
-                let rgt = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[1]));
-                let fn_cmp_path = cx.std_path(&[sym::cmp, sym::Ord, sym::cmp]);
-                cx.expr_call_global(span, fn_cmp_path, vec![lft, rgt])
+            CsFold::Combine(span, expr1, expr2) => {
+                let eq_arm = cx.arm(span, cx.pat_path(span, equal_path.clone()), expr1);
+                let neq_arm =
+                    cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
+                cx.expr_match(span, expr2, vec![eq_arm, neq_arm])
             }
-        }),
-        cx,
-        span,
-        substr,
+            CsFold::Fieldless => cx.expr_path(equal_path.clone()),
+            CsFold::EnumNonMatching(span, tag_tuple) => {
+                if tag_tuple.len() != 2 {
+                    cx.span_bug(span, "not exactly 2 arguments in `derive(Ord)`")
+                } else {
+                    let lft = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[0]));
+                    let rgt = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[1]));
+                    let fn_cmp_path = cx.std_path(&[sym::cmp, sym::Ord, sym::cmp]);
+                    cx.expr_call_global(span, fn_cmp_path, vec![lft, rgt])
+                }
+            }
+        },
     );
     BlockOrExpr::new_expr(expr)
 }
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
index e4af0166577..724c639984c 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
@@ -2,8 +2,7 @@ use crate::deriving::generic::ty::*;
 use crate::deriving::generic::*;
 use crate::deriving::{path_local, path_std};
 
-use rustc_ast::ptr::P;
-use rustc_ast::{BinOpKind, Expr, MetaItem};
+use rustc_ast::{BinOpKind, MetaItem};
 use rustc_expand::base::{Annotatable, ExtCtxt};
 use rustc_span::symbol::sym;
 use rustc_span::Span;
@@ -23,34 +22,22 @@ pub fn expand_deriving_partial_eq(
         combiner: BinOpKind,
         base: bool,
     ) -> BlockOrExpr {
-        let op = |cx: &mut ExtCtxt<'_>,
-                  span: Span,
-                  self_expr: P<Expr>,
-                  other_selflike_exprs: &[P<Expr>]| {
-            let [other_expr] = other_selflike_exprs else {
-                cx.span_bug(span, "not exactly 2 arguments in `derive(PartialEq)`");
-            };
-
-            cx.expr_binary(span, op, self_expr, other_expr.clone())
-        };
-
         let expr = cs_fold(
             true, // use foldl
-            |cx, span, old, self_expr, other_selflike_exprs| {
-                let eq = op(cx, span, self_expr, other_selflike_exprs);
-                cx.expr_binary(span, combiner, old, eq)
-            },
-            |cx, args| match args {
-                Some((span, self_expr, other_selflike_exprs)) => {
-                    // Special-case the base case to generate cleaner code.
-                    op(cx, span, self_expr, other_selflike_exprs)
-                }
-                None => cx.expr_bool(span, base),
-            },
-            Box::new(|cx, span, _| cx.expr_bool(span, !base)),
             cx,
             span,
             substr,
+            |cx, fold| match fold {
+                CsFold::Single(field) => {
+                    let [other_expr] = &field.other_selflike_exprs[..] else {
+                        cx.span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
+                    };
+                    cx.expr_binary(field.span, op, field.self_expr.clone(), other_expr.clone())
+                }
+                CsFold::Combine(span, expr1, expr2) => cx.expr_binary(span, combiner, expr1, expr2),
+                CsFold::Fieldless => cx.expr_bool(span, base),
+                CsFold::EnumNonMatching(span, _tag_tuple) => cx.expr_bool(span, !base),
+            },
         );
         BlockOrExpr::new_expr(expr)
     }
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
index bf52c63fad4..3f9843922da 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
@@ -63,61 +63,40 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_
         // foldr nests the if-elses correctly, leaving the first field
         // as the outermost one, and the last as the innermost.
         false,
-        |cx, span, old, self_expr, other_selflike_exprs| {
-            // match new {
-            //     Some(::core::cmp::Ordering::Equal) => old,
-            //     cmp => cmp
-            // }
-            let new = {
-                let [other_expr] = other_selflike_exprs else {
-                    cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`");
-                };
-
+        cx,
+        span,
+        substr,
+        |cx, fold| match fold {
+            CsFold::Single(field) => {
+                let [other_expr] = &field.other_selflike_exprs[..] else {
+                        cx.span_bug(field.span, "not exactly 2 arguments in `derive(Ord)`");
+                    };
                 let args = vec![
-                    cx.expr_addr_of(span, self_expr),
-                    cx.expr_addr_of(span, other_expr.clone()),
+                    cx.expr_addr_of(field.span, field.self_expr.clone()),
+                    cx.expr_addr_of(field.span, other_expr.clone()),
                 ];
-
-                cx.expr_call_global(span, partial_cmp_path.clone(), args)
-            };
-
-            let eq_arm =
-                cx.arm(span, cx.pat_some(span, cx.pat_path(span, equal_path.clone())), old);
-            let neq_arm = cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
-
-            cx.expr_match(span, new, vec![eq_arm, neq_arm])
-        },
-        |cx, args| match args {
-            Some((span, self_expr, other_selflike_exprs)) => {
-                let new = {
-                    let [other_expr] = other_selflike_exprs else {
-                            cx.span_bug(span, "not exactly 2 arguments in `derive(Ord)`");
-                        };
-                    let args = vec![
-                        cx.expr_addr_of(span, self_expr),
-                        cx.expr_addr_of(span, other_expr.clone()),
-                    ];
-                    cx.expr_call_global(span, partial_cmp_path.clone(), args)
-                };
-
-                new
+                cx.expr_call_global(field.span, partial_cmp_path.clone(), args)
             }
-            None => cx.expr_some(span, cx.expr_path(equal_path.clone())),
-        },
-        Box::new(|cx, span, tag_tuple| {
-            if tag_tuple.len() != 2 {
-                cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
-            } else {
-                let lft = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[0]));
-                let rgt = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[1]));
-                let fn_partial_cmp_path =
-                    cx.std_path(&[sym::cmp, sym::PartialOrd, sym::partial_cmp]);
-                cx.expr_call_global(span, fn_partial_cmp_path, vec![lft, rgt])
+            CsFold::Combine(span, expr1, expr2) => {
+                let eq_arm =
+                    cx.arm(span, cx.pat_some(span, cx.pat_path(span, equal_path.clone())), expr1);
+                let neq_arm =
+                    cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
+                cx.expr_match(span, expr2, vec![eq_arm, neq_arm])
             }
-        }),
-        cx,
-        span,
-        substr,
+            CsFold::Fieldless => cx.expr_some(span, cx.expr_path(equal_path.clone())),
+            CsFold::EnumNonMatching(span, tag_tuple) => {
+                if tag_tuple.len() != 2 {
+                    cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
+                } else {
+                    let lft = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[0]));
+                    let rgt = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[1]));
+                    let fn_partial_cmp_path =
+                        cx.std_path(&[sym::cmp, sym::PartialOrd, sym::partial_cmp]);
+                    cx.expr_call_global(span, fn_partial_cmp_path, vec![lft, rgt])
+                }
+            }
+        },
     );
     BlockOrExpr::new_expr(expr)
 }
diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
index 70a97c32b48..5cad71467a1 100644
--- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
@@ -296,11 +296,6 @@ pub enum SubstructureFields<'a> {
 pub type CombineSubstructureFunc<'a> =
     Box<dyn FnMut(&mut ExtCtxt<'_>, Span, &Substructure<'_>) -> BlockOrExpr + 'a>;
 
-/// Deal with non-matching enum variants. The slice is the identifiers holding
-/// the variant index value for each of the `Self` arguments.
-pub type EnumNonMatchCollapsedFunc<'a> =
-    Box<dyn FnMut(&mut ExtCtxt<'_>, Span, &[Ident]) -> P<Expr> + 'a>;
-
 pub fn combine_substructure(
     f: CombineSubstructureFunc<'_>,
 ) -> RefCell<CombineSubstructureFunc<'_>> {
@@ -1601,55 +1596,65 @@ impl<'a> TraitDef<'a> {
     }
 }
 
-/// Function to fold over fields, with three cases, to generate more efficient and concise code.
-/// When the `substructure` has grouped fields, there are two cases:
-/// Zero fields: call the base case function with `None` (like the usual base case of `cs_fold`).
-/// One or more fields: call the base case function on the first value (which depends on
-/// `use_fold`), and use that as the base case. Then perform `cs_fold` on the remainder of the
-/// fields.
-/// When the `substructure` is an `EnumNonMatchingCollapsed`, the result of `enum_nonmatch_f`
-/// is returned. Statics may not be folded over.
-pub fn cs_fold<F, B>(
+/// The function passed to `cs_fold` is called repeatedly with a value of this
+/// type. It describes one part of the code generation. The result is always an
+/// expression.
+pub enum CsFold<'a> {
+    /// The basic case: a field expression for one or more selflike args. E.g.
+    /// for `PartialEq::eq` this is something like `self.x == other.x`.
+    Single(&'a FieldInfo),
+
+    /// The combination of two field expressions. E.g. for `PartialEq::eq` this
+    /// is something like `<field1 equality> && <field2 equality>`.
+    Combine(Span, P<Expr>, P<Expr>),
+
+    // The fallback case for a struct or enum variant with no fields.
+    Fieldless,
+
+    /// The fallback case for non-matching enum variants. The slice is the
+    /// identifiers holding the variant index value for each of the `Self`
+    /// arguments.
+    EnumNonMatching(Span, &'a [Ident]),
+}
+
+/// Folds over fields, combining the expressions for each field in a sequence.
+/// Statics may not be folded over.
+pub fn cs_fold<F>(
     use_foldl: bool,
-    mut f: F,
-    mut b: B,
-    mut enum_nonmatch_f: EnumNonMatchCollapsedFunc<'_>,
     cx: &mut ExtCtxt<'_>,
     trait_span: Span,
     substructure: &Substructure<'_>,
+    mut f: F,
 ) -> P<Expr>
 where
-    F: FnMut(&mut ExtCtxt<'_>, Span, P<Expr>, P<Expr>, &[P<Expr>]) -> P<Expr>,
-    B: FnMut(&mut ExtCtxt<'_>, Option<(Span, P<Expr>, &[P<Expr>])>) -> P<Expr>,
+    F: FnMut(&mut ExtCtxt<'_>, CsFold<'_>) -> P<Expr>,
 {
     match *substructure.fields {
         EnumMatching(.., ref all_fields) | Struct(_, ref all_fields) => {
-            let (base, rest) = match (all_fields.is_empty(), use_foldl) {
-                (false, true) => {
-                    let (first, rest) = all_fields.split_first().unwrap();
-                    let args =
-                        (first.span, first.self_expr.clone(), &first.other_selflike_exprs[..]);
-                    (b(cx, Some(args)), rest)
-                }
-                (false, false) => {
-                    let (last, rest) = all_fields.split_last().unwrap();
-                    let args = (last.span, last.self_expr.clone(), &last.other_selflike_exprs[..]);
-                    (b(cx, Some(args)), rest)
-                }
-                (true, _) => (b(cx, None), &all_fields[..]),
+            if all_fields.is_empty() {
+                return f(cx, CsFold::Fieldless);
+            }
+
+            let (base_field, rest) = if use_foldl {
+                all_fields.split_first().unwrap()
+            } else {
+                all_fields.split_last().unwrap()
+            };
+
+            let base_expr = f(cx, CsFold::Single(base_field));
+
+            let op = |old, field: &FieldInfo| {
+                let new = f(cx, CsFold::Single(field));
+                f(cx, CsFold::Combine(field.span, old, new))
             };
 
             if use_foldl {
-                rest.iter().fold(base, |old, field| {
-                    f(cx, field.span, old, field.self_expr.clone(), &field.other_selflike_exprs)
-                })
+                rest.iter().fold(base_expr, op)
             } else {
-                rest.iter().rev().fold(base, |old, field| {
-                    f(cx, field.span, old, field.self_expr.clone(), &field.other_selflike_exprs)
-                })
+                rest.iter().rfold(base_expr, op)
             }
         }
-        EnumNonMatchingCollapsed(tuple) => enum_nonmatch_f(cx, trait_span, tuple),
+        EnumNonMatchingCollapsed(tuple) => f(cx, CsFold::EnumNonMatching(trait_span, tuple)),
         StaticEnum(..) | StaticStruct(..) => cx.span_bug(trait_span, "static function in `derive`"),
     }
 }