about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEsteban Küber <esteban@kuber.com.ar>2024-02-06 03:30:16 +0000
committerEsteban Küber <esteban@kuber.com.ar>2024-02-12 20:26:34 +0000
commit37d2ea2fa064411de78ec24a178a05dc02517673 (patch)
treee9f50135d04c2fc484cd435bf07704a751a8b468
parentbdc15928c8119a86d15e2946cb54851264607842 (diff)
downloadrust-37d2ea2fa064411de78ec24a178a05dc02517673.tar.gz
rust-37d2ea2fa064411de78ec24a178a05dc02517673.zip
Properly handle `async` blocks and `fn`s in `if` exprs without `else`
When encountering a tail expression in the then arm of an `if` expression
without an `else` arm, account for `async fn` and `async` blocks to
suggest `return`ing the value and pointing at the return type of the
`async fn`.

We now also account for AFIT when looking for the return type to point at.

Fix #115405.
-rw-r--r--compiler/rustc_hir_typeck/src/coercion.rs39
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs37
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs5
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs99
-rw-r--r--compiler/rustc_middle/src/hir/map/mod.rs2
-rw-r--r--compiler/rustc_parse/src/parser/diagnostics.rs2
-rw-r--r--tests/ui/async-await/missing-return-in-async-block.fixed22
-rw-r--r--tests/ui/async-await/missing-return-in-async-block.rs22
-rw-r--r--tests/ui/async-await/missing-return-in-async-block.stderr35
-rw-r--r--tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr2
-rw-r--r--tests/ui/loops/dont-suggest-break-thru-item.rs2
-rw-r--r--tests/ui/loops/dont-suggest-break-thru-item.stderr16
12 files changed, 236 insertions, 47 deletions
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index 549ad44d7e3..882fa770016 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -92,14 +92,17 @@ impl<'a, 'tcx> Deref for Coerce<'a, 'tcx> {
 
 type CoerceResult<'tcx> = InferResult<'tcx, (Vec<Adjustment<'tcx>>, Ty<'tcx>)>;
 
-struct CollectRetsVisitor<'tcx> {
-    ret_exprs: Vec<&'tcx hir::Expr<'tcx>>,
+pub struct CollectRetsVisitor<'tcx> {
+    pub ret_exprs: Vec<&'tcx hir::Expr<'tcx>>,
 }
 
 impl<'tcx> Visitor<'tcx> for CollectRetsVisitor<'tcx> {
     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
-        if let hir::ExprKind::Ret(_) = expr.kind {
-            self.ret_exprs.push(expr);
+        match expr.kind {
+            hir::ExprKind::Ret(_) => self.ret_exprs.push(expr),
+            // `return` in closures does not return from the outer function
+            hir::ExprKind::Closure(_) => return,
+            _ => {}
         }
         intravisit::walk_expr(self, expr);
     }
@@ -1845,13 +1848,31 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
         }
 
         let parent_id = fcx.tcx.hir().get_parent_item(id);
-        let parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id);
+        let mut parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id);
+        // When suggesting return, we need to account for closures and async blocks, not just items.
+        for (_, node) in fcx.tcx.hir().parent_iter(id) {
+            match node {
+                hir::Node::Expr(&hir::Expr {
+                    kind: hir::ExprKind::Closure(hir::Closure { .. }),
+                    ..
+                }) => {
+                    parent_item = node;
+                    break;
+                }
+                hir::Node::Item(_) | hir::Node::TraitItem(_) | hir::Node::ImplItem(_) => break,
+                _ => {}
+            }
+        }
 
