about summary refs log tree commit diff
diff options
context:
space:
mode:
authorShoyu Vanilla <modulo641@gmail.com>2024-03-06 21:16:41 +0900
committerShoyu Vanilla <modulo641@gmail.com>2024-03-06 21:16:41 +0900
commita8f56112eab74521dfc75ef811f1a9ea23bcf43c (patch)
tree3ffc8e7991f6280f7d7d72cce3aa689b5b9131c1
parent52d8ae791d0b1218da57b1b2de4d114d1dff2981 (diff)
downloadrust-a8f56112eab74521dfc75ef811f1a9ea23bcf43c.tar.gz
rust-a8f56112eab74521dfc75ef811f1a9ea23bcf43c.zip
fix: Function argument type inference with associated type impl trait
-rw-r--r--crates/hir-ty/src/lower.rs96
-rw-r--r--crates/hir-ty/src/tests/traits.rs47
2 files changed, 132 insertions, 11 deletions
diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs
index 75ac3b0d66b..dac20f22597 100644
--- a/crates/hir-ty/src/lower.rs
+++ b/crates/hir-ty/src/lower.rs
@@ -995,12 +995,12 @@ impl<'a> TyLoweringContext<'a> {
 
     pub(crate) fn lower_type_bound(
         &'a self,
-        bound: &'a TypeBound,
+        bound: &'a Interned<TypeBound>,
         self_ty: Ty,
         ignore_bindings: bool,
     ) -> impl Iterator<Item = QuantifiedWhereClause> + 'a {
         let mut bindings = None;
-        let trait_ref = match bound {
+        let trait_ref = match bound.as_ref() {
             TypeBound::Path(path, TraitBoundModifier::None) => {
                 bindings = self.lower_trait_ref_from_path(path, Some(self_ty));
                 bindings
@@ -1055,10 +1055,10 @@ impl<'a> TyLoweringContext<'a> {
 
     fn assoc_type_bindings_from_type_bound(
         &'a self,
-        bound: &'a TypeBound,
+        bound: &'a Interned<TypeBound>,
         trait_ref: TraitRef,
     ) -> impl Iterator<Item = QuantifiedWhereClause> + 'a {
-        let last_segment = match bound {
+        let last_segment = match bound.as_ref() {
             TypeBound::Path(path, TraitBoundModifier::None) | TypeBound::ForLifetime(_, path) => {
                 path.segments().last()
             }
@@ -1121,7 +1121,63 @@ impl<'a> TyLoweringContext<'a> {
                             );
                         }
                     } else {
-                        let ty = self.lower_ty(type_ref);
+                        let ty = 'ty: {
+                            if matches!(
+                                self.impl_trait_mode,
+                                ImplTraitLoweringState::Param(_)
+                                    | ImplTraitLoweringState::Variable(_)
+                            ) {
+                                // Find the generic index for the target of our `bound`
+                                let target_param_idx = self
+                                    .resolver
+                                    .where_predicates_in_scope()
+                                    .find_map(|p| match p {
+                                        WherePredicate::TypeBound {
+                                            target: WherePredicateTypeTarget::TypeOrConstParam(idx),
+                                            bound: b,
+                                        } if b == bound => Some(idx),
+                                        _ => None,
+                                    });
+                                if let Some(target_param_idx) = target_param_idx {
+                                    let mut counter = 0;
+                                    for (idx, data) in self.generics().params.type_or_consts.iter()
+                                    {
+                                        // Count the number of `impl Trait` things that appear before
+                                        // the target of our `bound`.
+                                        // Our counter within `impl_trait_mode` should be that number
+                                        // to properly lower each types within `type_ref`
+                                        if data.type_param().is_some_and(|p| {
+                                            p.provenance == TypeParamProvenance::ArgumentImplTrait
+                                        }) {
+                                            counter += 1;
+                                        }
+                                        if idx == *target_param_idx {
+                                            break;
+                                        }
+                                    }
+                                    let mut ext = TyLoweringContext::new_maybe_unowned(
+                                        self.db,
+                                        self.resolver,
+                                        self.owner,
+                                    )
+                                    .with_type_param_mode(self.type_param_mode);
+                                    match &self.impl_trait_mode {
+                                        ImplTraitLoweringState::Param(_) => {
+                                            ext.impl_trait_mode =
+                                                ImplTraitLoweringState::Param(Cell::new(counter));
+                                        }
+                                        ImplTraitLoweringState::Variable(_) => {
+                                            ext.impl_trait_mode = ImplTraitLoweringState::Variable(
+                                                Cell::new(counter),
+                                            );
+                                        }
+                                        _ => unreachable!(),
+                                    }
+                                    break 'ty ext.lower_ty(type_ref);
+                                }
+                            }
+                            self.lower_ty(type_ref)
+                        };
                         let alias_eq =
                             AliasEq { alias: AliasTy::Projection(projection_ty.clone()), ty };
                         predicates.push(crate::wrap_empty_binders(WhereClause::AliasEq(alias_eq)));
@@ -1403,8 +1459,14 @@ pub(crate) fn generic_predicates_for_param_query(
     assoc_name: Option<Name>,
 ) -> Arc<[Binders<QuantifiedWhereClause>]> {
     let resolver = def.resolver(db.upcast());
-    let ctx = TyLoweringContext::new(db, &resolver, def.into())
-        .with_type_param_mode(ParamLoweringMode::Variable);
+    let ctx = if let GenericDefId::FunctionId(_) = def {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_impl_trait_mode(ImplTraitLoweringMode::Variable)
+            .with_type_param_mode(ParamLoweringMode::Variable)
+    } else {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_type_param_mode(ParamLoweringMode::Variable)
+    };
     let generics = generics(db.upcast(), def);
 
     // we have to filter out all other predicates *first*, before attempting to lower them
@@ -1490,8 +1552,14 @@ pub(crate) fn trait_environment_query(
     def: GenericDefId,
 ) -> Arc<TraitEnvironment> {
     let resolver = def.resolver(db.upcast());
-    let ctx = TyLoweringContext::new(db, &resolver, def.into())
-        .with_type_param_mode(ParamLoweringMode::Placeholder);
+    let ctx = if let GenericDefId::FunctionId(_) = def {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_impl_trait_mode(ImplTraitLoweringMode::Param)
+            .with_type_param_mode(ParamLoweringMode::Placeholder)
+    } else {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_type_param_mode(ParamLoweringMode::Placeholder)
+    };
     let mut traits_in_scope = Vec::new();
     let mut clauses = Vec::new();
     for pred in resolver.where_predicates_in_scope() {
@@ -1549,8 +1617,14 @@ pub(crate) fn generic_predicates_query(
     def: GenericDefId,
 ) -> Arc<[Binders<QuantifiedWhereClause>]> {
     let resolver = def.resolver(db.upcast());
-    let ctx = TyLoweringContext::new(db, &resolver, def.into())
-        .with_type_param_mode(ParamLoweringMode::Variable);
+    let ctx = if let GenericDefId::FunctionId(_) = def {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_impl_trait_mode(ImplTraitLoweringMode::Variable)
+            .with_type_param_mode(ParamLoweringMode::Variable)
+    } else {
+        TyLoweringContext::new(db, &resolver, def.into())
+            .with_type_param_mode(ParamLoweringMode::Variable)
+    };
     let generics = generics(db.upcast(), def);
 
     let mut predicates = resolver
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 39c5547b8d0..b80cfe18e4c 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -1232,6 +1232,53 @@ fn test(x: impl Trait<u64>, y: &impl Trait<u64>) {
 }
 
 #[test]
+fn argument_impl_trait_with_projection() {
+    check_infer(
+        r#"
+trait X {
+    type Item;
+}
+
+impl<T> X for [T; 2] {
+    type Item = T;
+}
+
+trait Y {}
+
+impl<T> Y for T {}
+
+enum R<T, U> {
+    A(T),
+    B(U),
+}
+
+fn foo<T>(x: impl X<Item = R<impl Y, T>>) -> T { loop {} }
+
+fn bar() {
+    let a = foo([R::A(()), R::B(7)]);
+}
+"#,
+        expect![[r#"
+            153..154 'x': impl X<Item = R<impl Y + ?Sized, T>> + ?Sized
+            190..201 '{ loop {} }': T
+            192..199 'loop {}': !
+            197..199 '{}': ()
+            212..253 '{     ...)]); }': ()
+            222..223 'a': i32
+            226..229 'foo': fn foo<i32>([R<(), i32>; 2]) -> i32
+            226..250 'foo([R...B(7)])': i32
+            230..249 '[R::A(...:B(7)]': [R<(), i32>; 2]
+            231..235 'R::A': extern "rust-call" A<(), i32>(()) -> R<(), i32>
+            231..239 'R::A(())': R<(), i32>
+            236..238 '()': ()
+            241..245 'R::B': extern "rust-call" B<(), i32>(i32) -> R<(), i32>
+            241..248 'R::B(7)': R<(), i32>
+            246..247 '7': i32
+        "#]],
+    );
+}
+
+#[test]
 fn simple_return_pos_impl_trait() {
     cov_mark::check!(lower_rpit);
     check_infer(