about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyan Mehri <ryan.mehri1@gmail.com>2023-10-01 21:27:06 -0700
committerRyan Mehri <ryan.mehri1@gmail.com>2023-10-01 21:30:10 -0700
commit34d3490198fe6e7f56eb60c9665d25ef9cfd6f4e (patch)
tree6ec542cf8fe1c51500537d4943c16a3d570810cd
parent0840038f02daec6ba3238f05d8caa037d28701a0 (diff)
downloadrust-34d3490198fe6e7f56eb60c9665d25ef9cfd6f4e.tar.gz
rust-34d3490198fe6e7f56eb60c9665d25ef9cfd6f4e.zip
feat: add assist for applying De Morgan's law to iterators
-rw-r--r--crates/ide-assists/src/handlers/apply_demorgan.rs329
-rw-r--r--crates/ide-assists/src/lib.rs1
-rw-r--r--crates/ide-assists/src/tests/generated.rs24
3 files changed, 352 insertions, 2 deletions
diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs
index 66bc2f6dadc..74db300465a 100644
--- a/crates/ide-assists/src/handlers/apply_demorgan.rs
+++ b/crates/ide-assists/src/handlers/apply_demorgan.rs
@@ -1,7 +1,13 @@
 use std::collections::VecDeque;
 
+use ide_db::{
+    assists::GroupLabel,
+    famous_defs::FamousDefs,
+    source_change::SourceChangeBuilder,
+    syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
+};
 use syntax::{
-    ast::{self, AstNode, Expr::BinExpr},
+    ast::{self, make, AstNode, Expr::BinExpr, HasArgList},
     ted::{self, Position},
     SyntaxKind,
 };
@@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
 
     let dm_lhs = demorganed.lhs()?;
 
-    acc.add(
+    acc.add_group(
+        &GroupLabel("Apply De Morgan's law".to_string()),
         AssistId("apply_demorgan", AssistKind::RefactorRewrite),
         "Apply De Morgan's law",
         op_range,
@@ -143,6 +150,122 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
     )
 }
 
+// Assist: apply_demorgan_iterator
+//
+// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to
+// `Iterator::all` and `Iterator::any`.
+//
+// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
+// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
+// `Iterator::all` into `Iterator::any`.
+//
+// ```
+// # //- minicore: iterator
+// fn main() {
+//     let arr = [1, 2, 3];
+//     if !arr.into_iter().$0any(|num| num == 4) {
+//         println!("foo");
+//     }
+// }
+// ```
+// ->
+// ```
+// fn main() {
+//     let arr = [1, 2, 3];
+//     if arr.into_iter().all(|num| num != 4) {
+//         println!("foo");
+//     }
+// }
+// ```
+pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
+    let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
+    let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
+
+    let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
+    let closure_body = closure_expr.body()?;
+
+    let op_range = method_call.syntax().text_range();
+    let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
+    acc.add_group(
+        &GroupLabel("Apply De Morgan's law".to_string()),
+        AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite),
+        label,
+        op_range,
+        |edit| {
+            // replace the method name
+            let new_name = match name.text().as_str() {
+                "all" => make::name_ref("any"),
+                "any" => make::name_ref("all"),
+                _ => unreachable!(),
+            }
+            .clone_for_update();
+            edit.replace_ast(name, new_name);
+
+            // negate all tail expressions in the closure body
+            let tail_cb = &mut |e: &_| tail_cb_impl(edit, e);
+            walk_expr(&closure_body, &mut |expr| {
+                if let ast::Expr::ReturnExpr(ret_expr) = expr {
+                    if let Some(ret_expr_arg) = &ret_expr.expr() {
+                        for_each_tail_expr(ret_expr_arg, tail_cb);
+                    }
+                }
+            });
+            for_each_tail_expr(&closure_body, tail_cb);
+
+            // negate the whole method call
+            if let Some(prefix_expr) = method_call
+                .syntax()
+                .parent()
+                .and_then(ast::PrefixExpr::cast)
+                .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
+            {
+                edit.delete(prefix_expr.op_token().unwrap().text_range());
+            } else {
+                edit.insert(method_call.syntax().text_range().start(), "!");
+            }
+        },
+    )
+}
+
+/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
+fn validate_method_call_expr(
+    ctx: &AssistContext<'_>,
+    method_call: &ast::MethodCallExpr,
+) -> Option<(ast::NameRef, ast::Expr)> {
+    let name_ref = method_call.name_ref()?;
+    if name_ref.text() != "all" && name_ref.text() != "any" {
+        return None;
+    }
+    let arg_expr = method_call.arg_list()?.args().next()?;
+
+    let sema = &ctx.sema;
+
+    let receiver = method_call.receiver()?;
+    let it_type = sema.type_of_expr(&receiver)?.adjusted();
+    let module = sema.scope(receiver.syntax())?.module();
+    let krate = module.krate();
+
+    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
+    it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
+}
+
+fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) {
+    match e {
+        ast::Expr::BreakExpr(break_expr) => {
+            if let Some(break_expr_arg) = break_expr.expr() {
+                for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e))
+            }
+        }
+        ast::Expr::ReturnExpr(_) => {
+            // all return expressions have already been handled by the walk loop
+        }
+        e => {
+            let inverted_body = invert_boolean_expression(e.clone());
+            edit.replace(e.syntax().text_range(), inverted_body.syntax().text());
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -255,4 +378,206 @@ fn f() { !(S <= S || S < S) }
             "fn() { let x = a && b && c; }",
         )
     }
