diff options
Diffstat (limited to 'library/compiler-builtins/crates/libm-macros/src/lib.rs')
| -rw-r--r-- | library/compiler-builtins/crates/libm-macros/src/lib.rs | 504 |
1 files changed, 504 insertions, 0 deletions
diff --git a/library/compiler-builtins/crates/libm-macros/src/lib.rs b/library/compiler-builtins/crates/libm-macros/src/lib.rs new file mode 100644 index 00000000000..482da974ca8 --- /dev/null +++ b/library/compiler-builtins/crates/libm-macros/src/lib.rs @@ -0,0 +1,504 @@ +#![feature(let_chains)] + +mod enums; +mod parse; +mod shared; + +use parse::{Invocation, StructuredInput}; +use proc_macro as pm; +use proc_macro2::{self as pm2, Span}; +use quote::{ToTokens, quote}; +pub(crate) use shared::{ALL_OPERATIONS, FloatTy, MathOpInfo, Ty}; +use syn::spanned::Spanned; +use syn::visit_mut::VisitMut; +use syn::{Ident, ItemEnum}; + +const KNOWN_TYPES: &[&str] = &[ + "FTy", "CFn", "CArgs", "CRet", "RustFn", "RustArgs", "RustRet", "public", +]; + +/// Populate an enum with a variant representing function. Names are in upper camel case. +/// +/// Applied to an empty enum. Expects one attribute `#[function_enum(BaseName)]` that provides +/// the name of the `BaseName` enum. +#[proc_macro_attribute] +pub fn function_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream { + let item = syn::parse_macro_input!(tokens as ItemEnum); + let res = enums::function_enum(item, attributes.into()); + + match res { + Ok(ts) => ts, + Err(e) => e.into_compile_error(), + } + .into() +} + +/// Create an enum representing all possible base names, with names in upper camel case. +/// +/// Applied to an empty enum. +#[proc_macro_attribute] +pub fn base_name_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream { + let item = syn::parse_macro_input!(tokens as ItemEnum); + let res = enums::base_name_enum(item, attributes.into()); + + match res { + Ok(ts) => ts, + Err(e) => e.into_compile_error(), + } + .into() +} + +/// Do something for each function present in this crate. +/// +/// Takes a callback macro and invokes it multiple times, once for each function that +/// this crate exports. This makes it easy to create generic tests, benchmarks, or other checks +/// and apply it to each symbol. +/// +/// Additionally, the `extra` and `fn_extra` patterns can make use of magic identifiers: +/// +/// - `MACRO_FN_NAME`: gets replaced with the name of the function on that invocation. +/// - `MACRO_FN_NAME_NORMALIZED`: similar to the above, but removes sufixes so e.g. `sinf` becomes +/// `sin`, `cosf128` becomes `cos`, etc. +/// +/// Invoke as: +/// +/// ``` +/// // Macro that is invoked once per function +/// macro_rules! callback_macro { +/// ( +/// // Name of that function +/// fn_name: $fn_name:ident, +/// // The basic float type for this function (e.g. `f32`, `f64`) +/// FTy: $FTy:ty, +/// // Function signature of the C version (e.g. `fn(f32, &mut f32) -> f32`) +/// CFn: $CFn:ty, +/// // A tuple representing the C version's arguments (e.g. `(f32, &mut f32)`) +/// CArgs: $CArgs:ty, +/// // The C version's return type (e.g. `f32`) +/// CRet: $CRet:ty, +/// // Function signature of the Rust version (e.g. `fn(f32) -> (f32, f32)`) +/// RustFn: $RustFn:ty, +/// // A tuple representing the Rust version's arguments (e.g. `(f32,)`) +/// RustArgs: $RustArgs:ty, +/// // The Rust version's return type (e.g. `(f32, f32)`) +/// RustRet: $RustRet:ty, +/// // True if this is part of `libm`'s public API +/// public: $public:expr, +/// // Attributes for the current function, if any +/// attrs: [$($attr:meta),*], +/// // Extra tokens passed directly (if any) +/// extra: [$extra:ident], +/// // Extra function-tokens passed directly (if any) +/// fn_extra: $fn_extra:expr, +/// ) => { }; +/// } +/// +/// // All fields except for `callback` are optional. +/// libm_macros::for_each_function! { +/// // The macro to invoke as a callback +/// callback: callback_macro, +/// // Which types to include either as a list (`[CFn, RustFn, RustArgs]`) or "all" +/// emit_types: all, +/// // Functions to skip, i.e. `callback` shouldn't be called at all for these. +/// skip: [sin, cos], +/// // Attributes passed as `attrs` for specific functions. For example, here the invocation +/// // with `sinf` and that with `cosf` will both get `meta1` and `meta2`, but no others will. +/// // +/// // Note that `f16_enabled` and `f128_enabled` will always get emitted regardless of whether +/// // or not this is specified. +/// attributes: [ +/// #[meta1] +/// #[meta2] +/// [sinf, cosf], +/// ], +/// // Any tokens that should be passed directly to all invocations of the callback. This can +/// // be used to pass local variables or other things the macro needs access to. +/// extra: [foo], +/// // Similar to `extra`, but allow providing a pattern for only specific functions. Uses +/// // a simplified match-like syntax. +/// fn_extra: match MACRO_FN_NAME { +/// hypot | hypotf => |x| x.hypot(), +/// // `ALL_*` magic matchers also work to extract specific types +/// ALL_F64 => |x| x, +/// // The default pattern gets applied to everything that did not match +/// _ => |x| x, +/// }, +/// } +/// ``` +#[proc_macro] +pub fn for_each_function(tokens: pm::TokenStream) -> pm::TokenStream { + let input = syn::parse_macro_input!(tokens as Invocation); + + let res = StructuredInput::from_fields(input) + .and_then(|mut s_in| validate(&mut s_in).map(|fn_list| (s_in, fn_list))) + .and_then(|(s_in, fn_list)| expand(s_in, &fn_list)); + + match res { + Ok(ts) => ts.into(), + Err(e) => e.into_compile_error().into(), + } +} + +/// Check for any input that is structurally correct but has other problems. +/// +/// Returns the list of function names that we should expand for. +fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static MathOpInfo>> { + // Replace magic mappers with a list of relevant functions. + if let Some(map) = &mut input.fn_extra { + for (name, ty) in [ + ("ALL_F16", FloatTy::F16), + ("ALL_F32", FloatTy::F32), + ("ALL_F64", FloatTy::F64), + ("ALL_F128", FloatTy::F128), + ] { + let Some(k) = map.keys().find(|key| *key == name) else { + continue; + }; + + let key = k.clone(); + let val = map.remove(&key).unwrap(); + + for op in ALL_OPERATIONS.iter().filter(|op| op.float_ty == ty) { + map.insert(Ident::new(op.name, key.span()), val.clone()); + } + } + } + + // Collect lists of all functions that are provied as macro inputs in various fields (only, + // skip, attributes). + let attr_mentions = input + .attributes + .iter() + .flat_map(|map_list| map_list.iter()) + .flat_map(|attr_map| attr_map.names.iter()); + let only_mentions = input.only.iter().flat_map(|only_list| only_list.iter()); + let fn_extra_mentions = input + .fn_extra + .iter() + .flat_map(|v| v.keys()) + .filter(|name| *name != "_"); + let all_mentioned_fns = input + .skip + .iter() + .chain(only_mentions) + .chain(attr_mentions) + .chain(fn_extra_mentions); + + // Make sure that every function mentioned is a real function + for mentioned in all_mentioned_fns { + if !ALL_OPERATIONS.iter().any(|func| mentioned == func.name) { + let e = syn::Error::new( + mentioned.span(), + format!("unrecognized function name `{mentioned}`"), + ); + return Err(e); + } + } + + if !input.skip.is_empty() && input.only.is_some() { + let e = syn::Error::new( + input.only_span.unwrap(), + "only one of `skip` or `only` may be specified", + ); + return Err(e); + } + + // Construct a list of what we intend to expand + let mut fn_list = Vec::new(); + for func in ALL_OPERATIONS.iter() { + let fn_name = func.name; + // If we have an `only` list and it does _not_ contain this function name, skip it + if input + .only + .as_ref() + .is_some_and(|only| !only.iter().any(|o| o == fn_name)) + { + continue; + } + + // If there is a `skip` list that contains this function name, skip it + if input.skip.iter().any(|s| s == fn_name) { + continue; + } + + // Omit f16 and f128 functions if requested + if input.skip_f16_f128 && (func.float_ty == FloatTy::F16 || func.float_ty == FloatTy::F128) + { + continue; + } + + // Run everything else + fn_list.push(func); + } + + // Types that the user would like us to provide in the macro + let mut add_all_types = false; + for ty in &input.emit_types { + let ty_name = ty.to_string(); + if ty_name == "all" { + add_all_types = true; + continue; + } + + // Check that all requested types are valid + if !KNOWN_TYPES.contains(&ty_name.as_str()) { + let e = syn::Error::new( + ty_name.span(), + format!("unrecognized type identifier `{ty_name}`"), + ); + return Err(e); + } + } + + if add_all_types { + // Ensure that if `all` was specified that nothing else was + if input.emit_types.len() > 1 { + let e = syn::Error::new( + input.emit_types_span.unwrap(), + "if `all` is specified, no other type identifiers may be given", + ); + return Err(e); + } + + // ...and then add all types + input.emit_types.clear(); + for ty in KNOWN_TYPES { + let ident = Ident::new(ty, Span::call_site()); + input.emit_types.push(ident); + } + } + + if let Some(map) = &input.fn_extra + && !map.keys().any(|key| key == "_") + { + // No default provided; make sure every expected function is covered + let mut fns_not_covered = Vec::new(); + for func in &fn_list { + if !map.keys().any(|key| key == func.name) { + // `name` was not mentioned in the `match` statement + fns_not_covered.push(func); + } + } + + if !fns_not_covered.is_empty() { + let e = syn::Error::new( + input.fn_extra_span.unwrap(), + format!( + "`fn_extra`: no default `_` pattern specified and the following \ + patterns are not covered: {fns_not_covered:#?}" + ), + ); + return Err(e); + } + }; + + Ok(fn_list) +} + +/// Expand our structured macro input into invocations of the callback macro. +fn expand(input: StructuredInput, fn_list: &[&MathOpInfo]) -> syn::Result<pm2::TokenStream> { + let mut out = pm2::TokenStream::new(); + let default_ident = Ident::new("_", Span::call_site()); + let callback = input.callback; + + for func in fn_list { + let fn_name = Ident::new(func.name, Span::call_site()); + + // Prepare attributes in an `attrs: ...` field + let mut meta_fields = Vec::new(); + if let Some(attrs) = &input.attributes { + let meta_iter = attrs + .iter() + .filter(|map| map.names.contains(&fn_name)) + .flat_map(|map| &map.meta) + .map(|v| v.into_token_stream()); + + meta_fields.extend(meta_iter); + } + + // Always emit f16 and f128 meta so this doesn't need to be repeated everywhere + if func.rust_sig.args.contains(&Ty::F16) || func.rust_sig.returns.contains(&Ty::F16) { + let ts = quote! { cfg(f16_enabled) }; + meta_fields.push(ts); + } + if func.rust_sig.args.contains(&Ty::F128) || func.rust_sig.returns.contains(&Ty::F128) { + let ts = quote! { cfg(f128_enabled) }; + meta_fields.push(ts); + } + + let meta_field = quote! { attrs: [ #( #meta_fields ),* ], }; + + // Prepare extra in an `extra: ...` field, running the replacer + let extra_field = match input.extra.clone() { + Some(mut extra) => { + let mut v = MacroReplace::new(func.name); + v.visit_expr_mut(&mut extra); + v.finish()?; + + quote! { extra: #extra, } + } + None => pm2::TokenStream::new(), + }; + + // Prepare function-specific extra in a `fn_extra: ...` field, running the replacer + let fn_extra_field = match input.fn_extra { + Some(ref map) => { + let mut fn_extra = map + .get(&fn_name) + .or_else(|| map.get(&default_ident)) + .unwrap() + .clone(); + + let mut v = MacroReplace::new(func.name); + v.visit_expr_mut(&mut fn_extra); + v.finish()?; + + quote! { fn_extra: #fn_extra, } + } + None => pm2::TokenStream::new(), + }; + + let base_fty = func.float_ty; + let c_args = &func.c_sig.args; + let c_ret = &func.c_sig.returns; + let rust_args = &func.rust_sig.args; + let rust_ret = &func.rust_sig.returns; + let public = func.public; + + let mut ty_fields = Vec::new(); + for ty in &input.emit_types { + let field = match ty.to_string().as_str() { + "FTy" => quote! { FTy: #base_fty, }, + "CFn" => quote! { CFn: fn( #(#c_args),* ,) -> ( #(#c_ret),* ), }, + "CArgs" => quote! { CArgs: ( #(#c_args),* ,), }, + "CRet" => quote! { CRet: ( #(#c_ret),* ), }, + "RustFn" => quote! { RustFn: fn( #(#rust_args),* ,) -> ( #(#rust_ret),* ), }, + "RustArgs" => quote! { RustArgs: ( #(#rust_args),* ,), }, + "RustRet" => quote! { RustRet: ( #(#rust_ret),* ), }, + "public" => quote! { public: #public, }, + _ => unreachable!("checked in validation"), + }; + ty_fields.push(field); + } + + let new = quote! { + #callback! { + fn_name: #fn_name, + #( #ty_fields )* + #meta_field + #extra_field + #fn_extra_field + } + }; + + out.extend(new); + } + + Ok(out) +} + +/// Visitor to replace "magic" identifiers that we allow: `MACRO_FN_NAME` and +/// `MACRO_FN_NAME_NORMALIZED`. +struct MacroReplace { + fn_name: &'static str, + /// Remove the trailing `f` or `f128` to make + norm_name: String, + error: Option<syn::Error>, +} + +impl MacroReplace { + fn new(name: &'static str) -> Self { + let norm_name = base_name(name); + Self { + fn_name: name, + norm_name: norm_name.to_owned(), + error: None, + } + } + + fn finish(self) -> syn::Result<()> { + match self.error { + Some(e) => Err(e), + None => Ok(()), + } + } + + fn visit_ident_inner(&mut self, i: &mut Ident) { + let s = i.to_string(); + if !s.starts_with("MACRO") || self.error.is_some() { + return; + } + + match s.as_str() { + "MACRO_FN_NAME" => *i = Ident::new(self.fn_name, i.span()), + "MACRO_FN_NAME_NORMALIZED" => *i = Ident::new(&self.norm_name, i.span()), + _ => { + self.error = Some(syn::Error::new( + i.span(), + format!("unrecognized meta expression `{s}`"), + )); + } + } + } +} + +impl VisitMut for MacroReplace { + fn visit_ident_mut(&mut self, i: &mut Ident) { + self.visit_ident_inner(i); + syn::visit_mut::visit_ident_mut(self, i); + } +} + +/// Return the unsuffixed version of a function name; e.g. `abs` and `absf` both return `abs`, +/// `lgamma_r` and `lgammaf_r` both return `lgamma_r`. +fn base_name(name: &str) -> &str { + let known_mappings = &[ + ("erff", "erf"), + ("erf", "erf"), + ("lgammaf_r", "lgamma_r"), + ("modff", "modf"), + ("modf", "modf"), + ]; + + match known_mappings.iter().find(|known| known.0 == name) { + Some(found) => found.1, + None => name + .strip_suffix("f") + .or_else(|| name.strip_suffix("f16")) + .or_else(|| name.strip_suffix("f128")) + .unwrap_or(name), + } +} + +impl ToTokens for Ty { + fn to_tokens(&self, tokens: &mut pm2::TokenStream) { + let ts = match self { + Ty::F16 => quote! { f16 }, + Ty::F32 => quote! { f32 }, + Ty::F64 => quote! { f64 }, + Ty::F128 => quote! { f128 }, + Ty::I32 => quote! { i32 }, + Ty::CInt => quote! { ::core::ffi::c_int }, + Ty::MutF16 => quote! { &'a mut f16 }, + Ty::MutF32 => quote! { &'a mut f32 }, + Ty::MutF64 => quote! { &'a mut f64 }, + Ty::MutF128 => quote! { &'a mut f128 }, + Ty::MutI32 => quote! { &'a mut i32 }, + Ty::MutCInt => quote! { &'a mut core::ffi::c_int }, + }; + + tokens.extend(ts); + } +} +impl ToTokens for FloatTy { + fn to_tokens(&self, tokens: &mut pm2::TokenStream) { + let ts = match self { + FloatTy::F16 => quote! { f16 }, + FloatTy::F32 => quote! { f32 }, + FloatTy::F64 => quote! { f64 }, + FloatTy::F128 => quote! { f128 }, + }; + + tokens.extend(ts); + } +} |
