about summary refs log tree commit diff
diff options
context:
space:
mode:
authorTomasz Miąsko <tomasz.miasko@gmail.com>2021-01-25 00:00:00 +0000
committerTomasz Miąsko <tomasz.miasko@gmail.com>2021-02-09 08:15:37 +0100
commite4efccd4a6eef45e9648452360accab48b28f674 (patch)
tree6ea614450c1e1fff32e46b9ad9e37f2b09117f5a
parent921ec4b3fca17cc777766c240038d7d50ba98e0d (diff)
downloadrust-e4efccd4a6eef45e9648452360accab48b28f674.tar.gz
rust-e4efccd4a6eef45e9648452360accab48b28f674.zip
Fix derived PartialOrd operators
The derived implementation of `partial_cmp` compares matching fields one
by one, stopping the computation when the result of a comparison is not
equal to `Some(Equal)`.

On the other hand the derived implementation for `lt`, `le`, `gt` and
`ge` continues the computation when the result of a field comparison is
`None`, consequently those operators are not transitive and inconsistent
with `partial_cmp`.

Fix the inconsistency by using the default implementation that fall-backs
to the `partial_cmp`. This also avoids creating very deeply nested
closures that were quite costly to compile.
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs197
-rw-r--r--src/test/ui/derives/derive-partial-ord.rs60
2 files changed, 67 insertions, 190 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
index 21174ca4c8b..db808bf2ff5 100644
--- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
@@ -1,13 +1,11 @@
-pub use OrderingOp::*;
-
 use crate::deriving::generic::ty::*;
 use crate::deriving::generic::*;
-use crate::deriving::{path_local, path_std, pathvec_std};
+use crate::deriving::{path_std, pathvec_std};
 
 use rustc_ast::ptr::P;
-use rustc_ast::{self as ast, BinOpKind, Expr, MetaItem};
+use rustc_ast::{Expr, MetaItem};
 use rustc_expand::base::{Annotatable, ExtCtxt};
-use rustc_span::symbol::{sym, Ident, Symbol};
+use rustc_span::symbol::{sym, Ident};
 use rustc_span::Span;
 
 pub fn expand_deriving_partial_ord(
@@ -17,26 +15,6 @@ pub fn expand_deriving_partial_ord(
     item: &Annotatable,
     push: &mut dyn FnMut(Annotatable),
 ) {
-    macro_rules! md {
-        ($name:expr, $op:expr, $equal:expr) => {{
-            let inline = cx.meta_word(span, sym::inline);
-            let attrs = vec![cx.attribute(inline)];
-            MethodDef {
-                name: $name,
-                generics: Bounds::empty(),
-                explicit_self: borrowed_explicit_self(),
-                args: vec![(borrowed_self(), sym::other)],
-                ret_ty: Literal(path_local!(bool)),
-                attributes: attrs,
-                is_unsafe: false,
-                unify_fieldless_variants: true,
-                combine_substructure: combine_substructure(Box::new(|cx, span, substr| {
-                    cs_op($op, $equal, cx, span, substr)
-                })),
-            }
-        }};
-    }
-
     let ordering_ty = Literal(path_std!(cmp::Ordering));
     let ret_ty = Literal(Path::new_(
         pathvec_std!(option::Option),
@@ -62,21 +40,6 @@ pub fn expand_deriving_partial_ord(
         })),
     };
 
-    // avoid defining extra methods if we can
-    // c-like enums, enums without any fields and structs without fields
-    // can safely define only `partial_cmp`.
-    let methods = if is_type_without_fields(item) {
-        vec![partial_cmp_def]
-    } else {
-        vec![
-            partial_cmp_def,
-            md!(sym::lt, true, false),
-            md!(sym::le, true, true),
-            md!(sym::gt, false, false),
-            md!(sym::ge, false, true),
-        ]
-    };
-
     let trait_def = TraitDef {
         span,
         attributes: vec![],
@@ -85,39 +48,12 @@ pub fn expand_deriving_partial_ord(
         generics: Bounds::empty(),
         is_unsafe: false,
         supports_unions: false,
-        methods,
+        methods: vec![partial_cmp_def],
         associated_types: Vec::new(),
     };
     trait_def.expand(cx, mitem, item, push)
 }
 
-#[derive(Copy, Clone)]
-pub enum OrderingOp {
-    PartialCmpOp,
-    LtOp,
-    LeOp,
-    GtOp,
-    GeOp,
-}
-
-pub fn some_ordering_collapsed(
-    cx: &mut ExtCtxt<'_>,
-    span: Span,
-    op: OrderingOp,
-    self_arg_tags: &[Ident],
-) -> P<ast::Expr> {
-    let lft = cx.expr_ident(span, self_arg_tags[0]);
-    let rgt = cx.expr_addr_of(span, cx.expr_ident(span, self_arg_tags[1]));
-    let op_sym = match op {
-        PartialCmpOp => sym::partial_cmp,
-        LtOp => sym::lt,
-        LeOp => sym::le,
-        GtOp => sym::gt,
-        GeOp => sym::ge,
-    };
-    cx.expr_method_call(span, lft, Ident::new(op_sym, span), vec![rgt])
-}
-
 pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> P<Expr> {
     let test_id = Ident::new(sym::cmp, span);
     let ordering = cx.path_global(span, cx.std_path(&[sym::cmp, sym::Ordering, sym::Equal]));
@@ -171,7 +107,9 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_
             if self_args.len() != 2 {
                 cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
             } else {
-                some_ordering_collapsed(cx, span, PartialCmpOp, tag_tuple)
+                let lft = cx.expr_ident(span, tag_tuple[0]);
+                let rgt = cx.expr_addr_of(span, cx.expr_ident(span, tag_tuple[1]));
+                cx.expr_method_call(span, lft, Ident::new(sym::partial_cmp, span), vec![rgt])
             }
         }),
         cx,
