about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJason Newcomb <jsnewcomb@pm.me>2025-02-03 11:19:17 -0500
committerJason Newcomb <jsnewcomb@pm.me>2025-02-03 12:54:53 -0500
commit6378fbc366ad552ee791bcac670e0f3939489ef7 (patch)
treedb1ba7d3b1c1d4d1a0370b808c40a9b851dda452
parentb909c36f40035bcc0a25f8734ee6480685cba1b1 (diff)
downloadrust-6378fbc366ad552ee791bcac670e0f3939489ef7.tar.gz
rust-6378fbc366ad552ee791bcac670e0f3939489ef7.zip
Check for generic parameter mismatches on trait functions.
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs124
-rw-r--r--tests/ui/fn/param-mismatch-trait-fn.rs10
-rw-r--r--tests/ui/fn/param-mismatch-trait-fn.stderr23
-rw-r--r--tests/ui/methods/issues/issue-61525.stderr2
-rw-r--r--tests/ui/suggestions/trait-with-missing-associated-type-restriction.stderr2
-rw-r--r--tests/ui/traits/issue-52893.stderr2
6 files changed, 120 insertions, 43 deletions
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
index 7ee246e0774..e30c0e115dc 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
@@ -21,7 +21,7 @@ use rustc_middle::ty::visit::TypeVisitableExt;
 use rustc_middle::ty::{self, IsSuggestable, Ty, TyCtxt};
 use rustc_middle::{bug, span_bug};
 use rustc_session::Session;
-use rustc_span::{DUMMY_SP, Ident, Span, kw, sym};
+use rustc_span::{DUMMY_SP, Ident, Span, Symbol, kw, sym};
 use rustc_trait_selection::error_reporting::infer::{FailureCode, ObligationCauseExt};
 use rustc_trait_selection::infer::InferCtxtExt;
 use rustc_trait_selection::traits::{self, ObligationCauseCode, ObligationCtxt, SelectionContext};
