use itertools::Itertools; use proc_macro2::{Literal, Punct, Spacing, TokenStream}; use quote::{ToTokens, TokenStreamExt, format_ident, quote}; use regex::Regex; use serde::de::{self, MapAccess, Visitor}; use serde::{Deserialize, Deserializer, Serialize}; use std::fmt; use std::str::FromStr; use std::sync::LazyLock; use crate::intrinsic::Intrinsic; use crate::wildstring::WildStringPart; use crate::{ context::{self, Context, VariableType}, intrinsic::{Argument, LLVMLink, StaticDefinition}, matching::{MatchKindValues, MatchSizeValues}, typekinds::{BaseType, BaseTypeKind, TypeKind}, wildcards::Wildcard, wildstring::WildString, }; #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum IdentifierType { Variable, Symbol, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum LetVariant { Basic(WildString, Box), WithType(WildString, TypeKind, Box), MutWithType(WildString, TypeKind, Box), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FnCall( /// Function pointer pub Box, /// Function arguments pub Vec, /// Function turbofish arguments #[serde(default)] pub Vec, /// Function requires unsafe wrapper #[serde(default)] pub bool, ); impl FnCall { pub fn new_expression(fn_ptr: Expression, arguments: Vec) -> Expression { FnCall(Box::new(fn_ptr), arguments, Vec::new(), false).into() } pub fn new_unsafe_expression(fn_ptr: Expression, arguments: Vec) -> Expression { FnCall(Box::new(fn_ptr), arguments, Vec::new(), true).into() } pub fn is_llvm_link_call(&self, llvm_link_name: &str) -> bool { self.is_expected_call(llvm_link_name) } pub fn is_target_feature_call(&self) -> bool { self.is_expected_call("target_feature") } pub fn is_expected_call(&self, fn_call_name: &str) -> bool { if let Expression::Identifier(fn_name, IdentifierType::Symbol) = self.0.as_ref() { fn_name.to_string() == fn_call_name } else { false } } pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result { self.0.pre_build(ctx)?; self.1 .iter_mut() .chain(self.2.iter_mut()) .try_for_each(|ex| ex.pre_build(ctx)) } pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result { self.0.build(intrinsic, ctx)?; self.1 .iter_mut() .chain(self.2.iter_mut()) .try_for_each(|ex| ex.build(intrinsic, ctx)) } } impl ToTokens for FnCall { fn to_tokens(&self, tokens: &mut TokenStream) { let FnCall(fn_ptr, arguments, turbofish, _requires_unsafe_wrapper) = self; fn_ptr.to_tokens(tokens); if !turbofish.is_empty() { tokens.append_all(quote! {::<#(#turbofish),*>}); } tokens.append_all(quote! { (#(#arguments),*) }) } } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(remote = "Self", deny_unknown_fields)] pub enum Expression { /// (Re)Defines a variable Let(LetVariant), /// Performs a variable assignment operation Assign(String, Box), /// Performs a macro call MacroCall(String, String), /// Performs a function call FnCall(FnCall), /// Performs a method call. The following: /// `MethodCall: ["$object", "to_string", []]` /// is tokenized as: /// `object.to_string()`. MethodCall(Box, String, Vec), /// Symbol identifier name, prepend with a `$` to treat it as a scope variable /// which engages variable tracking and enables inference. /// E.g. `my_function_name` for a generic symbol or `$my_variable` for /// a variable. Identifier(WildString, IdentifierType), /// Constant signed integer number expression IntConstant(i32), /// Constant floating point number expression FloatConstant(f32), /// Constant boolean expression, either `true` or `false` BoolConstant(bool), /// Array expression Array(Vec), // complex expressions /// Makes an LLVM link. /// /// It stores the link's function name in the wildcard `{llvm_link}`, for use in /// subsequent expressions. LLVMLink(LLVMLink), /// Casts the given expression to the specified (unchecked) type CastAs(Box, String), /// Returns the LLVM `undef` symbol SvUndef, /// Multiplication Multiply(Box, Box), /// Xor Xor(Box, Box), /// Converts the specified constant to the specified type's kind ConvertConst(TypeKind, i32), /// Yields the given type in the Rust representation Type(TypeKind), MatchSize(TypeKind, MatchSizeValues>), MatchKind(TypeKind, MatchKindValues>), } impl Expression { pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result { match self { Self::FnCall(fn_call) => fn_call.pre_build(ctx), Self::MethodCall(cl_ptr_ex, _, arg_exs) => { cl_ptr_ex.pre_build(ctx)?; arg_exs.iter_mut().try_for_each(|ex| ex.pre_build(ctx)) } Self::Let( LetVariant::Basic(_, ex) | LetVariant::WithType(_, _, ex) | LetVariant::MutWithType(_, _, ex), ) => ex.pre_build(ctx), Self::CastAs(ex, _) => ex.pre_build(ctx), Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => { lhs.pre_build(ctx)?; rhs.pre_build(ctx) } Self::MatchSize(match_ty, values) => { *self = *values.get(match_ty, ctx.local)?.to_owned(); self.pre_build(ctx) } Self::MatchKind(match_ty, values) => { *self = *values.get(match_ty, ctx.local)?.to_owned(); self.pre_build(ctx) } _ => Ok(()), } } pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result { match self { Self::LLVMLink(link) => link.build_and_save(ctx), Self::Identifier(identifier, id_type) => { identifier.build_acle(ctx.local)?; if let IdentifierType::Variable = id_type { ctx.local .variables .get(&identifier.to_string()) .map(|_| ()) .ok_or_else(|| format!("invalid variable {identifier} being referenced")) } else { Ok(()) } } Self::FnCall(fn_call) => { fn_call.build(intrinsic, ctx)?; #[allow(clippy::collapsible_if)] if let Some(llvm_link_name) = ctx.local.substitutions.get(&Wildcard::LLVMLink) { if fn_call.is_llvm_link_call(llvm_link_name) { *self = intrinsic .llvm_link() .expect("got LLVMLink wildcard without a LLVM link in `compose`") .apply_conversions_to_call(fn_call.clone(), ctx)? } } Ok(()) } Self::MethodCall(cl_ptr_ex, _, arg_exs) => { cl_ptr_ex.build(intrinsic, ctx)?; arg_exs .iter_mut() .try_for_each(|ex| ex.build(intrinsic, ctx)) } Self::Let(variant) => { let (var_name, ex, ty) = match variant { LetVariant::Basic(var_name, ex) => (var_name, ex, None), LetVariant::WithType(var_name, ty, ex) | LetVariant::MutWithType(var_name, ty, ex) => { if let Some(w) = ty.wildcard() { ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?; } (var_name, ex, Some(ty.to_owned())) } }; var_name.build_acle(ctx.local)?; ctx.local.variables.insert( var_name.to_string(), ( ty.unwrap_or_else(|| TypeKind::Custom("unknown".to_string())), VariableType::Internal, ), ); ex.build(intrinsic, ctx) } Self::CastAs(ex, _) => ex.build(intrinsic, ctx), Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => { lhs.build(intrinsic, ctx)?; rhs.build(intrinsic, ctx) } Self::ConvertConst(ty, num) => { if let Some(w) = ty.wildcard() { *ty = ctx.local.provide_type_wildcard(w)? } if let Some(BaseType::Sized(BaseTypeKind::Float, _)) = ty.base() { *self = Expression::FloatConstant(*num as f32) } else { *self = Expression::IntConstant(*num) } Ok(()) } Self::Type(ty) => { if let Some(w) = ty.wildcard() { *ty = ctx.local.provide_type_wildcard(w)? } Ok(()) } _ => Ok(()), } } /// True if the expression requires an `unsafe` context in a safe function. /// /// The classification is somewhat fuzzy, based on actual usage (e.g. empirical function names) /// rather than a full parse. This is a reasonable approach because mistakes here will usually /// be caught at build time: /// /// - Missing an `unsafe` is a build error. /// - An unnecessary `unsafe` is a warning, made into an error by the CI's `-D warnings`. /// /// This **panics** if it encounters an expression that shouldn't appear in a safe function at /// all (such as `SvUndef`). pub fn requires_unsafe_wrapper(&self, ctx_fn: &str) -> bool { match self { // The call will need to be unsafe, but the declaration does not. Self::LLVMLink(..) => false, // Identifiers, literals and type names are never unsafe. Self::Identifier(..) => false, Self::IntConstant(..) => false, Self::FloatConstant(..) => false, Self::BoolConstant(..) => false, Self::Type(..) => false, Self::ConvertConst(..) => false, // Nested structures that aren't inherently unsafe, but could contain other expressions // that might be. Self::Assign(_var, exp) => exp.requires_unsafe_wrapper(ctx_fn), Self::Let( LetVariant::Basic(_, exp) | LetVariant::WithType(_, _, exp) | LetVariant::MutWithType(_, _, exp), ) => exp.requires_unsafe_wrapper(ctx_fn), Self::Array(exps) => exps.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)), Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => { lhs.requires_unsafe_wrapper(ctx_fn) || rhs.requires_unsafe_wrapper(ctx_fn) } Self::CastAs(exp, _ty) => exp.requires_unsafe_wrapper(ctx_fn), // Functions and macros can be unsafe, but can also contain other expressions. Self::FnCall(FnCall(fn_exp, args, turbo_args, requires_unsafe_wrapper)) => { let fn_name = fn_exp.to_string(); fn_exp.requires_unsafe_wrapper(ctx_fn) || fn_name.starts_with("_sv") || fn_name.starts_with("simd_") || fn_name.ends_with("transmute") || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) || turbo_args .iter() .any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) || *requires_unsafe_wrapper } Self::MethodCall(exp, fn_name, args) => match fn_name.as_str() { // `as_signed` and `as_unsigned` are unsafe because they're trait methods with // target features to allow use on feature-dependent types (such as SVE vectors). // We can safely wrap them here. "as_signed" => true, "as_unsigned" => true, _ => { exp.requires_unsafe_wrapper(ctx_fn) || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) } }, // We only use macros to check const generics (using static assertions). Self::MacroCall(_name, _args) => false, // Materialising uninitialised values is always unsafe, and we avoid it in safe // functions. Self::SvUndef => panic!("Refusing to wrap unsafe SvUndef in safe function '{ctx_fn}'."), // Variants that aren't tokenised. We shouldn't encounter these here. Self::MatchKind(..) => { unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.") } Self::MatchSize(..) => { unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.") } } } /// Determine if an expression is a `static_assert<...>` function call. pub fn is_static_assert(&self) -> bool { match self { Expression::FnCall(fn_call) => match fn_call.0.as_ref() { Expression::Identifier(wild_string, _) => { if let WildStringPart::String(function_name) = &wild_string.0[0] { function_name.starts_with("static_assert") } else { false } } _ => panic!("Badly defined function call: {fn_call:?}"), }, _ => false, } } /// Determine if an espression is a LLVM binding pub fn is_llvm_link(&self) -> bool { matches!(self, Expression::LLVMLink(_)) } } impl FromStr for Expression { type Err = String; fn from_str(s: &str) -> Result { static MACRO_RE: LazyLock = LazyLock::new(|| Regex::new(r"^(?P[\w\d_]+)!\((?P.*?)\);?$").unwrap()); if s == "SvUndef" { Ok(Expression::SvUndef) } else if MACRO_RE.is_match(s) { let c = MACRO_RE.captures(s).unwrap(); let ex = c["ex"].to_string(); let _: TokenStream = ex .parse() .map_err(|e| format!("could not parse macro call expression: {e:#?}"))?; Ok(Expression::MacroCall(c["name"].to_string(), ex)) } else { let (s, id_type) = if let Some(varname) = s.strip_prefix('$') { (varname, IdentifierType::Variable) } else { (s, IdentifierType::Symbol) }; let identifier = s.trim().parse()?; Ok(Expression::Identifier(identifier, id_type)) } } } impl From for Expression { fn from(fn_call: FnCall) -> Self { Expression::FnCall(fn_call) } } impl From for Expression { fn from(ws: WildString) -> Self { Expression::Identifier(ws, IdentifierType::Symbol) } } impl From<&Argument> for Expression { fn from(a: &Argument) -> Self { Expression::Identifier(a.name.to_owned(), IdentifierType::Variable) } } impl TryFrom<&StaticDefinition> for Expression { type Error = String; fn try_from(sd: &StaticDefinition) -> Result { match sd { StaticDefinition::Constant(imm) => Ok(imm.into()), StaticDefinition::Generic(t) => t.parse(), } } } impl fmt::Display for Expression { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Identifier(identifier, kind) => { write!( f, "{}{identifier}", matches!(kind, IdentifierType::Variable) .then_some("$") .unwrap_or_default() ) } Self::MacroCall(name, expression) => { write!(f, "{name}!({expression})") } _ => Err(fmt::Error), } } } impl ToTokens for Expression { fn to_tokens(&self, tokens: &mut TokenStream) { match self { Self::Let(LetVariant::Basic(var_name, exp)) => { let var_ident = format_ident!("{}", var_name.to_string()); tokens.append_all(quote! { let #var_ident = #exp }) } Self::Let(LetVariant::WithType(var_name, ty, exp)) => { let var_ident = format_ident!("{}", var_name.to_string()); tokens.append_all(quote! { let #var_ident: #ty = #exp }) } Self::Let(LetVariant::MutWithType(var_name, ty, exp)) => { let var_ident = format_ident!("{}", var_name.to_string()); tokens.append_all(quote! { let mut #var_ident: #ty = #exp }) } Self::Assign(var_name, exp) => { /* If we are dereferencing a variable to assign a value \ * the 'format_ident!' macro does not like the asterix */ let var_name_str: &str; if let Some(ch) = var_name.chars().nth(0) { /* Manually append the asterix and split out the rest of * the variable name */ if ch == '*' { tokens.append(Punct::new('*', Spacing::Alone)); var_name_str = &var_name[1..var_name.len()]; } else { var_name_str = var_name.as_str(); } } else { /* Should not be reached as you cannot have a variable * without a name */ panic!("Invalid variable name, must be at least one character") } let var_ident = format_ident!("{}", var_name_str); tokens.append_all(quote! { #var_ident = #exp }) } Self::MacroCall(name, ex) => { let name = format_ident!("{name}"); let ex: TokenStream = ex.parse().unwrap(); tokens.append_all(quote! { #name!(#ex) }) } Self::FnCall(fn_call) => fn_call.to_tokens(tokens), Self::MethodCall(exp, fn_name, args) => { let fn_ident = format_ident!("{}", fn_name); tokens.append_all(quote! { #exp.#fn_ident(#(#args),*) }) } Self::Identifier(identifier, _) => { assert!( !identifier.has_wildcards(), "expression {self:#?} was not built before calling to_tokens" ); identifier .to_string() .parse::() .unwrap_or_else(|_| panic!("invalid syntax: {self:?}")) .to_tokens(tokens); } Self::IntConstant(n) => tokens.append(Literal::i32_unsuffixed(*n)), Self::FloatConstant(n) => tokens.append(Literal::f32_unsuffixed(*n)), Self::BoolConstant(true) => tokens.append(format_ident!("true")), Self::BoolConstant(false) => tokens.append(format_ident!("false")), Self::Array(vec) => tokens.append_all(quote! { [ #(#vec),* ] }), Self::LLVMLink(link) => link.to_tokens(tokens), Self::CastAs(ex, ty) => { let ty: TokenStream = ty.parse().expect("invalid syntax"); tokens.append_all(quote! { #ex as #ty }) } Self::SvUndef => tokens.append_all(quote! { simd_reinterpret(()) }), Self::Multiply(lhs, rhs) => tokens.append_all(quote! { #lhs * #rhs }), Self::Xor(lhs, rhs) => tokens.append_all(quote! { #lhs ^ #rhs }), Self::Type(ty) => ty.to_tokens(tokens), _ => unreachable!("{self:?} cannot be converted to tokens."), } } } impl Serialize for Expression { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { match self { Self::IntConstant(v) => serializer.serialize_i32(*v), Self::FloatConstant(v) => serializer.serialize_f32(*v), Self::BoolConstant(v) => serializer.serialize_bool(*v), Self::Identifier(..) => serializer.serialize_str(&self.to_string()), Self::MacroCall(..) => serializer.serialize_str(&self.to_string()), _ => Expression::serialize(self, serializer), } } } impl<'de> Deserialize<'de> for Expression { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct CustomExpressionVisitor; impl<'de> Visitor<'de> for CustomExpressionVisitor { type Value = Expression; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("integer, float, boolean, string or map") } fn visit_bool(self, v: bool) -> Result where E: de::Error, { Ok(Expression::BoolConstant(v)) } fn visit_i64(self, v: i64) -> Result where E: de::Error, { Ok(Expression::IntConstant(v as i32)) } fn visit_u64(self, v: u64) -> Result where E: de::Error, { Ok(Expression::IntConstant(v as i32)) } fn visit_f64(self, v: f64) -> Result where E: de::Error, { Ok(Expression::FloatConstant(v as f32)) } fn visit_str(self, value: &str) -> Result where E: de::Error, { FromStr::from_str(value).map_err(de::Error::custom) } fn visit_seq(self, mut seq: A) -> Result where A: de::SeqAccess<'de>, { let arr = std::iter::from_fn(|| seq.next_element::().transpose()) .try_collect()?; Ok(Expression::Array(arr)) } fn visit_map(self, map: M) -> Result where M: MapAccess<'de>, { // `MapAccessDeserializer` is a wrapper that turns a `MapAccess` // into a `Deserializer`, allowing it to be used as the input to T's // `Deserialize` implementation. T then deserializes itself using // the entries from the map visitor. Expression::deserialize(de::value::MapAccessDeserializer::new(map)) } } deserializer.deserialize_any(CustomExpressionVisitor) } }