about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2022-07-13 05:39:01 +0000
committerMichael Goulet <michael@errs.io>2022-07-14 23:36:46 +0000
commitd25abdc0c52cc08cdd290be325f1be04f3cea548 (patch)
treec439ec566cdce1199ad5d9c8946ff2c791ae4d11
parentddb7003b7950a6df815d1f03ea6eca9f4cc2a408 (diff)
downloadrust-d25abdc0c52cc08cdd290be325f1be04f3cea548.tar.gz
rust-d25abdc0c52cc08cdd290be325f1be04f3cea548.zip
Point out custom Fn-family trait impl
-rw-r--r--compiler/rustc_middle/src/ty/closure.rs8
-rw-r--r--compiler/rustc_typeck/src/check/fn_ctxt/checks.rs143
-rw-r--r--src/test/ui/mismatched_types/overloaded-calls-bad.stderr16
3 files changed, 114 insertions, 53 deletions
diff --git a/compiler/rustc_middle/src/ty/closure.rs b/compiler/rustc_middle/src/ty/closure.rs
index f5ce43f3afb..8ead0512274 100644
--- a/compiler/rustc_middle/src/ty/closure.rs
+++ b/compiler/rustc_middle/src/ty/closure.rs
@@ -128,6 +128,14 @@ impl<'tcx> ClosureKind {
             None
         }
     }
+
+    pub fn to_def_id(&self, tcx: TyCtxt<'_>) -> DefId {
+        match self {
+            ClosureKind::Fn => tcx.lang_items().fn_once_trait().unwrap(),
+            ClosureKind::FnMut => tcx.lang_items().fn_mut_trait().unwrap(),
+            ClosureKind::FnOnce => tcx.lang_items().fn_trait().unwrap(),
+        }
+    }
 }
 
 /// A composite describing a `Place` that is captured by a closure.
diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
index 3dd63d74c3f..ec045d3e70c 100644
--- a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
+++ b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs
@@ -21,6 +21,7 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::{ExprKind, Node, QPath};
 use rustc_index::vec::IndexVec;
 use rustc_infer::infer::error_reporting::{FailureCode, ObligationCauseExt};
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
 use rustc_infer::infer::InferOk;
 use rustc_infer::infer::TypeTrace;
 use rustc_middle::ty::adjustment::AllowTwoPhase;
@@ -29,7 +30,9 @@ use rustc_middle::ty::{self, DefIdTree, IsSuggestable, Ty};
 use rustc_session::Session;
 use rustc_span::symbol::Ident;
 use rustc_span::{self, Span};
-use rustc_trait_selection::traits::{self, ObligationCauseCode, StatementAsExpression};
+use rustc_trait_selection::traits::{
+    self, ObligationCauseCode, SelectionContext, StatementAsExpression,
+};
 
 use std::iter;
 use std::slice;