-        if let (Some(expr), Some(_), Some((fn_id, fn_decl, _, _))) =
-            (expression, blk_id, fcx.get_node_fn_decl(parent_item))
-        {
+        if let (Some(expr), Some(_), Some(fn_decl)) = (expression, blk_id, parent_item.fn_decl()) {
             fcx.suggest_missing_break_or_return_expr(
-                &mut err, expr, fn_decl, expected, found, id, fn_id,
+                &mut err,
+                expr,
+                fn_decl,
+                expected,
+                found,
+                id,
+                parent_id.into(),
             );
         }
 
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index 165937de247..3847f03a378 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -963,14 +963,35 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 owner_id,
                 ..
             }) => Some((hir::HirId::make_owner(owner_id.def_id), &sig.decl, ident, false)),
-            Node::Expr(&hir::Expr { hir_id, kind: hir::ExprKind::Closure(..), .. })
-                if let Node::Item(&hir::Item {
-                    ident,
-                    kind: hir::ItemKind::Fn(ref sig, ..),
-                    owner_id,
-                    ..
-                }) = self.tcx.parent_hir_node(hir_id) =>
-            {
+            Node::Expr(&hir::Expr {
+                hir_id,
+                kind:
+                    hir::ExprKind::Closure(hir::Closure {
+                        kind: hir::ClosureKind::Coroutine(..), ..
+                    }),
+                ..
+            }) => {
+                let (ident, sig, owner_id) = match self.tcx.parent_hir_node(hir_id) {
+                    Node::Item(&hir::Item {
+                        ident,
+                        kind: hir::ItemKind::Fn(ref sig, ..),
+                        owner_id,
+                        ..
+                    }) => (ident, sig, owner_id),
+                    Node::TraitItem(&hir::TraitItem {
+                        ident,
+                        kind: hir::TraitItemKind::Fn(ref sig, ..),
+                        owner_id,
+                        ..
+                    }) => (ident, sig, owner_id),
+                    Node::ImplItem(&hir::ImplItem {
+                        ident,
+                        kind: hir::ImplItemKind::Fn(ref sig, ..),
+                        owner_id,
+                        ..
+                    }) => (ident, sig, owner_id),
+                    _ => return None,
+                };
                 Some((
                     hir::HirId::make_owner(owner_id.def_id),
                     &sig.decl,
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
index 35b3f27d791..65b8505c090 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
@@ -1726,7 +1726,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     }
 
     /// Given a function block's `HirId`, returns its `FnDecl` if it exists, or `None` otherwise.
-    fn get_parent_fn_decl(&self, blk_id: hir::HirId) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> {
+    pub(crate) fn get_parent_fn_decl(
+        &self,
+        blk_id: hir::HirId,
+    ) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> {
         let parent = self.tcx.hir_node_by_def_id(self.tcx.hir().get_parent_item(blk_id).def_id);
         self.get_node_fn_decl(parent).map(|(_, fn_decl, ident, _)| (fn_decl, ident))
     }
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
index 193c9a4b908..38cc1f5c102 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
@@ -1,5 +1,6 @@
 use super::FnCtxt;
 
+use crate::coercion::CollectRetsVisitor;
 use crate::errors;
 use crate::fluent_generated as fluent;
 use crate::fn_ctxt::rustc_span::BytePos;
@@ -16,6 +17,7 @@ use rustc_errors::{Applicability, Diagnostic, MultiSpan};
 use rustc_hir as hir;
 use rustc_hir::def::Res;
 use rustc_hir::def::{CtorKind, CtorOf, DefKind};
+use rustc_hir::intravisit::{Map, Visitor};
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::{
     CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node,
@@ -827,6 +829,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             }
             hir::FnRetTy::Return(hir_ty) => {
                 if let hir::TyKind::OpaqueDef(item_id, ..) = hir_ty.kind
+                    // FIXME: account for RPITIT.
                     && let hir::Node::Item(hir::Item {
                         kind: hir::ItemKind::OpaqueTy(op_ty), ..
                     }) = self.tcx.hir_node(item_id.hir_id())
@@ -1038,33 +1041,81 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             return;
         }
 
-        if let hir::FnRetTy::Return(ty) = fn_decl.output {
-            let ty = self.astconv().ast_ty_to_ty(ty);
-            let bound_vars = self.tcx.late_bound_vars(fn_id);
-            let ty = self
-                .tcx
-                .instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars));
-            let ty = match self.tcx.asyncness(fn_id.owner) {
-                ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
-                    span_bug!(fn_decl.output.span(), "failed to get output type of async function")
-                }),
-                ty::Asyncness::No => ty,
-            };
-            let ty = self.normalize(expr.span, ty);
-            if self.can_coerce(found, ty) {
-                if let Some(owner_node) = self.tcx.hir_node(fn_id).as_owner()
-                    && let Some(span) = expr.span.find_ancestor_inside(*owner_node.span())
+        let in_closure = matches!(
+            self.tcx
+                .hir()
+                .parent_iter(id)
+                .filter(|(_, node)| {
+                    matches!(
+                        node,
+                        Node::Expr(Expr { kind: ExprKind::Closure(..), .. })
+                            | Node::Item(_)
+                            | Node::TraitItem(_)
+                            | Node::ImplItem(_)
+                    )
+                })
+                .next(),
+            Some((_, Node::Expr(Expr { kind: ExprKind::Closure(..), .. })))
+        );
+
+        let can_return = match fn_decl.output {
+            hir::FnRetTy::Return(ty) => {
+                let ty = self.astconv().ast_ty_to_ty(ty);
+                let bound_vars = self.tcx.late_bound_vars(fn_id);
+                let ty = self
+                    .tcx
+                    .instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars));
+                let ty = match self.tcx.asyncness(fn_id.owner) {
+                    ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
+                        span_bug!(
+                            fn_decl.output.span(),
+                            "failed to get output type of async function"
+                        )
+                    }),
+                    ty::Asyncness::No => ty,
+                };
+                let ty = self.normalize(expr.span, ty);
+                self.can_coerce(found, ty)
+            }
+            hir::FnRetTy::DefaultReturn(_) if in_closure => {
+                let mut rets = vec![];
+                if let Some(ret_coercion) = self.ret_coercion.as_ref() {
+                    let ret_ty = ret_coercion.borrow().expected_ty();
+                    rets.push(ret_ty);
+                }
+                let mut visitor = CollectRetsVisitor { ret_exprs: vec![] };
+                if let Some(item) = self.tcx.hir().find(id)
+                    && let Node::Expr(expr) = item
                 {
-                    err.multipart_suggestion(
-                        "you might have meant to return this value",
-                        vec![
-                            (span.shrink_to_lo(), "return ".to_string()),
-                            (span.shrink_to_hi(), ";".to_string()),
-                        ],
-                        Applicability::MaybeIncorrect,
-                    );
+                    visitor.visit_expr(expr);
+                    for expr in visitor.ret_exprs {
+                        if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) {
+                            rets.push(ty);
+                        }
+                    }
+                    if let hir::ExprKind::Block(hir::Block { expr: Some(expr), .. }, _) = expr.kind
+                    {
+                        if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) {
+                            rets.push(ty);
+                        }
+                    }
                 }
