about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/diagnostics.rs10
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_trailing_return.rs46
2 files changed, 24 insertions, 32 deletions
diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs
index b9f86f34156..e07c48aa8ac 100644
--- a/crates/hir/src/diagnostics.rs
+++ b/crates/hir/src/diagnostics.rs
@@ -11,7 +11,7 @@ use cfg::{CfgExpr, CfgOptions};
 use either::Either;
 use hir_def::{body::SyntheticSyntax, hir::ExprOrPatId, path::ModPath, AssocItemId, DefWithBodyId};
 use hir_expand::{name::Name, HirFileId, InFile};
-use syntax::{ast, AstNode, AstPtr, SyntaxError, SyntaxNodePtr, TextRange};
+use syntax::{ast, AstPtr, SyntaxError, SyntaxNodePtr, TextRange};
 
 use crate::{AssocItem, Field, Local, MacroKind, Trait, Type};
 
@@ -346,8 +346,7 @@ pub struct TraitImplRedundantAssocItems {
 
 #[derive(Debug)]
 pub struct RemoveTrailingReturn {
-    pub file_id: HirFileId,
-    pub return_expr: AstPtr<ast::Expr>,
+    pub return_expr: InFile<AstPtr<ast::ReturnExpr>>,
 }
 
 #[derive(Debug)]
@@ -460,11 +459,10 @@ impl AnyDiagnostic {
             BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
                 if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
                     // Filters out desugared return expressions (e.g. desugared try operators).
-                    if ast::ReturnExpr::can_cast(source_ptr.value.kind()) {
+                    if let Some(ptr) = source_ptr.value.cast::<ast::ReturnExpr>() {
                         return Some(
                             RemoveTrailingReturn {
-                                file_id: source_ptr.file_id,
-                                return_expr: source_ptr.value,
+                                return_expr: InFile::new(source_ptr.file_id, ptr),
                             }
                             .into(),
                         );
diff --git a/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
index 6c4ec7c132f..276ac0d15d9 100644
--- a/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
+++ b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
@@ -1,9 +1,9 @@
-use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn, HirFileIdExt, InFile};
-use ide_db::{assists::Assist, source_change::SourceChange};
-use syntax::{ast, AstNode, SyntaxNodePtr};
+use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn};
+use ide_db::{assists::Assist, base_db::FileRange, source_change::SourceChange};
+use syntax::{ast, AstNode};
 use text_edit::TextEdit;
 
-use crate::{fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
+use crate::{adjusted_display_range, fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
 
 // Diagnostic: remove-trailing-return
 //
@@ -13,12 +13,12 @@ pub(crate) fn remove_trailing_return(
     ctx: &DiagnosticsContext<'_>,
     d: &RemoveTrailingReturn,
 ) -> Diagnostic {
-    let display_range = ctx.sema.diagnostics_display_range(InFile {
-        file_id: d.file_id,
-        value: expr_stmt(ctx, d)
-            .as_ref()
-            .map(|stmt| SyntaxNodePtr::new(stmt.syntax()))
-            .unwrap_or_else(|| d.return_expr.into()),
+    let display_range = adjusted_display_range(ctx, d.return_expr, &|return_expr| {
+        return_expr
+            .syntax()
+            .parent()
+            .and_then(ast::ExprStmt::cast)
+            .map(|stmt| stmt.syntax().text_range())
     });
     Diagnostic::new(
         DiagnosticCode::Clippy("needless_return"),
@@ -29,15 +29,20 @@ pub(crate) fn remove_trailing_return(
 }
 
 fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<Assist>> {
-    let return_expr = return_expr(ctx, d)?;
-    let stmt = expr_stmt(ctx, d);
+    let root = ctx.sema.db.parse_or_expand(d.return_expr.file_id);
+    let return_expr = d.return_expr.value.to_node(&root);
+    let stmt = return_expr.syntax().parent().and_then(ast::ExprStmt::cast);
+
+    let FileRange { range, file_id } =
+        ctx.sema.original_range_opt(stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax))?;
+    if Some(file_id) != d.return_expr.file_id.file_id() {
+        return None;
+    }
 
-    let range = stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax).text_range();
     let replacement =
         return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text()));
-
     let edit = TextEdit::replace(range, replacement);
-    let source_change = SourceChange::from_text_edit(d.file_id.original_file(ctx.sema.db), edit);
+    let source_change = SourceChange::from_text_edit(file_id, edit);
 
     Some(vec![fix(
         "remove_trailing_return",
@@ -47,17 +52,6 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<A
     )])
 }
 
-fn return_expr(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ReturnExpr> {
-    let root = ctx.sema.db.parse_or_expand(d.file_id);
-    let expr = d.return_expr.to_node(&root);
-    ast::ReturnExpr::cast(expr.syntax().clone())
-}
-
-fn expr_stmt(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ExprStmt> {
-    let return_expr = return_expr(ctx, d)?;
-    return_expr.syntax().parent().and_then(ast::ExprStmt::cast)
-}
-
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_diagnostics, check_fix};