@@ -179,124 +117,3 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_
         substr,
     )
 }
-
-/// Strict inequality.
-fn cs_op(
-    less: bool,
-    inclusive: bool,
-    cx: &mut ExtCtxt<'_>,
-    span: Span,
-    substr: &Substructure<'_>,
-) -> P<Expr> {
-    let ordering_path = |cx: &mut ExtCtxt<'_>, name: &str| {
-        cx.expr_path(
-            cx.path_global(span, cx.std_path(&[sym::cmp, sym::Ordering, Symbol::intern(name)])),
-        )
-    };
-
-    let par_cmp = |cx: &mut ExtCtxt<'_>, span, self_f: P<Expr>, other_fs: &[P<Expr>], default| {
-        let other_f = match other_fs {
-            [o_f] => o_f,
-            _ => cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`"),
-        };
-
-        // `PartialOrd::partial_cmp(self.fi, other.fi)`
-        let cmp_path = cx.expr_path(
-            cx.path_global(span, cx.std_path(&[sym::cmp, sym::PartialOrd, sym::partial_cmp])),
-        );
-        let cmp = cx.expr_call(
-            span,
-            cmp_path,
-            vec![cx.expr_addr_of(span, self_f), cx.expr_addr_of(span, other_f.clone())],
-        );
-
-        let default = ordering_path(cx, default);
-        // `Option::unwrap_or(_, Ordering::Equal)`
-        let unwrap_path = cx.expr_path(
-            cx.path_global(span, cx.std_path(&[sym::option, sym::Option, sym::unwrap_or])),
-        );
-        cx.expr_call(span, unwrap_path, vec![cmp, default])
-    };
-
-    let fold = cs_fold1(
-        false, // need foldr
-        |cx, span, subexpr, self_f, other_fs| {
-            // build up a series of `partial_cmp`s from the inside
-            // out (hence foldr) to get lexical ordering, i.e., for op ==
-            // `ast::lt`
-            //
-            // ```
-            // Ordering::then_with(
-            //    Option::unwrap_or(
-            //        PartialOrd::partial_cmp(self.f1, other.f1), Ordering::Equal)
-            //    ),
-            //    Option::unwrap_or(
-            //        PartialOrd::partial_cmp(self.f2, other.f2), Ordering::Greater)
-            //    )
-            // )
-            // == Ordering::Less
-            // ```
-            //
-            // and for op ==
-            // `ast::le`
-            //
-            // ```
-            // Ordering::then_with(
-            //    Option::unwrap_or(
-            //        PartialOrd::partial_cmp(self.f1, other.f1), Ordering::Equal)
-            //    ),
-            //    Option::unwrap_or(
-            //        PartialOrd::partial_cmp(self.f2, other.f2), Ordering::Greater)
-            //    )
-            // )
-            // != Ordering::Greater
-            // ```
-            //
-            // The optimiser should remove the redundancy. We explicitly
-            // get use the binops to avoid auto-deref dereferencing too many
-            // layers of pointers, if the type includes pointers.
-
-            // `Option::unwrap_or(PartialOrd::partial_cmp(self.fi, other.fi), Ordering::Equal)`
-            let par_cmp = par_cmp(cx, span, self_f, other_fs, "Equal");
-
-            // `Ordering::then_with(Option::unwrap_or(..), ..)`
-            let then_with_path = cx.expr_path(
-                cx.path_global(span, cx.std_path(&[sym::cmp, sym::Ordering, sym::then_with])),
-            );
-            cx.expr_call(span, then_with_path, vec![par_cmp, cx.lambda0(span, subexpr)])
-        },
-        |cx, args| match args {
-            Some((span, self_f, other_fs)) => {
-                let opposite = if less { "Greater" } else { "Less" };
-                par_cmp(cx, span, self_f, other_fs, opposite)
-            }
-            None => cx.expr_bool(span, inclusive),
-        },
-        Box::new(|cx, span, (self_args, tag_tuple), _non_self_args| {
-            if self_args.len() != 2 {
-                cx.span_bug(span, "not exactly 2 arguments in `derive(PartialOrd)`")
-            } else {
-                let op = match (less, inclusive) {
-                    (false, false) => GtOp,
-                    (false, true) => GeOp,
-                    (true, false) => LtOp,
-                    (true, true) => LeOp,
-                };
-                some_ordering_collapsed(cx, span, op, tag_tuple)
-            }
-        }),
-        cx,
-        span,
-        substr,
-    );
-
-    match *substr.fields {
-        EnumMatching(.., ref all_fields) | Struct(.., ref all_fields) if !all_fields.is_empty() => {
-            let ordering = ordering_path(cx, if less ^ inclusive { "Less" } else { "Greater" });
-            let comp_op = if inclusive { BinOpKind::Ne } else { BinOpKind::Eq };
-
-            cx.expr_binary(span, comp_op, fold, ordering)
-        }
-        _ => fold,
-    }
-}
diff --git a/src/test/ui/derives/derive-partial-ord.rs b/src/test/ui/derives/derive-partial-ord.rs
new file mode 100644
index 00000000000..9078a7ffa4f
--- /dev/null
+++ b/src/test/ui/derives/derive-partial-ord.rs
@@ -0,0 +1,60 @@
+// Checks that in a derived implementation of PartialOrd the lt, le, ge, gt methods are consistent
+// with partial_cmp. Also verifies that implementation is consistent with that for tuples.
+//
+// run-pass
+
+#[derive(PartialEq, PartialOrd)]
+struct P(f64, f64);
+
+fn main() {
+    let values: &[f64] = &[1.0, 2.0, f64::NAN];
+    for a in values {
+        for b in values {
+            for c in values {
+                for d in values {
+                    // Check impl for a tuple.
+                    check(&(*a, *b), &(*c, *d));
+
+                    // Check derived impl.
+                    check(&P(*a, *b), &P(*c, *d));
+
+                    // Check that impls agree with each other.
+                    assert_eq!(
+                        PartialOrd::partial_cmp(&(*a, *b), &(*c, *d)),
+                        PartialOrd::partial_cmp(&P(*a, *b), &P(*c, *d)),
+                    );
+                }
+            }
+        }
+    }
+}
+
+fn check<T: PartialOrd>(a: &T, b: &T) {
+    use std::cmp::Ordering::*;
+    match PartialOrd::partial_cmp(a, b) {
+        None => {
+            assert!(!(a < b));
+            assert!(!(a <= b));
+            assert!(!(a > b));
+            assert!(!(a >= b));
+        }
+        Some(Equal) => {
+            assert!(!(a < b));
+            assert!(a <= b);
+            assert!(!(a > b));
+            assert!(a >= b);
+        }
+        Some(Less) => {
+            assert!(a < b);
+            assert!(a <= b);
+            assert!(!(a > b));
+            assert!(!(a >= b));
+        }
+        Some(Greater) => {
+            assert!(!(a < b));
+            assert!(!(a <= b));
+            assert!(a > b);
+            assert!(a >= b);
+        }
+    }
+}