about summary refs log tree commit diff
path: root/src/libsyntax_ext
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2018-04-15 03:54:15 +0000
committerbors <bors@rust-lang.org>2018-04-15 03:54:15 +0000
commitbc001fa07f1e44f88b59c74290a2dd916824d33c (patch)
treea55b1eb61a3bf7d791bbe40f2833015679abf36e /src/libsyntax_ext
parentd4d43e248340b6acaf02f4439713c160fd77a846 (diff)
parent105c5180941f4034fd0d576a1d4c1bb71dd8e077 (diff)
downloadrust-bc001fa07f1e44f88b59c74290a2dd916824d33c.tar.gz
rust-bc001fa07f1e44f88b59c74290a2dd916824d33c.zip
Auto merge of #49881 - varkor:partialord-opt, r=Manishearth
Fix derive(PartialOrd) and optimise final field operation

```rust
// Before (`lt` on 2-field struct)
self.f1 < other.f1 || (!(other.f1 < self.f1) &&
(self.f2 < other.f2 || (!(other.f2 < self.f2) &&
(false)
))
)

// After
self.f1 < other.f1 || (!(other.f1 < self.f1) &&
self.f2 < other.f2
)

// Before (`le` on 2-field struct)
self.f1 < other.f1 || (!(other.f1 < self.f1) &&
(self.f2 < other.f2 || (!(other.f2 < self.f2) &&
(true)
))
)

// After
self.f1 < other.f1 || (self.f1 == other.f1 &&
self.f2 <= other.f2
)
```

(The big diff is mainly because of a past faulty rustfmt application that I corrected 😒)

