about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNilstrieb <48135649+Nilstrieb@users.noreply.github.com>2021-10-14 21:44:43 +0200
committerNilstrieb <48135649+Nilstrieb@users.noreply.github.com>2022-02-18 20:40:08 +0100
commit4bed7485da0706dfc30f019f94f07fcdf5122358 (patch)
tree96f2c77c5845e754b2906ea1bae4b8b1dcb6ba99
parentb8c56fa8c30821129b0960180f528d4a1a4f9316 (diff)
downloadrust-4bed7485da0706dfc30f019f94f07fcdf5122358.tar.gz
rust-4bed7485da0706dfc30f019f94f07fcdf5122358.zip
Suggest `impl Trait` return type
Address #85991

Suggest the `impl Trait` return type syntax if the user tried to return a generic parameter and we get a type mismatch

The suggestion is not emitted if the param appears in the function parameters, and only get the bounds that actually involve `T: ` directly

It also checks whether the generic param is contained in any where bound (where it isn't the self type), and if one is found (like `Option<T>: Send`), it is not suggested.

This also adds `TyS::contains`, which recursively vistits the type and looks if the other type is contained anywhere
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs24
-rw-r--r--compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs116
-rw-r--r--src/test/ui/return/return-impl-trait-bad.rs31
-rw-r--r--src/test/ui/return/return-impl-trait-bad.stderr59
-rw-r--r--src/test/ui/return/return-impl-trait.fixed30
-rw-r--r--src/test/ui/return/return-impl-trait.rs30
-rw-r--r--src/test/ui/return/return-impl-trait.stderr34
7 files changed, 321 insertions, 3 deletions
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index 9835211a748..7c6d6ea1cb6 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -8,7 +8,9 @@ use crate::infer::canonical::Canonical;
 use crate::ty::fold::ValidateBoundVars;
 use crate::ty::subst::{GenericArg, InternalSubsts, Subst, SubstsRef};
 use crate::ty::InferTy::{self, *};
-use crate::ty::{self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable};
+use crate::ty::{
+    self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable, TypeVisitor,
+};
 use crate::ty::{DelaySpanBugEmitted, List, ParamEnv};
 use polonius_engine::Atom;
 use rustc_data_structures::captures::Captures;
@@ -24,7 +26,7 @@ use std::borrow::Cow;
 use std::cmp::Ordering;
 use std::fmt;
 use std::marker::PhantomData;
-use std::ops::{Deref, Range};
+use std::ops::{ControlFlow, Deref, Range};
 use ty::util::IntTypeExt;
 
 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, TyEncodable, TyDecodable)]
@@ -2072,6 +2074,24 @@ impl<'tcx> Ty<'tcx> {
         !matches!(self.kind(), Param(_) | Infer(_) | Error(_))
     }
 
+    /// Checks whether a type recursively contains another type
+    ///
+    /// Example: `Option<()>` contains `()`
+    pub fn contains(self, other: Ty<'tcx>) -> bool {
+        struct ContainsTyVisitor<'tcx>(Ty<'tcx>);
+
+        impl<'tcx> TypeVisitor<'tcx> for ContainsTyVisitor<'tcx> {
+            type BreakTy = ();
+
+            fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow<Self::BreakTy> {
+                if self.0 == t { ControlFlow::BREAK } else { t.super_visit_with(self) }
+            }
+        }
+
+        let cf = self.visit_with(&mut ContainsTyVisitor(other));
+        cf.is_break()
+    }
+
     /// Returns the type and mutability of `*ty`.
     ///
     /// The parameter `explicit` indicates if this is an *explicit* dereference.
diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs b/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs
index 86cf850d723..f9c482713f1 100644
--- a/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs
+++ b/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs
@@ -8,8 +8,12 @@ use rustc_errors::{Applicability, DiagnosticBuilder};
 use rustc_hir as hir;
 use rustc_hir::def::{CtorOf, DefKind};
 use rustc_hir::lang_items::LangItem;
-use rustc_hir::{Expr, ExprKind, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind};
+use rustc_hir::{
+    Expr, ExprKind, GenericBound, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind,
+    WherePredicate,
+};
 use rustc_infer::infer::{self, TyCtxtInferExt};
+
 use rustc_middle::lint::in_external_macro;
 use rustc_middle::ty::{self, Binder, Ty};
 use rustc_span::symbol::{kw, sym};
@@ -559,6 +563,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 let ty = self.tcx.erase_late_bound_regions(ty);
                 if self.can_coerce(expected, ty) {
                     err.span_label(sp, format!("expected `{}` because of return type", expected));
+                    self.try_suggest_return_impl_trait(err, expected, ty, fn_id);
                     return true;
                 }
                 false
@@ -566,6 +571,115 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         }
     }
 
