about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/librustc_typeck/check/mod.rs43
-rw-r--r--src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs31
-rw-r--r--src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr68
3 files changed, 123 insertions, 19 deletions
diff --git a/src/librustc_typeck/check/mod.rs b/src/librustc_typeck/check/mod.rs
index b3ce92cb7d9..d8d01624f1d 100644
--- a/src/librustc_typeck/check/mod.rs
+++ b/src/librustc_typeck/check/mod.rs
@@ -3687,6 +3687,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         }
     }
 
+    /// If `expr` is a `match` expression that has only one non-`!` arm, use that arm's tail
+    /// expression's `Span`, otherwise return `expr.span`. This is done to give better errors
+    /// when given code like the following:
+    /// ```text
+    /// if false { return 0i32; } else { 1u32 }
+    /// //                               ^^^^ point at this instead of the whole `if` expression
+    /// ```
+    fn get_expr_coercion_span(&self, expr: &hir::Expr) -> syntax_pos::Span {
+        if let hir::ExprKind::Match(_, arms, _) = &expr.node {
+            let arm_spans: Vec<Span> = arms.iter().filter_map(|arm| {
+                self.in_progress_tables
+                    .and_then(|tables| tables.borrow().node_type_opt(arm.body.hir_id))
+                    .and_then(|arm_ty| {
+                        if arm_ty.is_never() {
+                            None
+                        } else {
+                            Some(match &arm.body.node {
+                                // Point at the tail expression when possible.
+                                hir::ExprKind::Block(block, _) => block.expr
+                                    .as_ref()
+                                    .map(|e| e.span)
+                                    .unwrap_or(block.span),
+                                _ => arm.body.span,
+                            })
+                        }
+                    })
+            }).collect();
+            if arm_spans.len() == 1 {
+                return arm_spans[0];
+            }
+        }
+        expr.span
+    }
+
     fn check_block_with_expected(
         &self,
         blk: &'tcx hir::Block,
@@ -3746,12 +3780,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             let coerce = ctxt.coerce.as_mut().unwrap();
             if let Some(tail_expr_ty) = tail_expr_ty {
                 let tail_expr = tail_expr.unwrap();
-                let cause = self.cause(tail_expr.span,
-                                       ObligationCauseCode::BlockTailExpression(blk.hir_id));
-                coerce.coerce(self,
-                              &cause,
-                              tail_expr,
-                              tail_expr_ty);
+                let span = self.get_expr_coercion_span(tail_expr);
+                let cause = self.cause(span, ObligationCauseCode::BlockTailExpression(blk.hir_id));
+                coerce.coerce(self, &cause, tail_expr, tail_expr_ty);
             } else {
                 // Subtle: if there is no explicit tail expression,
                 // that is typically equivalent to a tail expression
diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs
index 95b40368143..58109be447e 100644
--- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs
+++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs
@@ -17,10 +17,10 @@ fn bar() -> impl std::fmt::Display {
 
 fn baz() -> impl std::fmt::Display {
     if false {
-    //~^ ERROR mismatched types
         return 0i32;
     } else {
         1u32
+        //~^ ERROR mismatched types
     }
 }
 
@@ -33,4 +33,33 @@ fn qux() -> impl std::fmt::Display {
     }
 }
 
+fn bat() -> impl std::fmt::Display {
+    match 13 {
+        0 => return 0i32,
+        _ => 1u32,
+        //~^ ERROR mismatched types
+    }
+}
+
+fn can() -> impl std::fmt::Display {
+    match 13 {
+    //~^ ERROR mismatched types
+        0 => return 0i32,
+        1 => 1u32,
+        _ => 2u32,
+    }
+}
+
+fn cat() -> impl std::fmt::Display {
+    match 13 {
+        0 => {
+            return 0i32;
+        }
+        _ => {
+            1u32
+            //~^ ERROR mismatched types
+        }
+    }
+}
+
 fn main() {}
diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr
index ee1e36081e7..314ff84ae3c 100644
--- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr
+++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr
@@ -29,18 +29,16 @@ LL |         return 1u32;
               found type `u32`
 
 error[E0308]: mismatched types
-  --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:19:5
+  --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:22:9
    |
-LL |   fn baz() -> impl std::fmt::Display {
-   |               ---------------------- expected because this return type...
-LL | /     if false {
-LL | |
-LL | |         return 0i32;
-   | |                ---- ...is found to be `i32` here
-LL | |     } else {
-LL | |         1u32
-LL | |     }
-   | |_____^ expected i32, found u32
+LL | fn baz() -> impl std::fmt::Display {
+   |             ---------------------- expected because this return type...
+LL |     if false {
+LL |         return 0i32;
+   |                ---- ...is found to be `i32` here
+LL |     } else {
+LL |         1u32
+   |         ^^^^ expected i32, found u32
    |
    = note: expected type `i32`
               found type `u32`
@@ -61,6 +59,52 @@ LL | |     }
    = note: expected type `i32`
               found type `u32`
 
-error: aborting due to 4 previous errors
+error[E0308]: mismatched types
+  --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:39:14
+   |
+LL | fn bat() -> impl std::fmt::Display {
+   |             ---------------------- expected because this return type...
+LL |     match 13 {
+LL |         0 => return 0i32,
+   |                     ---- ...is found to be `i32` here
+LL |         _ => 1u32,
+   |              ^^^^ expected i32, found u32
+   |
+   = note: expected type `i32`
+              found type `u32`
+
+error[E0308]: mismatched types
+  --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:45:5
+   |
+LL |   fn can() -> impl std::fmt::Display {
+   |               ---------------------- expected because this return type...
+LL | /     match 13 {
+LL | |
+LL | |         0 => return 0i32,
+   | |                     ---- ...is found to be `i32` here
+LL | |         1 => 1u32,
+LL | |         _ => 2u32,
+LL | |     }
+   | |_____^ expected i32, found u32
+   |
+   = note: expected type `i32`
+              found type `u32`
+
+error[E0308]: mismatched types
+  --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:59:13
+   |
+LL | fn cat() -> impl std::fmt::Display {
+   |             ---------------------- expected because this return type...
+...
+LL |             return 0i32;
+   |                    ---- ...is found to be `i32` here
+...
+LL |             1u32
+   |             ^^^^ expected i32, found u32
+   |
+   = note: expected type `i32`
+              found type `u32`
+
+error: aborting due to 7 previous errors
 
 For more information about this error, try `rustc --explain E0308`.