about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-09-29 10:20:11 +0000
committerbors <bors@rust-lang.org>2023-09-29 10:20:11 +0000
commit87e2c310f9f1770aa09c1223ae4a5b7b8ce33360 (patch)
tree01f374eb428456ebe9b5bb5e86c96f5fb8e8e4c7
parentf19479a2ad238ad861cfe8d57e63beeccb56169e (diff)
parent1b3e5b2105b20b3237efaac17c4a9761890f6597 (diff)
downloadrust-87e2c310f9f1770aa09c1223ae4a5b7b8ce33360.tar.gz
rust-87e2c310f9f1770aa09c1223ae4a5b7b8ce33360.zip
Auto merge of #15667 - rmehri01:bool_to_enum_top_level, r=Veykril
fix: make bool_to_enum assist create enum at top-level

This pr makes the `bool_to_enum` assist create the `enum` at the next closest module block or at top-level, which fixes a few tricky cases such as with an associated `const` in a trait or module:

```rust
trait Foo {
    const $0BOOL: bool;
}

impl Foo for usize {
    const BOOL: bool = true;
}

fn main() {
    if <usize as Foo>::BOOL {
        println!("foo");
    }
}
```

Which now properly produces:

```rust
#[derive(PartialEq, Eq)]
enum Bool { True, False }

trait Foo {
    const BOOL: Bool;
}

impl Foo for usize {
    const BOOL: Bool = Bool::True;
}

fn main() {
    if <usize as Foo>::BOOL == Bool::True {
        println!("foo");
    }
}
```

I also think it's a bit nicer, especially for local variables, but didn't really know to do it in the first PR :)
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs224
-rw-r--r--crates/ide-assists/src/tests/generated.rs6
2 files changed, 194 insertions, 36 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 85b0b87d0c9..082839118c5 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -16,7 +16,7 @@ use syntax::{
         edit_in_place::{AttrsOwnerEdit, Indent},
         make, HasName,
     },
-    ted, AstNode, NodeOrToken, SyntaxNode, T,
+    ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
 };
 use text_edit::TextRange;
 
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
 // ```
 // ->
 // ```
-// fn main() {
-//     #[derive(PartialEq, Eq)]
-//     enum Bool { True, False }
+// #[derive(PartialEq, Eq)]
+// enum Bool { True, False }
 //
+// fn main() {
 //     let bool = Bool::True;
 //
 //     if bool == Bool::True {
@@ -270,6 +270,15 @@ fn replace_usages(
                         }
                         _ => (),
                     }
+                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
+                {
+                    edit.replace(ty_annotation.syntax().text_range(), "Bool");
+                    replace_bool_expr(edit, initializer);
+                } else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
+                    edit.replace(
+                        receiver.syntax().text_range(),
+                        format!("({} == Bool::True)", receiver),
+                    );
                 } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
                     // for any other usage in an expression, replace it with a check that it is the true variant
                     if let Some((record_field, expr)) = new_name
@@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
     }
 }
 
+fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
+    let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
+    if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
+        return None;
+    }
+
+    Some((const_.ty()?, const_.body()?))
+}
+
+fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
+    let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
+    let receiver = method_call.receiver()?;
+
+    if !receiver.syntax().descendants().contains(name.syntax()) {
+        return None;
+    }
+
+    Some(receiver)
+}
+
 /// Adds the definition of the new enum before the target node.
 fn add_enum_def(
     edit: &mut SourceChangeBuilder,
@@ -430,11 +459,12 @@ fn add_enum_def(
         .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
     let enum_def = make_bool_enum(make_enum_pub);
 
-    let indent = IndentLevel::from_node(&target_node);
+    let insert_before = node_to_insert_before(target_node);
+    let indent = IndentLevel::from_node(&insert_before);
     enum_def.reindent_to(indent);
 
     ted::insert_all(
-        ted::Position::before(&edit.make_syntax_mut(target_node)),
+        ted::Position::before(&edit.make_syntax_mut(insert_before)),
         vec![
             enum_def.syntax().clone().into(),
             make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@@ -442,6 +472,18 @@ fn add_enum_def(
     );
 }
 
+/// Finds where to put the new enum definition.
+/// Tries to find the ast node at the nearest module or at top-level, otherwise just
+/// returns the input node.
+fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
+    target_node
+        .ancestors()
+        .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
+        .filter(|it| ast::Item::can_cast(it.kind()))
+        .last()
+        .unwrap_or(target_node)
+}
+
 fn make_bool_enum(make_pub: bool) -> ast::Enum {
     let enum_def = make::enum_(
         if make_pub { Some(make::visibility_pub()) } else { None },
@@ -491,10 +533,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = Bool::True;
 
     if foo == Bool::True {
@@ -520,10 +562,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = Bool::True;
 
     if foo == Bool::False {
@@ -545,10 +587,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo: Bool = Bool::False;
 }
 "#,
@@ -565,10 +607,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = if 1 == 2 { Bool::True } else { Bool::False };
 }
 "#,
@@ -590,10 +632,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = Bool::False;
     let bar = true;
 
@@ -619,10 +661,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = Bool::True;
 
     if *&foo == Bool::True {
@@ -645,10 +687,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo: Bool;
     foo = Bool::True;
 }
@@ -671,10 +713,10 @@ fn main() {
 }
 "#,
             r#"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let foo = Bool::True;
     let bar = foo == Bool::False;
 
@@ -702,11 +744,11 @@ fn main() {
 }
 "#,
             r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
 fn main() {
     if !"foo".chars().any(|c| {
-        #[derive(PartialEq, Eq)]
-        enum Bool { True, False }
-
         let foo = Bool::True;
         foo == Bool::True
     }) {
@@ -1245,6 +1287,38 @@ fn main() {
     }
 
     #[test]
+    fn field_method_chain_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Foo {
+    $0bool: bool,
+}
+
+fn main() {
+    let foo = Foo { bool: true };
+
+    foo.bool.then(|| 2);
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+    bool: Bool,
+}
+
+fn main() {
+    let foo = Foo { bool: Bool::True };
+
+    (foo.bool == Bool::True).then(|| 2);
+}
+"#,
+        )
+    }
+
+    #[test]
     fn field_non_bool() {
         cov_mark::check!(not_applicable_non_bool_field);
         check_assist_not_applicable(
@@ -1446,6 +1520,90 @@ pub mod bar {
     }
 
     #[test]
+    fn const_in_impl_cross_file() {
+        check_assist(
+            bool_to_enum,
+            r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+impl Foo {
+    pub const $0BOOL: bool = true;
+}
+
+//- /foo.rs
+use crate::Foo;
+
+fn foo() -> bool {
+    Foo::BOOL
+}
+"#,
+            r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+impl Foo {
+    pub const BOOL: Bool = Bool::True;
+}
+
+//- /foo.rs
+use crate::{Foo, Bool};
+
+fn foo() -> bool {
+    Foo::BOOL == Bool::True
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn const_in_trait() {
+        check_assist(
+            bool_to_enum,
+            r#"
+trait Foo {
+    const $0BOOL: bool;
+}
+
+impl Foo for usize {
+    const BOOL: bool = true;
+}
+
+fn main() {
+    if <usize as Foo>::BOOL {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+trait Foo {
+    const BOOL: Bool;
+}
+
+impl Foo for usize {
+    const BOOL: Bool = Bool::True;
+}
+
+fn main() {
+    if <usize as Foo>::BOOL == Bool::True {
+        println!("foo");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn const_non_bool() {
         cov_mark::check!(not_applicable_non_bool_const);
         check_assist_not_applicable(
diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs
index 63a08a0e569..5a815d5c6a1 100644
--- a/crates/ide-assists/src/tests/generated.rs
+++ b/crates/ide-assists/src/tests/generated.rs
@@ -294,10 +294,10 @@ fn main() {
 }
 "#####,
         r#####"
-fn main() {
-    #[derive(PartialEq, Eq)]
-    enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
 
+fn main() {
     let bool = Bool::True;
 
     if bool == Bool::True {