+    /// check whether the return type is a generic type with a trait bound
+    /// only suggest this if the generic param is not present in the arguments
+    /// if this is true, hint them towards changing the return type to `impl Trait`
+    /// ```
+    /// fn cant_name_it<T: Fn() -> u32>() -> T {
+    ///     || 3
+    /// }
+    /// ```
+    fn try_suggest_return_impl_trait(
+        &self,
+        err: &mut DiagnosticBuilder<'_>,
+        expected: Ty<'tcx>,
+        found: Ty<'tcx>,
+        fn_id: hir::HirId,
+    ) {
+        // Only apply the suggestion if:
+        //  - the return type is a generic parameter
+        //  - the generic param is not used as a fn param
+        //  - the generic param has at least one bound
+        //  - the generic param doesn't appear in any other bounds where it's not the Self type
+        // Suggest:
+        //  - Changing the return type to be `impl <all bounds>`
+
+        debug!("try_suggest_return_impl_trait, expected = {:?}, found = {:?}", expected, found);
+
+        let ty::Param(expected_ty_as_param) = expected.kind() else { return };
+
+        let fn_node = self.tcx.hir().find(fn_id);
+
+        let Some(hir::Node::Item(hir::Item {
+            kind:
+                hir::ItemKind::Fn(
+                    hir::FnSig { decl: hir::FnDecl { inputs: fn_parameters, output: fn_return, .. }, .. },
+                    hir::Generics { params, where_clause, .. },
+                    _body_id,
+                ),
+            ..
+        })) = fn_node else { return };
+
+        let Some(expected_generic_param) = params.get(expected_ty_as_param.index as usize) else { return };
+
+        // get all where BoundPredicates here, because they are used in to cases below
+        let where_predicates = where_clause
+            .predicates
+            .iter()
+            .filter_map(|p| match p {
+                WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
+                    bounds,
+                    bounded_ty,
+                    ..
+                }) => {
+                    // FIXME: Maybe these calls to `ast_ty_to_ty` can be removed (and the ones below)
+                    let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, bounded_ty);
+                    Some((ty, bounds))
+                }
+                _ => None,
+            })
+            .map(|(ty, bounds)| match ty.kind() {
+                ty::Param(param_ty) if param_ty == expected_ty_as_param => Ok(Some(bounds)),
+                // check whether there is any predicate that contains our `T`, like `Option<T>: Send`
+                _ => match ty.contains(expected) {
+                    true => Err(()),
+                    false => Ok(None),
+                },
+            })
+            .collect::<Result<Vec<_>, _>>();
+
+        let Ok(where_predicates) =  where_predicates else { return };
+
+        // now get all predicates in the same types as the where bounds, so we can chain them
+        let predicates_from_where =
+            where_predicates.iter().flatten().map(|bounds| bounds.iter()).flatten();
+
+        // extract all bounds from the source code using their spans
+        let all_matching_bounds_strs = expected_generic_param
+            .bounds
+            .iter()
+            .chain(predicates_from_where)
+            .filter_map(|bound| match bound {
+                GenericBound::Trait(_, _) => {
+                    self.tcx.sess.source_map().span_to_snippet(bound.span()).ok()
+                }
+                _ => None,
+            })
+            .collect::<Vec<String>>();
+
+        if all_matching_bounds_strs.len() == 0 {
+            return;
+        }
+
+        let all_bounds_str = all_matching_bounds_strs.join(" + ");
+
+        let ty_param_used_in_fn_params = fn_parameters.iter().any(|param| {
+                let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, param);
+                matches!(ty.kind(), ty::Param(fn_param_ty_param) if expected_ty_as_param == fn_param_ty_param)
+            });
+
+        if ty_param_used_in_fn_params {
+            return;
+        }
+
+        err.span_suggestion(
+            fn_return.span(),
+            "consider using an impl return type",
+            format!("impl {}", all_bounds_str),
+            Applicability::MaybeIncorrect,
+        );
+    }
+
     pub(in super::super) fn suggest_missing_break_or_return_expr(
         &self,
         err: &mut DiagnosticBuilder<'_>,
diff --git a/src/test/ui/return/return-impl-trait-bad.rs b/src/test/ui/return/return-impl-trait-bad.rs
new file mode 100644
index 00000000000..e3f6ddb9a14
--- /dev/null
+++ b/src/test/ui/return/return-impl-trait-bad.rs
@@ -0,0 +1,31 @@
+trait Trait {}
+impl Trait for () {}
+
+fn bad_echo<T>(_t: T) -> T {
+    "this should not suggest impl Trait" //~ ERROR mismatched types
+}
+
+fn bad_echo_2<T: Trait>(_t: T) -> T {
+    "this will not suggest it, because that would probably be wrong" //~ ERROR mismatched types
+}
+
+fn other_bounds_bad<T>() -> T
+where
+    T: Send,
+    Option<T>: Send,
+{
+    "don't suggest this, because Option<T> places additional constraints" //~ ERROR mismatched types
+}
+
+// FIXME: implement this check
+trait GenericTrait<T> {}
+
+fn used_in_trait<T>() -> T
+where
+    T: Send,
+    (): GenericTrait<T>,
+{
+    "don't suggest this, because the generic param is used in the bound." //~ ERROR mismatched types
+}
+
+fn main() {}
diff --git a/src/test/ui/return/return-impl-trait-bad.stderr b/src/test/ui/return/return-impl-trait-bad.stderr
new file mode 100644
index 00000000000..237b85ee66a
--- /dev/null
+++ b/src/test/ui/return/return-impl-trait-bad.stderr
@@ -0,0 +1,59 @@
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait-bad.rs:5:5
+   |
+LL | fn bad_echo<T>(_t: T) -> T {
+   |             -            - expected `T` because of return type
+   |             |
+   |             this type parameter
+LL |     "this should not suggest impl Trait"
+   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
+   |
+   = note: expected type parameter `T`
+                   found reference `&'static str`
+
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait-bad.rs:9:5
+   |
+LL | fn bad_echo_2<T: Trait>(_t: T) -> T {
+   |               -                   - expected `T` because of return type
+   |               |
+   |               this type parameter
+LL |     "this will not suggest it, because that would probably be wrong"
+   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
+   |
+   = note: expected type parameter `T`
+                   found reference `&'static str`
+
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait-bad.rs:17:5
+   |
+LL | fn other_bounds_bad<T>() -> T
+   |                     -       - expected `T` because of return type
+   |                     |
+   |                     this type parameter
+...
+LL |     "don't suggest this, because Option<T> places additional constraints"
+   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
+   |
+   = note: expected type parameter `T`
+                   found reference `&'static str`
+
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait-bad.rs:28:5
+   |
+LL | fn used_in_trait<T>() -> T
+   |                  -       -
+   |                  |       |
+   |                  |       expected `T` because of return type
+   |                  |       help: consider using an impl return type: `impl Send`
+   |                  this type parameter
+...
+LL |     "don't suggest this, because the generic param is used in the bound."
+   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
+   |
+   = note: expected type parameter `T`
+                   found reference `&'static str`
+
+error: aborting due to 4 previous errors
+
+For more information about this error, try `rustc --explain E0308`.
diff --git a/src/test/ui/return/return-impl-trait.fixed b/src/test/ui/return/return-impl-trait.fixed
new file mode 100644
index 00000000000..ff2b02f73ea
--- /dev/null
+++ b/src/test/ui/return/return-impl-trait.fixed
@@ -0,0 +1,30 @@
+// run-rustfix
+
+trait Trait {}
+impl Trait for () {}
+
+// this works
+fn foo() -> impl Trait {
+    ()
+}
+
+fn bar<T: Trait + std::marker::Sync>() -> impl Trait + std::marker::Sync + Send
+where
+    T: Send,
+{
+    () //~ ERROR mismatched types
+}
+
+fn other_bounds<T>() -> impl Trait
+where
+    T: Trait,
+    Vec<usize>: Clone,
+{
+    () //~ ERROR mismatched types
+}
+
+fn main() {
+    foo();
+    bar::<()>();
+    other_bounds::<()>();
+}
diff --git a/src/test/ui/return/return-impl-trait.rs b/src/test/ui/return/return-impl-trait.rs
new file mode 100644
index 00000000000..e905d712f62
--- /dev/null
+++ b/src/test/ui/return/return-impl-trait.rs
@@ -0,0 +1,30 @@
+// run-rustfix
+
+trait Trait {}
+impl Trait for () {}
+
+// this works
+fn foo() -> impl Trait {
+    ()
+}
+
+fn bar<T: Trait + std::marker::Sync>() -> T
+where
+    T: Send,
+{
+    () //~ ERROR mismatched types
+}
+
+fn other_bounds<T>() -> T
+where
+    T: Trait,
+    Vec<usize>: Clone,
+{
+    () //~ ERROR mismatched types
+}
+
+fn main() {
+    foo();
+    bar::<()>();
+    other_bounds::<()>();
+}
diff --git a/src/test/ui/return/return-impl-trait.stderr b/src/test/ui/return/return-impl-trait.stderr
new file mode 100644
index 00000000000..43d40972fca
--- /dev/null
+++ b/src/test/ui/return/return-impl-trait.stderr
@@ -0,0 +1,34 @@
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait.rs:15:5
+   |
+LL | fn bar<T: Trait + std::marker::Sync>() -> T
+   |        -                                  -
+   |        |                                  |
+   |        |                                  expected `T` because of return type
+   |        this type parameter                help: consider using an impl return type: `impl Trait + std::marker::Sync + Send`
+...
+LL |     ()
+   |     ^^ expected type parameter `T`, found `()`
+   |
+   = note: expected type parameter `T`
+                   found unit type `()`
+
+error[E0308]: mismatched types
+  --> $DIR/return-impl-trait.rs:23:5
+   |
+LL | fn other_bounds<T>() -> T
+   |                 -       -
+   |                 |       |
+   |                 |       expected `T` because of return type
+   |                 |       help: consider using an impl return type: `impl Trait`
+   |                 this type parameter
+...
+LL |     ()
+   |     ^^ expected type parameter `T`, found `()`
+   |
+   = note: expected type parameter `T`
+                   found unit type `()`
+
+error: aborting due to 2 previous errors
+
+For more information about this error, try `rustc --explain E0308`.