about summary refs log tree commit diff
path: root/library/stdarch/crates/stdarch-gen-arm/src/expression.rs
diff options
context:
space:
mode:
authorJames Barford-Evans <james.barford-evans@arm.com>2024-12-18 15:42:53 +0000
committerAmanieu d'Antras <amanieu@gmail.com>2025-01-16 14:29:19 +0000
commitf283e449b11ebe8127570aab09b8871442d1e74b (patch)
treec9f7981d8f3fbf52ae2e36c07e60d14d0c0c5959 /library/stdarch/crates/stdarch-gen-arm/src/expression.rs
parent28cb01cd6a6d0ee09b6933352f07a12a4c3c01ca (diff)
downloadrust-f283e449b11ebe8127570aab09b8871442d1e74b.tar.gz
rust-f283e449b11ebe8127570aab09b8871442d1e74b.zip
PR feedback & pipeline
Diffstat (limited to 'library/stdarch/crates/stdarch-gen-arm/src/expression.rs')
-rw-r--r--library/stdarch/crates/stdarch-gen-arm/src/expression.rs576
1 files changed, 576 insertions, 0 deletions
diff --git a/library/stdarch/crates/stdarch-gen-arm/src/expression.rs b/library/stdarch/crates/stdarch-gen-arm/src/expression.rs
new file mode 100644
index 00000000000..83984679588
--- /dev/null
+++ b/library/stdarch/crates/stdarch-gen-arm/src/expression.rs
@@ -0,0 +1,576 @@
+use itertools::Itertools;
+use lazy_static::lazy_static;
+use proc_macro2::{Literal, Punct, Spacing, TokenStream};
+use quote::{format_ident, quote, ToTokens, TokenStreamExt};
+use regex::Regex;
+use serde::de::{self, MapAccess, Visitor};
+use serde::{Deserialize, Deserializer, Serialize};
+use std::fmt;
+use std::str::FromStr;
+
+use crate::intrinsic::Intrinsic;
+use crate::{
+    context::{self, Context, VariableType},
+    intrinsic::{Argument, LLVMLink, StaticDefinition},
+    matching::{MatchKindValues, MatchSizeValues},
+    typekinds::{BaseType, BaseTypeKind, TypeKind},
+    wildcards::Wildcard,
+    wildstring::WildString,
+};
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+pub enum IdentifierType {
+    Variable,
+    Symbol,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum LetVariant {
+    Basic(WildString, Box<Expression>),
+    WithType(WildString, TypeKind, Box<Expression>),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FnCall(
+    /// Function pointer
+    pub Box<Expression>,
+    /// Function arguments
+    pub Vec<Expression>,
+    /// Function turbofish arguments
+    #[serde(default)]
+    pub Vec<Expression>,
+);
+
+impl FnCall {
+    pub fn new_expression(fn_ptr: Expression, arguments: Vec<Expression>) -> Expression {
+        FnCall(Box::new(fn_ptr), arguments, Vec::new()).into()
+    }
+
+    pub fn is_llvm_link_call(&self, llvm_link_name: &String) -> bool {
+        self.is_expected_call(llvm_link_name)
+    }
+
+    pub fn is_target_feature_call(&self) -> bool {
+        self.is_expected_call("target_feature")
+    }
+
+    pub fn is_expected_call(&self, fn_call_name: &str) -> bool {
+        if let Expression::Identifier(fn_name, IdentifierType::Symbol) = self.0.as_ref() {
+            &fn_name.to_string() == fn_call_name
+        } else {
+            false
+        }
+    }
+
+    pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
+        self.0.pre_build(ctx)?;
+        self.1
+            .iter_mut()
+            .chain(self.2.iter_mut())
+            .try_for_each(|ex| ex.pre_build(ctx))
+    }
+
+    pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
+        self.0.build(intrinsic, ctx)?;
+        self.1
+            .iter_mut()
+            .chain(self.2.iter_mut())
+            .try_for_each(|ex| ex.build(intrinsic, ctx))
+    }
+}
+
+impl ToTokens for FnCall {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        let FnCall(fn_ptr, arguments, turbofish) = self;
+
+        fn_ptr.to_tokens(tokens);
+
+        if !turbofish.is_empty() {
+            tokens.append_all(quote! {::<#(#turbofish),*>});
+        }
+
+        tokens.append_all(quote! { (#(#arguments),*) })
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(remote = "Self", deny_unknown_fields)]
+pub enum Expression {
+    /// (Re)Defines a variable
+    Let(LetVariant),
+    /// Performs a variable assignment operation
+    Assign(String, Box<Expression>),
+    /// Performs a macro call
+    MacroCall(String, String),
+    /// Performs a function call
+    FnCall(FnCall),
+    /// Performs a method call. The following:
+    /// `MethodCall: ["$object", "to_string", []]`
+    /// is tokenized as:
+    /// `object.to_string()`.
+    MethodCall(Box<Expression>, String, Vec<Expression>),
+    /// Symbol identifier name, prepend with a `$` to treat it as a scope variable
+    /// which engages variable tracking and enables inference.
+    /// E.g. `my_function_name` for a generic symbol or `$my_variable` for
+    /// a variable.
+    Identifier(WildString, IdentifierType),
+    /// Constant signed integer number expression
+    IntConstant(i32),
+    /// Constant floating point number expression
+    FloatConstant(f32),
+    /// Constant boolean expression, either `true` or `false`
+    BoolConstant(bool),
+    /// Array expression
+    Array(Vec<Expression>),
+
+    // complex expressions
+    /// Makes an LLVM link.
+    ///
+    /// It stores the link's function name in the wildcard `{llvm_link}`, for use in
+    /// subsequent expressions.
+    LLVMLink(LLVMLink),
+    /// Casts the given expression to the specified (unchecked) type
+    CastAs(Box<Expression>, String),
+    /// Returns the LLVM `undef` symbol
+    SvUndef,
+    /// Multiplication
+    Multiply(Box<Expression>, Box<Expression>),
+    /// Xor
+    Xor(Box<Expression>, Box<Expression>),
+    /// Converts the specified constant to the specified type's kind
+    ConvertConst(TypeKind, i32),
+    /// Yields the given type in the Rust representation
+    Type(TypeKind),
+
+    MatchSize(TypeKind, MatchSizeValues<Box<Expression>>),
+    MatchKind(TypeKind, MatchKindValues<Box<Expression>>),
+}
+
+impl Expression {
+    pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
+        match self {
+            Self::FnCall(fn_call) => fn_call.pre_build(ctx),
+            Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
+                cl_ptr_ex.pre_build(ctx)?;
+                arg_exs.iter_mut().try_for_each(|ex| ex.pre_build(ctx))
+            }
+            Self::Let(LetVariant::Basic(_, ex) | LetVariant::WithType(_, _, ex)) => {
+                ex.pre_build(ctx)
+            }
+            Self::CastAs(ex, _) => ex.pre_build(ctx),
+            Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => {
+                lhs.pre_build(ctx)?;
+                rhs.pre_build(ctx)
+            }
+            Self::MatchSize(match_ty, values) => {
+                *self = *values.get(match_ty, ctx.local)?.to_owned();
+                self.pre_build(ctx)
+            }
+            Self::MatchKind(match_ty, values) => {
+                *self = *values.get(match_ty, ctx.local)?.to_owned();
+                self.pre_build(ctx)
+            }
+            _ => Ok(()),
+        }
+    }
+
+    pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
+        match self {
+            Self::LLVMLink(link) => link.build_and_save(ctx),
+            Self::Identifier(identifier, id_type) => {
+                identifier.build_acle(ctx.local)?;
+
+                if let IdentifierType::Variable = id_type {
+                    ctx.local
+                        .variables
+                        .get(&identifier.to_string())
+                        .map(|_| ())
+                        .ok_or_else(|| format!("invalid variable {identifier} being referenced"))
+                } else {
+                    Ok(())
+                }
+            }
+            Self::FnCall(fn_call) => {
+                fn_call.build(intrinsic, ctx)?;
+
+                if let Some(llvm_link_name) = ctx.local.substitutions.get(&Wildcard::LLVMLink) {
+                    if fn_call.is_llvm_link_call(llvm_link_name) {
+                        *self = intrinsic
+                            .llvm_link()
+                            .expect("got LLVMLink wildcard without a LLVM link in `compose`")
+                            .apply_conversions_to_call(fn_call.clone(), ctx.local)?
+                    }
+                }
+
+                Ok(())
+            }
+            Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
+                cl_ptr_ex.build(intrinsic, ctx)?;
+                arg_exs
+                    .iter_mut()
+                    .try_for_each(|ex| ex.build(intrinsic, ctx))
+            }
+            Self::Let(variant) => {
+                let (var_name, ex, ty) = match variant {
+                    LetVariant::Basic(var_name, ex) => (var_name, ex, None),
+                    LetVariant::WithType(var_name, ty, ex) => {
+                        if let Some(w) = ty.wildcard() {
+                            ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?;
+                        }
+                        (var_name, ex, Some(ty.to_owned()))
+                    }
+                };
+
+                var_name.build_acle(ctx.local)?;
+                ctx.local.variables.insert(
+                    var_name.to_string(),
+                    (
+                        ty.unwrap_or_else(|| TypeKind::Custom("unknown".to_string())),
+                        VariableType::Internal,
+                    ),
+                );
+                ex.build(intrinsic, ctx)
+            }
+            Self::CastAs(ex, _) => ex.build(intrinsic, ctx),
+            Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => {
+                lhs.build(intrinsic, ctx)?;
+                rhs.build(intrinsic, ctx)
+            }
+            Self::ConvertConst(ty, num) => {
+                if let Some(w) = ty.wildcard() {
+                    *ty = ctx.local.provide_type_wildcard(w)?
+                }
+
+                if let Some(BaseType::Sized(BaseTypeKind::Float, _)) = ty.base() {
+                    *self = Expression::FloatConstant(*num as f32)
+                } else {
+                    *self = Expression::IntConstant(*num)
+                }
+                Ok(())
+            }
+            Self::Type(ty) => {
+                if let Some(w) = ty.wildcard() {
+                    *ty = ctx.local.provide_type_wildcard(w)?
+                }
+
+                Ok(())
+            }
+            _ => Ok(()),
+        }
+    }
+
+    /// True if the expression requires an `unsafe` context in a safe function.
+    ///
+    /// The classification is somewhat fuzzy, based on actual usage (e.g. empirical function names)
+    /// rather than a full parse. This is a reasonable approach because mistakes here will usually
+    /// be caught at build time:
+    ///
+    ///  - Missing an `unsafe` is a build error.
+    ///  - An unnecessary `unsafe` is a warning, made into an error by the CI's `-D warnings`.
+    ///
+    /// This **panics** if it encounters an expression that shouldn't appear in a safe function at
+    /// all (such as `SvUndef`).
+    pub fn requires_unsafe_wrapper(&self, ctx_fn: &str) -> bool {
+        match self {
+            // The call will need to be unsafe, but the declaration does not.
+            Self::LLVMLink(..) => false,
+            // Identifiers, literals and type names are never unsafe.
+            Self::Identifier(..) => false,
+            Self::IntConstant(..) => false,
+            Self::FloatConstant(..) => false,
+            Self::BoolConstant(..) => false,
+            Self::Type(..) => false,
+            Self::ConvertConst(..) => false,
+            // Nested structures that aren't inherently unsafe, but could contain other expressions
+            // that might be.
+            Self::Assign(_var, exp) => exp.requires_unsafe_wrapper(ctx_fn),
+            Self::Let(LetVariant::Basic(_, exp) | LetVariant::WithType(_, _, exp)) => {
+                exp.requires_unsafe_wrapper(ctx_fn)
+            }
+            Self::Array(exps) => exps.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)),
+            Self::Multiply(lhs, rhs) | Self::Xor(lhs, rhs) => {
+                lhs.requires_unsafe_wrapper(ctx_fn) || rhs.requires_unsafe_wrapper(ctx_fn)
+            }
+            Self::CastAs(exp, _ty) => exp.requires_unsafe_wrapper(ctx_fn),
+            // Functions and macros can be unsafe, but can also contain other expressions.
+            Self::FnCall(FnCall(fn_exp, args, turbo_args)) => {
+                let fn_name = fn_exp.to_string();
+                fn_exp.requires_unsafe_wrapper(ctx_fn)
+                    || fn_name.starts_with("_sv")
+                    || fn_name.starts_with("simd_")
+                    || fn_name.ends_with("transmute")
+                    || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
+                    || turbo_args
+                        .iter()
+                        .any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
+            }
+            Self::MethodCall(exp, fn_name, args) => match fn_name.as_str() {
+                // `as_signed` and `as_unsigned` are unsafe because they're trait methods with
+                // target features to allow use on feature-dependent types (such as SVE vectors).
+                // We can safely wrap them here.
+                "as_signed" => true,
+                "as_unsigned" => true,
+                _ => {
+                    exp.requires_unsafe_wrapper(ctx_fn)
+                        || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
+                }
+            },
+            // We only use macros to check const generics (using static assertions).
+            Self::MacroCall(_name, _args) => false,
+            // Materialising uninitialised values is always unsafe, and we avoid it in safe
+            // functions.
+            Self::SvUndef => panic!("Refusing to wrap unsafe SvUndef in safe function '{ctx_fn}'."),
+            // Variants that aren't tokenised. We shouldn't encounter these here.
+            Self::MatchKind(..) => {
+                unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
+            }
+            Self::MatchSize(..) => {
+                unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
+            }
+        }
+    }
+}
+
+impl FromStr for Expression {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        lazy_static! {
+            static ref MACRO_RE: Regex =
+                Regex::new(r"^(?P<name>[\w\d_]+)!\((?P<ex>.*?)\);?$").unwrap();
+        }
+
+        if s == "SvUndef" {
+            Ok(Expression::SvUndef)
+        } else if MACRO_RE.is_match(s) {
+            let c = MACRO_RE.captures(s).unwrap();
+            let ex = c["ex"].to_string();
+            let _: TokenStream = ex
+                .parse()
+                .map_err(|e| format!("could not parse macro call expression: {e:#?}"))?;
+            Ok(Expression::MacroCall(c["name"].to_string(), ex))
+        } else {
+            let (s, id_type) = if let Some(varname) = s.strip_prefix('$') {
+                (varname, IdentifierType::Variable)
+            } else {
+                (s, IdentifierType::Symbol)
+            };
+            let identifier = s.trim().parse()?;
+            Ok(Expression::Identifier(identifier, id_type))
+        }
+    }
+}
+
+impl From<FnCall> for Expression {
+    fn from(fn_call: FnCall) -> Self {
+        Expression::FnCall(fn_call)
+    }
+}
+
+impl From<WildString> for Expression {
+    fn from(ws: WildString) -> Self {
+        Expression::Identifier(ws, IdentifierType::Symbol)
+    }
+}
+
+impl From<&Argument> for Expression {
+    fn from(a: &Argument) -> Self {
+        Expression::Identifier(a.name.to_owned(), IdentifierType::Variable)
+    }
+}
+
+impl TryFrom<&StaticDefinition> for Expression {
+    type Error = String;
+
+    fn try_from(sd: &StaticDefinition) -> Result<Self, Self::Error> {
+        match sd {
+            StaticDefinition::Constant(imm) => Ok(imm.into()),
+            StaticDefinition::Generic(t) => t.parse(),
+        }
+    }
+}
+
+impl fmt::Display for Expression {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Identifier(identifier, kind) => {
+                write!(
+                    f,
+                    "{}{identifier}",
+                    matches!(kind, IdentifierType::Variable)
+                        .then_some("$")
+                        .unwrap_or_default()
+                )
+            }
+            Self::MacroCall(name, expression) => {
+                write!(f, "{name}!({expression})")
+            }
+            _ => Err(fmt::Error),
+        }
+    }
+}
+
+impl ToTokens for Expression {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        match self {
+            Self::Let(LetVariant::Basic(var_name, exp)) => {
+                let var_ident = format_ident!("{}", var_name.to_string());
+                tokens.append_all(quote! { let #var_ident = #exp })
+            }
+            Self::Let(LetVariant::WithType(var_name, ty, exp)) => {
+                let var_ident = format_ident!("{}", var_name.to_string());
+                tokens.append_all(quote! { let #var_ident: #ty = #exp })
+            }
+            Self::Assign(var_name, exp) => {
+                /* If we are dereferencing a variable to assign a value \
+                 * the 'format_ident!' macro does not like the asterix */
+                let var_name_str: &str;
+
+                if let Some(ch) = var_name.chars().nth(0) {
+                    /* Manually append the asterix and split out the rest of
+                     * the variable name */
+                    if ch == '*' {
+                        tokens.append(Punct::new('*', Spacing::Alone));
+                        var_name_str = &var_name[1..var_name.len()];
+                    } else {
+                        var_name_str = var_name.as_str();
+                    }
+                } else {
+                    /* Should not be reached as you cannot have a variable
+                     * without a name */
+                    panic!("Invalid variable name, must be at least one character")
+                }
+
+                let var_ident = format_ident!("{}", var_name_str);
+                tokens.append_all(quote! { #var_ident = #exp })
+            }
+            Self::MacroCall(name, ex) => {
+                let name = format_ident!("{name}");
+                let ex: TokenStream = ex.parse().unwrap();
+                tokens.append_all(quote! { #name!(#ex) })
+            }
+            Self::FnCall(fn_call) => fn_call.to_tokens(tokens),
+            Self::MethodCall(exp, fn_name, args) => {
+                let fn_ident = format_ident!("{}", fn_name);
+                tokens.append_all(quote! { #exp.#fn_ident(#(#args),*) })
+            }
+            Self::Identifier(identifier, _) => {
+                assert!(
+                    !identifier.has_wildcards(),
+                    "expression {self:#?} was not built before calling to_tokens"
+                );
+                identifier
+                    .to_string()
+                    .parse::<TokenStream>()
+                    .expect(format!("invalid syntax: {:?}", self).as_str())
+                    .to_tokens(tokens);
+            }
+            Self::IntConstant(n) => tokens.append(Literal::i32_unsuffixed(*n)),
+            Self::FloatConstant(n) => tokens.append(Literal::f32_unsuffixed(*n)),
+            Self::BoolConstant(true) => tokens.append(format_ident!("true")),
+            Self::BoolConstant(false) => tokens.append(format_ident!("false")),
+            Self::Array(vec) => tokens.append_all(quote! { [ #(#vec),* ] }),
+            Self::LLVMLink(link) => link.to_tokens(tokens),
+            Self::CastAs(ex, ty) => {
+                let ty: TokenStream = ty.parse().expect("invalid syntax");
+                tokens.append_all(quote! { #ex as #ty })
+            }
+            Self::SvUndef => tokens.append_all(quote! { simd_reinterpret(()) }),
+            Self::Multiply(lhs, rhs) => tokens.append_all(quote! { #lhs * #rhs }),
+            Self::Xor(lhs, rhs) => tokens.append_all(quote! { #lhs ^ #rhs }),
+            Self::Type(ty) => ty.to_tokens(tokens),
+            _ => unreachable!("{self:?} cannot be converted to tokens."),
+        }
+    }
+}
+
+impl Serialize for Expression {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        match self {
+            Self::IntConstant(v) => serializer.serialize_i32(*v),
+            Self::FloatConstant(v) => serializer.serialize_f32(*v),
+            Self::BoolConstant(v) => serializer.serialize_bool(*v),
+            Self::Identifier(..) => serializer.serialize_str(&self.to_string()),
+            Self::MacroCall(..) => serializer.serialize_str(&self.to_string()),
+            _ => Expression::serialize(self, serializer),
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for Expression {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct CustomExpressionVisitor;
+
+        impl<'de> Visitor<'de> for CustomExpressionVisitor {
+            type Value = Expression;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("integer, float, boolean, string or map")
+            }
+
+            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(Expression::BoolConstant(v))
+            }
+
+            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(Expression::IntConstant(v as i32))
+            }
+
+            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(Expression::IntConstant(v as i32))
+            }
+
+            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(Expression::FloatConstant(v as f32))
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                FromStr::from_str(value).map_err(de::Error::custom)
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: de::SeqAccess<'de>,
+            {
+                let arr = std::iter::from_fn(|| seq.next_element::<Self::Value>().transpose())
+                    .try_collect()?;
+                Ok(Expression::Array(arr))
+            }
+
+            fn visit_map<M>(self, map: M) -> Result<Expression, M::Error>
+            where
+                M: MapAccess<'de>,
+            {
+                // `MapAccessDeserializer` is a wrapper that turns a `MapAccess`
+                // into a `Deserializer`, allowing it to be used as the input to T's
+                // `Deserialize` implementation. T then deserializes itself using
+                // the entries from the map visitor.
+                Expression::deserialize(de::value::MapAccessDeserializer::new(map))
+            }
+        }
+
+        deserializer.deserialize_any(CustomExpressionVisitor)
+    }
+}