about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-02-19 12:49:08 +0000
committerbors <bors@rust-lang.org>2024-02-19 12:49:08 +0000
commite6b96dba5614150b3b73a79db34a9f70c3729265 (patch)
tree5577542774680cb2254779661e5366d056d9582b
parentac1029fac3fe7d2c76cf2df71e0614fdb9683d51 (diff)
parentf2218e727840ec0286d89ecc758322674e1efb6d (diff)
downloadrust-e6b96dba5614150b3b73a79db34a9f70c3729265.tar.gz
rust-e6b96dba5614150b3b73a79db34a9f70c3729265.zip
Auto merge of #16590 - davidsemakula:unnecessary-else-diagnostic-fix, r=Veykril
fix: Fix false positives for "unnecessary else" diagnostic

Completes https://github.com/rust-lang/rust-analyzer/pull/16567 by `@ShoyuVanilla` (see https://github.com/rust-lang/rust-analyzer/pull/16567#discussion_r1492700667)
Fixes #16556
-rw-r--r--crates/hir-ty/src/diagnostics/expr.rs37
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs122
2 files changed, 151 insertions, 8 deletions
diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs
index 0c5d6399619..6c8a1875165 100644
--- a/crates/hir-ty/src/diagnostics/expr.rs
+++ b/crates/hir-ty/src/diagnostics/expr.rs
@@ -12,6 +12,7 @@ use hir_expand::name;
 use itertools::Itertools;
 use rustc_hash::FxHashSet;
 use rustc_pattern_analysis::usefulness::{compute_match_usefulness, ValidityConstraint};
+use syntax::{ast, AstNode};
 use tracing::debug;
 use triomphe::Arc;
 use typed_arena::Arena;
@@ -108,7 +109,7 @@ impl ExprValidator {
                     self.check_for_trailing_return(*body_expr, &body);
                 }
                 Expr::If { .. } => {
-                    self.check_for_unnecessary_else(id, expr, &body);
+                    self.check_for_unnecessary_else(id, expr, db);
                 }
                 Expr::Block { .. } => {
                     self.validate_block(db, expr);
@@ -336,12 +337,12 @@ impl ExprValidator {
         }
     }
 
