about summary refs log tree commit diff
path: root/library/stdarch/crates/stdarch-gen-loongarch/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/stdarch/crates/stdarch-gen-loongarch/src/main.rs')
-rw-r--r--library/stdarch/crates/stdarch-gen-loongarch/src/main.rs46
1 files changed, 29 insertions, 17 deletions
diff --git a/library/stdarch/crates/stdarch-gen-loongarch/src/main.rs b/library/stdarch/crates/stdarch-gen-loongarch/src/main.rs
index aa9990b6ccd..40132097f5d 100644
--- a/library/stdarch/crates/stdarch-gen-loongarch/src/main.rs
+++ b/library/stdarch/crates/stdarch-gen-loongarch/src/main.rs
@@ -274,13 +274,14 @@ fn gen_bind_body(
         }
     };
 
+    let is_mem = in_t.iter().any(|s| s.contains("POINTER"));
     let is_store = current_name.to_string().contains("vst");
     let link_function = {
         let fn_decl = {
             let fn_output = if out_t.to_lowercase() == "void" {
                 String::new()
             } else {
-                format!("-> {}", type_to_rst(out_t, is_store))
+                format!(" -> {}", type_to_rst(out_t, is_store))
             };
             let fn_inputs = match para_num {
                 1 => format!("(a: {})", type_to_rst(in_t[0], is_store)),
@@ -304,7 +305,7 @@ fn gen_bind_body(
                 ),
                 _ => panic!("unsupported parameter number"),
             };
-            format!("fn __{current_name}{fn_inputs} {fn_output};")
+            format!("fn __{current_name}{fn_inputs}{fn_output};")
         };
         let function = format!(
             r#"    #[link_name = "llvm.loongarch.{}"]
@@ -456,31 +457,40 @@ fn gen_bind_body(
             };
             rustc_legacy_const_generics = "rustc_legacy_const_generics(2, 3)";
         }
-        format!("pub unsafe fn {current_name}{fn_inputs} {fn_output}")
+        format!(
+            "pub {}fn {current_name}{fn_inputs} {fn_output}",
+            if is_mem { "unsafe " } else { "" }
+        )
     };
