about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-04-24 19:56:43 +0000
committerbors <bors@rust-lang.org>2023-04-24 19:56:43 +0000
commit15ef5f552374e52ca3b578757880938e354b6cc0 (patch)
tree0bee996e82440681d68b09e10e303d213f112308
parent65ac9f4602b87e69bf17298f4cd72dc813d5554b (diff)
parent12ba5cab112a863e63e51d56e0660d999c86b679 (diff)
downloadrust-15ef5f552374e52ca3b578757880938e354b6cc0.tar.gz
rust-15ef5f552374e52ca3b578757880938e354b6cc0.zip
Auto merge of #14641 - lowr:fix/obligation-for-value-path, r=Veykril
Register obligations during path inference

Fixes #14635

When we infer path expressions that resolve to some generic item, we need to consider their generic bounds. For example, when we resolve a path `Into::into` to `fn into<?0, ?1>` (note that `?0` is the self type of trait ref), we should register an obligation `?0: Into<?1>` or else their relationship would be lost.

Relevant part in rustc is [`add_required_obligations_with_code()`] that's called in [`instantiate_value_path()`].

[`instantiate_value_path()`]: https://github.com/rust-lang/rust/blob/3462f79e94f466a56ddaccfcdd3a3d44dd1dda9f/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs#L1052
[`add_required_obligations_with_code()`]: https://github.com/rust-lang/rust/blob/3462f79e94f466a56ddaccfcdd3a3d44dd1dda9f/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs#L1411
-rw-r--r--crates/hir-ty/src/builder.rs18
-rw-r--r--crates/hir-ty/src/chalk_db.rs12
-rw-r--r--crates/hir-ty/src/infer.rs15
-rw-r--r--crates/hir-ty/src/infer/path.rs91
-rw-r--r--crates/hir-ty/src/infer/unify.rs10
-rw-r--r--crates/hir-ty/src/method_resolution.rs12
-rw-r--r--crates/hir-ty/src/tests/traits.rs35
7 files changed, 142 insertions, 51 deletions
diff --git a/crates/hir-ty/src/builder.rs b/crates/hir-ty/src/builder.rs
index 03e9443599d..97924569ccc 100644
--- a/crates/hir-ty/src/builder.rs
+++ b/crates/hir-ty/src/builder.rs
@@ -18,7 +18,6 @@ use crate::{
     consteval::unknown_const_as_generic, db::HirDatabase, infer::unify::InferenceTable, primitive,
     to_assoc_type_id, to_chalk_trait_id, utils::generics, Binders, BoundVar, CallableSig,
     GenericArg, Interner, ProjectionTy, Substitution, TraitRef, Ty, TyDefId, TyExt, TyKind,
-    ValueTyDefId,
 };
 
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -362,21 +361,4 @@ impl TyBuilder<Binders<Ty>> {
     pub fn impl_self_ty(db: &dyn HirDatabase, def: hir_def::ImplId) -> TyBuilder<Binders<Ty>> {
         TyBuilder::subst_for_def(db, def, None).with_data(db.impl_self_ty(def))
     }
-
-    pub fn value_ty(
-        db: &dyn HirDatabase,
-        def: ValueTyDefId,
-        parent_subst: Option<Substitution>,
-    ) -> TyBuilder<Binders<Ty>> {
-        let poly_value_ty = db.value_ty(def);
-        let id = match def.to_generic_def_id() {
-            Some(id) => id,
-            None => {
-                // static items
-                assert!(parent_subst.is_none());
-                return TyBuilder::new_empty(poly_value_ty);
-            }
-        };
-        TyBuilder::subst_for_def(db, id, parent_subst).with_data(poly_value_ty)
-    }
 }
diff --git a/crates/hir-ty/src/chalk_db.rs b/crates/hir-ty/src/chalk_db.rs
index 9dd3bddbd41..fb7d99711d9 100644
--- a/crates/hir-ty/src/chalk_db.rs
+++ b/crates/hir-ty/src/chalk_db.rs
@@ -803,17 +803,17 @@ pub(crate) fn adt_variance_query(
     )
 }
 
