about summary refs log tree commit diff
diff options
context:
space:
mode:
authorVirginia Senioria <91khr@users.noreply.github.com>2024-09-25 07:34:53 +0000
committerVirginia Senioria <91khr@users.noreply.github.com>2024-09-25 08:45:40 +0000
commit986e20d5bb5df4274b390b4148aab0058081a241 (patch)
tree04c30547407838f4ff7d61d1e81060e689b41811
parent7042c269c166191cd5d8daf0409890903df7af57 (diff)
downloadrust-986e20d5bb5df4274b390b4148aab0058081a241.tar.gz
rust-986e20d5bb5df4274b390b4148aab0058081a241.zip
Fixed diagnostics for coroutines with () as input.
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs72
-rw-r--r--tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs11
-rw-r--r--tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr15
3 files changed, 61 insertions, 37 deletions
diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
index 19e2679ae4d..969f2528836 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
@@ -2635,49 +2635,47 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
         // This shouldn't be common unless manually implementing one of the
         // traits manually, but don't make it more confusing when it does
         // happen.
-        Ok(
-            if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait()
-                && not_tupled
-            {
-                self.report_and_explain_type_error(
-                    TypeTrace::trait_refs(
-                        &obligation.cause,
-                        true,
-                        expected_trait_ref,
-                        found_trait_ref,
-                    ),
-                    ty::error::TypeError::Mismatch,
-                )
-            } else if found.len() == expected.len() {
-                self.report_closure_arg_mismatch(
-                    span,
-                    found_span,
-                    found_trait_ref,
-                    expected_trait_ref,
-                    obligation.cause.code(),
-                    found_node,
-                    obligation.param_env,
-                )
-            } else {
-                let (closure_span, closure_arg_span, found) = found_did
-                    .and_then(|did| {
-                        let node = self.tcx.hir().get_if_local(did)?;
-                        let (found_span, closure_arg_span, found) =
-                            self.get_fn_like_arguments(node)?;
-                        Some((Some(found_span), closure_arg_span, found))
-                    })
-                    .unwrap_or((found_span, None, found));
-
-                self.report_arg_count_mismatch(
+        if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait() && not_tupled
+        {
+            return Ok(self.report_and_explain_type_error(
+                TypeTrace::trait_refs(&obligation.cause, true, expected_trait_ref, found_trait_ref),
+                ty::error::TypeError::Mismatch,
+            ));
+        }
+        if found.len() != expected.len() {
+            let (closure_span, closure_arg_span, found) = found_did
+                .and_then(|did| {
+                    let node = self.tcx.hir().get_if_local(did)?;
+                    let (found_span, closure_arg_span, found) = self.get_fn_like_arguments(node)?;
+                    Some((Some(found_span), closure_arg_span, found))
+                })
+                .unwrap_or((found_span, None, found));
+
+            // If the coroutine take a single () as its argument,
+            // the trait argument would found the coroutine take 0 arguments,
+            // but get_fn_like_arguments would give 1 argument.
+            // This would result in "Expected to take 1 argument, but it takes 1 argument".
+            // Check again to avoid this.
+            if found.len() != expected.len() {
+                return Ok(self.report_arg_count_mismatch(
                     span,
                     closure_span,
                     expected,
                     found,
                     found_trait_ty.is_closure(),
                     closure_arg_span,
-                )
-            },
-        )
+                ));
+            }
+        }
+        Ok(self.report_closure_arg_mismatch(
+            span,
+            found_span,
+            found_trait_ref,
+            expected_trait_ref,
+            obligation.cause.code(),
+            found_node,
+            obligation.param_env,
+        ))
     }
 
     /// Given some node representing a fn-like thing in the HIR map,
diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs
new file mode 100644
index 00000000000..448c7100df6
--- /dev/null
+++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs
@@ -0,0 +1,11 @@
+#![feature(coroutines, coroutine_trait, stmt_expr_attributes)]
+
+use std::ops::Coroutine;
+
+fn foo() -> impl Coroutine<u8> {
+    //~^ ERROR type mismatch in coroutine arguments
+    #[coroutine]
+    |_: ()| {}
+}
+
+fn main() { }
diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr
new file mode 100644
index 00000000000..c7d6507fd79
--- /dev/null
+++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr
@@ -0,0 +1,15 @@
+error[E0631]: type mismatch in coroutine arguments
+  --> $DIR/arg-count-mismatch-on-unit-input.rs:5:13
+   |
+LL | fn foo() -> impl Coroutine<u8> {
+   |             ^^^^^^^^^^^^^^^^^^ expected due to this
+...
+LL |     |_: ()| {}
+   |     ------- found signature defined here
+   |
+   = note: expected coroutine signature `fn(u8) -> _`
+              found coroutine signature `fn(()) -> _`
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0631`.