diff options
| author | Shoyu Vanilla <modulo641@gmail.com> | 2024-03-06 21:16:41 +0900 |
|---|---|---|
| committer | Shoyu Vanilla <modulo641@gmail.com> | 2024-03-06 21:16:41 +0900 |
| commit | a8f56112eab74521dfc75ef811f1a9ea23bcf43c (patch) | |
| tree | 3ffc8e7991f6280f7d7d72cce3aa689b5b9131c1 | |
| parent | 52d8ae791d0b1218da57b1b2de4d114d1dff2981 (diff) | |
| download | rust-a8f56112eab74521dfc75ef811f1a9ea23bcf43c.tar.gz rust-a8f56112eab74521dfc75ef811f1a9ea23bcf43c.zip | |
fix: Function argument type inference with associated type impl trait
| -rw-r--r-- | crates/hir-ty/src/lower.rs | 96 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/traits.rs | 47 |
2 files changed, 132 insertions, 11 deletions
diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index 75ac3b0d66b..dac20f22597 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -995,12 +995,12 @@ impl<'a> TyLoweringContext<'a> { pub(crate) fn lower_type_bound( &'a self, - bound: &'a TypeBound, + bound: &'a Interned<TypeBound>, self_ty: Ty, ignore_bindings: bool, ) -> impl Iterator<Item = QuantifiedWhereClause> + 'a { let mut bindings = None; - let trait_ref = match bound { + let trait_ref = match bound.as_ref() { TypeBound::Path(path, TraitBoundModifier::None) => { bindings = self.lower_trait_ref_from_path(path, Some(self_ty)); bindings @@ -1055,10 +1055,10 @@ impl<'a> TyLoweringContext<'a> { fn assoc_type_bindings_from_type_bound( &'a self, - bound: &'a TypeBound, + bound: &'a Interned<TypeBound>, trait_ref: TraitRef, ) -> impl Iterator<Item = QuantifiedWhereClause> + 'a { - let last_segment = match bound { + let last_segment = match bound.as_ref() { TypeBound::Path(path, TraitBoundModifier::None) | TypeBound::ForLifetime(_, path) => { path.segments().last() } @@ -1121,7 +1121,63 @@ impl<'a> TyLoweringContext<'a> { ); } } else { - let ty = self.lower_ty(type_ref); + let ty = 'ty: { + if matches!( + self.impl_trait_mode, + ImplTraitLoweringState::Param(_) + | ImplTraitLoweringState::Variable(_) + ) { + // Find the generic index for the target of our `bound` + let target_param_idx = self + .resolver + .where_predicates_in_scope() + .find_map(|p| match p { + WherePredicate::TypeBound { + target: WherePredicateTypeTarget::TypeOrConstParam(idx), + bound: b, + } if b == bound => Some(idx), + _ => None, + }); + if let Some(target_param_idx) = target_param_idx { + let mut counter = 0; + for (idx, data) in self.generics().params.type_or_consts.iter() + { + // Count the number of `impl Trait` things that appear before + // the target of our `bound`. + // Our counter within `impl_trait_mode` should be that number + // to properly lower each types within `type_ref` + if data.type_param().is_some_and(|p| { + p.provenance == TypeParamProvenance::ArgumentImplTrait + }) { + counter += 1; + } + if idx == *target_param_idx { + break; + } + } + let mut ext = TyLoweringContext::new_maybe_unowned( + self.db, + self.resolver, + self.owner, + ) + .with_type_param_mode(self.type_param_mode); + match &self.impl_trait_mode { + ImplTraitLoweringState::Param(_) => { + ext.impl_trait_mode = + ImplTraitLoweringState::Param(Cell::new(counter)); + } + ImplTraitLoweringState::Variable(_) => { + ext.impl_trait_mode = ImplTraitLoweringState::Variable( + Cell::new(counter), + ); + } + _ => unreachable!(), + } + break 'ty ext.lower_ty(type_ref); + } + } + self.lower_ty(type_ref) + }; let alias_eq = AliasEq { alias: AliasTy::Projection(projection_ty.clone()), ty }; predicates.push(crate::wrap_empty_binders(WhereClause::AliasEq(alias_eq))); @@ -1403,8 +1459,14 @@ pub(crate) fn generic_predicates_for_param_query( assoc_name: Option<Name>, ) -> Arc<[Binders<QuantifiedWhereClause>]> { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Variable); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Variable) + .with_type_param_mode(ParamLoweringMode::Variable) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Variable) + }; let generics = generics(db.upcast(), def); // we have to filter out all other predicates *first*, before attempting to lower them @@ -1490,8 +1552,14 @@ pub(crate) fn trait_environment_query( def: GenericDefId, ) -> Arc<TraitEnvironment> { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Placeholder); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Param) + .with_type_param_mode(ParamLoweringMode::Placeholder) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Placeholder) + }; let mut traits_in_scope = Vec::new(); let mut clauses = Vec::new(); for pred in resolver.where_predicates_in_scope() { @@ -1549,8 +1617,14 @@ pub(crate) fn generic_predicates_query( def: GenericDefId, ) -> Arc<[Binders<QuantifiedWhereClause>]> { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Variable); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Variable) + .with_type_param_mode(ParamLoweringMode::Variable) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Variable) + }; let generics = generics(db.upcast(), def); let mut predicates = resolver diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 39c5547b8d0..b80cfe18e4c 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -1232,6 +1232,53 @@ fn test(x: impl Trait<u64>, y: &impl Trait<u64>) { } #[test] +fn argument_impl_trait_with_projection() { + check_infer( + r#" +trait X { + type Item; +} + +impl<T> X for [T; 2] { + type Item = T; +} + +trait Y {} + +impl<T> Y for T {} + +enum R<T, U> { + A(T), + B(U), +} + +fn foo<T>(x: impl X<Item = R<impl Y, T>>) -> T { loop {} } + +fn bar() { + let a = foo([R::A(()), R::B(7)]); +} +"#, + expect![[r#" + 153..154 'x': impl X<Item = R<impl Y + ?Sized, T>> + ?Sized + 190..201 '{ loop {} }': T + 192..199 'loop {}': ! + 197..199 '{}': () + 212..253 '{ ...)]); }': () + 222..223 'a': i32 + 226..229 'foo': fn foo<i32>([R<(), i32>; 2]) -> i32 + 226..250 'foo([R...B(7)])': i32 + 230..249 '[R::A(...:B(7)]': [R<(), i32>; 2] + 231..235 'R::A': extern "rust-call" A<(), i32>(()) -> R<(), i32> + 231..239 'R::A(())': R<(), i32> + 236..238 '()': () + 241..245 'R::B': extern "rust-call" B<(), i32>(i32) -> R<(), i32> + 241..248 'R::B(7)': R<(), i32> + 246..247 '7': i32 + "#]], + ); +} + +#[test] fn simple_return_pos_impl_trait() { cov_mark::check!(lower_rpit); check_infer( |
