about summary refs log tree commit diff
path: root/compiler/rustc_hir_analysis/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_hir_analysis/src')
-rw-r--r--compiler/rustc_hir_analysis/src/errors.rs4
-rw-r--r--compiler/rustc_hir_analysis/src/hir_ty_lowering/cmse.rs50
-rw-r--r--compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs2
3 files changed, 31 insertions, 25 deletions
diff --git a/compiler/rustc_hir_analysis/src/errors.rs b/compiler/rustc_hir_analysis/src/errors.rs
index f6783364f16..a77f967a5ca 100644
--- a/compiler/rustc_hir_analysis/src/errors.rs
+++ b/compiler/rustc_hir_analysis/src/errors.rs
@@ -1690,11 +1690,13 @@ pub struct CmseCallInputsStackSpill {
     #[primary_span]
     #[label]
     pub span: Span,
+    pub plural: bool,
 }
 
 #[derive(Diagnostic)]
 #[diag(hir_analysis_cmse_call_output_stack_spill, code = E0798)]
-#[note]
+#[note(hir_analysis_note1)]
+#[note(hir_analysis_note2)]
 pub struct CmseCallOutputStackSpill {
     #[primary_span]
     #[label]
diff --git a/compiler/rustc_hir_analysis/src/hir_ty_lowering/cmse.rs b/compiler/rustc_hir_analysis/src/hir_ty_lowering/cmse.rs
index 8980173f738..e99717ce00f 100644
--- a/compiler/rustc_hir_analysis/src/hir_ty_lowering/cmse.rs
+++ b/compiler/rustc_hir_analysis/src/hir_ty_lowering/cmse.rs
@@ -13,7 +13,7 @@ use crate::errors;
 /// conditions, but by checking them here rustc can emit nicer error messages.
 pub fn validate_cmse_abi<'tcx>(
     tcx: TyCtxt<'tcx>,
-    dcx: &DiagCtxtHandle<'_>,
+    dcx: DiagCtxtHandle<'_>,
     hir_id: HirId,
     abi: abi::Abi,
     fn_sig: ty::PolyFnSig<'tcx>,
@@ -30,25 +30,20 @@ pub fn validate_cmse_abi<'tcx>(
             return;
         };
 
-        // fn(u32, u32, u32, u16, u16) -> u32,
-        //    ^^^^^^^^^^^^^^^^^^^^^^^     ^^^
-        let output_span = bare_fn_ty.decl.output.span();
-        let inputs_span = match (
-            bare_fn_ty.param_names.first(),
-            bare_fn_ty.decl.inputs.first(),
-            bare_fn_ty.decl.inputs.last(),
-        ) {
-            (Some(ident), Some(ty1), Some(ty2)) => ident.span.to(ty1.span).to(ty2.span),
-            _ => *bare_fn_span,
-        };
-
         match is_valid_cmse_inputs(tcx, fn_sig) {
-            Ok(true) => {}
-            Ok(false) => {
-                dcx.emit_err(errors::CmseCallInputsStackSpill { span: inputs_span });
+            Ok(Ok(())) => {}
+            Ok(Err(index)) => {
+                // fn(x: u32, u32, u32, u16, y: u16) -> u32,
+                //                           ^^^^^^
+                let span = bare_fn_ty.param_names[index]
+                    .span
+                    .to(bare_fn_ty.decl.inputs[index].span)
+                    .to(bare_fn_ty.decl.inputs.last().unwrap().span);
+                let plural = bare_fn_ty.param_names.len() - index != 1;
+                dcx.emit_err(errors::CmseCallInputsStackSpill { span, plural });
             }
             Err(layout_err) => {
-                if let Some(err) = cmse_layout_err(layout_err, inputs_span) {
+                if let Some(err) = cmse_layout_err(layout_err, *bare_fn_span) {
                     dcx.emit_err(err);
                 }
             }
@@ -57,10 +52,11 @@ pub fn validate_cmse_abi<'tcx>(
         match is_valid_cmse_output(tcx, fn_sig) {
             Ok(true) => {}
             Ok(false) => {
-                dcx.emit_err(errors::CmseCallOutputStackSpill { span: output_span });
+                let span = bare_fn_ty.decl.output.span();
+                dcx.emit_err(errors::CmseCallOutputStackSpill { span });
             }
             Err(layout_err) => {
-                if let Some(err) = cmse_layout_err(layout_err, output_span) {
+                if let Some(err) = cmse_layout_err(layout_err, *bare_fn_span) {
                     dcx.emit_err(err);
                 }
             }
@@ -72,10 +68,11 @@ pub fn validate_cmse_abi<'tcx>(
 fn is_valid_cmse_inputs<'tcx>(
     tcx: TyCtxt<'tcx>,
     fn_sig: ty::PolyFnSig<'tcx>,
-) -> Result<bool, &'tcx LayoutError<'tcx>> {
+) -> Result<Result<(), usize>, &'tcx LayoutError<'tcx>> {
+    let mut span = None;
     let mut accum = 0u64;
 
-    for arg_def in fn_sig.inputs().iter() {
+    for (index, arg_def) in fn_sig.inputs().iter().enumerate() {
         let layout = tcx.layout_of(ParamEnv::reveal_all().and(*arg_def.skip_binder()))?;
 
         let align = layout.layout.align().abi.bytes();
@@ -83,10 +80,17 @@ fn is_valid_cmse_inputs<'tcx>(
 
         accum += size;
         accum = accum.next_multiple_of(Ord::max(4, align));
+
+        // i.e. exceeds 4 32-bit registers
+        if accum > 16 {
+            span = span.or(Some(index));
+        }
     }
 
-    // i.e. 4 32-bit registers
-    Ok(accum <= 16)
+    match span {
+        None => Ok(Ok(())),
+        Some(span) => Ok(Err(span)),
+    }
 }
 
 /// Returns whether the output will fit into the available registers
diff --git a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
index c118181780a..a632619aef2 100644
--- a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
+++ b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
@@ -2326,7 +2326,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
         let bare_fn_ty = ty::Binder::bind_with_vars(fn_ty, bound_vars);
 
         // reject function types that violate cmse ABI requirements
-        cmse::validate_cmse_abi(self.tcx(), &self.dcx(), hir_id, abi, bare_fn_ty);
+        cmse::validate_cmse_abi(self.tcx(), self.dcx(), hir_id, abi, bare_fn_ty);
 
         // Find any late-bound regions declared in return type that do
         // not appear in the arguments. These are not well-formed.