about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2024-12-14 23:56:28 +0100
committerGitHub <noreply@github.com>2024-12-14 23:56:28 +0100
commitdb77788dc5d72da6dc2077e59d9ff321cbda1cec (patch)
treec66a72aeeff0f31ff9a70e4a3b0b9c49fb0c9569 /compiler/rustc_trait_selection/src
parent0aeaa5eb22180fdf12a8489e63c4daa18da6f236 (diff)
parent831f4549cd1b23915729cbd2f1dd841621c4e8b8 (diff)
downloadrust-db77788dc5d72da6dc2077e59d9ff321cbda1cec.tar.gz
rust-db77788dc5d72da6dc2077e59d9ff321cbda1cec.zip
Rollup merge of #132939 - uellenberg:suggest-deref, r=oli-obk
Suggest using deref in patterns

Fixes #132784

This changes the following code:
```rs
use std::sync::Arc;
fn main() {
    let mut x = Arc::new(Some(1));
    match x {
        Some(_) => {}
        None => {}
    }
}
```

to output
```rs
error[E0308]: mismatched types
  --> src/main.rs:5:9
   |
LL |     match x {
   |           - this expression has type `Arc<Option<{integer}>>`
...
LL |         Some(_) => {}
   |         ^^^^^^^ expected `Arc<Option<{integer}>>`, found `Option<_>`
   |
   = note: expected struct `Arc<Option<{integer}>>`
                found enum `Option<_>`
help: consider dereferencing to access the inner value using the Deref trait
   |
LL |     match *x {
   |           ~~
```

instead of
```rs
error[E0308]: mismatched types
 --> src/main.rs:5:9
  |
4 |     match x {
  |           - this expression has type `Arc<Option<{integer}>>`
5 |         Some(_) => {}
  |         ^^^^^^^ expected `Arc<Option<{integer}>>`, found `Option<_>`
  |
  = note: expected struct `Arc<Option<{integer}>>`
               found enum `Option<_>`
```

This makes it more obvious that a Deref is available, and gives a suggestion on how to use it in order to fix the issue at hand.
Diffstat (limited to 'compiler/rustc_trait_selection/src')
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs115
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs2
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs11
3 files changed, 98 insertions, 30 deletions
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
index f856a8d7abb..5a62a4c3bd5 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
@@ -63,10 +63,11 @@ use rustc_hir::{self as hir};
 use rustc_macros::extension;
 use rustc_middle::bug;
 use rustc_middle::dep_graph::DepContext;
+use rustc_middle::traits::PatternOriginExpr;
 use rustc_middle::ty::error::{ExpectedFound, TypeError, TypeErrorToStringExt};
 use rustc_middle::ty::print::{PrintError, PrintTraitRefExt as _, with_forced_trimmed_paths};
 use rustc_middle::ty::{
-    self, List, Region, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable,
+    self, List, ParamEnv, Region, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable,
     TypeVisitableExt,
 };
 use rustc_span::def_id::LOCAL_CRATE;
@@ -77,7 +78,7 @@ use crate::error_reporting::TypeErrCtxt;
 use crate::errors::{ObligationCauseFailureCode, TypeErrorAdditionalDiags};
 use crate::infer;
 use crate::infer::relate::{self, RelateResult, TypeRelation};
-use crate::infer::{InferCtxt, TypeTrace, ValuePairs};
+use crate::infer::{InferCtxt, InferCtxtExt as _, TypeTrace, ValuePairs};
 use crate::solve::deeply_normalize_for_diagnostics;
 use crate::traits::{
     IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
@@ -433,15 +434,22 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
         cause: &ObligationCause<'tcx>,
         exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
         terr: TypeError<'tcx>,
+        param_env: Option<ParamEnv<'tcx>>,
     ) {
         match *cause.code() {
-            ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
-                let ty = self.resolve_vars_if_possible(root_ty);
-                if !matches!(ty.kind(), ty::Infer(ty::InferTy::TyVar(_) | ty::InferTy::FreshTy(_)))
-                {
+            ObligationCauseCode::Pattern {
+                origin_expr: Some(origin_expr),
+                span: Some(span),
+                root_ty,
+            } => {
+                let expected_ty = self.resolve_vars_if_possible(root_ty);
+                if !matches!(
+                    expected_ty.kind(),
+                    ty::Infer(ty::InferTy::TyVar(_) | ty::InferTy::FreshTy(_))
+                ) {
                     // don't show type `_`
                     if span.desugaring_kind() == Some(DesugaringKind::ForLoop)
-                        && let ty::Adt(def, args) = ty.kind()
+                        && let ty::Adt(def, args) = expected_ty.kind()
                         && Some(def.did()) == self.tcx.get_diagnostic_item(sym::Option)
                     {
                         err.span_label(
@@ -449,22 +457,48 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                             format!("this is an iterator with items of type `{}`", args.type_at(0)),
                         );
                     } else {
-                        err.span_label(span, format!("this expression has type `{ty}`"));
+                        err.span_label(span, format!("this expression has type `{expected_ty}`"));
                     }
                 }
                 if let Some(ty::error::ExpectedFound { found, .. }) = exp_found
-                    && ty.boxed_ty() == Some(found)
-                    && let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span)
+                    && let Ok(mut peeled_snippet) =
+                        self.tcx.sess.source_map().span_to_snippet(origin_expr.peeled_span)
                 {
-                    err.span_suggestion(
-                        span,
-                        "consider dereferencing the boxed value",
-                        format!("*{snippet}"),
-                        Applicability::MachineApplicable,
-                    );
+                    // Parentheses are needed for cases like as casts.
+                    // We use the peeled_span for deref suggestions.
+                    // It's also safe to use for box, since box only triggers if there
+                    // wasn't a reference to begin with.
+                    if origin_expr.peeled_prefix_suggestion_parentheses {
+                        peeled_snippet = format!("({peeled_snippet})");
+                    }
+
+                    // Try giving a box suggestion first, as it is a special case of the
+                    // deref suggestion.
+                    if expected_ty.boxed_ty() == Some(found) {
+                        err.span_suggestion_verbose(
+                            span,
+                            "consider dereferencing the boxed value",
+                            format!("*{peeled_snippet}"),
+                            Applicability::MachineApplicable,
+                        );
+                    } else if let Some(param_env) = param_env
+                        && let Some(prefix) = self.should_deref_suggestion_on_mismatch(
+                            param_env,
+                            found,
+                            expected_ty,
+                            origin_expr,
+                        )
+                    {
+                        err.span_suggestion_verbose(
+                            span,
+                            "consider dereferencing to access the inner value using the Deref trait",
+                            format!("{prefix}{peeled_snippet}"),
+                            Applicability::MaybeIncorrect,
+                        );
+                    }
                 }
             }
-            ObligationCauseCode::Pattern { origin_expr: false, span: Some(span), .. } => {
+            ObligationCauseCode::Pattern { origin_expr: None, span: Some(span), .. } => {
                 err.span_label(span, "expected due to this");
             }
             ObligationCauseCode::BlockTailExpression(
@@ -618,6 +652,45 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
         }
     }
 
+    /// Determines whether deref_to == <deref_from as Deref>::Target, and if so,
+    /// returns a prefix that should be added to deref_from as a suggestion.
+    fn should_deref_suggestion_on_mismatch(
+        &self,
+        param_env: ParamEnv<'tcx>,
+        deref_to: Ty<'tcx>,
+        deref_from: Ty<'tcx>,
+        origin_expr: PatternOriginExpr,
+    ) -> Option<String> {
+        // origin_expr contains stripped away versions of our expression.
+        // We'll want to use that to avoid suggesting things like *&x.
+        // However, the type that we have access to hasn't been stripped away,
+        // so we need to ignore the first n dereferences, where n is the number
+        // that's been stripped away in origin_expr.
+
+        // Find a way to autoderef from deref_from to deref_to.
+        let Some((num_derefs, (after_deref_ty, _))) = (self.autoderef_steps)(deref_from)
+            .into_iter()
+            .enumerate()
+            .find(|(_, (ty, _))| self.infcx.can_eq(param_env, *ty, deref_to))
+        else {
+            return None;
+        };
+
+        if num_derefs <= origin_expr.peeled_count {
+            return None;
+        }
+
+        let deref_part = "*".repeat(num_derefs - origin_expr.peeled_count);
+
+        // If the user used a reference in the original expression, they probably
+        // want the suggestion to still give a reference.
+        if deref_from.is_ref() && !after_deref_ty.is_ref() {
+            Some(format!("&{deref_part}"))
+        } else {
+            Some(deref_part)
+        }
+    }
+
     /// Given that `other_ty` is the same as a type argument for `name` in `sub`, populate `value`
     /// highlighting `name` and every type argument that isn't at `pos` (which is `other_ty`), and
     /// populate `other_value` with `other_ty`.
@@ -1406,8 +1479,8 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
             Variable(ty::error::ExpectedFound<Ty<'a>>),
             Fixed(&'static str),
         }
-        let (expected_found, exp_found, is_simple_error, values) = match values {
-            None => (None, Mismatch::Fixed("type"), false, None),
+        let (expected_found, exp_found, is_simple_error, values, param_env) = match values {
+            None => (None, Mismatch::Fixed("type"), false, None, None),
             Some(ty::ParamEnvAnd { param_env, value: values }) => {
                 let mut values = self.resolve_vars_if_possible(values);
                 if self.next_trait_solver() {
@@ -1459,7 +1532,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                     diag.downgrade_to_delayed_bug();
                     return;
                 };
-                (Some(vals), exp_found, is_simple_error, Some(values))
+                (Some(vals), exp_found, is_simple_error, Some(values), Some(param_env))
             }
         };
 
@@ -1791,7 +1864,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
 
         // It reads better to have the error origin as the final
         // thing.
-        self.note_error_origin(diag, cause, exp_found, terr);
+        self.note_error_origin(diag, cause, exp_found, terr, param_env);
 
         debug!(?diag);
     }
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
index fc2d0ba36f0..08775df5ac9 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/suggest.rs
@@ -210,7 +210,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
             (Some(ty), _) if self.same_type_modulo_infer(ty, exp_found.found) => match cause.code()
             {
                 ObligationCauseCode::Pattern { span: Some(then_span), origin_expr, .. } => {
-                    origin_expr.then_some(ConsiderAddingAwait::FutureSugg {
+                    origin_expr.is_some().then_some(ConsiderAddingAwait::FutureSugg {
                         span: then_span.shrink_to_hi(),
                     })
                 }
diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs
index 94682f501a8..cc8941b9224 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs
@@ -20,7 +20,8 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::{
-    CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, HirId, Node, is_range_literal,
+    CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, HirId, Node, expr_needs_parens,
+    is_range_literal,
 };
 use rustc_infer::infer::{BoundRegionConversionTime, DefineOpaqueTypes, InferCtxt, InferOk};
 use rustc_middle::hir::map;
@@ -1391,13 +1392,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                         let Some(expr) = expr_finder.result else {
                             return false;
                         };
-                        let needs_parens = match expr.kind {
-                            // parenthesize if needed (Issue #46756)
-                            hir::ExprKind::Cast(_, _) | hir::ExprKind::Binary(_, _, _) => true,
-                            // parenthesize borrows of range literals (Issue #54505)
-                            _ if is_range_literal(expr) => true,
-                            _ => false,
-                        };
+                        let needs_parens = expr_needs_parens(expr);
 
                         let span = if needs_parens { span } else { span.shrink_to_lo() };
                         let suggestions = if !needs_parens {