Fixes #49650 and fixes #49505.
Diffstat (limited to 'src/libsyntax_ext')
-rw-r--r--src/libsyntax_ext/deriving/cmp/partial_eq.rs61
-rw-r--r--src/libsyntax_ext/deriving/cmp/partial_ord.rs126
-rw-r--r--src/libsyntax_ext/deriving/generic/mod.rs116
3 files changed, 212 insertions, 91 deletions
diff --git a/src/libsyntax_ext/deriving/cmp/partial_eq.rs b/src/libsyntax_ext/deriving/cmp/partial_eq.rs
index 75db7cc1e4c..81ca7e73228 100644
--- a/src/libsyntax_ext/deriving/cmp/partial_eq.rs
+++ b/src/libsyntax_ext/deriving/cmp/partial_eq.rs
@@ -26,41 +26,48 @@ pub fn expand_deriving_partial_eq(cx: &mut ExtCtxt,
                                   push: &mut FnMut(Annotatable)) {
     // structures are equal if all fields are equal, and non equal, if
     // any fields are not equal or if the enum variants are different
-    fn cs_eq(cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<Expr> {
-        cs_fold(true, // use foldl
-                |cx, span, subexpr, self_f, other_fs| {
+    fn cs_op(cx: &mut ExtCtxt,
+             span: Span,
+             substr: &Substructure,
+             op: BinOpKind,
+             combiner: BinOpKind,
+             base: bool)
+             -> P<Expr>
+    {
+        let op = |cx: &mut ExtCtxt, span: Span, self_f: P<Expr>, other_fs: &[P<Expr>]| {
             let other_f = match (other_fs.len(), other_fs.get(0)) {
                 (1, Some(o_f)) => o_f,
                 _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialEq)`"),
             };
 
-            let eq = cx.expr_binary(span, BinOpKind::Eq, self_f, other_f.clone());
+            cx.expr_binary(span, op, self_f, other_f.clone())
+        };
 
-            cx.expr_binary(span, BinOpKind::And, subexpr, eq)
-        },
-                cx.expr_bool(span, true),
-                Box::new(|cx, span, _, _| cx.expr_bool(span, false)),
-                cx,
-                span,
-                substr)
+        cs_fold1(true, // use foldl
+            |cx, span, subexpr, self_f, other_fs| {
+                let eq = op(cx, span, self_f, other_fs);
+                cx.expr_binary(span, combiner, subexpr, eq)
+            },
+            |cx, args| {
+                match args {
+                    Some((span, self_f, other_fs)) => {
+                        // Special-case the base case to generate cleaner code.
+                        op(cx, span, self_f, other_fs)
+                    }
+                    None => cx.expr_bool(span, base),
+                }
+            },
+            Box::new(|cx, span, _, _| cx.expr_bool(span, !base)),
+            cx,
+            span,
+            substr)
     }
-    fn cs_ne(cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<Expr> {
-        cs_fold(true, // use foldl
-                |cx, span, subexpr, self_f, other_fs| {
-            let other_f = match (other_fs.len(), other_fs.get(0)) {
-                (1, Some(o_f)) => o_f,
-                _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialEq)`"),
-            };
-
-            let eq = cx.expr_binary(span, BinOpKind::Ne, self_f, other_f.clone());
 
-            cx.expr_binary(span, BinOpKind::Or, subexpr, eq)
-        },
-                cx.expr_bool(span, false),
-                Box::new(|cx, span, _, _| cx.expr_bool(span, true)),
-                cx,
-                span,
-                substr)
+    fn cs_eq(cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<Expr> {
+        cs_op(cx, span, substr, BinOpKind::Eq, BinOpKind::And, true)
+    }
+    fn cs_ne(cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<Expr> {
+        cs_op(cx, span, substr, BinOpKind::Ne, BinOpKind::Or, false)
     }
 
     macro_rules! md {
diff --git a/src/libsyntax_ext/deriving/cmp/partial_ord.rs b/src/libsyntax_ext/deriving/cmp/partial_ord.rs
index 92183c58eb2..9560fd0570a 100644
--- a/src/libsyntax_ext/deriving/cmp/partial_ord.rs
+++ b/src/libsyntax_ext/deriving/cmp/partial_ord.rs
@@ -190,54 +190,86 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<
 
 /// Strict inequality.
 fn cs_op(less: bool, equal: bool, cx: &mut ExtCtxt, span: Span, substr: &Substructure) -> P<Expr> {
-    let op = if less { BinOpKind::Lt } else { BinOpKind::Gt };
-    cs_fold(false, // need foldr,
-            |cx, span, subexpr, self_f, other_fs| {
-        // build up a series of chain ||'s and &&'s from the inside
-        // out (hence foldr) to get lexical ordering, i.e. for op ==
-        // `ast::lt`
-        //
-        // ```
-        // self.f1 < other.f1 || (!(other.f1 < self.f1) &&
-        // (self.f2 < other.f2 || (!(other.f2 < self.f2) &&
-        // (false)
-        // ))
-        // )
-        // ```
-        //
-        // The optimiser should remove the redundancy. We explicitly
-        // get use the binops to avoid auto-deref dereferencing too many
-        // layers of pointers, if the type includes pointers.
-        //
-        let other_f = match (other_fs.len(), other_fs.get(0)) {
-            (1, Some(o_f)) => o_f,
-            _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`"),
-        };
-
-        let cmp = cx.expr_binary(span, op, self_f.clone(), other_f.clone());
+    let strict_op = if less { BinOpKind::Lt } else { BinOpKind::Gt };
+    cs_fold1(false, // need foldr,
+        |cx, span, subexpr, self_f, other_fs| {
+            // build up a series of chain ||'s and &&'s from the inside
+            // out (hence foldr) to get lexical ordering, i.e. for op ==
+            // `ast::lt`
+            //
+            // ```
+            // self.f1 < other.f1 || (!(other.f1 < self.f1) &&
+            // self.f2 < other.f2
+            // )
+            // ```
+            //
+            // and for op ==
+            // `ast::le`
+            //
+            // ```
+            // self.f1 < other.f1 || (self.f1 == other.f1 &&
+            // self.f2 <= other.f2
+            // )
+            // ```
+            //
+            // The optimiser should remove the redundancy. We explicitly
+            // get use the binops to avoid auto-deref dereferencing too many
+            // layers of pointers, if the type includes pointers.
+            //
+            let other_f = match (other_fs.len(), other_fs.get(0)) {
+                (1, Some(o_f)) => o_f,
+                _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`"),
+            };
 
-        let not_cmp = cx.expr_unary(span,
-                                    ast::UnOp::Not,
-                                    cx.expr_binary(span, op, other_f.clone(), self_f));
+            let strict_ineq = cx.expr_binary(span, strict_op, self_f.clone(), other_f.clone());
 
-        let and = cx.expr_binary(span, BinOpKind::And, not_cmp, subexpr);
-        cx.expr_binary(span, BinOpKind::Or, cmp, and)
-    },
-            cx.expr_bool(span, equal),
-            Box::new(|cx, span, (self_args, tag_tuple), _non_self_args| {
-        if self_args.len() != 2 {
-            cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
-        } else {
-            let op = match (less, equal) {
-                (true, true) => LeOp,
-                (true, false) => LtOp,
-                (false, true) => GeOp,
-                (false, false) => GtOp,
+            let deleg_cmp = if !equal {
+                cx.expr_unary(span,
+                            ast::UnOp::Not,
+                            cx.expr_binary(span, strict_op, other_f.clone(), self_f))
+            } else {
+                cx.expr_binary(span, BinOpKind::Eq, self_f, other_f.clone())
             };
-            some_ordering_collapsed(cx, span, op, tag_tuple)
-        }
-    }),
-            cx,
-            span,
-            substr)
+
+            let and = cx.expr_binary(span, BinOpKind::And, deleg_cmp, subexpr);
+            cx.expr_binary(span, BinOpKind::Or, strict_ineq, and)
+        },
+        |cx, args| {
+            match args {
+                Some((span, self_f, other_fs)) => {
+                    // Special-case the base case to generate cleaner code with
+                    // fewer operations (e.g. `<=` instead of `<` and `==`).
+                    let other_f = match (other_fs.len(), other_fs.get(0)) {
+                        (1, Some(o_f)) => o_f,
+                        _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`"),
+                    };
+
+                    let op = match (less, equal) {
+                        (false, false) => BinOpKind::Gt,
+                        (false, true) => BinOpKind::Ge,
+                        (true, false) => BinOpKind::Lt,
+                        (true, true) => BinOpKind::Le,
+                    };
+
+                    cx.expr_binary(span, op, self_f, other_f.clone())
+                }
+                None => cx.expr_bool(span, equal)
+            }
+        },
+        Box::new(|cx, span, (self_args, tag_tuple), _non_self_args| {
+            if self_args.len() != 2 {
+                cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
+            } else {
+                let op = match (less, equal) {
+                    (false, false) => GtOp,
+                    (false, true) => GeOp,
+                    (true, false) => LtOp,
+                    (true, true) => LeOp,
+                };
+                some_ordering_collapsed(cx, span, op, tag_tuple)
+            }
+        }),
+        cx,
+        span,
+        substr)
 }
diff --git a/src/libsyntax_ext/deriving/generic/mod.rs b/src/libsyntax_ext/deriving/generic/mod.rs
index 66053e037e1..1f80385cfbd 100644
--- a/src/libsyntax_ext/deriving/generic/mod.rs
+++ b/src/libsyntax_ext/deriving/generic/mod.rs
@@ -1680,12 +1680,55 @@ impl<'a> TraitDef<'a> {
 
 // helpful premade recipes
 
+pub fn cs_fold_fields<'a, F>(use_foldl: bool,
+                             mut f: F,
+                             base: P<Expr>,
+                             cx: &mut ExtCtxt,
+                             all_fields: &[FieldInfo<'a>])
+                             -> P<Expr>
+    where F: FnMut(&mut ExtCtxt, Span, P<Expr>, P<Expr>, &[P<Expr>]) -> P<Expr>
+{
+    if use_foldl {
+        all_fields.iter().fold(base, |old, field| {
+            f(cx, field.span, old, field.self_.clone(), &field.other)
+        })
+    } else {
+        all_fields.iter().rev().fold(base, |old, field| {
+            f(cx, field.span, old, field.self_.clone(), &field.other)
+        })
+    }
+}
+
+pub fn cs_fold_enumnonmatch(mut enum_nonmatch_f: EnumNonMatchCollapsedFunc,
+                            cx: &mut ExtCtxt,
+                            trait_span: Span,
+                            substructure: &Substructure)
+                            -> P<Expr>
+{
+    match *substructure.fields {
+        EnumNonMatchingCollapsed(ref all_args, _, tuple) => {
+            enum_nonmatch_f(cx,
+                            trait_span,
+                            (&all_args[..], tuple),
+                            substructure.nonself_args)
+        }
+        _ => cx.span_bug(trait_span, "cs_fold_enumnonmatch expected an EnumNonMatchingCollapsed")
+    }
+}
+
+pub fn cs_fold_static(cx: &mut ExtCtxt,
+                      trait_span: Span)
+                      -> P<Expr>
+{
+    cx.span_bug(trait_span, "static function in `derive`")
+}
+
 /// Fold the fields. `use_foldl` controls whether this is done
 /// left-to-right (`true`) or right-to-left (`false`).
 pub fn cs_fold<F>(use_foldl: bool,
-                  mut f: F,
+                  f: F,
                   base: P<Expr>,
-                  mut enum_nonmatch_f: EnumNonMatchCollapsedFunc,
+                  enum_nonmatch_f: EnumNonMatchCollapsedFunc,
                   cx: &mut ExtCtxt,
                   trait_span: Span,
                   substructure: &Substructure)
@@ -1695,26 +1738,65 @@ pub fn cs_fold<F>(use_foldl: bool,
     match *substructure.fields {
         EnumMatching(.., ref all_fields) |
         Struct(_, ref all_fields) => {
-            if use_foldl {
-                all_fields.iter().fold(base, |old, field| {
-                    f(cx, field.span, old, field.self_.clone(), &field.other)
-                })
-            } else {
-                all_fields.iter().rev().fold(base, |old, field| {
-                    f(cx, field.span, old, field.self_.clone(), &field.other)
-                })
-            }
+            cs_fold_fields(use_foldl, f, base, cx, all_fields)
         }
-        EnumNonMatchingCollapsed(ref all_args, _, tuple) => {
-            enum_nonmatch_f(cx,
-                            trait_span,
-                            (&all_args[..], tuple),
-                            substructure.nonself_args)
+        EnumNonMatchingCollapsed(..) => {
+            cs_fold_enumnonmatch(enum_nonmatch_f, cx, trait_span, substructure)
+        }
+        StaticEnum(..) | StaticStruct(..) => {
+            cs_fold_static(cx, trait_span)
         }
-        StaticEnum(..) | StaticStruct(..) => cx.span_bug(trait_span, "static function in `derive`"),
     }
 }
 
+/// Function to fold over fields, with three cases, to generate more efficient and concise code.
+/// When the `substructure` has grouped fields, there are two cases:
+/// Zero fields: call the base case function with None (like the usual base case of `cs_fold`).
+/// One or more fields: call the base case function on the first value (which depends on
+/// `use_fold`), and use that as the base case. Then perform `cs_fold` on the remainder of the
+/// fields.
+/// When the `substructure` is a `EnumNonMatchingCollapsed`, the result of `enum_nonmatch_f`
+/// is returned. Statics may not be folded over.
+/// See `cs_op` in `partial_ord.rs` for a model example.
+pub fn cs_fold1<F, B>(use_foldl: bool,
+                      f: F,
+                      mut b: B,
+                      enum_nonmatch_f: EnumNonMatchCollapsedFunc,
+                      cx: &mut ExtCtxt,
+                      trait_span: Span,
+                      substructure: &Substructure)
+                      -> P<Expr>
+    where F: FnMut(&mut ExtCtxt, Span, P<Expr>, P<Expr>, &[P<Expr>]) -> P<Expr>,
+          B: FnMut(&mut ExtCtxt, Option<(Span, P<Expr>, &[P<Expr>])>) -> P<Expr>
+{
+    match *substructure.fields {
+        EnumMatching(.., ref all_fields) |
+        Struct(_, ref all_fields) => {
+            let (base, all_fields) = match (all_fields.is_empty(), use_foldl) {
+                (false, true) => {
+                    let field = &all_fields[0];
+                    let args = (field.span, field.self_.clone(), &field.other[..]);
+                    (b(cx, Some(args)), &all_fields[1..])
+                }
+                (false, false) => {
+                    let idx = all_fields.len() - 1;
+                    let field = &all_fields[idx];
+                    let args = (field.span, field.self_.clone(), &field.other[..]);
+                    (b(cx, Some(args)), &all_fields[..idx])
+                }
+                (true, _) => (b(cx, None), &all_fields[..])
+            };
+
+            cs_fold_fields(use_foldl, f, base, cx, all_fields)
+        }
+        EnumNonMatchingCollapsed(..) => {
+            cs_fold_enumnonmatch(enum_nonmatch_f, cx, trait_span, substructure)
+        }
+        StaticEnum(..) | StaticStruct(..) => {
+            cs_fold_static(cx, trait_span)
+        }
+    }
+}
 
 /// Call the method that is being derived on all the fields, and then
 /// process the collected results. i.e.