about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-04-10 18:42:48 -0400
committerMichael Goulet <michael@errs.io>2024-04-10 18:58:15 -0400
commit3764af6119ee32061de6ff25effb2bf72dced487 (patch)
tree6c585d040b6a2e954f651b55971df0c587473e9e
parentb3bd7058c139e71bae0862ef8f8ac936208873e9 (diff)
downloadrust-3764af6119ee32061de6ff25effb2bf72dced487.tar.gz
rust-3764af6119ee32061de6ff25effb2bf72dced487.zip
Use suggest_impl_trait in return type suggestion
-rw-r--r--compiler/rustc_hir_analysis/src/collect.rs98
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs12
-rw-r--r--tests/ui/async-await/track-caller/async-closure-gate.afn.stderr4
-rw-r--r--tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr4
-rw-r--r--tests/ui/suggestions/return-closures.stderr2
-rw-r--r--tests/ui/typeck/return_type_containing_closure.rs2
-rw-r--r--tests/ui/typeck/return_type_containing_closure.stderr6
7 files changed, 68 insertions, 60 deletions
diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs
index a705d3bc107..c1bf65367aa 100644
--- a/compiler/rustc_hir_analysis/src/collect.rs
+++ b/compiler/rustc_hir_analysis/src/collect.rs
@@ -30,7 +30,7 @@ use rustc_middle::query::Providers;
 use rustc_middle::ty::util::{Discr, IntTypeExt};
 use rustc_middle::ty::{self, AdtKind, Const, IsSuggestable, ToPredicate, Ty, TyCtxt};
 use rustc_span::symbol::{kw, sym, Ident, Symbol};
-use rustc_span::Span;
+use rustc_span::{Span, DUMMY_SP};
 use rustc_target::abi::FieldIdx;
 use rustc_target::spec::abi;
 use rustc_trait_selection::infer::InferCtxtExt;
@@ -1383,7 +1383,9 @@ fn infer_return_ty_for_fn_sig<'tcx>(
                     Applicability::MachineApplicable,
                 );
                 should_recover = true;
-            } else if let Some(sugg) = suggest_impl_trait(tcx, ret_ty, ty.span, def_id) {
+            } else if let Some(sugg) =
+                suggest_impl_trait(&tcx.infer_ctxt().build(), tcx.param_env(def_id), ret_ty)
+            {
                 diag.span_suggestion(
                     ty.span,
                     "replace with an appropriate return type",
@@ -1426,11 +1428,10 @@ fn infer_return_ty_for_fn_sig<'tcx>(
     }
 }
 
-fn suggest_impl_trait<'tcx>(
-    tcx: TyCtxt<'tcx>,
+pub fn suggest_impl_trait<'tcx>(
+    infcx: &InferCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
     ret_ty: Ty<'tcx>,
