about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyan Mehri <rmehri01@DESKTOP-QEEKTG9.hitronhub.home>2023-09-08 10:06:17 -0700
committerRyan Mehri <ryan.mehri1@gmail.com>2023-09-09 11:59:59 -0700
commit455dacfd3b5387bcf2854f2a88edb9b69361e69f (patch)
treee284040cb7cb6718fe0f5091b8221f51cf9cd488
parent91ac1d619475e1b61bf4ae8d318c4740a0adce66 (diff)
downloadrust-455dacfd3b5387bcf2854f2a88edb9b69361e69f.tar.gz
rust-455dacfd3b5387bcf2854f2a88edb9b69361e69f.zip
fix: only trigger assist on Name
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs41
1 files changed, 29 insertions, 12 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 4158b75dc00..56749edf463 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -85,7 +85,9 @@ struct BoolNodeData {
 
 /// Attempts to find an appropriate node to apply the action to.
 fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
-    if let Some(let_stmt) = ctx.find_node_at_offset::<ast::LetStmt>() {
+    let name: ast::Name = ctx.find_node_at_offset()?;
+
+    if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) {
         let bind_pat = match let_stmt.pat()? {
             ast::Pat::IdentPat(pat) => pat,
             _ => {
@@ -101,12 +103,12 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
 
         Some(BoolNodeData {
             target_node: let_stmt.syntax().clone(),
-            name: bind_pat.name()?,
+            name,
             ty_annotation: let_stmt.ty(),
             initializer: let_stmt.initializer(),
             definition: Definition::Local(def),
         })
-    } else if let Some(const_) = ctx.find_node_at_offset::<ast::Const>() {
+    } else if let Some(const_) = name.syntax().ancestors().find_map(ast::Const::cast) {
         let def = ctx.sema.to_def(&const_)?;
         if !def.ty(ctx.db()).is_bool() {
             cov_mark::hit!(not_applicable_non_bool_const);
@@ -115,12 +117,12 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
 
         Some(BoolNodeData {
             target_node: const_.syntax().clone(),
-            name: const_.name()?,
+            name,
             ty_annotation: const_.ty(),
             initializer: const_.body(),
             definition: Definition::Const(def),
         })
-    } else if let Some(static_) = ctx.find_node_at_offset::<ast::Static>() {
+    } else if let Some(static_) = name.syntax().ancestors().find_map(ast::Static::cast) {
         let def = ctx.sema.to_def(&static_)?;
         if !def.ty(ctx.db()).is_bool() {
             cov_mark::hit!(not_applicable_non_bool_static);
@@ -129,14 +131,14 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
 
         Some(BoolNodeData {
             target_node: static_.syntax().clone(),
-            name: static_.name()?,
+            name,
             ty_annotation: static_.ty(),
             initializer: static_.body(),
             definition: Definition::Static(def),
         })
-    } else if let Some(field_name) = ctx.find_node_at_offset::<ast::Name>() {
-        let field = field_name.syntax().ancestors().find_map(ast::RecordField::cast)?;
-        if field.name()? != field_name {
+    } else {
+        let field = name.syntax().ancestors().find_map(ast::RecordField::cast)?;
+        if field.name()? != name {
             return None;
         }
 
@@ -148,13 +150,11 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
         }
         Some(BoolNodeData {
             target_node: strukt.syntax().clone(),
-            name: field_name,
+            name,
             ty_annotation: field.ty(),
             initializer: None,
             definition: Definition::Field(def),
         })
-    } else {
-        None
     }
 }
 
@@ -529,6 +529,18 @@ fn main() {
     }
 
     #[test]
+    fn local_variable_cursor_not_on_ident() {
+        check_assist_not_applicable(
+            bool_to_enum,
+            r#"
+fn main() {
+    let foo = $0true;
+}
+"#,
+        )
+    }
+
+    #[test]
     fn local_variable_non_ident_pat() {
         cov_mark::check!(not_applicable_in_non_ident_pat);
         check_assist_not_applicable(
@@ -762,4 +774,9 @@ fn main() {
 "#,
         )
     }
+
+    #[test]
+    fn not_applicable_to_other_names() {
+        check_assist_not_applicable(bool_to_enum, "fn $0main() {}")
+    }
 }