about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2022-05-03 11:50:31 +0000
committerbors <bors@rust-lang.org>2022-05-03 11:50:31 +0000
commit0ee4e6a22d79f58b6b459dbc874d6b90a4495d83 (patch)
tree74ab0ec5fe138371815d20df5575aeaca19dc1fb
parenteeb45329e33ca03a7d1c21f35183fed19ee1c292 (diff)
parent970276b5594a6205709b0f2d0f6af9b6a0121683 (diff)
downloadrust-0ee4e6a22d79f58b6b459dbc874d6b90a4495d83.tar.gz
rust-0ee4e6a22d79f58b6b459dbc874d6b90a4495d83.zip
Auto merge of #12086 - iDawer:infer.rpit, r=flodiebold
infer from RPIT bounds of _this_ function

Collect obligations from RPITs (Return Position `impl Trait`) of a function which is being inferred.
This allows inferring {unknown}s from RPIT bounds.

Closes #8403
-rw-r--r--crates/hir-ty/src/infer.rs52
-rw-r--r--crates/hir-ty/src/tests/traits.rs36
2 files changed, 75 insertions, 13 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 7f5ad415096..2a11c9d9bf1 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -19,7 +19,7 @@ use std::sync::Arc;
 use chalk_ir::{cast::Cast, ConstValue, DebruijnIndex, Mutability, Safety, Scalar, TypeFlags};
 use hir_def::{
     body::Body,
-    data::{ConstData, FunctionData, StaticData},
+    data::{ConstData, StaticData},
     expr::{BindingAnnotation, ExprId, PatId},
     lang_item::LangItemTarget,
     path::{path, Path},
@@ -32,12 +32,13 @@ use hir_expand::name::{name, Name};
 use itertools::Either;
 use la_arena::ArenaMap;
 use rustc_hash::FxHashMap;
-use stdx::impl_from;
+use stdx::{always, impl_from};
 
 use crate::{
-    db::HirDatabase, fold_tys_and_consts, infer::coerce::CoerceMany, lower::ImplTraitLoweringMode,
-    to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal, GenericArg, Goal, InEnvironment,
-    Interner, ProjectionTy, Substitution, TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
+    db::HirDatabase, fold_tys, fold_tys_and_consts, infer::coerce::CoerceMany,
+    lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal,
+    GenericArg, Goal, ImplTraitId, InEnvironment, Interner, ProjectionTy, Substitution,
+    TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
 };
 
 // This lint has a false positive here. See the link below for details.
@@ -64,7 +65,7 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
 
     match def {
         DefWithBodyId::ConstId(c) => ctx.collect_const(&db.const_data(c)),
-        DefWithBodyId::FunctionId(f) => ctx.collect_fn(&db.function_data(f)),
+        DefWithBodyId::FunctionId(f) => ctx.collect_fn(f),
         DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_data(s)),
     }
 
@@ -457,7 +458,8 @@ impl<'a> InferenceContext<'a> {
         self.return_ty = self.make_ty(&data.type_ref);
     }
 
-    fn collect_fn(&mut self, data: &FunctionData) {
+    fn collect_fn(&mut self, func: FunctionId) {
+        let data = self.db.function_data(func);
         let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver)
             .with_impl_trait_mode(ImplTraitLoweringMode::Param);
         let param_tys =
@@ -474,8 +476,42 @@ impl<'a> InferenceContext<'a> {
         } else {
             &*data.ret_type
         };
-        let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Disallowed); // FIXME implement RPIT
+        let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Opaque);
         self.return_ty = return_ty;
+
+        if let Some(rpits) = self.db.return_type_impl_traits(func) {
+            // RPIT opaque types use substitution of their parent function.
+            let fn_placeholders = TyBuilder::placeholder_subst(self.db, func);
+            self.return_ty = fold_tys(
+                self.return_ty.clone(),
+                |ty, _| {
+                    let opaque_ty_id = match ty.kind(Interner) {
+                        TyKind::OpaqueType(opaque_ty_id, _) => *opaque_ty_id,
+                        _ => return ty,
+                    };
+                    let idx = match self.db.lookup_intern_impl_trait_id(opaque_ty_id.into()) {
+                        ImplTraitId::ReturnTypeImplTrait(_, idx) => idx,
+                        _ => unreachable!(),
+                    };
+                    let bounds = (*rpits).map_ref(|rpits| {
+                        rpits.impl_traits[idx as usize].bounds.map_ref(|it| it.into_iter())
+                    });
+                    let var = self.table.new_type_var();
+                    let var_subst = Substitution::from1(Interner, var.clone());
+                    for bound in bounds {
+                        let predicate =
+                            bound.map(|it| it.cloned()).substitute(Interner, &fn_placeholders);
+                        let (var_predicate, binders) = predicate
+                            .substitute(Interner, &var_subst)
+                            .into_value_and_skipped_binders();
+                        always!(binders.len(Interner) == 0); // quantified where clauses not yet handled
+                        self.push_obligation(var_predicate.cast(Interner));
+                    }
+                    var
+                },
+                DebruijnIndex::INNERMOST,
+            );
+        }
     }
 
     fn infer_body(&mut self) {
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 5e58d5ad838..0b08aa4711c 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -1256,6 +1256,32 @@ fn test() {
 }
 
 #[test]
+fn infer_from_return_pos_impl_trait() {
+    check_infer_with_mismatches(
+        r#"
+//- minicore: fn, sized
+trait Trait<T> {}
+struct Bar<T>(T);
+impl<T> Trait<T> for Bar<T> {}
+fn foo<const C: u8, T>() -> (impl FnOnce(&str, T), impl Trait<u8>) {
+    (|input, t| {}, Bar(C))
+}
+"#,
+        expect![[r#"
+            134..165 '{     ...(C)) }': (|&str, T| -> (), Bar<u8>)
+            140..163 '(|inpu...ar(C))': (|&str, T| -> (), Bar<u8>)
+            141..154 '|input, t| {}': |&str, T| -> ()
+            142..147 'input': &str
+            149..150 't': T
+            152..154 '{}': ()
+            156..159 'Bar': Bar<u8>(u8) -> Bar<u8>
+            156..162 'Bar(C)': Bar<u8>
+            160..161 'C': u8
+        "#]],
+    );
+}
+
+#[test]
 fn dyn_trait() {
     check_infer(
         r#"
@@ -2392,7 +2418,7 @@ fn test() -> impl Trait<i32> {
             171..182 '{ loop {} }': T
             173..180 'loop {}': !
             178..180 '{}': ()
-            213..309 '{     ...t()) }': S<{unknown}>
+            213..309 '{     ...t()) }': S<i32>
             223..225 's1': S<u32>
             228..229 'S': S<u32>(u32) -> S<u32>
             228..240 'S(default())': S<u32>
@@ -2408,10 +2434,10 @@ fn test() -> impl Trait<i32> {
             276..288 'S(default())': S<i32>
             278..285 'default': fn default<i32>() -> i32
             278..287 'default()': i32
-            295..296 'S': S<{unknown}>({unknown}) -> S<{unknown}>
-            295..307 'S(default())': S<{unknown}>
-            297..304 'default': fn default<{unknown}>() -> {unknown}
-            297..306 'default()': {unknown}
+            295..296 'S': S<i32>(i32) -> S<i32>
+            295..307 'S(default())': S<i32>
+            297..304 'default': fn default<i32>() -> i32
+            297..306 'default()': i32
         "#]],
     );
 }