about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs166
1 files changed, 165 insertions, 1 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index f59b0528131..784a0d35599 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -263,6 +263,11 @@ fn replace_usages(
 fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
     let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?;
 
+    if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) {
+        cov_mark::hit!(dont_assign_incorrect_ref);
+        return None;
+    }
+
     if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
         bin_expr.rhs()
     } else {
@@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
 fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> {
     let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
 
+    if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
+        cov_mark::hit!(dont_overwrite_expression_inside_negation);
+        return None;
+    }
+
     if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
         let inner_expr = prefix_expr.expr()?;
         Some((prefix_expr, inner_expr))
@@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie
     let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?;
     let initializer = record_field.expr()?;
 
-    Some((record_field, initializer))
+    if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) {
+        Some((record_field, initializer))
+    } else {
+        cov_mark::hit!(dont_overwrite_wrong_record_field);
+        None
+    }
 }
 
 /// Adds the definition of the new enum before the target node.
@@ -562,6 +577,37 @@ fn main() {
     }
 
     #[test]
+    fn local_variable_nested_in_negation() {
+        cov_mark::check!(dont_overwrite_expression_inside_negation);
+        check_assist(
+            bool_to_enum,
+            r#"
+fn main() {
+    if !"foo".chars().any(|c| {
+        let $0foo = true;
+        foo
+    }) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    if !"foo".chars().any(|c| {
+        #[derive(PartialEq, Eq)]
+        enum Bool { True, False }
+
+        let foo = Bool::True;
+        foo == Bool::True
+    }) {
+        println!("foo");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn local_variable_non_bool() {
         cov_mark::check!(not_applicable_non_bool_local);
         check_assist_not_applicable(
@@ -639,6 +685,42 @@ fn main() {
     }
 
     #[test]
+    fn field_negated() {
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Foo {
+    $0bar: bool,
+}
+
+fn main() {
+    let foo = Foo { bar: false };
+
+    if !foo.bar {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+    bar: Bool,
+}
+
+fn main() {
+    let foo = Foo { bar: Bool::False };
+
+    if foo.bar == Bool::False {
+        println!("foo");
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn field_in_mod_properly_indented() {
         check_assist(
             bool_to_enum,
@@ -715,6 +797,88 @@ fn main() {
     }
 
     #[test]
+    fn field_assigned_to_another() {
+        cov_mark::check!(dont_assign_incorrect_ref);
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Foo {
+    $0foo: bool,
+}
+
+struct Bar {
+    bar: bool,
+}
+
+fn main() {
+    let foo = Foo { foo: true };
+    let mut bar = Bar { bar: true };
+
+    bar.bar = foo.foo;
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+    foo: Bool,
+}
+
+struct Bar {
+    bar: bool,
+}
+
+fn main() {
+    let foo = Foo { foo: Bool::True };
+    let mut bar = Bar { bar: true };
+
+    bar.bar = foo.foo == Bool::True;
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn field_initialized_with_other() {
+        cov_mark::check!(dont_overwrite_wrong_record_field);
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Foo {
+    $0foo: bool,
+}
+
+struct Bar {
+    bar: bool,
+}
+
+fn main() {
+    let foo = Foo { foo: true };
+    let bar = Bar { bar: foo.foo };
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+    foo: Bool,
+}
+
+struct Bar {
+    bar: bool,
+}
+
+fn main() {
+    let foo = Foo { foo: Bool::True };
+    let bar = Bar { bar: foo.foo == Bool::True };
+}
+"#,
+        )
+    }
+
+    #[test]
     fn field_non_bool() {
         cov_mark::check!(not_applicable_non_bool_field);
         check_assist_not_applicable(