about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_hir_typeck/src/_match.rs99
-rw-r--r--compiler/rustc_hir_typeck/src/coercion.rs46
-rw-r--r--compiler/rustc_hir_typeck/src/expr.rs14
-rw-r--r--compiler/rustc_infer/src/infer/mod.rs15
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs18
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs61
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs42
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs24
8 files changed, 132 insertions, 187 deletions
diff --git a/compiler/rustc_hir_typeck/src/_match.rs b/compiler/rustc_hir_typeck/src/_match.rs
index 4ac260cb15f..6467adb54da 100644
--- a/compiler/rustc_hir_typeck/src/_match.rs
+++ b/compiler/rustc_hir_typeck/src/_match.rs
@@ -1,12 +1,12 @@
 use rustc_errors::{Applicability, Diag};
 use rustc_hir::def::{CtorOf, DefKind, Res};
 use rustc_hir::def_id::LocalDefId;
-use rustc_hir::{self as hir, ExprKind, PatKind};
+use rustc_hir::{self as hir, ExprKind, HirId, PatKind};
 use rustc_hir_pretty::ty_to_string;
 use rustc_middle::ty::{self, Ty};
 use rustc_span::Span;
 use rustc_trait_selection::traits::{
-    IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
+    MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
 };
 use tracing::{debug, instrument};
 
@@ -414,105 +414,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
     pub(crate) fn if_cause(
         &self,
-        span: Span,
-        cond_span: Span,
-        then_expr: &'tcx hir::Expr<'tcx>,
+        expr_id: HirId,
         else_expr: &'tcx hir::Expr<'tcx>,
-        then_ty: Ty<'tcx>,
-        else_ty: Ty<'tcx>,
         tail_defines_return_position_impl_trait: Option<LocalDefId>,
     ) -> ObligationCause<'tcx> {
-        let mut outer_span = if self.tcx.sess.source_map().is_multiline(span) {
-            // The `if`/`else` isn't in one line in the output, include some context to make it
-            // clear it is an if/else expression:
-            // ```
-            // LL |      let x = if true {
-            //    | _____________-
-            // LL ||         10i32
-            //    ||         ----- expected because of this
-            // LL ||     } else {
-            // LL ||         10u32
-            //    ||         ^^^^^ expected `i32`, found `u32`
-            // LL ||     };
-            //    ||_____- `if` and `else` have incompatible types
-            // ```
-            Some(span)
-        } else {
-            // The entire expression is in one line, only point at the arms
-            // ```
-            // LL |     let x = if true { 10i32 } else { 10u32 };
-            //    |                       -----          ^^^^^ expected `i32`, found `u32`
-            //    |                       |
-            //    |                       expected because of this
-            // ```
-            None
-        };
-
-        let (error_sp, else_id) = if let ExprKind::Block(block, _) = &else_expr.kind {
-            let block = block.innermost_block();
-
-            // Avoid overlapping spans that aren't as readable:
-            // ```
-            // 2 |        let x = if true {
-            //   |   _____________-
-            // 3 |  |         3
-            //   |  |         - expected because of this
-            // 4 |  |     } else {
-            //   |  |____________^
-            // 5 | ||
-            // 6 | ||     };
-            //   | ||     ^
-            //   | ||_____|
-            //   | |______if and else have incompatible types
-            //   |        expected integer, found `()`
-            // ```
-            // by not pointing at the entire expression:
-            // ```
-            // 2 |       let x = if true {
-            //   |               ------- `if` and `else` have incompatible types
-            // 3 |           3
-            //   |           - expected because of this
-            // 4 |       } else {
-            //   |  ____________^
-            // 5 | |
-            // 6 | |     };
-            //   | |_____^ expected integer, found `()`
-            // ```
-            if block.expr.is_none()
-                && block.stmts.is_empty()
-                && let Some(outer_span) = &mut outer_span
-                && let Some(cond_span) = cond_span.find_ancestor_inside(*outer_span)
-            {
-                *outer_span = outer_span.with_hi(cond_span.hi())
-            }
-
-            (self.find_block_span(block), block.hir_id)
-        } else {
-            (else_expr.span, else_expr.hir_id)
-        };
-
-        let then_id = if let ExprKind::Block(block, _) = &then_expr.kind {
-            let block = block.innermost_block();
-            // Exclude overlapping spans
-            if block.expr.is_none() && block.stmts.is_empty() {
-                outer_span = None;
-            }
-            block.hir_id
-        } else {
-            then_expr.hir_id
-        };
+        let error_sp = self.find_block_span_from_hir_id(else_expr.hir_id);
 
         // Finally construct the cause:
         self.cause(
             error_sp,
-            ObligationCauseCode::IfExpression(Box::new(IfExpressionCause {
-                else_id,
-                then_id,
-                then_ty,
-                else_ty,
-                outer_span,
-                tail_defines_return_position_impl_trait,
-            })),
+            ObligationCauseCode::IfExpression { expr_id, tail_defines_return_position_impl_trait },
         )
     }
 
diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index 0ce0bc313c7..a9367415263 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -46,8 +46,7 @@ use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
 use rustc_infer::infer::relate::RelateResult;
 use rustc_infer::infer::{Coercion, DefineOpaqueTypes, InferOk, InferResult};
 use rustc_infer::traits::{
-    IfExpressionCause, ImplSource, MatchExpressionArmCause, Obligation, PredicateObligation,
-    PredicateObligations, SelectionError,
+    MatchExpressionArmCause, Obligation, PredicateObligation, PredicateObligations, SelectionError,
 };
 use rustc_middle::span_bug;
 use rustc_middle::ty::adjustment::{
@@ -59,7 +58,7 @@ use rustc_span::{BytePos, DUMMY_SP, DesugaringKind, Span};
 use rustc_trait_selection::infer::InferCtxtExt as _;
 use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
 use rustc_trait_selection::traits::{
-    self, NormalizeExt, ObligationCause, ObligationCauseCode, ObligationCtxt,
+    self, ImplSource, NormalizeExt, ObligationCause, ObligationCauseCode, ObligationCtxt,
 };
 use smallvec::{SmallVec, smallvec};
 use tracing::{debug, instrument};
@@ -1719,14 +1718,17 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
                             );
                         }
                     }
-                    ObligationCauseCode::IfExpression(box IfExpressionCause {
-                        then_id,
-                        else_id,
-                        then_ty,
-                        else_ty,
+                    ObligationCauseCode::IfExpression {
+                        expr_id,
                         tail_defines_return_position_impl_trait: Some(rpit_def_id),
-                        ..
-                    }) => {
+                    } => {
+                        let hir::Node::Expr(hir::Expr {
+                            kind: hir::ExprKind::If(_, then_expr, Some(else_expr)),
+                            ..
+                        }) = fcx.tcx.hir_node(expr_id)
+                        else {
+                            unreachable!();
+                        };
                         err = fcx.err_ctxt().report_mismatched_types(
                             cause,
                             fcx.param_env,
@@ -1734,24 +1736,12 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
                             found,
                             coercion_error,
                         );
-                        let then_span = fcx.find_block_span_from_hir_id(then_id);
-                        let else_span = fcx.find_block_span_from_hir_id(else_id);
-                        // don't suggest wrapping either blocks in `if .. {} else {}`
-                        let is_empty_arm = |id| {
-                            let hir::Node::Block(blk) = fcx.tcx.hir_node(id) else {
-                                return false;
-                            };
-                            if blk.expr.is_some() || !blk.stmts.is_empty() {
-                                return false;
-                            }
-                            let Some((_, hir::Node::Expr(expr))) =
-                                fcx.tcx.hir_parent_iter(id).nth(1)
-                            else {
-                                return false;
-                            };
-                            matches!(expr.kind, hir::ExprKind::If(..))
-                        };
-                        if !is_empty_arm(then_id) && !is_empty_arm(else_id) {
+                        let then_span = fcx.find_block_span_from_hir_id(then_expr.hir_id);
+                        let else_span = fcx.find_block_span_from_hir_id(else_expr.hir_id);
+                        // Don't suggest wrapping whole block in `Box::new`.
+                        if then_span != then_expr.span && else_span != else_expr.span {
+                            let then_ty = fcx.typeck_results.borrow().expr_ty(then_expr);
+                            let else_ty = fcx.typeck_results.borrow().expr_ty(else_expr);
                             self.suggest_boxing_tail_for_return_position_impl_trait(
                                 fcx,
                                 &mut err,
diff --git a/compiler/rustc_hir_typeck/src/expr.rs b/compiler/rustc_hir_typeck/src/expr.rs
index 2bc9dadb665..3a0d57dca12 100644
--- a/compiler/rustc_hir_typeck/src/expr.rs
+++ b/compiler/rustc_hir_typeck/src/expr.rs
@@ -583,7 +583,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 ascribed_ty
             }
             ExprKind::If(cond, then_expr, opt_else_expr) => {
-                self.check_expr_if(cond, then_expr, opt_else_expr, expr.span, expected)
+                self.check_expr_if(expr.hir_id, cond, then_expr, opt_else_expr, expr.span, expected)
             }
             ExprKind::DropTemps(e) => self.check_expr_with_expectation(e, expected),
             ExprKind::Array(args) => self.check_expr_array(args, expected, expr),
@@ -1343,6 +1343,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     // or 'if-else' expression.
     fn check_expr_if(
         &self,
+        expr_id: HirId,
         cond_expr: &'tcx hir::Expr<'tcx>,
         then_expr: &'tcx hir::Expr<'tcx>,
         opt_else_expr: Option<&'tcx hir::Expr<'tcx>>,
@@ -1382,15 +1383,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
             let tail_defines_return_position_impl_trait =
                 self.return_position_impl_trait_from_match_expectation(orig_expected);
-            let if_cause = self.if_cause(
-                sp,
-                cond_expr.span,
-                then_expr,
-                else_expr,
-                then_ty,
-                else_ty,
-                tail_defines_return_position_impl_trait,
-            );
+            let if_cause =
+                self.if_cause(expr_id, else_expr, tail_defines_return_position_impl_trait);
 
             coerce.coerce(self, &if_cause, else_expr, else_ty);
 
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index e9b58eb959b..491efba9eb0 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -35,7 +35,7 @@ use rustc_middle::ty::{
     PseudoCanonicalInput, Term, TermKind, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
     TypeSuperFoldable, TypeVisitable, TypeVisitableExt, TypingEnv, TypingMode, fold_regions,
 };
-use rustc_span::{Span, Symbol};
+use rustc_span::{DUMMY_SP, Span, Symbol};
 use snapshot::undo_log::InferCtxtUndoLogs;
 use tracing::{debug, instrument};
 use type_variable::TypeVariableOrigin;
@@ -1557,15 +1557,16 @@ impl<'tcx> InferCtxt<'tcx> {
         }
     }
 
-    /// Given a [`hir::HirId`] for a block, get the span of its last expression
-    /// or statement, peeling off any inner blocks.
+    /// Given a [`hir::HirId`] for a block (or an expr of a block), get the span
+    /// of its last expression or statement, peeling off any inner blocks.
     pub fn find_block_span_from_hir_id(&self, hir_id: hir::HirId) -> Span {
         match self.tcx.hir_node(hir_id) {
-            hir::Node::Block(blk) => self.find_block_span(blk),
-            // The parser was in a weird state if either of these happen, but
-            // it's better not to panic.
+            hir::Node::Block(blk)
+            | hir::Node::Expr(&hir::Expr { kind: hir::ExprKind::Block(blk, _), .. }) => {
+                self.find_block_span(blk)
+            }
             hir::Node::Expr(e) => e.span,
-            _ => rustc_span::DUMMY_SP,
+            _ => DUMMY_SP,
         }
     }
 }
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index d877bd5c626..1a5a9765ce7 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -332,7 +332,11 @@ pub enum ObligationCauseCode<'tcx> {
     },
 
     /// Computing common supertype in an if expression
-    IfExpression(Box<IfExpressionCause<'tcx>>),
+    IfExpression {
+        expr_id: HirId,
+        // Is the expectation of this match expression an RPIT?
+        tail_defines_return_position_impl_trait: Option<LocalDefId>,
+    },
 
     /// Computing common supertype of an if expression with no else counter-part
     IfExpressionWithNoElse,
@@ -550,18 +554,6 @@ pub struct PatternOriginExpr {
     pub peeled_prefix_suggestion_parentheses: bool,
 }
 
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-#[derive(TypeFoldable, TypeVisitable, HashStable, TyEncodable, TyDecodable)]
-pub struct IfExpressionCause<'tcx> {
-    pub then_id: HirId,
-    pub else_id: HirId,
-    pub then_ty: Ty<'tcx>,
-    pub else_ty: Ty<'tcx>,
-    pub outer_span: Option<Span>,
-    // Is the expectation of this match expression an RPIT?
-    pub tail_defines_return_position_impl_trait: Option<LocalDefId>,
-}
-
 #[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
 #[derive(TypeVisitable, TypeFoldable)]
 pub struct DerivedCause<'tcx> {
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
index 2c16672d786..bc464b099e2 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
@@ -82,9 +82,7 @@ use crate::infer;
 use crate::infer::relate::{self, RelateResult, TypeRelation};
 use crate::infer::{InferCtxt, InferCtxtExt as _, TypeTrace, ValuePairs};
 use crate::solve::deeply_normalize_for_diagnostics;
-use crate::traits::{
-    IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
-};
+use crate::traits::{MatchExpressionArmCause, ObligationCause, ObligationCauseCode};
 
 mod note_and_explain;
 mod suggest;
@@ -613,18 +611,28 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                     }
                 }
             },