@@ -2414,11 +2414,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     })
                 {
                     let Some(generic_param) = generic_param else {
-                        spans.push_span_label(param.span, "");
+                        spans.push_span_label(param.span(), "");
                         continue;
                     };
 
-                    let other_params_matched: Vec<(ExpectedIdx, &hir::Param<'_>)> =
+                    let other_params_matched: Vec<(ExpectedIdx, FnParam<'_>)> =
                         params_with_generics
                             .iter_enumerated()
                             .filter(|&(other_idx, &(other_generic_param, _))| {
@@ -2447,9 +2447,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         let other_param_matched_names: Vec<String> = other_params_matched
                             .iter()
                             .map(|(idx, other_param)| {
-                                if let hir::PatKind::Binding(_, _, ident, _) = other_param.pat.kind
-                                {
-                                    format!("`{ident}`")
+                                if let Some(name) = other_param.name() {
+                                    format!("`{name}`")
                                 } else {
                                     format!("parameter #{}", idx.as_u32() + 1)
                                 }
@@ -2462,7 +2461,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
                         if matched_inputs[idx].is_some() {
                             spans.push_span_label(
-                                param.span,
+                                param.span(),
                                 format!(
                                     "{} need{} to match the {} type of this parameter",
                                     listify(&other_param_matched_names, |n| n.to_string())
@@ -2477,7 +2476,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             );
                         } else {
                             spans.push_span_label(
-                                param.span,
+                                param.span(),
                                 format!(
                                     "this parameter needs to match the {} type of {}",
                                     matched_ty,
@@ -2488,7 +2487,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         }
                         generics_with_unmatched_params.push(generic_param);
                     } else {
-                        spans.push_span_label(param.span, "");
+                        spans.push_span_label(param.span(), "");
                     }
                 }
 
@@ -2515,8 +2514,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             }
                         })
                         .map(|(idx, &(_, param))| {
-                            if let hir::PatKind::Binding(_, _, ident, _) = param.pat.kind {
-                                format!("`{ident}`")
+                            if let Some(name) = param.name() {
+                                format!("`{name}`")
                             } else {
                                 format!("parameter #{}", idx.as_u32() + 1)
                             }
@@ -2673,35 +2672,56 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         &self,
         def_id: DefId,
         is_method: bool,
-    ) -> Option<IndexVec<ExpectedIdx, (Option<&hir::GenericParam<'_>>, &hir::Param<'_>)>> {
-        let fn_node = self.tcx.hir().get_if_local(def_id)?;
-        let fn_decl = fn_node.fn_decl()?;
-        let generic_params = fn_node.generics()?.params;
-
-        // Remove both the receiver and variadic arguments. Neither can have an unmatched generic
-        // parameter.
-        let params = self.tcx.hir().body(fn_node.body_id()?).params;
-        let params = params.get(is_method as usize..params.len() - fn_decl.c_variadic as usize)?;
-        let fn_inputs = fn_decl.inputs.get(is_method as usize..)?;
-        debug_assert_eq!(params.len(), fn_inputs.len());
-
-        Some(
-            fn_inputs
-                .into_iter()
-                .map(|param| {
-                    if let hir::TyKind::Path(QPath::Resolved(
-                        _,
-                        &hir::Path { res: Res::Def(_, res_def_id), .. },
-                    )) = param.kind
-                    {
-                        generic_params.iter().find(|param| param.def_id.to_def_id() == res_def_id)
-                    } else {
-                        None
-                    }
-                })
-                .zip(params)
-                .collect(),
-        )
+    ) -> Option<IndexVec<ExpectedIdx, (Option<&hir::GenericParam<'_>>, FnParam<'_>)>> {
+        let (sig, generics, body_id, param_names) = match self.tcx.hir().get_if_local(def_id)? {
+            hir::Node::TraitItem(&hir::TraitItem {
+                generics,
+                kind: hir::TraitItemKind::Fn(sig, trait_fn),
+                ..
+            }) => match trait_fn {
+                hir::TraitFn::Required(params) => (sig, generics, None, Some(params)),
+                hir::TraitFn::Provided(body) => (sig, generics, Some(body), None),
+            },
+            hir::Node::ImplItem(&hir::ImplItem {
+                generics,
+                kind: hir::ImplItemKind::Fn(sig, body),
+                ..
+            })
+            | hir::Node::Item(&hir::Item {
+                kind: hir::ItemKind::Fn { sig, generics, body, .. },
+                ..
+            }) => (sig, generics, Some(body), None),
+            _ => return None,
+        };
+
+        // Make sure to remove both the receiver and variadic argument. Both are removed
+        // when matching parameter types.
+        let fn_inputs = sig.decl.inputs.get(is_method as usize..)?.iter().map(|param| {
+            if let hir::TyKind::Path(QPath::Resolved(
+                _,
+                &hir::Path { res: Res::Def(_, res_def_id), .. },
+            )) = param.kind
+            {
+                generics.params.iter().find(|param| param.def_id.to_def_id() == res_def_id)
+            } else {
+                None
+            }
+        });
+        match (body_id, param_names) {
+            (Some(_), Some(_)) | (None, None) => unreachable!(),
+            (Some(body), None) => {
+                let params = self.tcx.hir().body(body).params;
+                let params =
+                    params.get(is_method as usize..params.len() - sig.decl.c_variadic as usize)?;
+                debug_assert_eq!(params.len(), fn_inputs.len());
+                Some(fn_inputs.zip(params.iter().map(|param| FnParam::Param(param))).collect())
+            }
+            (None, Some(params)) => {
+                let params = params.get(is_method as usize..)?;
+                debug_assert_eq!(params.len(), fn_inputs.len());
+                Some(fn_inputs.zip(params.iter().map(|param| FnParam::Name(param))).collect())
+            }
+        }
     }
 }
 
@@ -2724,3 +2744,27 @@ impl<'tcx> Visitor<'tcx> for FindClosureArg<'tcx> {
         hir::intravisit::walk_expr(self, ex);
     }
 }
+
+#[derive(Clone, Copy)]
+enum FnParam<'hir> {
+    Param(&'hir hir::Param<'hir>),
+    Name(&'hir Ident),
+}
+impl FnParam<'_> {
+    fn span(&self) -> Span {
+        match self {
+            Self::Param(x) => x.span,
+            Self::Name(x) => x.span,
+        }
+    }
+
+    fn name(&self) -> Option<Symbol> {
+        match self {
+            Self::Param(x) if let hir::PatKind::Binding(_, _, ident, _) = x.pat.kind => {
+                Some(ident.name)
+            }
+            Self::Name(x) if x.name != kw::Empty => Some(x.name),
+            _ => None,
+        }
+    }
+}
diff --git a/tests/ui/fn/param-mismatch-trait-fn.rs b/tests/ui/fn/param-mismatch-trait-fn.rs
new file mode 100644
index 00000000000..69ded6a9068
--- /dev/null
+++ b/tests/ui/fn/param-mismatch-trait-fn.rs
@@ -0,0 +1,10 @@
+trait Foo {
+    fn same_type<T>(_: T, _: T);
+}
+
+fn f<T: Foo, X, Y>(x: X, y: Y) {
+    T::same_type([x], Some(y));
+    //~^ ERROR mismatched types
+}
+
+fn main() {}
diff --git a/tests/ui/fn/param-mismatch-trait-fn.stderr b/tests/ui/fn/param-mismatch-trait-fn.stderr
new file mode 100644
index 00000000000..28e1bcaaf49
--- /dev/null
+++ b/tests/ui/fn/param-mismatch-trait-fn.stderr
@@ -0,0 +1,23 @@
+error[E0308]: mismatched types
+  --> $DIR/param-mismatch-trait-fn.rs:6:23
+   |
+LL |     T::same_type([x], Some(y));
+   |     ------------ ---  ^^^^^^^ expected `[X; 1]`, found `Option<Y>`
+   |     |            |
+   |     |            expected all arguments to be this `[X; 1]` type because they need to match the type of this parameter
+   |     arguments to this function are incorrect
+   |
+   = note: expected array `[X; 1]`
+               found enum `Option<Y>`
+note: associated function defined here
+  --> $DIR/param-mismatch-trait-fn.rs:2:8
+   |
+LL |     fn same_type<T>(_: T, _: T);
+   |        ^^^^^^^^^ -  -     - this parameter needs to match the `[X; 1]` type of parameter #1
+   |                  |  |
+   |                  |  parameter #2 needs to match the `[X; 1]` type of this parameter
+   |                  parameter #1 and parameter #2 both reference this parameter `T`
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0308`.
diff --git a/tests/ui/methods/issues/issue-61525.stderr b/tests/ui/methods/issues/issue-61525.stderr
index 35001ae22a6..7ac3d3dc0cf 100644
--- a/tests/ui/methods/issues/issue-61525.stderr
+++ b/tests/ui/methods/issues/issue-61525.stderr
@@ -32,7 +32,7 @@ note: method defined here
   --> $DIR/issue-61525.rs:2:8
    |
 LL |     fn query<Q>(self, q: Q);
-   |        ^^^^^
+   |        ^^^^^          -
 
 error: aborting due to 2 previous errors
 
diff --git a/tests/ui/suggestions/trait-with-missing-associated-type-restriction.stderr b/tests/ui/suggestions/trait-with-missing-associated-type-restriction.stderr
index 980c2455c8e..df59a28c4b9 100644
--- a/tests/ui/suggestions/trait-with-missing-associated-type-restriction.stderr
+++ b/tests/ui/suggestions/trait-with-missing-associated-type-restriction.stderr
@@ -94,7 +94,7 @@ note: method defined here
   --> $DIR/trait-with-missing-associated-type-restriction.rs:9:8
    |
 LL |     fn funk(&self, _: Self::A);
-   |        ^^^^
+   |        ^^^^        -
 help: consider constraining the associated type `<T as Trait<i32>>::A` to `{integer}`
    |
 LL | fn bar2<T: Trait<i32, A = {integer}>>(x: T) {
diff --git a/tests/ui/traits/issue-52893.stderr b/tests/ui/traits/issue-52893.stderr
index c37dde90e33..3c5df82fcdc 100644
--- a/tests/ui/traits/issue-52893.stderr
+++ b/tests/ui/traits/issue-52893.stderr
@@ -22,7 +22,7 @@ note: method defined here
   --> $DIR/issue-52893.rs:11:8
    |
 LL |     fn push(self, other: T) -> Self::PushRes;
-   |        ^^^^
+   |        ^^^^       -----
 
 error: aborting due to 1 previous error