about summary refs log tree commit diff
diff options
context:
space:
mode:
authordavidsemakula <hello@davidsemakula.com>2024-01-30 19:57:36 +0300
committerdavidsemakula <hello@davidsemakula.com>2024-02-08 19:32:53 +0300
commit2987fac76fdc09056670180c5f2e2357d9af9247 (patch)
tree59967b4dcb5b44657e0e9d6afc285e2e5aee815b
parente07183461fe710e89cdc75edec0d81e8e9882116 (diff)
downloadrust-2987fac76fdc09056670180c5f2e2357d9af9247.tar.gz
rust-2987fac76fdc09056670180c5f2e2357d9af9247.zip
diagnostic to remove trailing return
-rw-r--r--crates/hir-ty/src/diagnostics/expr.rs31
-rw-r--r--crates/hir/src/diagnostics.rs18
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_trailing_return.rs262
-rw-r--r--crates/ide-diagnostics/src/handlers/type_mismatch.rs7
-rw-r--r--crates/ide-diagnostics/src/lib.rs2
5 files changed, 318 insertions, 2 deletions
diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs
index c09351390af..af3ca3c082d 100644
--- a/crates/hir-ty/src/diagnostics/expr.rs
+++ b/crates/hir-ty/src/diagnostics/expr.rs
@@ -44,6 +44,9 @@ pub enum BodyValidationDiagnostic {
         match_expr: ExprId,
         uncovered_patterns: String,
     },
+    RemoveTrailingReturn {
+        return_expr: ExprId,
+    },
     RemoveUnnecessaryElse {
         if_expr: ExprId,
     },
