about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2022-10-10 23:35:51 +0000
committerMichael Goulet <michael@errs.io>2022-10-15 17:46:05 +0000
commite994de803df1f1a9f2bbd4da1258d03ea05b4231 (patch)
treeecd05e7c1f134e8778d7f9fe4e13e77c1fa7b143 /compiler
parentcb20758257a5efe790e27460df53c12bf1c90403 (diff)
downloadrust-e994de803df1f1a9f2bbd4da1258d03ea05b4231.tar.gz
rust-e994de803df1f1a9f2bbd4da1258d03ea05b4231.zip
Equate full fn signatures to infer all region variables
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_hir_analysis/src/check/compare_method.rs46
1 files changed, 33 insertions, 13 deletions
diff --git a/compiler/rustc_hir_analysis/src/check/compare_method.rs b/compiler/rustc_hir_analysis/src/check/compare_method.rs
index 986d5bed39e..6ee53436385 100644
--- a/compiler/rustc_hir_analysis/src/check/compare_method.rs
+++ b/compiler/rustc_hir_analysis/src/check/compare_method.rs
@@ -465,30 +465,30 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
     let ocx = ObligationCtxt::new(infcx);
 
     let norm_cause = ObligationCause::misc(return_span, impl_m_hir_id);
-    let impl_return_ty = ocx.normalize(
+    let impl_sig = ocx.normalize(
         norm_cause.clone(),
         param_env,
-        infcx
-            .replace_bound_vars_with_fresh_vars(
-                return_span,
-                infer::HigherRankedType,
-                tcx.fn_sig(impl_m.def_id),
-            )
-            .output(),
+        infcx.replace_bound_vars_with_fresh_vars(
+            return_span,
+            infer::HigherRankedType,
+            tcx.fn_sig(impl_m.def_id),
+        ),
     );
+    let impl_return_ty = impl_sig.output();
 
     let mut collector = ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id);
-    let unnormalized_trait_return_ty = tcx
+    let unnormalized_trait_sig = tcx
         .liberate_late_bound_regions(
             impl_m.def_id,
             tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs),
         )
-        .output()
         .fold_with(&mut collector);
-    let trait_return_ty =
-        ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_return_ty);
+    let trait_sig = ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_sig);
+    let trait_return_ty = trait_sig.output();
 
-    let wf_tys = FxHashSet::from_iter([unnormalized_trait_return_ty, trait_return_ty]);
+    let wf_tys = FxHashSet::from_iter(
+        unnormalized_trait_sig.inputs_and_output.iter().chain(trait_sig.inputs_and_output.iter()),
+    );
 
     match infcx.at(&cause, param_env).eq(trait_return_ty, impl_return_ty) {
         Ok(infer::InferOk { value: (), obligations }) => {
@@ -521,6 +521,26 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
         }
     }
 
+    // Unify the whole function signature. We need to do this to fully infer
+    // the lifetimes of the return type, but do this after unifying just the
+    // return types, since we want to avoid duplicating errors from
+    // `compare_predicate_entailment`.
+    match infcx
+        .at(&cause, param_env)
+        .eq(tcx.mk_fn_ptr(ty::Binder::dummy(trait_sig)), tcx.mk_fn_ptr(ty::Binder::dummy(impl_sig)))
+    {
+        Ok(infer::InferOk { value: (), obligations }) => {
+            ocx.register_obligations(obligations);
+        }
+        Err(terr) => {
+            let guar = tcx.sess.delay_span_bug(
+                return_span,
+                format!("could not unify `{trait_sig}` and `{impl_sig}`: {terr:?}"),
+            );
+            return Err(guar);
+        }
+    }
+
     // Check that all obligations are satisfied by the implementation's
     // RPITs.
     let errors = ocx.select_all_or_error();