-            ObligationCauseCode::IfExpression(box IfExpressionCause {
-                then_id,
-                else_id,
-                then_ty,
-                else_ty,
-                outer_span,
-                ..
-            }) => {
-                let then_span = self.find_block_span_from_hir_id(then_id);
-                let else_span = self.find_block_span_from_hir_id(else_id);
-                if let hir::Node::Expr(e) = self.tcx.hir_node(else_id)
-                    && let hir::ExprKind::If(_cond, _then, None) = e.kind
+            ObligationCauseCode::IfExpression { expr_id, .. } => {
+                let hir::Node::Expr(&hir::Expr {
+                    kind: hir::ExprKind::If(cond_expr, then_expr, Some(else_expr)),
+                    span: expr_span,
+                    ..
+                }) = self.tcx.hir_node(expr_id)
+                else {
+                    return;
+                };
+                let then_span = self.find_block_span_from_hir_id(then_expr.hir_id);
+                let then_ty = self
+                    .typeck_results
+                    .as_ref()
+                    .expect("if expression only expected inside FnCtxt")
+                    .expr_ty(then_expr);
+                let else_span = self.find_block_span_from_hir_id(else_expr.hir_id);
+                let else_ty = self
+                    .typeck_results
+                    .as_ref()
+                    .expect("if expression only expected inside FnCtxt")
+                    .expr_ty(else_expr);
+                if let hir::ExprKind::If(_cond, _then, None) = else_expr.kind
                     && else_ty.is_unit()
                 {
                     // Account for `let x = if a { 1 } else if b { 2 };`
@@ -632,9 +640,32 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                     err.note("consider adding an `else` block that evaluates to the expected type");
                 }
                 err.span_label(then_span, "expected because of this");
+
+                let outer_span = if self.tcx.sess.source_map().is_multiline(expr_span) {
+                    if then_span.hi() == expr_span.hi() || else_span.hi() == expr_span.hi() {
+                        // Point at condition only if either block has the same end point as
+                        // the whole expression, since that'll cause awkward overlapping spans.
+                        Some(expr_span.shrink_to_lo().to(cond_expr.peel_drop_temps().span))
+                    } else {
+                        Some(expr_span)
+                    }
+                } else {
+                    None
+                };
                 if let Some(sp) = outer_span {
                     err.span_label(sp, "`if` and `else` have incompatible types");
                 }
+
+                let then_id = if let hir::ExprKind::Block(then_blk, _) = then_expr.kind {
+                    then_blk.hir_id
+                } else {
+                    then_expr.hir_id
+                };
+                let else_id = if let hir::ExprKind::Block(else_blk, _) = else_expr.kind {
+                    else_blk.hir_id
+                } else {
+                    else_expr.hir_id
+                };
                 if let Some(subdiag) = self.suggest_remove_semi_or_return_binding(
                     Some(then_id),
                     then_ty,
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs
index be508c8cee1..0a4a9144c94 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs
@@ -420,19 +420,33 @@ impl<T> Trait<T> for X {
                         }
                         // If two if arms can be coerced to a trait object, provide a structured
                         // suggestion.
-                        let ObligationCauseCode::IfExpression(cause) = cause.code() else {
+                        let ObligationCauseCode::IfExpression { expr_id, .. } = cause.code() else {
                             return;
                         };
-                        let hir::Node::Block(blk) = self.tcx.hir_node(cause.then_id) else {
-                            return;
-                        };
-                        let Some(then) = blk.expr else {
-                            return;
-                        };
-                        let hir::Node::Block(blk) = self.tcx.hir_node(cause.else_id) else {
-                            return;
-                        };
-                        let Some(else_) = blk.expr else {
+                        let hir::Node::Expr(&hir::Expr {
+                            kind:
+                                hir::ExprKind::If(
+                                    _,
+                                    &hir::Expr {
+                                        kind:
+                                            hir::ExprKind::Block(
+                                                &hir::Block { expr: Some(then), .. },
+                                                _,
+                                            ),
+                                        ..
+                                    },
+                                    Some(&hir::Expr {
+                                        kind:
+                                            hir::ExprKind::Block(
+                                                &hir::Block { expr: Some(else_), .. },
+                                                _,
+                                            ),
+                                        ..
+                                    }),
+                                ),
+                            ..
+                        }) = self.tcx.hir_node(*expr_id)
+                        else {
                             return;
                         };
                         let expected = match values.found.kind() {
@@ -486,8 +500,10 @@ impl<T> Trait<T> for X {
                         }
                     }
                     (ty::Adt(_, _), ty::Adt(def, args))
-                        if let ObligationCauseCode::IfExpression(cause) = cause.code()
-                            && let hir::Node::Block(blk) = self.tcx.hir_node(cause.then_id)
+                        if let ObligationCauseCode::IfExpression { expr_id, .. } = cause.code()
+                            && let hir::Node::Expr(if_expr) = self.tcx.hir_node(*expr_id)
+                            && let hir::ExprKind::If(_, then_expr, _) = if_expr.kind
+                            && let hir::ExprKind::Block(blk, _) = then_expr.kind
                             && let Some(then) = blk.expr
                             && def.is_box()
                             && let boxed_ty = args.type_at(0)
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
index 3804c13acce..c0daf08ce07 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
@@ -8,9 +8,7 @@ use rustc_errors::{Applicability, Diag};
 use rustc_hir as hir;
 use rustc_hir::def::Res;
 use rustc_hir::{MatchSource, Node};
-use rustc_middle::traits::{
-    IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
-};
+use rustc_middle::traits::{MatchExpressionArmCause, ObligationCause, ObligationCauseCode};
 use rustc_middle::ty::error::TypeError;
 use rustc_middle::ty::print::with_no_trimmed_paths;
 use rustc_middle::ty::{self as ty, GenericArgKind, IsSuggestable, Ty, TypeVisitableExt};
@@ -196,8 +194,14 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
             (Some(exp), Some(found)) if self.same_type_modulo_infer(exp, found) => match cause
                 .code()
             {
-                ObligationCauseCode::IfExpression(box IfExpressionCause { then_id, .. }) => {
-                    let then_span = self.find_block_span_from_hir_id(*then_id);
+                ObligationCauseCode::IfExpression { expr_id, .. } => {
+                    let hir::Node::Expr(hir::Expr {
+                        kind: hir::ExprKind::If(_, then_expr, _), ..
+                    }) = self.tcx.hir_node(*expr_id)
+                    else {
+                        return;
+                    };
+                    let then_span = self.find_block_span_from_hir_id(then_expr.hir_id);
                     Some(ConsiderAddingAwait::BothFuturesSugg {
                         first: then_span.shrink_to_hi(),
                         second: exp_span.shrink_to_hi(),
@@ -232,8 +236,14 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
                         span: then_span.shrink_to_hi(),
                     })
                 }
-                ObligationCauseCode::IfExpression(box IfExpressionCause { then_id, .. }) => {
-                    let then_span = self.find_block_span_from_hir_id(*then_id);
+                ObligationCauseCode::IfExpression { expr_id, .. } => {
+                    let hir::Node::Expr(hir::Expr {
+                        kind: hir::ExprKind::If(_, then_expr, _), ..
+                    }) = self.tcx.hir_node(*expr_id)
+                    else {
+                        return;
+                    };
+                    let then_span = self.find_block_span_from_hir_id(then_expr.hir_id);
                     Some(ConsiderAddingAwait::FutureSugg { span: then_span.shrink_to_hi() })
                 }
                 ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause {