about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs129
1 files changed, 112 insertions, 17 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
index 2aeca0bae0b..0aa23ccc840 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -1,3 +1,4 @@
+use either::Either;
 use hir::{ImportPathConfig, ModuleDef};
 use ide_db::{
     assists::{AssistId, AssistKind},
@@ -97,27 +98,30 @@ struct BoolNodeData {
 fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
     let name: ast::Name = ctx.find_node_at_offset()?;
 
-    if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) {
-        let bind_pat = match let_stmt.pat()? {
-            ast::Pat::IdentPat(pat) => pat,
-            _ => {
-                cov_mark::hit!(not_applicable_in_non_ident_pat);
-                return None;
-            }
-        };
-        let def = ctx.sema.to_def(&bind_pat)?;
+    if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
+        let def = ctx.sema.to_def(&ident_pat)?;
         if !def.ty(ctx.db()).is_bool() {
             cov_mark::hit!(not_applicable_non_bool_local);
             return None;
         }
 
-        Some(BoolNodeData {
-            target_node: let_stmt.syntax().clone(),
-            name,
-            ty_annotation: let_stmt.ty(),
-            initializer: let_stmt.initializer(),
-            definition: Definition::Local(def),
-        })
+        let local_definition = Definition::Local(def);
+        match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
+            Either::Left(param) => Some(BoolNodeData {
+                target_node: param.syntax().clone(),
+                name,
+                ty_annotation: param.ty(),
+                initializer: None,
+                definition: local_definition,
+            }),
+            Either::Right(let_stmt) => Some(BoolNodeData {
+                target_node: let_stmt.syntax().clone(),
+                name,
+                ty_annotation: let_stmt.ty(),
+                initializer: let_stmt.initializer(),
+                definition: local_definition,
+            }),
+        }
     } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
         let def = ctx.sema.to_def(&const_)?;
         if !def.ty(ctx.db()).is_bool() {
@@ -525,6 +529,98 @@ mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
 
     #[test]
+    fn parameter_with_first_param_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn function($0foo: bool, bar: bool) {
+    if foo {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: Bool, bar: bool) {
+    if foo == Bool::True {
+        println!("foo");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn parameter_with_last_param_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn function(foo: bool, $0bar: bool) {
+    if bar {
+        println!("bar");
+    }
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+    if bar == Bool::True {
+        println!("bar");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn parameter_with_middle_param_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn function(foo: bool, $0bar: bool, baz: bool) {
+    if bar {
+        println!("bar");
+    }
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool, baz: bool) {
+    if bar == Bool::True {
+        println!("bar");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn parameter_with_closure_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn main() {
+    let foo = |$0bar: bool| bar;
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+    let foo = |bar: Bool| bar == Bool::True;
+}
+"#,
+        )
+    }
+
+    #[test]
     fn local_variable_with_usage() {
         check_assist(
             bool_to_enum,
@@ -791,7 +887,6 @@ fn main() {
 
     #[test]
     fn local_variable_non_ident_pat() {
-        cov_mark::check!(not_applicable_in_non_ident_pat);
         check_assist_not_applicable(
             bool_to_enum,
             r#"