about summary refs log tree commit diff
path: root/compiler/rustc_mir_build/src/check_tail_calls.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_build/src/check_tail_calls.rs')
-rw-r--r--compiler/rustc_mir_build/src/check_tail_calls.rs57
1 files changed, 54 insertions, 3 deletions
diff --git a/compiler/rustc_mir_build/src/check_tail_calls.rs b/compiler/rustc_mir_build/src/check_tail_calls.rs
index 6ed100899d8..3ecccb422c4 100644
--- a/compiler/rustc_mir_build/src/check_tail_calls.rs
+++ b/compiler/rustc_mir_build/src/check_tail_calls.rs
@@ -60,9 +60,13 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
         let BodyTy::Fn(caller_sig) = self.thir.body_type else {
             span_bug!(
                 call.span,
-                "`become` outside of functions should have been disallowed by hit_typeck"
+                "`become` outside of functions should have been disallowed by hir_typeck"
             )
         };
+        // While the `caller_sig` does have its regions erased, it does not have its
+        // binders anonymized. We call `erase_regions` once again to anonymize any binders
+        // within the signature, such as in function pointer or `dyn Trait` args.
+        let caller_sig = self.tcx.erase_regions(caller_sig);
 
         let ExprKind::Scope { value, .. } = call.kind else {
             span_bug!(call.span, "expected scope, found: {call:?}")
@@ -95,9 +99,15 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
             // So we have to check for them in this weird way...
             let parent = self.tcx.parent(did);
             if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
-                && args.first().and_then(|arg| arg.as_type()).is_some_and(Ty::is_closure)
+                && let Some(this) = args.first()
+                && let Some(this) = this.as_type()
             {
-                self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
+                if this.is_closure() {
+                    self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
+                } else {
+                    // This can happen when tail calling `Box` that wraps a function
+                    self.report_nonfn_callee(fn_span, self.thir[fun].span, this);
+                }
 
                 // Tail calling is likely to cause unrelated errors (ABI, argument mismatches),
                 // skip them, producing an error about calling a closure is enough.
@@ -109,6 +119,13 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
             }
         }
 
+        let (ty::FnDef(..) | ty::FnPtr(..)) = ty.kind() else {
+            self.report_nonfn_callee(fn_span, self.thir[fun].span, ty);
+
+            // `fn_sig` below panics otherwise
+            return;
+        };
+
         // Erase regions since tail calls don't care about lifetimes
         let callee_sig =
             self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
@@ -294,6 +311,40 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
         self.found_errors = Err(err);
     }
 
+    fn report_nonfn_callee(&mut self, call_sp: Span, fun_sp: Span, ty: Ty<'_>) {
+        let mut err = self
+            .tcx
+            .dcx()
+            .struct_span_err(
+                call_sp,
+                "tail calls can only be performed with function definitions or pointers",
+            )
+            .with_note(format!("callee has type `{ty}`"));
+
+        let mut ty = ty;
+        let mut refs = 0;
+        while ty.is_box() || ty.is_ref() {
+            ty = ty.builtin_deref(false).unwrap();
+            refs += 1;
+        }
+
+        if refs > 0 && ty.is_fn() {
+            let thing = if ty.is_fn_ptr() { "pointer" } else { "definition" };
+
+            let derefs =
+                std::iter::once('(').chain(std::iter::repeat_n('*', refs)).collect::<String>();
+
+            err.multipart_suggestion(
+                format!("consider dereferencing the expression to get a function {thing}"),
+                vec![(fun_sp.shrink_to_lo(), derefs), (fun_sp.shrink_to_hi(), ")".to_owned())],
+                Applicability::MachineApplicable,
+            );
+        }
+
+        let err = err.emit();
+        self.found_errors = Err(err);
+    }
+
     fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
         let err = self
             .tcx