@@ -75,6 +78,10 @@ impl ExprValidator {
         let body = db.body(self.owner);
         let mut filter_map_next_checker = None;
 
+        if matches!(self.owner, DefWithBodyId::FunctionId(_)) {
+            self.check_for_trailing_return(body.body_expr, &body);
+        }
+
         for (id, expr) in body.exprs.iter() {
             if let Some((variant, missed_fields, true)) =
                 record_literal_missing_fields(db, &self.infer, id, expr)
@@ -93,12 +100,16 @@ impl ExprValidator {
                 Expr::Call { .. } | Expr::MethodCall { .. } => {
                     self.validate_call(db, id, expr, &mut filter_map_next_checker);
                 }
+                Expr::Closure { body: body_expr, .. } => {
+                    self.check_for_trailing_return(*body_expr, &body);
+                }
                 Expr::If { .. } => {
                     self.check_for_unnecessary_else(id, expr, &body);
                 }
                 _ => {}
             }
         }
+
         for (id, pat) in body.pats.iter() {
             if let Some((variant, missed_fields, true)) =
                 record_pattern_missing_fields(db, &self.infer, id, pat)
@@ -244,6 +255,26 @@ impl ExprValidator {
         pattern
     }
 
+    fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) {
+        match &body.exprs[body_expr] {
+            Expr::Block { statements, tail, .. } => {
+                let last_stmt = tail.or_else(|| match statements.last()? {
+                    Statement::Expr { expr, .. } => Some(*expr),
+                    _ => None,
+                });
+                if let Some(last_stmt) = last_stmt {
+                    self.check_for_trailing_return(last_stmt, body);
+                }
+            }
+            Expr::Return { .. } => {
+                self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn {
+                    return_expr: body_expr,
+                });
+            }
+            _ => (),
+        }
+    }
+
     fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) {
         if let Expr::If { condition: _, then_branch, else_branch } = expr {
             if else_branch.is_none() {
diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs
index 487e0c8f7a5..cec6be8e892 100644
--- a/crates/hir/src/diagnostics.rs
+++ b/crates/hir/src/diagnostics.rs
@@ -68,6 +68,7 @@ diagnostics![
     PrivateAssocItem,
     PrivateField,
     ReplaceFilterMapNextWithFindMap,
+    RemoveTrailingReturn,
     RemoveUnnecessaryElse,
     TraitImplIncorrectSafety,
     TraitImplMissingAssocItems,
@@ -344,6 +345,12 @@ pub struct TraitImplRedundantAssocItems {
 }
 
 #[derive(Debug)]
+pub struct RemoveTrailingReturn {
+    pub file_id: HirFileId,
+    pub return_expr: AstPtr<ast::Expr>,
+}
+
+#[derive(Debug)]
 pub struct RemoveUnnecessaryElse {
     pub if_expr: InFile<AstPtr<ast::IfExpr>>,
 }
@@ -450,6 +457,17 @@ impl AnyDiagnostic {
                     Err(SyntheticSyntax) => (),
                 }
             }
+            BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
+                if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
+                    return Some(
+                        RemoveTrailingReturn {
+                            file_id: source_ptr.file_id,
+                            return_expr: source_ptr.value,
+                        }
+                        .into(),
+                    );
+                }
+            }
             BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => {
                 if let Ok(source_ptr) = source_map.expr_syntax(if_expr) {
                     if let Some(ptr) = source_ptr.value.cast::<ast::IfExpr>() {
diff --git a/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
new file mode 100644
index 00000000000..6cb5911096f
--- /dev/null
+++ b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
@@ -0,0 +1,262 @@
+use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn, HirFileIdExt, InFile};
+use ide_db::{assists::Assist, source_change::SourceChange};
+use syntax::{ast, AstNode, SyntaxNodePtr};
+use text_edit::TextEdit;
+
+use crate::{fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
+
+// Diagnostic: remove-trailing-return
+//
+// This diagnostic is triggered when there is a redundant `return` at the end of a function
+// or closure.
+pub(crate) fn remove_trailing_return(
+    ctx: &DiagnosticsContext<'_>,
+    d: &RemoveTrailingReturn,
+) -> Diagnostic {
+    let display_range = ctx.sema.diagnostics_display_range(InFile {
+        file_id: d.file_id,
+        value: expr_stmt(ctx, d)
+            .as_ref()
+            .map(|stmt| SyntaxNodePtr::new(stmt.syntax()))
+            .unwrap_or_else(|| d.return_expr.into()),
+    });
+    Diagnostic::new(
+        DiagnosticCode::Clippy("needless_return"),
+        "replace return <expr>; with <expr>",
+        display_range,
+    )
+    .with_fixes(fixes(ctx, d))
+}
+
+fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<Assist>> {
+    let return_expr = return_expr(ctx, d)?;
+    let stmt = expr_stmt(ctx, d);
+
+    let range = stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax).text_range();
+    let replacement =
+        return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text()));
+
+    let edit = TextEdit::replace(range, replacement);
+    let source_change = SourceChange::from_text_edit(d.file_id.original_file(ctx.sema.db), edit);
+
+    Some(vec![fix(
+        "remove_trailing_return",
+        "Replace return <expr>; with <expr>",
+        source_change,
+        range,
+    )])
+}
+
+fn return_expr(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ReturnExpr> {
+    let root = ctx.sema.db.parse_or_expand(d.file_id);
+    let expr = d.return_expr.to_node(&root);
+    ast::ReturnExpr::cast(expr.syntax().clone())
+}
+
+fn expr_stmt(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ExprStmt> {
+    let return_expr = return_expr(ctx, d)?;
+    return_expr.syntax().parent().and_then(ast::ExprStmt::cast)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::{check_diagnostics, check_fix};
+
+    #[test]
+    fn remove_trailing_return() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    return 2;
+} //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+        );
+    }
+
+    #[test]
+    fn remove_trailing_return_inner_function() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    fn bar() -> u8 {
+        return 2;
+    } //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+    bar()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn remove_trailing_return_closure() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    let bar = || return 2;
+    bar()      //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+}
+"#,
+        );
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    let bar = || {
+        return 2;
+    };//^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+    bar()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn remove_trailing_return_unit() {
+        check_diagnostics(
+            r#"
+fn foo() {
+    return
+} //^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+        );
+    }
+
+    #[test]
+    fn remove_trailing_return_no_semi() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    return 2
+} //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+        );
+    }
+
+    #[test]
+    fn no_diagnostic_if_no_return_keyword() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    3
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn no_diagnostic_if_not_last_statement() {
+        check_diagnostics(
+            r#"
+fn foo() -> u8 {
+    if true { return 2; }
+    3
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn replace_with_expr() {
+        check_fix(
+            r#"
+fn foo() -> u8 {
+    return$0 2;
+}
+"#,
+            r#"
+fn foo() -> u8 {
+    2
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn replace_with_unit() {
+        check_fix(
+            r#"
+fn foo() {
+    return$0/*ensure tidy is happy*/
+}
+"#,
+            r#"
+fn foo() {
+    /*ensure tidy is happy*/
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn replace_with_expr_no_semi() {
+        check_fix(
+            r#"
+fn foo() -> u8 {
+    return$0 2
+}
+"#,
+            r#"
+fn foo() -> u8 {
+    2
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn replace_in_inner_function() {
+        check_fix(
+            r#"
+fn foo() -> u8 {
+    fn bar() -> u8 {
+        return$0 2;
+    }
+    bar()
+}
+"#,
+            r#"
+fn foo() -> u8 {
+    fn bar() -> u8 {
+        2
+    }
+    bar()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn replace_in_closure() {
+        check_fix(
+            r#"
+fn foo() -> u8 {
+    let bar = || return$0 2;
+    bar()
+}
+"#,
+            r#"
+fn foo() -> u8 {
+    let bar = || 2;
+    bar()
+}
+"#,
+        );
+        check_fix(
+            r#"
+fn foo() -> u8 {
+    let bar = || {
+        return$0 2;
+    };
+    bar()
+}
+"#,
+            r#"
+fn foo() -> u8 {
+    let bar = || {
+        2
+    };
+    bar()
+}
+"#,
+        );
+    }
+}
diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
index 750189beecb..eec8efe785a 100644
--- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs
+++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
@@ -186,7 +186,9 @@ fn str_ref_to_owned(
 
 #[cfg(test)]
 mod tests {
-    use crate::tests::{check_diagnostics, check_fix, check_no_fix};
+    use crate::tests::{
+        check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix,
+    };
 
     #[test]
     fn missing_reference() {
@@ -718,7 +720,7 @@ struct Bar {
 
     #[test]
     fn return_no_value() {
-        check_diagnostics(
+        check_diagnostics_with_disabled(
             r#"
 fn f() -> i32 {
     return;
@@ -727,6 +729,7 @@ fn f() -> i32 {
 }
 fn g() { return; }
 "#,
+            std::iter::once("needless_return".to_string()),
         );
     }
 
diff --git a/crates/ide-diagnostics/src/lib.rs b/crates/ide-diagnostics/src/lib.rs
index 7423de0be74..7c5cf673303 100644
--- a/crates/ide-diagnostics/src/lib.rs
+++ b/crates/ide-diagnostics/src/lib.rs
@@ -43,6 +43,7 @@ mod handlers {
     pub(crate) mod no_such_field;
     pub(crate) mod private_assoc_item;
     pub(crate) mod private_field;
+    pub(crate) mod remove_trailing_return;
     pub(crate) mod remove_unnecessary_else;
     pub(crate) mod replace_filter_map_next_with_find_map;
     pub(crate) mod trait_impl_incorrect_safety;
@@ -383,6 +384,7 @@ pub fn diagnostics(
             AnyDiagnostic::UnusedVariable(d) => handlers::unused_variables::unused_variables(&ctx, &d),
             AnyDiagnostic::BreakOutsideOfLoop(d) => handlers::break_outside_of_loop::break_outside_of_loop(&ctx, &d),
             AnyDiagnostic::MismatchedTupleStructPatArgCount(d) => handlers::mismatched_arg_count::mismatched_tuple_struct_pat_arg_count(&ctx, &d),
+            AnyDiagnostic::RemoveTrailingReturn(d) => handlers::remove_trailing_return::remove_trailing_return(&ctx, &d),
             AnyDiagnostic::RemoveUnnecessaryElse(d) => handlers::remove_unnecessary_else::remove_unnecessary_else(&ctx, &d),
         };
         res.push(d)