about summary refs log tree commit diff
diff options
context:
space:
mode:
authordavidsemakula <hello@davidsemakula.com>2024-02-16 20:39:52 +0300
committerdavidsemakula <hello@davidsemakula.com>2024-02-19 15:24:45 +0300
commitff7031008651021c330b93d4bd502810022b045d (patch)
treebf45e0e50b08d5d5bb47a075f16dd8bb6bb44f46
parent1205853c3689a69e81578dfd066b17e3ebe376cf (diff)
downloadrust-ff7031008651021c330b93d4bd502810022b045d.tar.gz
rust-ff7031008651021c330b93d4bd502810022b045d.zip
fix: only emit "unnecessary else" diagnostic for expr stmts
-rw-r--r--crates/hir-ty/src/diagnostics/expr.rs64
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs14
2 files changed, 49 insertions, 29 deletions
diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs
index 718409e1599..4fe75f24b80 100644
--- a/crates/hir-ty/src/diagnostics/expr.rs
+++ b/crates/hir-ty/src/diagnostics/expr.rs
@@ -109,7 +109,7 @@ impl ExprValidator {
                     self.check_for_trailing_return(*body_expr, &body);
                 }
                 Expr::If { .. } => {
-                    self.check_for_unnecessary_else(id, expr, db);
+                    self.check_for_unnecessary_else(id, expr, &body, db);
                 }
                 Expr::Block { .. } => {
                     self.validate_block(db, expr);
@@ -337,35 +337,17 @@ impl ExprValidator {
         }
     }
 
-    fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) {
+    fn check_for_unnecessary_else(
+        &mut self,
+        id: ExprId,
+        expr: &Expr,
+        body: &Body,
+        db: &dyn HirDatabase,
+    ) {
         if let Expr::If { condition: _, then_branch, else_branch } = expr {
             if else_branch.is_none() {
                 return;
             }
-            let (body, source_map) = db.body_with_source_map(self.owner);
-            let Ok(source_ptr) = source_map.expr_syntax(id) else {
-                return;
-            };
-            let root = source_ptr.file_syntax(db.upcast());
-            let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
-                return;
-            };
-            let mut top_if_expr = if_expr;
-            loop {
-                let parent = top_if_expr.syntax().parent();
-                let has_parent_let_stmt =
-                    parent.as_ref().map_or(false, |node| ast::LetStmt::can_cast(node.kind()));
-                if has_parent_let_stmt {
-                    // Bail if parent or direct ancestor is a let stmt.
-                    return;
-                }
-                let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
-                    // Parent is neither an if expr nor a let stmt.
-                    break;
-                };
-                // Check parent if expr.
-                top_if_expr = parent_if_expr;
-            }
             if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] {
                 let last_then_expr = tail.or_else(|| match statements.last()? {
                     Statement::Expr { expr, .. } => Some(*expr),
@@ -374,6 +356,36 @@ impl ExprValidator {
                 if let Some(last_then_expr) = last_then_expr {
                     let last_then_expr_ty = &self.infer[last_then_expr];
                     if last_then_expr_ty.is_never() {
+                        // Only look at sources if the then branch diverges and we have an else branch.
+                        let (_, source_map) = db.body_with_source_map(self.owner);
+                        let Ok(source_ptr) = source_map.expr_syntax(id) else {
+                            return;
+                        };
+                        let root = source_ptr.file_syntax(db.upcast());
+                        let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
+                            return;
+                        };
+                        let mut top_if_expr = if_expr;
+                        loop {
+                            let parent = top_if_expr.syntax().parent();
+                            let has_parent_expr_stmt_or_stmt_list =
+                                parent.as_ref().map_or(false, |node| {
+                                    ast::ExprStmt::can_cast(node.kind())
+                                        | ast::StmtList::can_cast(node.kind())
+                                });
+                            if has_parent_expr_stmt_or_stmt_list {
+                                // Only emit diagnostic if parent or direct ancestor is either
+                                // an expr stmt or a stmt list.
+                                break;
+                            }
+                            let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
+                                // Bail if parent is neither an if expr, an expr stmt nor a stmt list.
+                                return;
+                            };
+                            // Check parent if expr.
+                            top_if_expr = parent_if_expr;
+                        }
+
                         self.diagnostics
                             .push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id })
                     }
diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
index 9564807a334..7bfd64596ed 100644
--- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
+++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
@@ -467,10 +467,10 @@ fn test() {
     }
 
     #[test]
-    fn no_diagnostic_if_tail_exists_in_else_branch() {
+    fn no_diagnostic_if_not_expr_stmt() {
         check_diagnostics_with_needless_return_disabled(
             r#"
-fn test1(a: bool) {
+fn test1() {
     let _x = if a {
         return;
     } else {
@@ -478,7 +478,7 @@ fn test1(a: bool) {
     };
 }
 
-fn test2(a: bool, b: bool, c: bool) {
+fn test2() {
     let _x = if a {
         return;
     } else if b {
@@ -491,5 +491,13 @@ fn test2(a: bool, b: bool, c: bool) {
 }
 "#,
         );
+        check_diagnostics_with_disabled(
+            r#"
+fn test3() {
+    foo(if a { return 1 } else { 0 })
+}
+"#,
+            std::iter::once("E0308".to_owned()),
+        );
     }
 }