about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-04-11 02:37:35 +0000
committerbors <bors@rust-lang.org>2024-04-11 02:37:35 +0000
commit08273780d84816d85f002f4385c342fc7eaba58b (patch)
tree72d64c0f45a00edb2424af00a3434c9649ca8a78
parent4435924bb62cda0131e38dd5d2bba36f9616039f (diff)
parentf2465f8f2009e58412e8ca62abba6b2e4b8dd554 (diff)
downloadrust-08273780d84816d85f002f4385c342fc7eaba58b.tar.gz
rust-08273780d84816d85f002f4385c342fc7eaba58b.zip
Auto merge of #122213 - estebank:issue-50195, r=oli-obk,estebank
Provide suggestion to dereference closure tail if appropriate

When encoutnering a case like

```rust
use std::collections::HashMap;

fn main() {
    let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

    let mut counts = HashMap::new();
    for num in vs {
        let count = counts.entry(num).or_insert(0);
        *count += 1;
    }

    let _ = counts.iter().max_by_key(|(_, v)| v);
```
produce the following suggestion
```
error: lifetime may not live long enough
  --> $DIR/return-value-lifetime-error.rs:13:47
   |
LL |     let _ = counts.iter().max_by_key(|(_, v)| v);
   |                                       ------- ^ returning this value requires that `'1` must outlive `'2`
   |                                       |     |
   |                                       |     return type of closure is &'2 &i32
   |                                       has type `&'1 (&i32, &i32)`
   |
help: dereference the return value
   |
LL |     let _ = counts.iter().max_by_key(|(_, v)| **v);
   |                                               ++
