about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir-ty/src/infer/unify.rs12
-rw-r--r--crates/hir-ty/src/method_resolution.rs84
-rw-r--r--crates/hir/src/source_analyzer.rs65
-rw-r--r--crates/ide/src/goto_definition.rs82
4 files changed, 170 insertions, 73 deletions
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs
index b00e3216b2d..12f45f00f9c 100644
--- a/crates/hir-ty/src/infer/unify.rs
+++ b/crates/hir-ty/src/infer/unify.rs
@@ -340,8 +340,8 @@ impl<'a> InferenceTable<'a> {
         self.resolve_with_fallback(t, &|_, _, d, _| d)
     }
 
-    /// Unify two types and register new trait goals that arise from that.
-    pub(crate) fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
+    /// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that.
+    pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool {
         let result = match self.try_unify(ty1, ty2) {
             Ok(r) => r,
             Err(_) => return false,
@@ -350,9 +350,13 @@ impl<'a> InferenceTable<'a> {
         true
     }
 
-    /// Unify two types and return new trait goals arising from it, so the
+    /// Unify two relatable values (e.g. `Ty`) and return new trait goals arising from it, so the
     /// caller needs to deal with them.
-    pub(crate) fn try_unify<T: Zip<Interner>>(&mut self, t1: &T, t2: &T) -> InferResult<()> {
+    pub(crate) fn try_unify<T: ?Sized + Zip<Interner>>(
+        &mut self,
+        t1: &T,
+        t2: &T,
+    ) -> InferResult<()> {
         match self.var_unification_table.relate(
             Interner,
             &self.db,
diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs
index 5998680dcd3..b1178ba0d2a 100644
--- a/crates/hir-ty/src/method_resolution.rs
+++ b/crates/hir-ty/src/method_resolution.rs
@@ -22,10 +22,10 @@ use crate::{
     from_foreign_def_id,
     infer::{unify::InferenceTable, Adjust, Adjustment, AutoBorrow, OverloadedDeref, PointerCast},
     primitive::{FloatTy, IntTy, UintTy},
-    static_lifetime,
+    static_lifetime, to_chalk_trait_id,
     utils::all_super_traits,
     AdtId, Canonical, CanonicalVarKinds, DebruijnIndex, ForeignDefId, InEnvironment, Interner,
-    Scalar, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
+    Scalar, Substitution, TraitEnvironment, TraitRef, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
 };
 
 /// This is used as a key for indexing impls.
@@ -624,52 +624,76 @@ pub(crate) fn iterate_method_candidates<T>(
     slot
 }
 
+/// Looks up the impl method that actually runs for the trait method `func`.
+///
+/// Returns `func` if it's not a method defined in a trait or the lookup failed.
 pub fn lookup_impl_method(
-    self_ty: &Ty,
     db: &dyn HirDatabase,
     env: Arc<TraitEnvironment>,
-    trait_: TraitId,
+    func: FunctionId,
+    fn_subst: Substitution,
+) -> FunctionId {
+    let trait_id = match func.lookup(db.upcast()).container {
+        ItemContainerId::TraitId(id) => id,
+        _ => return func,
+    };
+    let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
+    let fn_params = fn_subst.len(Interner) - trait_params;
+    let trait_ref = TraitRef {
+        trait_id: to_chalk_trait_id(trait_id),
+        substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)),
+    };
+
+    let name = &db.function_data(func).name;
+    lookup_impl_method_for_trait_ref(trait_ref, db, env, name).unwrap_or(func)
+}
+
+fn lookup_impl_method_for_trait_ref(
+    trait_ref: TraitRef,
+    db: &dyn HirDatabase,
+    env: Arc<TraitEnvironment>,
     name: &Name,
 ) -> Option<FunctionId> {
-    let self_ty_fp = TyFingerprint::for_trait_impl(self_ty)?;
-    let trait_impls = db.trait_impls_in_deps(env.krate);
-    let impls = trait_impls.for_trait_and_self_ty(trait_, self_ty_fp);
-    let mut table = InferenceTable::new(db, env.clone());
-    find_matching_impl(impls, &mut table, &self_ty).and_then(|data| {
-        data.items.iter().find_map(|it| match it {
-            AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
-            _ => None,
-        })
+    let self_ty = trait_ref.self_type_parameter(Interner);
+    let self_ty_fp = TyFingerprint::for_trait_impl(&self_ty)?;
+    let impls = db.trait_impls_in_deps(env.krate);
+    let impls = impls.for_trait_and_self_ty(trait_ref.hir_trait_id(), self_ty_fp);
+
+    let table = InferenceTable::new(db, env);
+
+    let impl_data = find_matching_impl(impls, table, trait_ref)?;
+    impl_data.items.iter().find_map(|it| match it {
+        AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
+        _ => None,
     })
 }
 
 fn find_matching_impl(
     mut impls: impl Iterator<Item = ImplId>,
-    table: &mut InferenceTable<'_>,
-    self_ty: &Ty,
+    mut table: InferenceTable<'_>,
+    actual_trait_ref: TraitRef,
 ) -> Option<Arc<ImplData>> {
     let db = table.db;
     loop {
         let impl_ = impls.next()?;
         let r = table.run_in_snapshot(|table| {
             let impl_data = db.impl_data(impl_);
-            let substs =
+            let impl_substs =
                 TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build();
-            let impl_ty = db.impl_self_ty(impl_).substitute(Interner, &substs);
-
-            table
-                .unify(self_ty, &impl_ty)
-                .then(|| {
-                    let wh_goals =
-                        crate::chalk_db::convert_where_clauses(db, impl_.into(), &substs)
-                            .into_iter()
-                            .map(|b| b.cast(Interner));
+            let trait_ref = db
+                .impl_trait(impl_)
+                .expect("non-trait method in find_matching_impl")
+                .substitute(Interner, &impl_substs);
 
-                    let goal = crate::Goal::all(Interner, wh_goals);
+            if !table.unify(&trait_ref, &actual_trait_ref) {
+                return None;
+            }
 
-                    table.try_obligation(goal).map(|_| impl_data)
-                })
-                .flatten()
+            let wcs = crate::chalk_db::convert_where_clauses(db, impl_.into(), &impl_substs)
+                .into_iter()
+                .map(|b| b.cast(Interner));
+            let goal = crate::Goal::all(Interner, wcs);
+            table.try_obligation(goal).map(|_| impl_data)
         });
         if r.is_some() {
             break r;
@@ -1214,7 +1238,7 @@ fn is_valid_fn_candidate(
             let expected_receiver =
                 sig.map(|s| s.params()[0].clone()).substitute(Interner, &fn_subst);
 
-            check_that!(table.unify(&receiver_ty, &expected_receiver));
+            check_that!(table.unify(receiver_ty, &expected_receiver));
         }
 
         if let ItemContainerId::ImplId(impl_id) = container {
diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs
index 07bae2b38c7..f86c5710053 100644
--- a/crates/hir/src/source_analyzer.rs
+++ b/crates/hir/src/source_analyzer.rs
@@ -270,7 +270,7 @@ impl SourceAnalyzer {
         let expr_id = self.expr_id(db, &call.clone().into())?;
         let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?;
 
-        Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs))
     }
 
     pub(crate) fn resolve_await_to_poll(
@@ -311,7 +311,7 @@ impl SourceAnalyzer {
         // HACK: subst for `poll()` coincides with that for `Future` because `poll()` itself
         // doesn't have any generic parameters, so we skip building another subst for `poll()`.
         let substs = hir_ty::TyBuilder::subst_for_def(db, future_trait, None).push(ty).build();
-        Some(self.resolve_impl_method_or_trait_def(db, poll_fn, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, poll_fn, substs))
     }
 
     pub(crate) fn resolve_prefix_expr(
@@ -331,7 +331,7 @@ impl SourceAnalyzer {
         // don't have any generic parameters, so we skip building another subst for the methods.
         let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
 
-        Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
     }
 
     pub(crate) fn resolve_index_expr(
@@ -351,7 +351,7 @@ impl SourceAnalyzer {
             .push(base_ty.clone())
             .push(index_ty.clone())
             .build();
-        Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
     }
 
     pub(crate) fn resolve_bin_expr(
@@ -372,7 +372,7 @@ impl SourceAnalyzer {
             .push(rhs.clone())
             .build();
 
-        Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
     }
 
     pub(crate) fn resolve_try_expr(
@@ -392,7 +392,7 @@ impl SourceAnalyzer {
         // doesn't have any generic parameters, so we skip building another subst for `branch()`.
         let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
 
-        Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
+        Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
     }
 
     pub(crate) fn resolve_field(
@@ -487,9 +487,9 @@ impl SourceAnalyzer {
 
         let mut prefer_value_ns = false;
         let resolved = (|| {
+            let infer = self.infer.as_deref()?;
             if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) {
                 let expr_id = self.expr_id(db, &path_expr.into())?;
-                let infer = self.infer.as_ref()?;
                 if let Some(assoc) = infer.assoc_resolutions_for_expr(expr_id) {
                     let assoc = match assoc {
                         AssocItemId::FunctionId(f_in_trait) => {
@@ -497,9 +497,12 @@ impl SourceAnalyzer {
                                 None => assoc,
                                 Some(func_ty) => {
                                     if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) {
-                                        self.resolve_impl_method(db, f_in_trait, subs)
-                                            .map(AssocItemId::FunctionId)
-                                            .unwrap_or(assoc)
+                                        self.resolve_impl_method_or_trait_def(
+                                            db,
+                                            f_in_trait,
+                                            subs.clone(),
+                                        )
+                                        .into()
                                     } else {
                                         assoc
                                     }
@@ -520,18 +523,18 @@ impl SourceAnalyzer {
                 prefer_value_ns = true;
             } else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) {
                 let pat_id = self.pat_id(&path_pat.into())?;
-                if let Some(assoc) = self.infer.as_ref()?.assoc_resolutions_for_pat(pat_id) {
+                if let Some(assoc) = infer.assoc_resolutions_for_pat(pat_id) {
                     return Some(PathResolution::Def(AssocItem::from(assoc).into()));
                 }
                 if let Some(VariantId::EnumVariantId(variant)) =
-                    self.infer.as_ref()?.variant_resolution_for_pat(pat_id)
+                    infer.variant_resolution_for_pat(pat_id)
                 {
                     return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
                 }
             } else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) {
                 let expr_id = self.expr_id(db, &rec_lit.into())?;
                 if let Some(VariantId::EnumVariantId(variant)) =
-                    self.infer.as_ref()?.variant_resolution_for_expr(expr_id)
+                    infer.variant_resolution_for_expr(expr_id)
                 {
                     return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
                 }
@@ -541,8 +544,7 @@ impl SourceAnalyzer {
                     || parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from);
                 if let Some(pat) = record_pat.or_else(tuple_struct_pat) {
                     let pat_id = self.pat_id(&pat)?;
-                    let variant_res_for_pat =
-                        self.infer.as_ref()?.variant_resolution_for_pat(pat_id);
+                    let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id);
                     if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat {
                         return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
                     }
@@ -780,37 +782,22 @@ impl SourceAnalyzer {
         false
     }
 
-    fn resolve_impl_method(
+    fn resolve_impl_method_or_trait_def(
         &self,
         db: &dyn HirDatabase,
         func: FunctionId,
-        substs: &Substitution,
-    ) -> Option<FunctionId> {
-        let impled_trait = match func.lookup(db.upcast()).container {
-            ItemContainerId::TraitId(trait_id) => trait_id,
-            _ => return None,
-        };
-        if substs.is_empty(Interner) {
-            return None;
-        }
-        let self_ty = substs.at(Interner, 0).ty(Interner)?;
+        substs: Substitution,
+    ) -> FunctionId {
         let krate = self.resolver.krate();
-        let trait_env = self.resolver.body_owner()?.as_generic_def_id().map_or_else(
+        let owner = match self.resolver.body_owner() {
+            Some(it) => it,
+            None => return func,
+        };
+        let env = owner.as_generic_def_id().map_or_else(
             || Arc::new(hir_ty::TraitEnvironment::empty(krate)),
             |d| db.trait_environment(d),
         );
-
-        let fun_data = db.function_data(func);
-        method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name)
-    }
-
-    fn resolve_impl_method_or_trait_def(
-        &self,
-        db: &dyn HirDatabase,
-        func: FunctionId,
-        substs: &Substitution,
-    ) -> FunctionId {
-        self.resolve_impl_method(db, func, substs).unwrap_or(func)
+        method_resolution::lookup_impl_method(db, env, func, substs)
     }
 
     fn lang_trait_fn(
diff --git a/crates/ide/src/goto_definition.rs b/crates/ide/src/goto_definition.rs
index d0be1b3f404..f97c67b144a 100644
--- a/crates/ide/src/goto_definition.rs
+++ b/crates/ide/src/goto_definition.rs
@@ -1834,4 +1834,86 @@ fn f() {
 "#,
         );
     }
+
+    #[test]
+    fn goto_bin_op_multiple_impl() {
+        check(
+            r#"
+//- minicore: add
+struct S;
+impl core::ops::Add for S {
+    fn add(
+     //^^^
+    ) {}
+}
+impl core::ops::Add<usize> for S {
+    fn add(
+    ) {}
+}
+
+fn f() {
+    S +$0 S
+}
+"#,
+        );
+
+        check(
+            r#"
+//- minicore: add
+struct S;
+impl core::ops::Add for S {
+    fn add(
+    ) {}
+}
+impl core::ops::Add<usize> for S {
+    fn add(
+     //^^^
+    ) {}
+}
+
+fn f() {
+    S +$0 0usize
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn path_call_multiple_trait_impl() {
+        check(
+            r#"
+trait Trait<T> {
+    fn f(_: T);
+}
+impl Trait<i32> for usize {
+    fn f(_: i32) {}
+     //^
+}
+impl Trait<i64> for usize {
+    fn f(_: i64) {}
+}
+fn main() {
+    usize::f$0(0i32);
+}
+"#,
+        );
+
+        check(
+            r#"
+trait Trait<T> {
+    fn f(_: T);
+}
+impl Trait<i32> for usize {
+    fn f(_: i32) {}
+}
+impl Trait<i64> for usize {
+    fn f(_: i64) {}
+     //^
+}
+fn main() {
+    usize::f$0(0i64);
+}
+"#,
+        )
+    }
 }