about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2021-11-17 21:38:04 -0800
committerMichael Goulet <michael@errs.io>2021-11-20 09:53:08 -0800
commit33c443dd9d50c2a8d95609c2250708ea93cbc7f2 (patch)
tree5782c92a0df2c964b67b06f2b2e19f320013fbe8 /compiler
parent93542a8240c5f926ac5f3f99cef99366082f9c2b (diff)
downloadrust-33c443dd9d50c2a8d95609c2250708ea93cbc7f2.tar.gz
rust-33c443dd9d50c2a8d95609c2250708ea93cbc7f2.zip
Suggest await on cases involving infer
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_infer/src/infer/error_reporting/mod.rs34
1 files changed, 31 insertions, 3 deletions
diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
index c25ec1356e2..3c2f9900080 100644
--- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs
+++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs
@@ -310,6 +310,34 @@ pub fn unexpected_hidden_region_diagnostic(
     err
 }
 
+/// Structurally compares two types, modulo any inference variables.
+///
+/// Returns `true` if two types are equal, or if one type is an inference variable compatible
+/// with the other type. A TyVar inference type is compatible with any type, and an IntVar or
+/// FloatVar inference type are compatible with themselves or their concrete types (Int and
+/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
+pub fn same_type_modulo_infer(a: Ty<'tcx>, b: Ty<'ctx>) -> bool {
+    match (&a.kind(), &b.kind()) {
+        (&ty::Adt(did_a, substs_a), &ty::Adt(did_b, substs_b)) => {
+            if did_a != did_b {
+                return false;
+            }
+
+            substs_a.types().zip(substs_b.types()).all(|(a, b)| same_type_modulo_infer(a, b))
+        }
+        (&ty::Int(_), &ty::Infer(ty::InferTy::IntVar(_)))
+        | (&ty::Infer(ty::InferTy::IntVar(_)), &ty::Int(_) | &ty::Infer(ty::InferTy::IntVar(_)))
+        | (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
+        | (
+            &ty::Infer(ty::InferTy::FloatVar(_)),
+            &ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
+        )
+        | (&ty::Infer(ty::InferTy::TyVar(_)), _)
+        | (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
+        _ => a == b,
+    }
+}
+
 impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
     pub fn report_region_errors(&self, errors: &Vec<RegionResolutionError<'tcx>>) {
         debug!("report_region_errors(): {} errors to start", errors.len());
@@ -1761,7 +1789,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
             self.get_impl_future_output_ty(exp_found.expected),
             self.get_impl_future_output_ty(exp_found.found),
         ) {
-            (Some(exp), Some(found)) if ty::TyS::same_type(exp, found) => match &cause.code {
+            (Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match &cause.code {
                 ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
                     diag.multipart_suggestion(
                         "consider `await`ing on both `Future`s",
@@ -1793,7 +1821,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
                     diag.help("consider `await`ing on both `Future`s");
                 }
             },
-            (_, Some(ty)) if ty::TyS::same_type(exp_found.expected, ty) => {
+            (_, Some(ty)) if same_type_modulo_infer(exp_found.expected, ty) => {
                 diag.span_suggestion_verbose(
                     exp_span.shrink_to_hi(),
                     "consider `await`ing on the `Future`",
@@ -1801,7 +1829,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
                     Applicability::MaybeIncorrect,
                 );
             }
-            (Some(ty), _) if ty::TyS::same_type(ty, exp_found.found) => match cause.code {
+            (Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code {
                 ObligationCauseCode::Pattern { span: Some(span), .. }
                 | ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => {
                     diag.span_suggestion_verbose(