-    fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) {
+    fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) {
         if let Expr::If { condition: _, then_branch, else_branch } = expr {
             if else_branch.is_none() {
                 return;
             }
-            if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] {
+            if let Expr::Block { statements, tail, .. } = &self.body.exprs[*then_branch] {
                 let last_then_expr = tail.or_else(|| match statements.last()? {
                     Statement::Expr { expr, .. } => Some(*expr),
                     _ => None,
@@ -349,6 +350,36 @@ impl ExprValidator {
                 if let Some(last_then_expr) = last_then_expr {
                     let last_then_expr_ty = &self.infer[last_then_expr];
                     if last_then_expr_ty.is_never() {
+                        // Only look at sources if the then branch diverges and we have an else branch.
+                        let (_, source_map) = db.body_with_source_map(self.owner);
+                        let Ok(source_ptr) = source_map.expr_syntax(id) else {
+                            return;
+                        };
+                        let root = source_ptr.file_syntax(db.upcast());
+                        let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
+                            return;
+                        };
+                        let mut top_if_expr = if_expr;
+                        loop {
+                            let parent = top_if_expr.syntax().parent();
+                            let has_parent_expr_stmt_or_stmt_list =
+                                parent.as_ref().map_or(false, |node| {
+                                    ast::ExprStmt::can_cast(node.kind())
+                                        | ast::StmtList::can_cast(node.kind())
+                                });
+                            if has_parent_expr_stmt_or_stmt_list {
+                                // Only emit diagnostic if parent or direct ancestor is either
+                                // an expr stmt or a stmt list.
+                                break;
+                            }
+                            let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
+                                // Bail if parent is neither an if expr, an expr stmt nor a stmt list.
+                                return;
+                            };
+                            // Check parent if expr.
+                            top_if_expr = parent_if_expr;
+                        }
+
                         self.diagnostics
                             .push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id })
                     }
diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
index ae8241ec2c6..9c63d79d910 100644
--- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
+++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
@@ -2,7 +2,10 @@ use hir::{db::ExpandDatabase, diagnostics::RemoveUnnecessaryElse, HirFileIdExt};
 use ide_db::{assists::Assist, source_change::SourceChange};
 use itertools::Itertools;
 use syntax::{
-    ast::{self, edit::IndentLevel},
+    ast::{
+        self,
+        edit::{AstNodeEdit, IndentLevel},
+    },
     AstNode, SyntaxToken, TextRange,
 };
 use text_edit::TextEdit;
@@ -41,10 +44,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveUnnecessaryElse) -> Option<Vec<
         indent = indent + 1;
     }
     let else_replacement = match if_expr.else_branch()? {
-        ast::ElseBranch::Block(ref block) => {
-            block.statements().map(|stmt| format!("\n{indent}{stmt}")).join("")
-        }
-        ast::ElseBranch::IfExpr(ref nested_if_expr) => {
+        ast::ElseBranch::Block(block) => block
+            .statements()
+            .map(|stmt| format!("\n{indent}{stmt}"))
+            .chain(block.tail_expr().map(|tail| format!("\n{indent}{tail}")))
+            .join(""),
+        ast::ElseBranch::IfExpr(mut nested_if_expr) => {
+            if has_parent_if_expr {
+                nested_if_expr = nested_if_expr.indent(IndentLevel(1))
+            }
             format!("\n{indent}{nested_if_expr}")
         }
     };
@@ -172,6 +180,41 @@ fn test() {
     }
 
     #[test]
+    fn remove_unnecessary_else_for_return3() {
+        check_diagnostics_with_needless_return_disabled(
+            r#"
+fn test(a: bool) -> i32 {
+    if a {
+        return 1;
+    } else {
+    //^^^^ 💡 weak: remove unnecessary else block
+        0
+    }
+}
+"#,
+        );
+        check_fix(
+            r#"
+fn test(a: bool) -> i32 {
+    if a {
+        return 1;
+    } else$0 {
+        0
+    }
+}
+"#,
+            r#"
+fn test(a: bool) -> i32 {
+    if a {
+        return 1;
+    }
+    0
+}
+"#,
+        );
+    }
+
+    #[test]
     fn remove_unnecessary_else_for_return_in_child_if_expr() {
         check_diagnostics_with_needless_return_disabled(
             r#"
@@ -215,6 +258,41 @@ fn test() {
     }
 
     #[test]
+    fn remove_unnecessary_else_for_return_in_child_if_expr2() {
+        check_fix(
+            r#"
+fn test() {
+    if foo {
+        do_something();
+    } else if qux {
+        return bar;
+    } else$0 if quux {
+        do_something_else();
+    } else {
+        do_something_else2();
+    }
+}
+"#,
+            r#"
+fn test() {
+    if foo {
+        do_something();
+    } else {
+        if qux {
+            return bar;
+        }
+        if quux {
+            do_something_else();
+        } else {
+            do_something_else2();
+        }
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
     fn remove_unnecessary_else_for_break() {
         check_diagnostics(
             r#"
@@ -387,4 +465,38 @@ fn test() {
 "#,
         );
     }
+
+    #[test]
+    fn no_diagnostic_if_not_expr_stmt() {
+        check_diagnostics_with_needless_return_disabled(
+            r#"
+fn test1() {
+    let _x = if a {
+        return;
+    } else {
+        1
+    };
+}
+
+fn test2() {
+    let _x = if a {
+        return;
+    } else if b {
+        return;
+    } else if c {
+        1
+    } else {
+        return;
+    };
+}
+"#,
+        );
+        check_diagnostics(
+            r#"
+fn test3() -> u8 {
+    foo(if a { return 1 } else { 0 })
+}
+"#,
+        );
+    }
 }