about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2021-10-12 17:32:58 +0000
committerGitHub <noreply@github.com>2021-10-12 17:32:58 +0000
commita871da36937f427624883860eafb580d6349f8da (patch)
tree3af854868c6ca4494d3ee2b298bea3913e0f71f0
parentd56c8796d6521cfacb10d31c09fcc246cf511dbd (diff)
parent601ed3a10dacc2ba2ee0ca436c23529ae7fde292 (diff)
downloadrust-a871da36937f427624883860eafb580d6349f8da.tar.gz
rust-a871da36937f427624883860eafb580d6349f8da.zip
Merge #10529
10529: Generate `PartialOrd` implementations r=Veykril a=yoshuawuyts

_co-authored with `@rylev_`

This closes #5946 (which should've been closed already, lol). This PR makes it so we generate `PartialOrd` code implementations where possible. This is the last of Rust's built-in traits that was missing codegen.

After this has been merged we should look at moving the tests to a better spot, and maybe cleaning up the implementation somewhat (it's rather copy-pasty at the moment).

Either way, this finishes up the functionality. Thanks heaps!

Co-authored-by: Yoshua Wuyts <yoshuawuyts@gmail.com>
-rw-r--r--crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs188
-rw-r--r--crates/ide_assists/src/utils/gen_trait_fn_body.rs195
2 files changed, 383 insertions, 0 deletions
diff --git a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
index c0b7db332e2..b04bd6ba098 100644
--- a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -676,6 +676,194 @@ impl Clone for Foo {
     }
 
     #[test]
+    fn add_custom_impl_partial_ord_record_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo {
+    bin: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        self.bin.partial_cmp(other.bin)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_record_struct_multi_field() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo {
+    bin: usize,
+    bar: usize,
+    baz: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+    bar: usize,
+    baz: usize,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        (self.bin, self.bar, self.baz).partial_cmp((other.bin, other.bar, other.baz))
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_tuple_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo(usize, usize, usize);
+"#,
+            r#"
+struct Foo(usize, usize, usize);
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        (self.0, self.1, self.2).partial_cmp((other.0, other.1, other.2))
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+enum Foo {
+    Bin,
+    Bar,
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bin,
+    Bar,
+    Baz,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other))
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_record_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz {
+        qux: String,
+        fez: String,
+    },
+    Qux {},
+    Bin,
+}
+"#,
+            r#"
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz {
+        qux: String,
+        fez: String,
+    },
+    Qux {},
+    Bin,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match (self, other) {
+            (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin.partial_cmp(r_bin),
+            (Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => {
+                (l_qux, l_fez).partial_cmp((r_qux, r_fez))
+            }
+            _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)),
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_tuple_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+enum Foo {
+    Bar(String),
+    Baz(String, String),
+    Qux(),
+    Bin,
+}
+"#,
+            r#"
+enum Foo {
+    Bar(String),
+    Baz(String, String),
+    Qux(),
+    Bin,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match (self, other) {
+            (Self::Bar(l0), Self::Bar(r0)) => l0.partial_cmp(r0),
+            (Self::Baz(l0, l1), Self::Baz(r0, r1)) => {
+                (l0, l1).partial_cmp((r0, r1))
+            }
+            _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)),
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn add_custom_impl_partial_eq_record_struct() {
         check_assist(
             replace_derive_with_manual_impl,
diff --git a/crates/ide_assists/src/utils/gen_trait_fn_body.rs b/crates/ide_assists/src/utils/gen_trait_fn_body.rs
index 6915460209b..c883e6fb11b 100644
--- a/crates/ide_assists/src/utils/gen_trait_fn_body.rs
+++ b/crates/ide_assists/src/utils/gen_trait_fn_body.rs
@@ -21,6 +21,7 @@ pub(crate) fn gen_trait_fn_body(
         "Default" => gen_default_impl(adt, func),
         "Hash" => gen_hash_impl(adt, func),
         "PartialEq" => gen_partial_eq(adt, func),
+        "PartialOrd" => gen_partial_ord(adt, func),
         _ => None,
     }
 }
@@ -572,6 +573,200 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
     Some(())
 }
 
+fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
+    fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr {
+        let method = make::name_ref("partial_cmp");
+        make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
+    }
+    fn gen_partial_cmp_call2(mut lhs: Vec<ast::Expr>, mut rhs: Vec<ast::Expr>) -> ast::Expr {
+        let (lhs, rhs) = match (lhs.len(), rhs.len()) {
+            (1, 1) => (lhs.pop().unwrap(), rhs.pop().unwrap()),
+            _ => (make::expr_tuple(lhs.into_iter()), make::expr_tuple(rhs.into_iter())),
+        };
+        let method = make::name_ref("partial_cmp");
+        make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
+    }
+
+    fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
+        let pat = make::ext::simple_ident_pat(make::name(&pat_name));
+        let name_ref = make::name_ref(field_name);
+        make::record_pat_field(name_ref, pat.into())
+    }
+
+    fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
+        let list = make::record_pat_field_list(fields);
+        make::record_pat_with_fields(record_name, list)
+    }
+
+    fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
+        make::ext::path_from_idents(["Self", &variant.name()?.to_string()])
+    }
+
+    fn gen_tuple_field(field_name: &String) -> ast::Pat {
+        ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
+    }
+
+    // FIXME: return `None` if the trait carries a generic type; we can only
+    // generate this code `Self` for the time being.
+
+    let body = match adt {
+        // `Hash` cannot be derived for unions, so no default impl can be provided.
+        ast::Adt::Union(_) => return None,
+
+        ast::Adt::Enum(enum_) => {
+            // => std::mem::discriminant(self) == std::mem::discriminant(other)
+            let lhs_name = make::expr_path(make::ext::ident_path("self"));
+            let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())));
+            let rhs_name = make::expr_path(make::ext::ident_path("other"));
+            let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())));
+            let ord_check = gen_partial_cmp_call(lhs, rhs);
+
+            let mut case_count = 0;
+            let mut arms = vec![];
+            for variant in enum_.variant_list()?.variants() {
+                case_count += 1;
+                match variant.field_list() {
+                    // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
+                    Some(ast::FieldList::RecordFieldList(list)) => {
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
+                        let mut l_fields = vec![];
+                        let mut r_fields = vec![];
+
+                        for field in list.fields() {
+                            let field_name = field.name()?.to_string();
+
+                            let l_name = &format!("l_{}", field_name);
+                            l_pat_fields.push(gen_record_pat_field(&field_name, &l_name));
+
+                            let r_name = &format!("r_{}", field_name);
+                            r_pat_fields.push(gen_record_pat_field(&field_name, &r_name));
+
+                            let lhs = make::expr_path(make::ext::ident_path(l_name));
+                            let rhs = make::expr_path(make::ext::ident_path(r_name));
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
+                        }
+
+                        let left_pat = gen_record_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat = gen_record_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
+
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
+                                    .indent(ast::edit::IndentLevel(1))
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
+                        }
+                    }
+
+                    Some(ast::FieldList::TupleFieldList(list)) => {
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
+                        let mut l_fields = vec![];
+                        let mut r_fields = vec![];
+
+                        for (i, _) in list.fields().enumerate() {
+                            let field_name = format!("{}", i);
+
+                            let l_name = format!("l{}", field_name);
+                            l_pat_fields.push(gen_tuple_field(&l_name));
+
+                            let r_name = format!("r{}", field_name);
+                            r_pat_fields.push(gen_tuple_field(&r_name));
+
+                            let lhs = make::expr_path(make::ext::ident_path(&l_name));
+                            let rhs = make::expr_path(make::ext::ident_path(&r_name));
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
+                        }
+
+                        let left_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
+
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
+                                    .indent(ast::edit::IndentLevel(1))
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
+                        }
+                    }
+                    None => continue,
+                }
+            }
+
+            let expr = match arms.len() {
+                0 => ord_check,
+                _ => {
+                    if case_count > arms.len() {
+                        let lhs = make::wildcard_pat().into();
+                        arms.push(make::match_arm(Some(lhs), None, ord_check));
+                    }
+
+                    let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
+                    let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
+                    make::expr_match(match_target, list)
+                }
+            };
+
+            make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
+        }
+        ast::Adt::Struct(strukt) => match strukt.field_list() {
+            Some(ast::FieldList::RecordFieldList(field_list)) => {
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
+                for field in field_list.fields() {
+                    let lhs = make::expr_path(make::ext::ident_path("self"));
+                    let lhs = make::expr_field(lhs, &field.name()?.to_string());
+                    let rhs = make::expr_path(make::ext::ident_path("other"));
+                    let rhs = make::expr_field(rhs, &field.name()?.to_string());
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
+                }
+
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
+            }
+
+            Some(ast::FieldList::TupleFieldList(field_list)) => {
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
+                for (i, _) in field_list.fields().enumerate() {
+                    let idx = format!("{}", i);
+                    let lhs = make::expr_path(make::ext::ident_path("self"));
+                    let lhs = make::expr_field(lhs, &idx);
+                    let rhs = make::expr_path(make::ext::ident_path("other"));
+                    let rhs = make::expr_field(rhs, &idx);
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
+                }
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
+            }
+
+            // No fields in the body means there's nothing to hash.
+            None => {
+                let expr = make::expr_literal("true").into();
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
+            }
+        },
+    };
+
+    ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
+    Some(())
+}
+
 fn make_discriminant() -> Option<ast::Expr> {
     Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?))
 }