@@ -393,41 +396,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         }
 
         if !call_appears_satisfied {
-            // Next, let's construct the error
-            let (error_span, full_call_span, ctor_of) = match &call_expr.kind {
-                hir::ExprKind::Call(
-                    hir::Expr {
-                        span,
-                        kind:
-                            hir::ExprKind::Path(hir::QPath::Resolved(
-                                _,
-                                hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. },
-                            )),
-                        ..
-                    },
-                    _,
-                ) => (call_span, *span, Some(of)),
-                hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None),
-                hir::ExprKind::MethodCall(path_segment, _, span) => {
-                    let ident_span = path_segment.ident.span;
-                    let ident_span = if let Some(args) = path_segment.args {
-                        ident_span.with_hi(args.span_ext.hi())
-                    } else {
-                        ident_span
-                    };
-                    (
-                        *span, ident_span, None, // methods are never ctors
-                    )
-                }
-                k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k),
-            };
-            let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span);
-            let call_name = match ctor_of {
-                Some(CtorOf::Struct) => "struct",
-                Some(CtorOf::Variant) => "enum variant",
-                None => "function",
-            };
-
             let compatibility_diagonal = IndexVec::from_raw(compatibility_diagonal);
             let provided_args = IndexVec::from_iter(provided_args.iter().take(if c_variadic {
                 minimum_input_count
@@ -451,13 +419,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 compatibility_diagonal,
                 formal_and_expected_inputs,
                 provided_args,
-                full_call_span,
-                error_span,
-                args_span,
-                call_name,
                 c_variadic,
                 err_code,
                 fn_def_id,
+                call_span,
                 call_expr,
             );
         }
@@ -468,15 +433,47 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         compatibility_diagonal: IndexVec<ProvidedIdx, Compatibility<'tcx>>,
         formal_and_expected_inputs: IndexVec<ExpectedIdx, (Ty<'tcx>, Ty<'tcx>)>,
         provided_args: IndexVec<ProvidedIdx, &'tcx hir::Expr<'tcx>>,
-        full_call_span: Span,
-        error_span: Span,
-        args_span: Span,
-        call_name: &str,
         c_variadic: bool,
         err_code: &str,
         fn_def_id: Option<DefId>,
+        call_span: Span,
         call_expr: &hir::Expr<'tcx>,
     ) {
+        // Next, let's construct the error
+        let (error_span, full_call_span, ctor_of) = match &call_expr.kind {
+            hir::ExprKind::Call(
+                hir::Expr {
+                    span,
+                    kind:
+                        hir::ExprKind::Path(hir::QPath::Resolved(
+                            _,
+                            hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. },
+                        )),
+                    ..
+                },
+                _,
+            ) => (call_span, *span, Some(of)),
+            hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None),
+            hir::ExprKind::MethodCall(path_segment, _, span) => {
+                let ident_span = path_segment.ident.span;
+                let ident_span = if let Some(args) = path_segment.args {
+                    ident_span.with_hi(args.span_ext.hi())
+                } else {
+                    ident_span
+                };
+                (
+                    *span, ident_span, None, // methods are never ctors
+                )
+            }
+            k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k),
+        };
+        let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span);
+        let call_name = match ctor_of {
+            Some(CtorOf::Struct) => "struct",
+            Some(CtorOf::Variant) => "enum variant",
+            None => "function",
+        };
+
         // Don't print if it has error types or is just plain `_`
         fn has_error_or_infer<'tcx>(tys: impl IntoIterator<Item = Ty<'tcx>>) -> bool {
             tys.into_iter().any(|ty| ty.references_error() || ty.is_ty_var())
@@ -1818,17 +1815,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     fn label_fn_like(
         &self,
         err: &mut rustc_errors::DiagnosticBuilder<'tcx, rustc_errors::ErrorGuaranteed>,
-        def_id: Option<DefId>,
+        callable_def_id: Option<DefId>,
         callee_ty: Option<Ty<'tcx>>,
     ) {
-        let Some(mut def_id) = def_id else {
+        let Some(mut def_id) = callable_def_id else {
             return;
         };
 
         if let Some(assoc_item) = self.tcx.opt_associated_item(def_id)
-            && let trait_def_id = assoc_item.trait_item_def_id.unwrap_or_else(|| self.tcx.parent(def_id))
+            // Possibly points at either impl or trait item, so try to get it
+            // to point to trait item, then get the parent.
+            // This parent might be an impl in the case of an inherent function,
+            // but the next check will fail.
+            && let maybe_trait_item_def_id = assoc_item.trait_item_def_id.unwrap_or(def_id)
+            && let maybe_trait_def_id = self.tcx.parent(maybe_trait_item_def_id)
             // Just an easy way to check "trait_def_id == Fn/FnMut/FnOnce"
-            && ty::ClosureKind::from_def_id(self.tcx, trait_def_id).is_some()
+            && let Some(call_kind) = ty::ClosureKind::from_def_id(self.tcx, maybe_trait_def_id)
             && let Some(callee_ty) = callee_ty
         {
             let callee_ty = callee_ty.peel_refs();
@@ -1853,7 +1855,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             std::iter::zip(instantiated.predicates, instantiated.spans)
                         {
                             if let ty::PredicateKind::Trait(pred) = predicate.kind().skip_binder()
-                                && pred.self_ty() == callee_ty
+                                && pred.self_ty().peel_refs() == callee_ty
                                 && ty::ClosureKind::from_def_id(self.tcx, pred.def_id()).is_some()
                             {
                                 err.span_note(span, "callable defined here");
@@ -1862,11 +1864,46 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         }
                     }
                 }
-                ty::Opaque(new_def_id, _) | ty::Closure(new_def_id, _) | ty::FnDef(new_def_id, _) => {
+                ty::Opaque(new_def_id, _)
+                | ty::Closure(new_def_id, _)
+                | ty::FnDef(new_def_id, _) => {
                     def_id = new_def_id;
                 }
                 _ => {
-                    return;
+                    // Look for a user-provided impl of a `Fn` trait, and point to it.
+                    let new_def_id = self.probe(|_| {
+                        let trait_ref = ty::TraitRef::new(
+                            call_kind.to_def_id(self.tcx),
+                            self.tcx.mk_substs([
+                                ty::GenericArg::from(callee_ty),
+                                self.next_ty_var(TypeVariableOrigin {
+                                    kind: TypeVariableOriginKind::MiscVariable,
+                                    span: rustc_span::DUMMY_SP,
+                                })
+                                .into(),
+                            ].into_iter()),
+                        );
+                        let obligation = traits::Obligation::new(
+                            traits::ObligationCause::dummy(),
+                            self.param_env,
+                            ty::Binder::dummy(ty::TraitPredicate {
+                                trait_ref,
+                                constness: ty::BoundConstness::NotConst,
+                                polarity: ty::ImplPolarity::Positive,
+                            }),
+                        );
+                        match SelectionContext::new(&self).select(&obligation) {
+                            Ok(Some(traits::ImplSource::UserDefined(impl_source))) => {
+                                Some(impl_source.impl_def_id)
+                            }
+                            _ => None
+                        }
+                    });
+                    if let Some(new_def_id) = new_def_id {
+                        def_id = new_def_id;
+                    } else {
+                        return;
+                    }
                 }
             }
         }
@@ -1888,8 +1925,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
             let def_kind = self.tcx.def_kind(def_id);
             err.span_note(spans, &format!("{} defined here", def_kind.descr(def_id)));
-        } else if let def_kind @ (DefKind::Closure | DefKind::OpaqueTy) = self.tcx.def_kind(def_id)
-        {
+        } else {
+            let def_kind = self.tcx.def_kind(def_id);
             err.span_note(
                 self.tcx.def_span(def_id),
                 &format!("{} defined here", def_kind.descr(def_id)),
diff --git a/src/test/ui/mismatched_types/overloaded-calls-bad.stderr b/src/test/ui/mismatched_types/overloaded-calls-bad.stderr
index 5ed15468fd6..475ea9dfaf1 100644
--- a/src/test/ui/mismatched_types/overloaded-calls-bad.stderr
+++ b/src/test/ui/mismatched_types/overloaded-calls-bad.stderr
@@ -5,6 +5,12 @@ LL |     let ans = s("what");
    |               - ^^^^^^ expected `isize`, found `&str`
    |               |
    |               arguments to this function are incorrect
+   |
+note: implementation defined here
+  --> $DIR/overloaded-calls-bad.rs:10:1
+   |
+LL | impl FnMut<(isize,)> for S {
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 error[E0057]: this function takes 1 argument but 0 arguments were supplied
   --> $DIR/overloaded-calls-bad.rs:29:15
@@ -12,6 +18,11 @@ error[E0057]: this function takes 1 argument but 0 arguments were supplied
 LL |     let ans = s();
    |               ^-- an argument of type `isize` is missing
    |
+note: implementation defined here
+  --> $DIR/overloaded-calls-bad.rs:10:1
+   |
+LL | impl FnMut<(isize,)> for S {
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^
 help: provide the argument
    |
 LL |     let ans = s(/* isize */);
@@ -25,6 +36,11 @@ LL |     let ans = s("burma", "shave");
    |                 |
    |                 expected `isize`, found `&str`
    |
+note: implementation defined here
+  --> $DIR/overloaded-calls-bad.rs:10:1
+   |
+LL | impl FnMut<(isize,)> for S {
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^
 help: remove the extra argument
    |
 LL |     let ans = s(/* isize */);