about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs58
-rw-r--r--crates/ide-diagnostics/src/handlers/type_mismatch.rs31
-rw-r--r--crates/ide-diagnostics/src/lib.rs20
3 files changed, 69 insertions, 40 deletions
diff --git a/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs b/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs
index 95a3ac1d519..5f8b3e543b9 100644
--- a/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs
+++ b/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs
@@ -1,11 +1,9 @@
-use ide_db::base_db::{FileRange, SourceDatabase};
 use syntax::{
-    algo::find_node_at_range,
     ast::{self, HasArgList},
     AstNode, TextRange,
 };
 
-use crate::{Diagnostic, DiagnosticsContext};
+use crate::{adjusted_display_range, Diagnostic, DiagnosticsContext};
 
 // Diagnostic: mismatched-arg-count
 //
@@ -20,40 +18,32 @@ pub(crate) fn mismatched_arg_count(
 }
 
 fn invalid_args_range(ctx: &DiagnosticsContext<'_>, d: &hir::MismatchedArgCount) -> TextRange {
-    let FileRange { file_id, range } =
-        ctx.sema.diagnostics_display_range(d.call_expr.clone().map(|it| it.into()));
-
-    let source_file = ctx.sema.db.parse(file_id);
-    let expr = find_node_at_range::<ast::Expr>(&source_file.syntax_node(), range)
-        .filter(|it| it.syntax().text_range() == range);
-    let arg_list = match expr {
-        Some(ast::Expr::CallExpr(call)) => call.arg_list(),
-        Some(ast::Expr::MethodCallExpr(call)) => call.arg_list(),
-        _ => None,
-    };
-    let arg_list = match arg_list {
-        Some(it) => it,
-        None => return range,
-    };
-    if d.found < d.expected {
-        if d.found == 0 {
-            return arg_list.syntax().text_range();
+    adjusted_display_range::<ast::Expr>(ctx, d.call_expr.clone().map(|it| it.into()), &|expr| {
+        let arg_list = match expr {
+            ast::Expr::CallExpr(call) => call.arg_list()?,
+            ast::Expr::MethodCallExpr(call) => call.arg_list()?,
+            _ => return None,
+        };
+        if d.found < d.expected {
+            if d.found == 0 {
+                return Some(arg_list.syntax().text_range());
+            }
+            if let Some(r_paren) = arg_list.r_paren_token() {
+                return Some(r_paren.text_range());
+            }
         }
-        if let Some(r_paren) = arg_list.r_paren_token() {
-            return r_paren.text_range();
+        if d.expected < d.found {
+            if d.expected == 0 {
+                return Some(arg_list.syntax().text_range());
+            }
+            let zip = arg_list.args().nth(d.expected).zip(arg_list.r_paren_token());
+            if let Some((arg, r_paren)) = zip {
+                return Some(arg.syntax().text_range().cover(r_paren.text_range()));
+            }
         }
-    }
-    if d.expected < d.found {
-        if d.expected == 0 {
-            return arg_list.syntax().text_range();
-        }
-        let zip = arg_list.args().nth(d.expected).zip(arg_list.r_paren_token());
-        if let Some((arg, r_paren)) = zip {
-            return arg.syntax().text_range().cover(r_paren.text_range());
-        }
-    }
 
-    range
+        None
+    })
 }
 
 #[cfg(test)]
diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
index 5826bed3434..6bf90e645b4 100644
--- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs
+++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
@@ -1,18 +1,28 @@
 use hir::{db::AstDatabase, HirDisplay, Type};
 use ide_db::{famous_defs::FamousDefs, source_change::SourceChange};
 use syntax::{
-    ast::{BlockExpr, ExprStmt},
+    ast::{self, BlockExpr, ExprStmt},
     AstNode,
 };
 use text_edit::TextEdit;
 
-use crate::{fix, Assist, Diagnostic, DiagnosticsContext};
+use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticsContext};
 
 // Diagnostic: type-mismatch
 //
 // This diagnostic is triggered when the type of an expression does not match
 // the expected type.
 pub(crate) fn type_mismatch(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Diagnostic {
+    let display_range = adjusted_display_range::<ast::BlockExpr>(
+        ctx,
+        d.expr.clone().map(|it| it.into()),
+        &|block| {
+            let r_curly_range = block.stmt_list()?.r_curly_token()?.text_range();
+            cov_mark::hit!(type_mismatch_on_block);
+            Some(r_curly_range)
+        },
+    );
+
     let mut diag = Diagnostic::new(
         "type-mismatch",
         format!(
@@ -20,7 +30,7 @@ pub(crate) fn type_mismatch(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch)
             d.expected.display(ctx.sema.db),
             d.actual.display(ctx.sema.db)
         ),
-        ctx.sema.diagnostics_display_range(d.expr.clone().map(|it| it.into())).range,
+        display_range,
     )
     .with_fixes(fixes(ctx, d));
     if diag.fixes.is_none() {
@@ -545,4 +555,19 @@ fn test() -> String {
             "#,
         );
     }
+
+    #[test]
+    fn type_mismatch_on_block() {
+        cov_mark::check!(type_mismatch_on_block);
+        check_diagnostics(
+            r#"
+fn f() -> i32 {
+    let x = 1;
+    let y = 2;
+    let _ = x + y;
+  }
+//^ error: expected i32, found ()
+"#,
+        );
+    }
 }
diff --git a/crates/ide-diagnostics/src/lib.rs b/crates/ide-diagnostics/src/lib.rs
index daf9b168867..41abaa836f5 100644
--- a/crates/ide-diagnostics/src/lib.rs
+++ b/crates/ide-diagnostics/src/lib.rs
@@ -55,15 +55,15 @@ mod handlers {
 #[cfg(test)]
 mod tests;
 
-use hir::{diagnostics::AnyDiagnostic, Semantics};
+use hir::{diagnostics::AnyDiagnostic, InFile, Semantics};
 use ide_db::{
     assists::{Assist, AssistId, AssistKind, AssistResolveStrategy},
-    base_db::{FileId, SourceDatabase},
+    base_db::{FileId, FileRange, SourceDatabase},
     label::Label,
     source_change::SourceChange,
     FxHashSet, RootDatabase,
 };
-use syntax::{ast::AstNode, TextRange};
+use syntax::{algo::find_node_at_range, ast::AstNode, SyntaxNodePtr, TextRange};
 
 #[derive(Copy, Clone, Debug, PartialEq)]
 pub struct DiagnosticCode(pub &'static str);
@@ -244,3 +244,17 @@ fn unresolved_fix(id: &'static str, label: &str, target: TextRange) -> Assist {
         trigger_signature_help: false,
     }
 }
+
+fn adjusted_display_range<N: AstNode>(
+    ctx: &DiagnosticsContext<'_>,
+    diag_ptr: InFile<SyntaxNodePtr>,
+    adj: &dyn Fn(N) -> Option<TextRange>,
+) -> TextRange {
+    let FileRange { file_id, range } = ctx.sema.diagnostics_display_range(diag_ptr);
+
+    let source_file = ctx.sema.db.parse(file_id);
+    find_node_at_range::<N>(&source_file.syntax_node(), range)
+        .filter(|it| it.syntax().text_range() == range)
+        .and_then(adj)
+        .unwrap_or(range)
+}