+    let unsafe_start = if !is_mem { "unsafe { " } else { "" };
+    let unsafe_end = if !is_mem { " }" } else { "" };
     let mut call_params = {
         match para_num {
-            1 => format!("__{current_name}(a)"),
-            2 => format!("__{current_name}(a, b)"),
-            3 => format!("__{current_name}(a, b, c)"),
-            4 => format!("__{current_name}(a, b, c, d)"),
+            1 => format!("{unsafe_start}__{current_name}(a){unsafe_end}"),
+            2 => format!("{unsafe_start}__{current_name}(a, b){unsafe_end}"),
+            3 => format!("{unsafe_start}__{current_name}(a, b, c){unsafe_end}"),
+            4 => format!("{unsafe_start}__{current_name}(a, b, c, d){unsafe_end}"),
             _ => panic!("unsupported parameter number"),
         }
     };
     if para_num == 1 && in_t[0] == "HI" {
         call_params = match asm_fmts[1].as_str() {
             "si10" => {
-                format!("static_assert_simm_bits!(IMM_S10, 10);\n    __{current_name}(IMM_S10)")
+                format!(
+                    "static_assert_simm_bits!(IMM_S10, 10);\n    {unsafe_start}__{current_name}(IMM_S10){unsafe_end}"
+                )
             }
             "i13" => {
-                format!("static_assert_simm_bits!(IMM_S13, 13);\n    __{current_name}(IMM_S13)")
+                format!(
+                    "static_assert_simm_bits!(IMM_S13, 13);\n    {unsafe_start}__{current_name}(IMM_S13){unsafe_end}"
+                )
             }
             _ => panic!("unsupported assembly format: {}", asm_fmts[2]),
         }
     } else if para_num == 2 && (in_t[1] == "UQI" || in_t[1] == "USI") {
         call_params = if asm_fmts[2].starts_with("ui") {
             format!(
-                "static_assert_uimm_bits!(IMM{0}, {0});\n    __{current_name}(a, IMM{0})",
+                "static_assert_uimm_bits!(IMM{0}, {0});\n    {unsafe_start}__{current_name}(a, IMM{0}){unsafe_end}",
                 asm_fmts[2].get(2..).unwrap()
             )
         } else {
@@ -489,14 +499,16 @@ fn gen_bind_body(
     } else if para_num == 2 && in_t[1] == "QI" {
         call_params = match asm_fmts[2].as_str() {
             "si5" => {
-                format!("static_assert_simm_bits!(IMM_S5, 5);\n    __{current_name}(a, IMM_S5)")
+                format!(
+                    "static_assert_simm_bits!(IMM_S5, 5);\n    {unsafe_start}__{current_name}(a, IMM_S5){unsafe_end}"
+                )
             }
             _ => panic!("unsupported assembly format: {}", asm_fmts[2]),
         };
     } else if para_num == 2 && in_t[0] == "CVPOINTER" && in_t[1] == "SI" {
         call_params = if asm_fmts[2].starts_with("si") {
             format!(
-                "static_assert_simm_bits!(IMM_S{0}, {0});\n    __{current_name}(mem_addr, IMM_S{0})",
+                "static_assert_simm_bits!(IMM_S{0}, {0});\n    {unsafe_start}__{current_name}(mem_addr, IMM_S{0}){unsafe_end}",
                 asm_fmts[2].get(2..).unwrap()
             )
         } else {
@@ -504,13 +516,13 @@ fn gen_bind_body(
         }
     } else if para_num == 2 && in_t[0] == "CVPOINTER" && in_t[1] == "DI" {
         call_params = match asm_fmts[2].as_str() {
-            "rk" => format!("__{current_name}(mem_addr, b)"),
+            "rk" => format!("{unsafe_start}__{current_name}(mem_addr, b){unsafe_end}"),
             _ => panic!("unsupported assembly format: {}", asm_fmts[2]),
         };
     } else if para_num == 3 && (in_t[2] == "USI" || in_t[2] == "UQI") {
         call_params = if asm_fmts[2].starts_with("ui") {
             format!(
-                "static_assert_uimm_bits!(IMM{0}, {0});\n    __{current_name}(a, b, IMM{0})",
+                "static_assert_uimm_bits!(IMM{0}, {0});\n    {unsafe_start}__{current_name}(a, b, IMM{0}){unsafe_end}",
                 asm_fmts[2].get(2..).unwrap()
             )
         } else {
@@ -519,19 +531,19 @@ fn gen_bind_body(
     } else if para_num == 3 && in_t[1] == "CVPOINTER" && in_t[2] == "SI" {
         call_params = match asm_fmts[2].as_str() {
             "si12" => format!(
-                "static_assert_simm_bits!(IMM_S12, 12);\n    __{current_name}(a, mem_addr, IMM_S12)"
+                "static_assert_simm_bits!(IMM_S12, 12);\n    {unsafe_start}__{current_name}(a, mem_addr, IMM_S12){unsafe_end}"
             ),
             _ => panic!("unsupported assembly format: {}", asm_fmts[2]),
         };
     } else if para_num == 3 && in_t[1] == "CVPOINTER" && in_t[2] == "DI" {
         call_params = match asm_fmts[2].as_str() {
-            "rk" => format!("__{current_name}(a, mem_addr, b)"),
+            "rk" => format!("{unsafe_start}__{current_name}(a, mem_addr, b){unsafe_end}"),
             _ => panic!("unsupported assembly format: {}", asm_fmts[2]),
         };
     } else if para_num == 4 {
         call_params = match (asm_fmts[2].as_str(), current_name.chars().last().unwrap()) {
             ("si8", t) => format!(
-                "static_assert_simm_bits!(IMM_S8, 8);\n    static_assert_uimm_bits!(IMM{0}, {0});\n    __{current_name}(a, mem_addr, IMM_S8, IMM{0})",
+                "static_assert_simm_bits!(IMM_S8, 8);\n    static_assert_uimm_bits!(IMM{0}, {0});\n    {unsafe_start}__{current_name}(a, mem_addr, IMM_S8, IMM{0}){unsafe_end}",
                 type_to_imm(t)
             ),
             (_, _) => panic!(