+/// Returns instantiated predicates.
 pub(super) fn convert_where_clauses(
     db: &dyn HirDatabase,
     def: GenericDefId,
     substs: &Substitution,
 ) -> Vec<chalk_ir::QuantifiedWhereClause<Interner>> {
-    let generic_predicates = db.generic_predicates(def);
-    let mut result = Vec::with_capacity(generic_predicates.len());
-    for pred in generic_predicates.iter() {
-        result.push(pred.clone().substitute(Interner, substs));
-    }
-    result
+    db.generic_predicates(def)
+        .iter()
+        .cloned()
+        .map(|pred| pred.substitute(Interner, substs))
+        .collect()
 }
 
 pub(super) fn generic_predicate_to_inline_bound(
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 65c29b1cc94..4affe7424e1 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -16,7 +16,10 @@
 use std::sync::Arc;
 use std::{convert::identity, ops::Index};
 
-use chalk_ir::{cast::Cast, DebruijnIndex, Mutability, Safety, Scalar, TypeFlags};
+use chalk_ir::{
+    cast::Cast, fold::TypeFoldable, interner::HasInterner, DebruijnIndex, Mutability, Safety,
+    Scalar, TypeFlags,
+};
 use either::Either;
 use hir_def::{
     body::Body,
@@ -798,7 +801,10 @@ impl<'a> InferenceContext<'a> {
         self.table.insert_type_vars_shallow(ty)
     }
 
-    fn insert_type_vars(&mut self, ty: Ty) -> Ty {
+    fn insert_type_vars<T>(&mut self, ty: T) -> T
+    where
+        T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
+    {
         self.table.insert_type_vars(ty)
     }
 
@@ -875,7 +881,10 @@ impl<'a> InferenceContext<'a> {
     /// type annotation (e.g. from a let type annotation, field type or function
     /// call). `make_ty` handles this already, but e.g. for field types we need
     /// to do it as well.
-    fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
+    fn normalize_associated_types_in<T>(&mut self, ty: T) -> T
+    where
+        T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
+    {
         self.table.normalize_associated_types_in(ty)
     }
 
diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs
index 368c3f65243..95a20f983f1 100644
--- a/crates/hir-ty/src/infer/path.rs
+++ b/crates/hir-ty/src/infer/path.rs
@@ -4,7 +4,7 @@ use chalk_ir::cast::Cast;
 use hir_def::{
     path::{Path, PathSegment},
     resolver::{ResolveValueResult, TypeNs, ValueNs},
-    AdtId, AssocItemId, EnumVariantId, ItemContainerId, Lookup,
+    AdtId, AssocItemId, EnumVariantId, GenericDefId, ItemContainerId, Lookup,
 };
 use hir_expand::name::Name;
 use stdx::never;
@@ -13,6 +13,7 @@ use crate::{
     builder::ParamKind,
     consteval,
     method_resolution::{self, VisibleFromModule},
+    to_chalk_trait_id,
     utils::generics,
     InferenceDiagnostic, Interner, Substitution, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
     ValueTyDefId,
@@ -20,15 +21,25 @@ use crate::{
 
 use super::{ExprOrPatId, InferenceContext, TraitRef};
 
-impl<'a> InferenceContext<'a> {
+impl InferenceContext<'_> {
     pub(super) fn infer_path(&mut self, path: &Path, id: ExprOrPatId) -> Option<Ty> {
-        let ty = self.resolve_value_path(path, id)?;
-        let ty = self.insert_type_vars(ty);
+        let (value_def, generic_def, substs) = match self.resolve_value_path(path, id)? {
+            ValuePathResolution::GenericDef(value_def, generic_def, substs) => {
+                (value_def, generic_def, substs)
+            }
+            ValuePathResolution::NonGeneric(ty) => return Some(ty),
+        };
+        let substs = self.insert_type_vars(substs);
+        let substs = self.normalize_associated_types_in(substs);
+
+        self.add_required_obligations_for_value_path(generic_def, &substs);
+
+        let ty = self.db.value_ty(value_def).substitute(Interner, &substs);
         let ty = self.normalize_associated_types_in(ty);
         Some(ty)
     }
 
-    fn resolve_value_path(&mut self, path: &Path, id: ExprOrPatId) -> Option<Ty> {
+    fn resolve_value_path(&mut self, path: &Path, id: ExprOrPatId) -> Option<ValuePathResolution> {
         let (value, self_subst) = if let Some(type_ref) = path.type_anchor() {
             let last = path.segments().last()?;
 
@@ -56,9 +67,9 @@ impl<'a> InferenceContext<'a> {
             }
         };
 
-        let typable: ValueTyDefId = match value {
+        let value_def = match value {
             ValueNs::LocalBinding(pat) => match self.result.type_of_binding.get(pat) {
-                Some(ty) => return Some(ty.clone()),
+                Some(ty) => return Some(ValuePathResolution::NonGeneric(ty.clone())),
                 None => {
                     never!("uninferred pattern?");
                     return None;
@@ -82,28 +93,45 @@ impl<'a> InferenceContext<'a> {
                 let substs = generics.placeholder_subst(self.db);
                 let ty = self.db.impl_self_ty(impl_id).substitute(Interner, &substs);
                 if let Some((AdtId::StructId(struct_id), substs)) = ty.as_adt() {
-                    let ty = self.db.value_ty(struct_id.into()).substitute(Interner, &substs);
-                    return Some(ty);
+                    return Some(ValuePathResolution::GenericDef(
+                        struct_id.into(),
+                        struct_id.into(),
+                        substs.clone(),
+                    ));
                 } else {
                     // FIXME: report error, invalid Self reference
                     return None;
                 }
             }
-            ValueNs::GenericParam(it) => return Some(self.db.const_param_ty(it)),
+            ValueNs::GenericParam(it) => {
+                return Some(ValuePathResolution::NonGeneric(self.db.const_param_ty(it)))
+            }
         };
 
         let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver);
-        let substs = ctx.substs_from_path(path, typable, true);
+        let substs = ctx.substs_from_path(path, value_def, true);
         let substs = substs.as_slice(Interner);
         let parent_substs = self_subst.or_else(|| {
-            let generics = generics(self.db.upcast(), typable.to_generic_def_id()?);
+            let generics = generics(self.db.upcast(), value_def.to_generic_def_id()?);
             let parent_params_len = generics.parent_generics()?.len();
             let parent_args = &substs[substs.len() - parent_params_len..];
             Some(Substitution::from_iter(Interner, parent_args))
         });
         let parent_substs_len = parent_substs.as_ref().map_or(0, |s| s.len(Interner));
         let mut it = substs.iter().take(substs.len() - parent_substs_len).cloned();
-        let ty = TyBuilder::value_ty(self.db, typable, parent_substs)
+
+        let Some(generic_def) = value_def.to_generic_def_id() else {
+            // `value_def` is the kind of item that can never be generic (i.e. statics, at least
+            // currently). We can just skip the binders to get its type.
+            let (ty, binders) = self.db.value_ty(value_def).into_value_and_skipped_binders();
+            stdx::always!(
+                parent_substs.is_none() && binders.is_empty(Interner),
+                "non-empty binders for non-generic def",
+            );
+            return Some(ValuePathResolution::NonGeneric(ty));
+        };
+        let builder = TyBuilder::subst_for_def(self.db, generic_def, parent_substs);
+        let substs = builder
             .fill(|x| {
                 it.next().unwrap_or_else(|| match x {
                     ParamKind::Type => self.result.standard_types.unknown.clone().cast(Interner),
@@ -111,7 +139,35 @@ impl<'a> InferenceContext<'a> {
                 })
             })
             .build();
-        Some(ty)
+
+        Some(ValuePathResolution::GenericDef(value_def, generic_def, substs))
+    }
+
+    fn add_required_obligations_for_value_path(&mut self, def: GenericDefId, subst: &Substitution) {
+        let predicates = self.db.generic_predicates(def);
+        for predicate in predicates.iter() {
+            let (predicate, binders) =
+                predicate.clone().substitute(Interner, &subst).into_value_and_skipped_binders();
+            // Quantified where clauses are not yet handled.
+            stdx::always!(binders.is_empty(Interner));
+            self.push_obligation(predicate.cast(Interner));
+        }
+
+        // We need to add `Self: Trait` obligation when `def` is a trait assoc item.
+        let container = match def {
+            GenericDefId::FunctionId(id) => id.lookup(self.db.upcast()).container,
+            GenericDefId::ConstId(id) => id.lookup(self.db.upcast()).container,
+            _ => return,
+        };
+
+        if let ItemContainerId::TraitId(trait_) = container {
+            let param_len = generics(self.db.upcast(), def).len_self();
+            let parent_subst =
+                Substitution::from_iter(Interner, subst.iter(Interner).skip(param_len));
+            let trait_ref =
+                TraitRef { trait_id: to_chalk_trait_id(trait_), substitution: parent_subst };
+            self.push_obligation(trait_ref.cast(Interner));
+        }
     }
 
     fn resolve_assoc_item(
@@ -307,3 +363,10 @@ impl<'a> InferenceContext<'a> {
         Some((ValueNs::EnumVariantId(variant), subst.clone()))
     }
 }
+
+enum ValuePathResolution {
+    // It's awkward to wrap a single ID in two enums, but we need both and this saves fallible
+    // conversion between them + `unwrap()`.
+    GenericDef(ValueTyDefId, GenericDefId, Substitution),
+    NonGeneric(Ty),
+}
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs
index 2328d97c3b1..2988c710398 100644
--- a/crates/hir-ty/src/infer/unify.rs
+++ b/crates/hir-ty/src/infer/unify.rs
@@ -231,7 +231,10 @@ impl<'a> InferenceTable<'a> {
     /// type annotation (e.g. from a let type annotation, field type or function
     /// call). `make_ty` handles this already, but e.g. for field types we need
     /// to do it as well.
-    pub(crate) fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
+    pub(crate) fn normalize_associated_types_in<T>(&mut self, ty: T) -> T
+    where
+        T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
+    {
         fold_tys(
             ty,
             |ty, _| match ty.kind(Interner) {
@@ -720,7 +723,10 @@ impl<'a> InferenceTable<'a> {
         }
     }
 
-    pub(super) fn insert_type_vars(&mut self, ty: Ty) -> Ty {
+    pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T
+    where
+        T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
+    {
         fold_tys_and_consts(
             ty,
             |x, _| match x {
diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs
index 504dbe77d93..9fb7fdcc5fc 100644
--- a/crates/hir-ty/src/method_resolution.rs
+++ b/crates/hir-ty/src/method_resolution.rs
@@ -742,9 +742,8 @@ fn find_matching_impl(
     actual_trait_ref: TraitRef,
 ) -> Option<(Arc<ImplData>, Substitution)> {
     let db = table.db;
-    loop {
-        let impl_ = impls.next()?;
-        let r = table.run_in_snapshot(|table| {
+    impls.find_map(|impl_| {
+        table.run_in_snapshot(|table| {
             let impl_data = db.impl_data(impl_);
             let impl_substs =
                 TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build();
@@ -762,11 +761,8 @@ fn find_matching_impl(
                 .map(|b| b.cast(Interner));
             let goal = crate::Goal::all(Interner, wcs);
             table.try_obligation(goal).map(|_| (impl_data, table.resolve_completely(impl_substs)))
-        });
-        if r.is_some() {
-            break r;
-        }
-    }
+        })
+    })
 }
 
 fn is_inherent_impl_coherent(
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 857891a2148..829a6ab189e 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -4375,3 +4375,38 @@ fn derive_macro_bounds() {
         "#,
     );
 }
+
+#[test]
+fn trait_obligations_should_be_registered_during_path_inference() {
+    check_types(
+        r#"
+//- minicore: fn, from
+struct S<T>(T);
+fn map<T, U, F: FnOnce(T) -> S<U>>(_: T, _: F) -> U { loop {} }
+
+fn test(v: S<i32>) {
+    let res = map(v, Into::into);
+      //^^^ i32
+}
+"#,
+    );
+}
+
+#[test]
+fn fn_obligation_should_be_registered_during_path_inference() {
+    check_types(
+        r#"
+//- minicore: fn, from
+struct S<T>(T);
+impl<T> S<T> {
+    fn foo<U: Into<S<T>>>(_: U) -> Self { loop {} }
+}
+fn map<T, U, F: FnOnce(T) -> U>(_: T, _: F) -> U { loop {} }
+
+fn test(v: S<i32>) {
+    let res = map(v, S::foo);
+      //^^^ S<i32>
+}
+"#,
+    );
+}