about summary refs log tree commit diff
path: root/library/stdarch/crates/stdarch-gen-arm/src/input.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/stdarch/crates/stdarch-gen-arm/src/input.rs')
-rw-r--r--library/stdarch/crates/stdarch-gen-arm/src/input.rs433
1 files changed, 433 insertions, 0 deletions
diff --git a/library/stdarch/crates/stdarch-gen-arm/src/input.rs b/library/stdarch/crates/stdarch-gen-arm/src/input.rs
new file mode 100644
index 00000000000..adefbf3215b
--- /dev/null
+++ b/library/stdarch/crates/stdarch-gen-arm/src/input.rs
@@ -0,0 +1,433 @@
+use itertools::Itertools;
+use serde::{Deserialize, Deserializer, Serialize, de};
+
+use crate::{
+    context::{self, GlobalContext},
+    intrinsic::Intrinsic,
+    predicate_forms::{PredicateForm, PredicationMask, PredicationMethods},
+    typekinds::TypeKind,
+    wildstring::WildString,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum InputType {
+    /// PredicateForm variant argument
+    #[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip.
+    PredicateForm(PredicateForm),
+    /// Operand from which to generate an N variant
+    #[serde(skip)]
+    NVariantOp(Option<WildString>),
+    /// TypeKind variant argument
+    Type(TypeKind),
+}
+
+impl InputType {
+    /// Optionally unwraps as a PredicateForm.
+    pub fn predicate_form(&self) -> Option<&PredicateForm> {
+        match self {
+            InputType::PredicateForm(pf) => Some(pf),
+            _ => None,
+        }
+    }
+
+    /// Optionally unwraps as a mutable PredicateForm
+    pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> {
+        match self {
+            InputType::PredicateForm(pf) => Some(pf),
+            _ => None,
+        }
+    }
+
+    /// Optionally unwraps as a TypeKind.
+    pub fn typekind(&self) -> Option<&TypeKind> {
+        match self {
+            InputType::Type(ty) => Some(ty),
+            _ => None,
+        }
+    }
+
+    /// Optionally unwraps as a NVariantOp
+    pub fn n_variant_op(&self) -> Option<&WildString> {
+        match self {
+            InputType::NVariantOp(Some(op)) => Some(op),
+            _ => None,
+        }
+    }
+}
+
+impl PartialOrd for InputType {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for InputType {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        use std::cmp::Ordering::*;
+
+        match (self, other) {
+            (InputType::PredicateForm(pf1), InputType::PredicateForm(pf2)) => pf1.cmp(pf2),
+            (InputType::Type(ty1), InputType::Type(ty2)) => ty1.cmp(ty2),
+
+            (InputType::NVariantOp(None), InputType::NVariantOp(Some(..))) => Less,
+            (InputType::NVariantOp(Some(..)), InputType::NVariantOp(None)) => Greater,
+            (InputType::NVariantOp(_), InputType::NVariantOp(_)) => Equal,
+
+            (InputType::Type(..), InputType::PredicateForm(..)) => Less,
+            (InputType::PredicateForm(..), InputType::Type(..)) => Greater,
+
+            (InputType::Type(..), InputType::NVariantOp(..)) => Less,
+            (InputType::NVariantOp(..), InputType::Type(..)) => Greater,
+
+            (InputType::PredicateForm(..), InputType::NVariantOp(..)) => Less,
+            (InputType::NVariantOp(..), InputType::PredicateForm(..)) => Greater,
+        }
+    }
+}
+
+mod many_or_one {
+    use serde::{Deserialize, Serialize, de::Deserializer, ser::Serializer};
+
+    pub fn serialize<T, S>(vec: &Vec<T>, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        T: Serialize,
+        S: Serializer,
+    {
+        if vec.len() == 1 {
+            vec.first().unwrap().serialize(serializer)
+        } else {
+            vec.serialize(serializer)
+        }
+    }
+
+    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
+    where
+        T: Deserialize<'de>,
+        D: Deserializer<'de>,
+    {
+        #[derive(Debug, Clone, Serialize, Deserialize)]
+        #[serde(untagged)]
+        enum ManyOrOne<T> {
+            Many(Vec<T>),
+            One(T),
+        }
+
+        match ManyOrOne::deserialize(deserializer)? {
+            ManyOrOne::Many(vec) => Ok(vec),
+            ManyOrOne::One(val) => Ok(vec![val]),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct InputSet(#[serde(with = "many_or_one")] Vec<InputType>);
+
+impl InputSet {
+    pub fn get(&self, idx: usize) -> Option<&InputType> {
+        self.0.get(idx)
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = &InputType> + '_ {
+        self.0.iter()
+    }
+
+    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut InputType> + '_ {
+        self.0.iter_mut()
+    }
+
+    pub fn into_iter(self) -> impl Iterator<Item = InputType> + Clone {
+        self.0.into_iter()
+    }
+
+    pub fn types_len(&self) -> usize {
+        self.iter().filter_map(|arg| arg.typekind()).count()
+    }
+
+    pub fn typekind(&self, idx: Option<usize>) -> Option<TypeKind> {
+        let types_len = self.types_len();
+        self.get(idx.unwrap_or(0)).and_then(move |arg: &InputType| {
+            if (idx.is_none() && types_len != 1) || (idx.is_some() && types_len == 1) {
+                None
+            } else {
+                arg.typekind().cloned()
+            }
+        })
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec<InputSet>);
+
+impl InputSetEntry {
+    pub fn new(input: Vec<InputSet>) -> Self {
+        Self(input)
+    }
+
+    pub fn get(&self, idx: usize) -> Option<&InputSet> {
+        self.0.get(idx)
+    }
+}
+
+fn validate_types<'de, D>(deserializer: D) -> Result<Vec<InputSetEntry>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let v: Vec<InputSetEntry> = Vec::deserialize(deserializer)?;
+
+    let mut it = v.iter();
+    if let Some(first) = it.next() {
+        it.try_fold(first, |last, cur| {
+            if last.0.len() == cur.0.len() {
+                Ok(cur)
+            } else {
+                Err("the length of the InputSets and the product lists must match".to_string())
+            }
+        })
+        .map_err(de::Error::custom)?;
+    }
+
+    Ok(v)
+}
+
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+pub struct IntrinsicInput {
+    #[serde(default)]
+    #[serde(deserialize_with = "validate_types")]
+    pub types: Vec<InputSetEntry>,
+
+    #[serde(flatten)]
+    pub predication_methods: PredicationMethods,
+
+    /// Generates a _n variant where the specified operand is a primitive type
+    /// that requires conversion to an SVE one. The `{_n}` wildcard is required
+    /// in the intrinsic's name, otherwise an error will be thrown.
+    #[serde(default)]
+    pub n_variant_op: WildString,
+}
+
+impl IntrinsicInput {
+    /// Extracts all the possible variants as an iterator.
+    pub fn variants(
+        &self,
+        intrinsic: &Intrinsic,
+    ) -> context::Result<impl Iterator<Item = InputSet> + '_> {
+        let mut top_product = vec![];
+
+        if !self.types.is_empty() {
+            top_product.push(
+                self.types
+                    .iter()
+                    .flat_map(|ty_in| {
+                        ty_in
+                            .0
+                            .iter()
+                            .map(|v| v.clone().into_iter())
+                            .multi_cartesian_product()
+                    })
+                    .collect_vec(),
+            )
+        }
+
+        if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) {
+            top_product.push(
+                PredicateForm::compile_list(&mask, &self.predication_methods)?
+                    .into_iter()
+                    .map(|pf| vec![InputType::PredicateForm(pf)])
+                    .collect_vec(),
+            )
+        }
+
+        if !self.n_variant_op.is_empty() {
+            top_product.push(vec![
+                vec![InputType::NVariantOp(None)],
+                vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))],
+            ])
+        }
+
+        let it = top_product
+            .into_iter()
+            .map(|v| v.into_iter())
+            .multi_cartesian_product()
+            .filter(|set| !set.is_empty())
+            .map(|set| InputSet(set.into_iter().flatten().collect_vec()));
+        Ok(it)
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GeneratorInput {
+    #[serde(flatten)]
+    pub ctx: GlobalContext,
+    pub intrinsics: Vec<Intrinsic>,
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        input::*,
+        predicate_forms::{DontCareMethod, ZeroingMethod},
+    };
+
+    #[test]
+    fn test_empty() {
+        let str = r#"types: []"#;
+        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
+        let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
+        assert_eq!(variants.next(), None);
+    }
+
+    #[test]
+    fn test_product() {
+        let str = r#"types:
+- [f64, f32]
+- [i64, [f64, f32]]
+"#;
+        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
+        let mut intrinsic = Intrinsic::default();
+        intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap();
+        let mut variants = input.variants(&intrinsic).unwrap().into_iter();
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("f64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::Merging),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("f64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::Type("f64".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::Merging),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::Type("f64".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::Merging),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
+            ])),
+        );
+        assert_eq!(variants.next(), None);
+    }
+
+    #[test]
+    fn test_n_variant() {
+        let str = r#"types:
+- [f64, f32]
+n_variant_op: op2
+"#;
+        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
+        let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("f64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::NVariantOp(None),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("f64".parse().unwrap()),
+                InputType::Type("f32".parse().unwrap()),
+                InputType::NVariantOp(Some("op2".parse().unwrap())),
+            ]))
+        );
+        assert_eq!(variants.next(), None)
+    }
+
+    #[test]
+    fn test_invalid_length() {
+        let str = r#"types: [i32, [[u64], [u32]]]"#;
+        serde_yaml::from_str::<IntrinsicInput>(str).expect_err("failure expected");
+    }
+
+    #[test]
+    fn test_invalid_predication() {
+        let str = "types: []";
+        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
+        let mut intrinsic = Intrinsic::default();
+        intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
+        input
+            .variants(&intrinsic)
+            .map(|v| v.collect_vec())
+            .expect_err("failure expected");
+    }
+
+    #[test]
+    fn test_invalid_predication_mask() {
+        "test_intrinsic{_mxy}"
+            .parse::<WildString>()
+            .expect_err("failure expected");
+        "test_intrinsic{_}"
+            .parse::<WildString>()
+            .expect_err("failure expected");
+    }
+
+    #[test]
+    fn test_zeroing_predication() {
+        let str = r#"types: [i64]
+zeroing_method: { drop: inactive }"#;
+        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
+        let mut intrinsic = Intrinsic::default();
+        intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
+        let mut variants = input.variants(&intrinsic).unwrap();
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::Merging),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)),
+            ]))
+        );
+        assert_eq!(
+            variants.next(),
+            Some(InputSet(vec![
+                InputType::Type("i64".parse().unwrap()),
+                InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop {
+                    drop: "inactive".parse().unwrap()
+                })),
+            ]))
+        );
+        assert_eq!(variants.next(), None)
+    }
+}