+
+    #[test]
+    fn demorgan_iterator_any_all_reverse() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().all(|num| num $0!= 4) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().any(|num| num == 4) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_all_any() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0all(|num| num > 3) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().any(|num| num <= 3) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_multiple_terms() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_double_negation() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0all(|num| !(num > 3)) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().any(|num| num > 3) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_double_parens() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_multiline() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if arr
+        .into_iter()
+        .all$0(|num| !num.is_negative())
+    {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr
+        .into_iter()
+        .any(|num| num.is_negative())
+    {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_block_closure() {
+        check_assist(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [-1, 1, 2, 3];
+    if arr.into_iter().all(|num: i32| {
+        $0if num.is_positive() {
+            num <= 3
+        } else {
+            num >= -1
+        }
+    }) {
+        println!("foo");
+    }
+}
+"#,
+            r#"
+fn main() {
+    let arr = [-1, 1, 2, 3];
+    if !arr.into_iter().any(|num: i32| {
+        if num.is_positive() {
+            num > 3
+        } else {
+            num < -1
+        }
+    }) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn demorgan_iterator_wrong_method() {
+        check_assist_not_applicable(
+            apply_demorgan_iterator,
+            r#"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0map(|num| num > 3) {
+        println!("foo");
+    }
+}
+"#,
+        );
+    }
 }
diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs
index a17ce93e928..50476ccf363 100644
--- a/crates/ide-assists/src/lib.rs
+++ b/crates/ide-assists/src/lib.rs
@@ -226,6 +226,7 @@ mod handlers {
             add_return_type::add_return_type,
             add_turbo_fish::add_turbo_fish,
             apply_demorgan::apply_demorgan,
+            apply_demorgan::apply_demorgan_iterator,
             auto_import::auto_import,
             bind_unused_param::bind_unused_param,
             bool_to_enum::bool_to_enum,
diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs
index 5a815d5c6a1..65bd74c018b 100644
--- a/crates/ide-assists/src/tests/generated.rs
+++ b/crates/ide-assists/src/tests/generated.rs
@@ -245,6 +245,30 @@ fn main() {
 }
 
 #[test]
+fn doctest_apply_demorgan_iterator() {
+    check_doc_test(
+        "apply_demorgan_iterator",
+        r#####"
+//- minicore: iterator
+fn main() {
+    let arr = [1, 2, 3];
+    if !arr.into_iter().$0any(|num| num == 4) {
+        println!("foo");
+    }
+}
+"#####,
+        r#####"
+fn main() {
+    let arr = [1, 2, 3];
+    if arr.into_iter().all(|num| num != 4) {
+        println!("foo");
+    }
+}
+"#####,
+    )
+}
+
+#[test]
 fn doctest_auto_import() {
     check_doc_test(
         "auto_import",