about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs41
-rw-r--r--crates/ide_assists/src/utils/gen_trait_fn_body.rs27
2 files changed, 56 insertions, 12 deletions
diff --git a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
index 0f5a3843153..4fceefe331d 100644
--- a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -822,6 +822,47 @@ impl PartialOrd for Foo {
     }
 
     #[test]
+    fn add_custom_impl_partial_ord_tuple_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+enum Foo {
+    Bar(String),
+    Baz(String, String),
+    Qux(),
+    Bin,
+}
+"#,
+            r#"
+enum Foo {
+    Bar(String),
+    Baz(String, String),
+    Qux(),
+    Bin,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match (self, other) {
+            (Self::Bar(l0), Self::Bar(r0)) => l0.partial_cmp(r0),
+            (Self::Baz(l0, l1), Self::Baz(r0, r1)) => {
+                match l0.partial_cmp(r0) {
+                    Some(core::cmp::Ordering::Eq) => {}
+                    ord => return ord,
+                }
+                l1.partial_cmp(r1)
+            }
+            _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)),
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn add_custom_impl_partial_eq_record_struct() {
         check_assist(
             replace_derive_with_manual_impl,
diff --git a/crates/ide_assists/src/utils/gen_trait_fn_body.rs b/crates/ide_assists/src/utils/gen_trait_fn_body.rs
index 9633fd263b6..10b781636f4 100644
--- a/crates/ide_assists/src/utils/gen_trait_fn_body.rs
+++ b/crates/ide_assists/src/utils/gen_trait_fn_body.rs
@@ -574,13 +574,6 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
 }
 
 fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
-    fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
-        match expr {
-            Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
-            None => Some(cmp),
-        }
-    }
-
     fn gen_partial_eq_match(match_target: ast::Expr) -> Option<ast::Stmt> {
         let mut arms = vec![];
 
@@ -683,7 +676,7 @@ fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
                     }
 
                     Some(ast::FieldList::TupleFieldList(list)) => {
-                        let mut expr = None;
+                        let mut exprs = vec![];
                         let mut l_fields = vec![];
                         let mut r_fields = vec![];
 
@@ -698,16 +691,26 @@ fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
 
                             let lhs = make::expr_path(make::ext::ident_path(&l_name));
                             let rhs = make::expr_path(make::ext::ident_path(&r_name));
-                            let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
-                            expr = gen_eq_chain(expr, cmp);
+                            let ord = gen_partial_cmp_call(lhs, rhs);
+                            exprs.push(ord);
                         }
 
                         let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
                         let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
                         let tuple = make::tuple_pat(vec![left.into(), right.into()]);
 
-                        if let Some(expr) = expr {
-                            arms.push(make::match_arm(Some(tuple.into()), None, expr));
+                        if let Some(tail) = exprs.pop() {
+                            let stmts = exprs
+                                .into_iter()
+                                .map(gen_partial_eq_match)
+                                .collect::<Option<Vec<ast::Stmt>>>()?;
+                            let expr = match stmts.len() {
+                                0 => tail,
+                                _ => make::block_expr(stmts.into_iter(), Some(tail))
+                                    .indent(ast::edit::IndentLevel(1))
+                                    .into(),
+                            };
+                            arms.push(make::match_arm(Some(tuple.into()), None, expr.into()));
                         }
                     }
                     None => continue,