about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGuillaume Gomez <guillaume1.gomez@gmail.com>2024-09-27 00:43:31 +0200
committerGitHub <noreply@github.com>2024-09-27 00:43:31 +0200
commit0acddf50604bc713aee96fb6f6f553b564688ac1 (patch)
tree327a97703a34451e1652de6395b2f3f9d0a08a09
parenta4a591a78c78e96746978e2018c9db8291ae1a8d (diff)
parent986e20d5bb5df4274b390b4148aab0058081a241 (diff)
downloadrust-0acddf50604bc713aee96fb6f6f553b564688ac1.tar.gz
rust-0acddf50604bc713aee96fb6f6f553b564688ac1.zip
Rollup merge of #130820 - 91khr:fix-coroutine-unit-arg, r=compiler-errors
Fix diagnostics for coroutines with () as input.

This may be a more real-life example to trigger the diagnostic:

```rust
#![features(try_blocks, coroutine_trait, coroutines)]

use std::ops::Coroutine;

struct Request;
struct Response;
fn get_args() -> Result<String, String> { todo!() }
fn build_request(_arg: String) -> Request { todo!() }
fn work() -> impl Coroutine<Option<Response>, Yield = Request> {
    #[coroutine]
    |_| {
        let r: Result<(), String> = try {
            let req = get_args()?;
            yield build_request(req)
        };
        if let Err(msg) = r {
            eprintln!("Error: {msg}");
        }
    }
}
```
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs72
-rw-r--r--tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs11
-rw-r--r--tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr15
3 files changed, 61 insertions, 37 deletions
diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
index 19e2679ae4d..969f2528836 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
@@ -2635,49 +2635,47 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
         // This shouldn't be common unless manually implementing one of the
         // traits manually, but don't make it more confusing when it does
         // happen.
-        Ok(
-            if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait()
-                && not_tupled
-            {
-                self.report_and_explain_type_error(
-                    TypeTrace::trait_refs(
-                        &obligation.cause,
-                        true,
-                        expected_trait_ref,
-                        found_trait_ref,
-                    ),
-                    ty::error::TypeError::Mismatch,
-                )
-            } else if found.len() == expected.len() {
-                self.report_closure_arg_mismatch(
-                    span,
-                    found_span,
-                    found_trait_ref,
-                    expected_trait_ref,
-                    obligation.cause.code(),
-                    found_node,
-                    obligation.param_env,
-                )
-            } else {
-                let (closure_span, closure_arg_span, found) = found_did
-                    .and_then(|did| {
-                        let node = self.tcx.hir().get_if_local(did)?;
-                        let (found_span, closure_arg_span, found) =
-                            self.get_fn_like_arguments(node)?;
-                        Some((Some(found_span), closure_arg_span, found))
-                    })
-                    .unwrap_or((found_span, None, found));
-
-                self.report_arg_count_mismatch(
+        if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait() && not_tupled
+        {
+            return Ok(self.report_and_explain_type_error(
+                TypeTrace::trait_refs(&obligation.cause, true, expected_trait_ref, found_trait_ref),
+                ty::error::TypeError::Mismatch,
+            ));
+        }
+        if found.len() != expected.len() {
+            let (closure_span, closure_arg_span, found) = found_did
+                .and_then(|did| {
+                    let node = self.tcx.hir().get_if_local(did)?;
+                    let (found_span, closure_arg_span, found) = self.get_fn_like_arguments(node)?;
+                    Some((Some(found_span), closure_arg_span, found))
+                })
+                .unwrap_or((found_span, None, found));
+
+            // If the coroutine take a single () as its argument,
+            // the trait argument would found the coroutine take 0 arguments,
+            // but get_fn_like_arguments would give 1 argument.
+            // This would result in "Expected to take 1 argument, but it takes 1 argument".
+            // Check again to avoid this.
+            if found.len() != expected.len() {
+                return Ok(self.report_arg_count_mismatch(
                     span,
                     closure_span,
                     expected,
                     found,
                     found_trait_ty.is_closure(),
                     closure_arg_span,
-                )
-            },
-        )
+                ));
+            }
+        }
+        Ok(self.report_closure_arg_mismatch(
+            span,
+            found_span,
+            found_trait_ref,
+            expected_trait_ref,
+            obligation.cause.code(),
+            found_node,
+            obligation.param_env,
+        ))
     }
 
     /// Given some node representing a fn-like thing in the HIR map,
diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs
new file mode 100644
index 00000000000..448c7100df6
--- /dev/null
+++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs
@@ -0,0 +1,11 @@
+#![feature(coroutines, coroutine_trait, stmt_expr_attributes)]
+
+use std::ops::Coroutine;
+
+fn foo() -> impl Coroutine<u8> {
+    //~^ ERROR type mismatch in coroutine arguments
+    #[coroutine]
+    |_: ()| {}
+}
+
+fn main() { }
diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr
new file mode 100644
index 00000000000..c7d6507fd79
--- /dev/null
+++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr
@@ -0,0 +1,15 @@
+error[E0631]: type mismatch in coroutine arguments
+  --> $DIR/arg-count-mismatch-on-unit-input.rs:5:13
+   |
+LL | fn foo() -> impl Coroutine<u8> {
+   |             ^^^^^^^^^^^^^^^^^^ expected due to this
+...
+LL |     |_: ()| {}
+   |     ------- found signature defined here
+   |
+   = note: expected coroutine signature `fn(u8) -> _`
+              found coroutine signature `fn(()) -> _`
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0631`.