```

Fix #50195.
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs12
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/region_errors.rs145
-rw-r--r--compiler/rustc_hir_typeck/src/lib.rs25
-rw-r--r--compiler/rustc_middle/src/query/keys.rs13
-rw-r--r--compiler/rustc_middle/src/query/mod.rs3
-rw-r--r--tests/ui/closures/return-value-lifetime-error.fixed16
-rw-r--r--tests/ui/closures/return-value-lifetime-error.rs16
-rw-r--r--tests/ui/closures/return-value-lifetime-error.stderr16
8 files changed, 241 insertions, 5 deletions
diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
index 62e16d445c6..47bd24f1e14 100644
--- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
@@ -1469,27 +1469,31 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
         let hir = tcx.hir();
         let Some(body_id) = tcx.hir_node(self.mir_hir_id()).body_id() else { return };
         struct FindUselessClone<'hir> {
+            tcx: TyCtxt<'hir>,
+            def_id: DefId,
             pub clones: Vec<&'hir hir::Expr<'hir>>,
         }
         impl<'hir> FindUselessClone<'hir> {
-            pub fn new() -> Self {
-                Self { clones: vec![] }
+            pub fn new(tcx: TyCtxt<'hir>, def_id: DefId) -> Self {
+                Self { tcx, def_id, clones: vec![] }
             }
         }
 
         impl<'v> Visitor<'v> for FindUselessClone<'v> {
             fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) {
-                // FIXME: use `lookup_method_for_diagnostic`?
                 if let hir::ExprKind::MethodCall(segment, _rcvr, args, _span) = ex.kind
                     && segment.ident.name == sym::clone
                     && args.len() == 0
+                    && let Some(def_id) = self.def_id.as_local()
+                    && let Some(method) = self.tcx.lookup_method_for_diagnostic((def_id, ex.hir_id))
+                    && Some(self.tcx.parent(method)) == self.tcx.lang_items().clone_trait()
                 {
                     self.clones.push(ex);
                 }
                 hir::intravisit::walk_expr(self, ex);
             }
         }
-        let mut expr_finder = FindUselessClone::new();
+        let mut expr_finder = FindUselessClone::new(tcx, self.mir_def_id().into());
 
         let body = hir.body(body_id).value;
         expr_finder.visit_expr(body);
diff --git a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
index c92fccc959f..304d41d6941 100644
--- a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
@@ -26,6 +26,9 @@ use rustc_middle::ty::{self, RegionVid, Ty};
 use rustc_middle::ty::{Region, TyCtxt};
 use rustc_span::symbol::{kw, Ident};
 use rustc_span::Span;
+use rustc_trait_selection::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
+use rustc_trait_selection::infer::InferCtxtExt;
+use rustc_trait_selection::traits::{Obligation, ObligationCtxt};
 
 use crate::borrowck_errors;
 use crate::session_diagnostics::{
@@ -810,6 +813,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
         self.add_static_impl_trait_suggestion(&mut diag, *fr, fr_name, *outlived_fr);
         self.suggest_adding_lifetime_params(&mut diag, *fr, *outlived_fr);
         self.suggest_move_on_borrowing_closure(&mut diag);
+        self.suggest_deref_closure_value(&mut diag);
 
         diag
     }
@@ -1041,6 +1045,147 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
 
     #[allow(rustc::diagnostic_outside_of_impl)]
     #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
+    /// When encountering a lifetime error caused by the return type of a closure, check the
+    /// corresponding trait bound and see if dereferencing the closure return value would satisfy
+    /// them. If so, we produce a structured suggestion.
+    fn suggest_deref_closure_value(&self, diag: &mut Diag<'_>) {
+        let tcx = self.infcx.tcx;
+        let map = tcx.hir();
+
+        // Get the closure return value and type.
+        let body_id = map.body_owned_by(self.mir_def_id());
+        let body = &map.body(body_id);
+        let value = &body.value.peel_blocks();
+        let hir::Node::Expr(closure_expr) = tcx.hir_node_by_def_id(self.mir_def_id()) else {
+            return;
+        };
+        let fn_call_id = tcx.parent_hir_id(self.mir_hir_id());
+        let hir::Node::Expr(expr) = tcx.hir_node(fn_call_id) else { return };
+        let def_id = map.enclosing_body_owner(fn_call_id);
+        let tables = tcx.typeck(def_id);
+        let Some(return_value_ty) = tables.node_type_opt(value.hir_id) else { return };
+        let return_value_ty = self.infcx.resolve_vars_if_possible(return_value_ty);
+
+        // We don't use `ty.peel_refs()` to get the number of `*`s needed to get the root type.
+        let mut ty = return_value_ty;
+        let mut count = 0;
+        while let ty::Ref(_, t, _) = ty.kind() {
+            ty = *t;
+            count += 1;
+        }
+        if !self.infcx.type_is_copy_modulo_regions(self.param_env, ty) {
+            return;
+        }
+
+        // Build a new closure where the return type is an owned value, instead of a ref.
+        let Some(ty::Closure(did, args)) =
+            tables.node_type_opt(closure_expr.hir_id).as_ref().map(|ty| ty.kind())
+        else {
+            return;
+        };
+        let sig = args.as_closure().sig();
+        let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr(
+            tcx,
+            sig.map_bound(|s| {
+                let unsafety = hir::Unsafety::Normal;
+                use rustc_target::spec::abi;
+                tcx.mk_fn_sig(
+                    [s.inputs()[0]],
+                    s.output().peel_refs(),
+                    s.c_variadic,
+                    unsafety,
+                    abi::Abi::Rust,
+                )
+            }),
+        );
+        let parent_args = GenericArgs::identity_for_item(
+            tcx,
+            tcx.typeck_root_def_id(self.mir_def_id().to_def_id()),
+        );
+        let closure_kind = args.as_closure().kind();
+        let closure_kind_ty = Ty::from_closure_kind(tcx, closure_kind);
+        let tupled_upvars_ty = self.infcx.next_ty_var(TypeVariableOrigin {
+            kind: TypeVariableOriginKind::ClosureSynthetic,
+            span: closure_expr.span,
+        });
+        let closure_args = ty::ClosureArgs::new(
+            tcx,
+            ty::ClosureArgsParts {
+                parent_args,
+                closure_kind_ty,
+                closure_sig_as_fn_ptr_ty,
+                tupled_upvars_ty,
+            },
+        );
+        let closure_ty = Ty::new_closure(tcx, *did, closure_args.args);
+        let closure_ty = tcx.erase_regions(closure_ty);
+
+        let hir::ExprKind::MethodCall(_, rcvr, args, _) = expr.kind else { return };
+        let Some(pos) = args
+            .iter()
+            .enumerate()
+            .find(|(_, arg)| arg.hir_id == closure_expr.hir_id)
+            .map(|(i, _)| i)
+        else {
+            return;
+        };
+        // The found `Self` type of the method call.
+        let Some(possible_rcvr_ty) = tables.node_type_opt(rcvr.hir_id) else { return };
+
+        // The `MethodCall` expression is `Res::Err`, so we search for the method on the `rcvr_ty`.
+        let Some(method) = tcx.lookup_method_for_diagnostic((self.mir_def_id(), expr.hir_id))
+        else {
+            return;
+        };
+
+        // Get the type for the parameter corresponding to the argument the closure with the
+        // lifetime error we had.
+        let Some(input) = tcx
+            .fn_sig(method)
+            .instantiate_identity()
+            .inputs()
+            .skip_binder()
+            // Methods have a `self` arg, so `pos` is actually `+ 1` to match the method call arg.
+            .get(pos + 1)
+        else {
+            return;
+        };
+
+        trace!(?input);
+
+        let ty::Param(closure_param) = input.kind() else { return };
+
+        // Get the arguments for the found method, only specifying that `Self` is the receiver type.
+        let args = GenericArgs::for_item(tcx, method, |param, _| {
+            if param.index == 0 {
+                possible_rcvr_ty.into()
+            } else if param.index == closure_param.index {
+                closure_ty.into()
+            } else {
+                self.infcx.var_for_def(expr.span, param)
+            }
+        });
+
+        let preds = tcx.predicates_of(method).instantiate(tcx, args);
+
+        let ocx = ObligationCtxt::new(&self.infcx);
+        ocx.register_obligations(preds.iter().map(|(pred, span)| {
+            trace!(?pred);
+            Obligation::misc(tcx, span, self.mir_def_id(), self.param_env, pred)
+        }));
+
+        if ocx.select_all_or_error().is_empty() {
+            diag.span_suggestion_verbose(
+                value.span.shrink_to_lo(),
+                "dereference the return value",
+                "*".repeat(count),
+                Applicability::MachineApplicable,
+            );
+        }
+    }
+
+    #[allow(rustc::diagnostic_outside_of_impl)]
+    #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
     fn suggest_move_on_borrowing_closure(&self, diag: &mut Diag<'_>) {
         let map = self.infcx.tcx.hir();
         let body_id = map.body_owned_by(self.mir_def_id());
diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs
index 700dde184f2..476df9ae793 100644
--- a/compiler/rustc_hir_typeck/src/lib.rs
+++ b/compiler/rustc_hir_typeck/src/lib.rs
@@ -56,7 +56,7 @@ use rustc_data_structures::unord::UnordSet;
 use rustc_errors::{codes::*, struct_span_code_err, ErrorGuaranteed};
 use rustc_hir as hir;
 use rustc_hir::def::{DefKind, Res};
-use rustc_hir::intravisit::Visitor;
+use rustc_hir::intravisit::{Map, Visitor};
 use rustc_hir::{HirIdMap, Node};
 use rustc_hir_analysis::check::check_abi;
 use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
@@ -436,6 +436,28 @@ fn fatally_break_rust(tcx: TyCtxt<'_>, span: Span) -> ! {
     diag.emit()
 }
 
+pub fn lookup_method_for_diagnostic<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    (def_id, hir_id): (LocalDefId, hir::HirId),
+) -> Option<DefId> {
+    let root_ctxt = TypeckRootCtxt::new(tcx, def_id);
+    let param_env = tcx.param_env(def_id);
+    let fn_ctxt = FnCtxt::new(&root_ctxt, param_env, def_id);
+    let hir::Node::Expr(expr) = tcx.hir().hir_node(hir_id) else {
+        return None;
+    };
+    let hir::ExprKind::MethodCall(segment, rcvr, _, _) = expr.kind else {
+        return None;
+    };
+    let tables = tcx.typeck(def_id);
+    // The found `Self` type of the method call.
+    let possible_rcvr_ty = tables.node_type_opt(rcvr.hir_id)?;
+    fn_ctxt
+        .lookup_method_for_diagnostic(possible_rcvr_ty, segment, expr.span, expr, rcvr)
+        .ok()
+        .map(|method| method.def_id)
+}
+
 pub fn provide(providers: &mut Providers) {
     method::provide(providers);
     *providers = Providers {
@@ -443,6 +465,7 @@ pub fn provide(providers: &mut Providers) {
         diagnostic_only_typeck,
         has_typeck_results,
         used_trait_imports,
+        lookup_method_for_diagnostic: lookup_method_for_diagnostic,
         ..*providers
     };
 }
diff --git a/compiler/rustc_middle/src/query/keys.rs b/compiler/rustc_middle/src/query/keys.rs
index c1548eb99f5..faa137019cb 100644
--- a/compiler/rustc_middle/src/query/keys.rs
+++ b/compiler/rustc_middle/src/query/keys.rs
@@ -555,6 +555,19 @@ impl Key for HirId {
     }
 }
 
+impl Key for (LocalDefId, HirId) {
+    type Cache<V> = DefaultCache<Self, V>;
+
+    fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
+        tcx.hir().span(self.1)
+    }
+
+    #[inline(always)]
+    fn key_as_def_id(&self) -> Option<DefId> {
+        Some(self.0.into())
+    }
+}
+
 impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
     type Cache<V> = DefaultCache<Self, V>;
 
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 5ef7a20f460..394515f091f 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -983,6 +983,9 @@ rustc_queries! {
     query diagnostic_only_typeck(key: LocalDefId) -> &'tcx ty::TypeckResults<'tcx> {
         desc { |tcx| "type-checking `{}`", tcx.def_path_str(key) }
     }
+    query lookup_method_for_diagnostic((def_id, hir_id): (LocalDefId, hir::HirId)) -> Option<DefId> {
+        desc { |tcx| "lookup_method_for_diagnostics `{}`", tcx.def_path_str(def_id) }
+    }
 
     query used_trait_imports(key: LocalDefId) -> &'tcx UnordSet<LocalDefId> {
         desc { |tcx| "finding used_trait_imports `{}`", tcx.def_path_str(key) }
diff --git a/tests/ui/closures/return-value-lifetime-error.fixed b/tests/ui/closures/return-value-lifetime-error.fixed
new file mode 100644
index 00000000000..bf1f7e4a6cf
--- /dev/null
+++ b/tests/ui/closures/return-value-lifetime-error.fixed
@@ -0,0 +1,16 @@
+//@ run-rustfix
+use std::collections::HashMap;
+
+fn main() {
+    let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];
+
+    let mut counts = HashMap::new();
+    for num in vs {
+        let count = counts.entry(num).or_insert(0);
+        *count += 1;
+    }
+
+    let _ = counts.iter().max_by_key(|(_, v)| **v);
+    //~^ ERROR lifetime may not live long enough
+    //~| HELP dereference the return value
+}
diff --git a/tests/ui/closures/return-value-lifetime-error.rs b/tests/ui/closures/return-value-lifetime-error.rs
new file mode 100644
index 00000000000..411c91f413e
--- /dev/null
+++ b/tests/ui/closures/return-value-lifetime-error.rs
@@ -0,0 +1,16 @@
+//@ run-rustfix
+use std::collections::HashMap;
+
+fn main() {
+    let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];
+
+    let mut counts = HashMap::new();
+    for num in vs {
+        let count = counts.entry(num).or_insert(0);
+        *count += 1;
+    }
+
+    let _ = counts.iter().max_by_key(|(_, v)| v);
+    //~^ ERROR lifetime may not live long enough
+    //~| HELP dereference the return value
+}
diff --git a/tests/ui/closures/return-value-lifetime-error.stderr b/tests/ui/closures/return-value-lifetime-error.stderr
new file mode 100644
index 00000000000..a0ad127db28
--- /dev/null
+++ b/tests/ui/closures/return-value-lifetime-error.stderr
@@ -0,0 +1,16 @@
+error: lifetime may not live long enough
+  --> $DIR/return-value-lifetime-error.rs:13:47
+   |
+LL |     let _ = counts.iter().max_by_key(|(_, v)| v);
+   |                                       ------- ^ returning this value requires that `'1` must outlive `'2`
+   |                                       |     |
+   |                                       |     return type of closure is &'2 &i32
+   |                                       has type `&'1 (&i32, &i32)`
+   |
+help: dereference the return value
+   |
+LL |     let _ = counts.iter().max_by_key(|(_, v)| **v);
+   |                                               ++
+
+error: aborting due to 1 previous error
+