+                rets.into_iter().all(|ty| self.can_coerce(found, ty))
             }
+            _ => false,
+        };
+        if can_return
+            && let Some(owner_node) = self.tcx.hir_node(fn_id).as_owner()
+            && let Some(span) = expr.span.find_ancestor_inside(owner_node.span())
+        {
+            err.multipart_suggestion(
+                "you might have meant to return this value",
+                vec![
+                    (span.shrink_to_lo(), "return ".to_string()),
+                    (span.shrink_to_hi(), ";".to_string()),
+                ],
+                Applicability::MaybeIncorrect,
+            );
         }
     }
 
diff --git a/compiler/rustc_middle/src/hir/map/mod.rs b/compiler/rustc_middle/src/hir/map/mod.rs
index 8e1cb6a514f..e7d9dc04886 100644
--- a/compiler/rustc_middle/src/hir/map/mod.rs
+++ b/compiler/rustc_middle/src/hir/map/mod.rs
@@ -617,7 +617,7 @@ impl<'hir> Map<'hir> {
                 Node::Item(_)
                 | Node::ForeignItem(_)
                 | Node::TraitItem(_)
-                | Node::Expr(Expr { kind: ExprKind::Closure { .. }, .. })
+                | Node::Expr(Expr { kind: ExprKind::Closure(_), .. })
                 | Node::ImplItem(_)
                     // The input node `id` must be enclosed in the method's body as opposed
                     // to some other place such as its return type (fixes #114918).
diff --git a/compiler/rustc_parse/src/parser/diagnostics.rs b/compiler/rustc_parse/src/parser/diagnostics.rs
index 7a24b819b5f..445d5b2ce79 100644
--- a/compiler/rustc_parse/src/parser/diagnostics.rs
+++ b/compiler/rustc_parse/src/parser/diagnostics.rs
@@ -900,7 +900,7 @@ impl<'a> Parser<'a> {
             // fn foo() -> Foo {
             //     field: value,
             // }
-            info!(?maybe_struct_name, ?self.token);
+            debug!(?maybe_struct_name, ?self.token);
             let mut snapshot = self.create_snapshot_for_diagnostic();
             let path = Path {
                 segments: ThinVec::new(),
diff --git a/tests/ui/async-await/missing-return-in-async-block.fixed b/tests/ui/async-await/missing-return-in-async-block.fixed
new file mode 100644
index 00000000000..3dbac7945b6
--- /dev/null
+++ b/tests/ui/async-await/missing-return-in-async-block.fixed
@@ -0,0 +1,22 @@
+// run-rustfix
+// edition:2021
+use std::future::Future;
+use std::pin::Pin;
+pub struct S;
+pub fn foo() {
+    let _ = Box::pin(async move {
+        if true {
+            return Ok(S); //~ ERROR mismatched types
+        }
+        Err(())
+    });
+}
+pub fn bar() -> Pin<Box<dyn Future<Output = Result<S, ()>> + 'static>> {
+    Box::pin(async move {
+        if true {
+            return Ok(S); //~ ERROR mismatched types
+        }
+        Err(())
+    })
+}
+fn main() {}
diff --git a/tests/ui/async-await/missing-return-in-async-block.rs b/tests/ui/async-await/missing-return-in-async-block.rs
new file mode 100644
index 00000000000..7d04e0e0fad
--- /dev/null
+++ b/tests/ui/async-await/missing-return-in-async-block.rs
@@ -0,0 +1,22 @@
+// run-rustfix
+// edition:2021
+use std::future::Future;
+use std::pin::Pin;
+pub struct S;
+pub fn foo() {
+    let _ = Box::pin(async move {
+        if true {
+            Ok(S) //~ ERROR mismatched types
+        }
+        Err(())
+    });
+}
+pub fn bar() -> Pin<Box<dyn Future<Output = Result<S, ()>> + 'static>> {
+    Box::pin(async move {
+        if true {
+            Ok(S) //~ ERROR mismatched types
+        }
+        Err(())
+    })
+}
+fn main() {}
diff --git a/tests/ui/async-await/missing-return-in-async-block.stderr b/tests/ui/async-await/missing-return-in-async-block.stderr
new file mode 100644
index 00000000000..5ea76e5f7bf
--- /dev/null
+++ b/tests/ui/async-await/missing-return-in-async-block.stderr
@@ -0,0 +1,35 @@
+error[E0308]: mismatched types
+  --> $DIR/missing-return-in-async-block.rs:9:13
+   |
+LL | /         if true {
+LL | |             Ok(S)
+   | |             ^^^^^ expected `()`, found `Result<S, _>`
+LL | |         }
+   | |_________- expected this to be `()`
+   |
+   = note: expected unit type `()`
+                   found enum `Result<S, _>`
+help: you might have meant to return this value
+   |
+LL |             return Ok(S);
+   |             ++++++      +
+
+error[E0308]: mismatched types
+  --> $DIR/missing-return-in-async-block.rs:17:13
+   |
+LL | /         if true {
+LL | |             Ok(S)
+   | |             ^^^^^ expected `()`, found `Result<S, _>`
+LL | |         }
+   | |_________- expected this to be `()`
+   |
+   = note: expected unit type `()`
+                   found enum `Result<S, _>`
+help: you might have meant to return this value
+   |
+LL |             return Ok(S);
+   |             ++++++      +
+
+error: aborting due to 2 previous errors
+
+For more information about this error, try `rustc --explain E0308`.
diff --git a/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr b/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr
index 77f6945f064..9fa73d817ca 100644
--- a/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr
+++ b/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr
@@ -1,6 +1,8 @@
 error[E0308]: mismatched types
   --> $DIR/default-body-type-err-2.rs:7:9
    |
+LL |     async fn woopsie_async(&self) -> String {
+   |                                      ------ expected `String` because of return type
 LL |         42
    |         ^^- help: try using a conversion method: `.to_string()`
    |         |
diff --git a/tests/ui/loops/dont-suggest-break-thru-item.rs b/tests/ui/loops/dont-suggest-break-thru-item.rs
index b46ba89e81d..308101115e5 100644
--- a/tests/ui/loops/dont-suggest-break-thru-item.rs
+++ b/tests/ui/loops/dont-suggest-break-thru-item.rs
@@ -8,6 +8,7 @@ fn closure() {
             if true {
                 Err(1)
                 //~^ ERROR mismatched types
+                //~| HELP you might have meant to return this value
             }
 
             Ok(())
@@ -21,6 +22,7 @@ fn async_block() {
             if true {
                 Err(1)
                 //~^ ERROR mismatched types
+                //~| HELP you might have meant to return this value
             }
 
             Ok(())
diff --git a/tests/ui/loops/dont-suggest-break-thru-item.stderr b/tests/ui/loops/dont-suggest-break-thru-item.stderr
index 4fce4715119..c84a98198f5 100644
--- a/tests/ui/loops/dont-suggest-break-thru-item.stderr
+++ b/tests/ui/loops/dont-suggest-break-thru-item.stderr
@@ -5,27 +5,37 @@ LL | /             if true {
 LL | |                 Err(1)
    | |                 ^^^^^^ expected `()`, found `Result<_, {integer}>`
 LL | |
+LL | |
 LL | |             }
    | |_____________- expected this to be `()`
    |
    = note: expected unit type `()`
                    found enum `Result<_, {integer}>`
+help: you might have meant to return this value
+   |
+LL |                 return Err(1);
+   |                 ++++++       +
 
 error[E0308]: mismatched types
-  --> $DIR/dont-suggest-break-thru-item.rs:22:17
+  --> $DIR/dont-suggest-break-thru-item.rs:23:17
    |
 LL | /             if true {
 LL | |                 Err(1)
    | |                 ^^^^^^ expected `()`, found `Result<_, {integer}>`
 LL | |
+LL | |
 LL | |             }
    | |_____________- expected this to be `()`
    |
    = note: expected unit type `()`
                    found enum `Result<_, {integer}>`
+help: you might have meant to return this value
+   |
+LL |                 return Err(1);
+   |                 ++++++       +
 
 error[E0308]: mismatched types
-  --> $DIR/dont-suggest-break-thru-item.rs:35:17
+  --> $DIR/dont-suggest-break-thru-item.rs:37:17
    |
 LL | /             if true {
 LL | |                 Err(1)
@@ -38,7 +48,7 @@ LL | |             }
                    found enum `Result<_, {integer}>`
 
 error[E0308]: mismatched types
-  --> $DIR/dont-suggest-break-thru-item.rs:47:17
+  --> $DIR/dont-suggest-break-thru-item.rs:49:17
    |
 LL | /             if true {
 LL | |                 Err(1)