-    span: Span,
-    def_id: LocalDefId,
 ) -> Option<String> {
     let format_as_assoc: fn(_, _, _, _, _) -> _ =
         |tcx: TyCtxt<'tcx>,
@@ -1464,24 +1465,28 @@ fn suggest_impl_trait<'tcx>(
 
     for (trait_def_id, assoc_item_def_id, formatter) in [
         (
-            tcx.get_diagnostic_item(sym::Iterator),
-            tcx.get_diagnostic_item(sym::IteratorItem),
+            infcx.tcx.get_diagnostic_item(sym::Iterator),
+            infcx.tcx.get_diagnostic_item(sym::IteratorItem),
             format_as_assoc,
         ),
         (
-            tcx.lang_items().future_trait(),
-            tcx.get_diagnostic_item(sym::FutureOutput),
+            infcx.tcx.lang_items().future_trait(),
+            infcx.tcx.get_diagnostic_item(sym::FutureOutput),
             format_as_assoc,
         ),
-        (tcx.lang_items().fn_trait(), tcx.lang_items().fn_once_output(), format_as_parenthesized),
         (
-            tcx.lang_items().fn_mut_trait(),
-            tcx.lang_items().fn_once_output(),
+            infcx.tcx.lang_items().fn_trait(),
+            infcx.tcx.lang_items().fn_once_output(),
+            format_as_parenthesized,
+        ),
+        (
+            infcx.tcx.lang_items().fn_mut_trait(),
+            infcx.tcx.lang_items().fn_once_output(),
             format_as_parenthesized,
         ),
         (
-            tcx.lang_items().fn_once_trait(),
-            tcx.lang_items().fn_once_output(),
+            infcx.tcx.lang_items().fn_once_trait(),
+            infcx.tcx.lang_items().fn_once_output(),
             format_as_parenthesized,
         ),
     ] {
@@ -1491,36 +1496,45 @@ fn suggest_impl_trait<'tcx>(
         let Some(assoc_item_def_id) = assoc_item_def_id else {
             continue;
         };
-        if tcx.def_kind(assoc_item_def_id) != DefKind::AssocTy {
+        if infcx.tcx.def_kind(assoc_item_def_id) != DefKind::AssocTy {
             continue;
         }
-        let param_env = tcx.param_env(def_id);
-        let infcx = tcx.infer_ctxt().build();
-        let args = ty::GenericArgs::for_item(tcx, trait_def_id, |param, _| {
-            if param.index == 0 { ret_ty.into() } else { infcx.var_for_def(span, param) }
+        let sugg = infcx.probe(|_| {
+            let args = ty::GenericArgs::for_item(infcx.tcx, trait_def_id, |param, _| {
+                if param.index == 0 { ret_ty.into() } else { infcx.var_for_def(DUMMY_SP, param) }
+            });
+            if !infcx
+                .type_implements_trait(trait_def_id, args, param_env)
+                .must_apply_modulo_regions()
+            {
+                return None;
+            }
+            let ocx = ObligationCtxt::new(&infcx);
+            let item_ty = ocx.normalize(
+                &ObligationCause::dummy(),
+                param_env,
+                Ty::new_projection(infcx.tcx, assoc_item_def_id, args),
+            );
+            // FIXME(compiler-errors): We may benefit from resolving regions here.
+            if ocx.select_where_possible().is_empty()
+                && let item_ty = infcx.resolve_vars_if_possible(item_ty)
+                && let Some(item_ty) = item_ty.make_suggestable(infcx.tcx, false, None)
+                && let Some(sugg) = formatter(
+                    infcx.tcx,
+                    infcx.resolve_vars_if_possible(args),
+                    trait_def_id,
+                    assoc_item_def_id,
+                    item_ty,
+                )
+            {
+                return Some(sugg);
+            }
+
+            None
         });
-        if !infcx.type_implements_trait(trait_def_id, args, param_env).must_apply_modulo_regions() {
-            continue;
-        }
-        let ocx = ObligationCtxt::new(&infcx);
-        let item_ty = ocx.normalize(
-            &ObligationCause::misc(span, def_id),
-            param_env,
-            Ty::new_projection(tcx, assoc_item_def_id, args),
-        );
-        // FIXME(compiler-errors): We may benefit from resolving regions here.
-        if ocx.select_where_possible().is_empty()
-            && let item_ty = infcx.resolve_vars_if_possible(item_ty)
-            && let Some(item_ty) = item_ty.make_suggestable(tcx, false, None)
-            && let Some(sugg) = formatter(
-                tcx,
-                infcx.resolve_vars_if_possible(args),
-                trait_def_id,
-                assoc_item_def_id,
-                item_ty,
-            )
-        {
-            return Some(sugg);
+
+        if sugg.is_some() {
+            return sugg;
         }
     }
     None
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
index 442bfd75746..1e1136ef467 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
@@ -20,6 +20,7 @@ use rustc_hir::{
     CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node,
     Path, QPath, Stmt, StmtKind, TyKind, WherePredicate,
 };
+use rustc_hir_analysis::collect::suggest_impl_trait;
 use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
 use rustc_infer::traits::{self};
 use rustc_middle::lint::in_external_macro;
@@ -814,17 +815,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         errors::AddReturnTypeSuggestion::Add { span, found: found.to_string() },
                     );
                     return true;
-                } else if let ty::Closure(_, args) = found.kind()
-                    // FIXME(compiler-errors): Get better at printing binders...
-                    && let closure = args.as_closure()
-                    && closure.sig().is_suggestable(self.tcx, false)
-                {
+                } else if let Some(sugg) = suggest_impl_trait(self, self.param_env, found) {
                     err.subdiagnostic(
                         self.dcx(),
-                        errors::AddReturnTypeSuggestion::Add {
-                            span,
-                            found: closure.print_as_impl_trait().to_string(),
-                        },
+                        errors::AddReturnTypeSuggestion::Add { span, found: sugg },
                     );
                     return true;
                 } else {
diff --git a/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr b/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr
index 92f38d5a796..640d946421a 100644
--- a/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr
+++ b/tests/ui/async-await/track-caller/async-closure-gate.afn.stderr
@@ -62,7 +62,7 @@ error[E0308]: mismatched types
   --> $DIR/async-closure-gate.rs:27:5
    |
 LL |   fn foo3() {
-   |            - help: a return type might be missing here: `-> _`
+   |            - help: try adding a return type: `-> impl Future<Output = ()>`
 LL | /     async {
 LL | |
 LL | |         let _ = #[track_caller] || {
@@ -78,7 +78,7 @@ error[E0308]: mismatched types
   --> $DIR/async-closure-gate.rs:44:5
    |
 LL |   fn foo5() {
-   |            - help: a return type might be missing here: `-> _`
+   |            - help: try adding a return type: `-> impl Future<Output = ()>`
 LL | /     async {
 LL | |
 LL | |         let _ = || {
diff --git a/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr b/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr
index 92f38d5a796..640d946421a 100644
--- a/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr
+++ b/tests/ui/async-await/track-caller/async-closure-gate.nofeat.stderr
@@ -62,7 +62,7 @@ error[E0308]: mismatched types
   --> $DIR/async-closure-gate.rs:27:5
    |
 LL |   fn foo3() {
-   |            - help: a return type might be missing here: `-> _`
+   |            - help: try adding a return type: `-> impl Future<Output = ()>`
 LL | /     async {
 LL | |
 LL | |         let _ = #[track_caller] || {
@@ -78,7 +78,7 @@ error[E0308]: mismatched types
   --> $DIR/async-closure-gate.rs:44:5
    |
 LL |   fn foo5() {
-   |            - help: a return type might be missing here: `-> _`
+   |            - help: try adding a return type: `-> impl Future<Output = ()>`
 LL | /     async {
 LL | |
 LL | |         let _ = || {
diff --git a/tests/ui/suggestions/return-closures.stderr b/tests/ui/suggestions/return-closures.stderr
index 97c13200ac3..ef1f50b8a6c 100644
--- a/tests/ui/suggestions/return-closures.stderr
+++ b/tests/ui/suggestions/return-closures.stderr
@@ -2,7 +2,7 @@ error[E0308]: mismatched types
   --> $DIR/return-closures.rs:3:5
    |
 LL | fn foo() {
-   |         - help: try adding a return type: `-> impl for<'a> Fn(&'a i32) -> i32`
+   |         - help: try adding a return type: `-> impl FnOnce(&i32) -> i32`
 LL |
 LL |     |x: &i32| 1i32
    |     ^^^^^^^^^^^^^^ expected `()`, found closure
diff --git a/tests/ui/typeck/return_type_containing_closure.rs b/tests/ui/typeck/return_type_containing_closure.rs
index 8b826daeede..b81cac0a58a 100644
--- a/tests/ui/typeck/return_type_containing_closure.rs
+++ b/tests/ui/typeck/return_type_containing_closure.rs
@@ -1,5 +1,5 @@
 #[allow(unused)]
-fn foo() { //~ HELP a return type might be missing here
+fn foo() { //~ HELP try adding a return type
     vec!['a'].iter().map(|c| c)
     //~^ ERROR mismatched types [E0308]
     //~| NOTE expected `()`, found `Map<Iter<'_, char>, ...>`
diff --git a/tests/ui/typeck/return_type_containing_closure.stderr b/tests/ui/typeck/return_type_containing_closure.stderr
index ea9c74be362..3f14650a82c 100644
--- a/tests/ui/typeck/return_type_containing_closure.stderr
+++ b/tests/ui/typeck/return_type_containing_closure.stderr
@@ -10,10 +10,10 @@ help: consider using a semicolon here
    |
 LL |     vec!['a'].iter().map(|c| c);
    |                                +
-help: a return type might be missing here
+help: try adding a return type
    |
-LL | fn foo() -> _ {
-   |          ++++
+LL | fn foo() -> impl Iterator<Item = &char> {
+   |          ++++++++++++++++++++++++++++++
 
 error: aborting due to 1 previous error