use itertools::Itertools; use serde::{Deserialize, Deserializer, Serialize, de}; use crate::{ context::{self, GlobalContext}, intrinsic::Intrinsic, predicate_forms::{PredicateForm, PredicationMask, PredicationMethods}, typekinds::TypeKind, wildstring::WildString, }; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(untagged)] pub enum InputType { /// PredicateForm variant argument #[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip. PredicateForm(PredicateForm), /// Operand from which to generate an N variant #[serde(skip)] NVariantOp(Option), /// TypeKind variant argument Type(TypeKind), } impl InputType { /// Optionally unwraps as a PredicateForm. pub fn predicate_form(&self) -> Option<&PredicateForm> { match self { InputType::PredicateForm(pf) => Some(pf), _ => None, } } /// Optionally unwraps as a mutable PredicateForm pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> { match self { InputType::PredicateForm(pf) => Some(pf), _ => None, } } /// Optionally unwraps as a TypeKind. pub fn typekind(&self) -> Option<&TypeKind> { match self { InputType::Type(ty) => Some(ty), _ => None, } } /// Optionally unwraps as a NVariantOp pub fn n_variant_op(&self) -> Option<&WildString> { match self { InputType::NVariantOp(Some(op)) => Some(op), _ => None, } } } impl PartialOrd for InputType { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for InputType { fn cmp(&self, other: &Self) -> std::cmp::Ordering { use std::cmp::Ordering::*; match (self, other) { (InputType::PredicateForm(pf1), InputType::PredicateForm(pf2)) => pf1.cmp(pf2), (InputType::Type(ty1), InputType::Type(ty2)) => ty1.cmp(ty2), (InputType::NVariantOp(None), InputType::NVariantOp(Some(..))) => Less, (InputType::NVariantOp(Some(..)), InputType::NVariantOp(None)) => Greater, (InputType::NVariantOp(_), InputType::NVariantOp(_)) => Equal, (InputType::Type(..), InputType::PredicateForm(..)) => Less, (InputType::PredicateForm(..), InputType::Type(..)) => Greater, (InputType::Type(..), InputType::NVariantOp(..)) => Less, (InputType::NVariantOp(..), InputType::Type(..)) => Greater, (InputType::PredicateForm(..), InputType::NVariantOp(..)) => Less, (InputType::NVariantOp(..), InputType::PredicateForm(..)) => Greater, } } } mod many_or_one { use serde::{Deserialize, Serialize, de::Deserializer, ser::Serializer}; pub fn serialize(vec: &Vec, serializer: S) -> Result where T: Serialize, S: Serializer, { if vec.len() == 1 { vec.first().unwrap().serialize(serializer) } else { vec.serialize(serializer) } } pub fn deserialize<'de, T, D>(deserializer: D) -> Result, D::Error> where T: Deserialize<'de>, D: Deserializer<'de>, { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] enum ManyOrOne { Many(Vec), One(T), } match ManyOrOne::deserialize(deserializer)? { ManyOrOne::Many(vec) => Ok(vec), ManyOrOne::One(val) => Ok(vec![val]), } } } #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct InputSet(#[serde(with = "many_or_one")] Vec); impl InputSet { pub fn get(&self, idx: usize) -> Option<&InputType> { self.0.get(idx) } pub fn is_empty(&self) -> bool { self.0.is_empty() } pub fn iter(&self) -> impl Iterator + '_ { self.0.iter() } pub fn iter_mut(&mut self) -> impl Iterator + '_ { self.0.iter_mut() } pub fn into_iter(self) -> impl Iterator + Clone { self.0.into_iter() } pub fn types_len(&self) -> usize { self.iter().filter_map(|arg| arg.typekind()).count() } pub fn typekind(&self, idx: Option) -> Option { let types_len = self.types_len(); self.get(idx.unwrap_or(0)).and_then(move |arg: &InputType| { if (idx.is_none() && types_len != 1) || (idx.is_some() && types_len == 1) { None } else { arg.typekind().cloned() } }) } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec); impl InputSetEntry { pub fn new(input: Vec) -> Self { Self(input) } pub fn get(&self, idx: usize) -> Option<&InputSet> { self.0.get(idx) } } fn validate_types<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { let v: Vec = Vec::deserialize(deserializer)?; let mut it = v.iter(); if let Some(first) = it.next() { it.try_fold(first, |last, cur| { if last.0.len() == cur.0.len() { Ok(cur) } else { Err("the length of the InputSets and the product lists must match".to_string()) } }) .map_err(de::Error::custom)?; } Ok(v) } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct IntrinsicInput { #[serde(default)] #[serde(deserialize_with = "validate_types")] pub types: Vec, #[serde(flatten)] pub predication_methods: PredicationMethods, /// Generates a _n variant where the specified operand is a primitive type /// that requires conversion to an SVE one. The `{_n}` wildcard is required /// in the intrinsic's name, otherwise an error will be thrown. #[serde(default)] pub n_variant_op: WildString, } impl IntrinsicInput { /// Extracts all the possible variants as an iterator. pub fn variants( &self, intrinsic: &Intrinsic, ) -> context::Result + '_> { let mut top_product = vec![]; if !self.types.is_empty() { top_product.push( self.types .iter() .flat_map(|ty_in| { ty_in .0 .iter() .map(|v| v.clone().into_iter()) .multi_cartesian_product() }) .collect_vec(), ) } if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) { top_product.push( PredicateForm::compile_list(&mask, &self.predication_methods)? .into_iter() .map(|pf| vec![InputType::PredicateForm(pf)]) .collect_vec(), ) } if !self.n_variant_op.is_empty() { top_product.push(vec![ vec![InputType::NVariantOp(None)], vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))], ]) } let it = top_product .into_iter() .map(|v| v.into_iter()) .multi_cartesian_product() .filter(|set| !set.is_empty()) .map(|set| InputSet(set.into_iter().flatten().collect_vec())); Ok(it) } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneratorInput { #[serde(flatten)] pub ctx: GlobalContext, pub intrinsics: Vec, } #[cfg(test)] mod tests { use crate::{ input::*, predicate_forms::{DontCareMethod, ZeroingMethod}, }; #[test] fn test_empty() { let str = r#"types: []"#; let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter(); assert_eq!(variants.next(), None); } #[test] fn test_product() { let str = r#"types: - [f64, f32] - [i64, [f64, f32]] "#; let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); let mut intrinsic = Intrinsic::default(); intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap(); let mut variants = input.variants(&intrinsic).unwrap().into_iter(); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("f64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::PredicateForm(PredicateForm::Merging), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("f64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::Type("f64".parse().unwrap()), InputType::PredicateForm(PredicateForm::Merging), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::Type("f64".parse().unwrap()), InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::PredicateForm(PredicateForm::Merging), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), ])), ); assert_eq!(variants.next(), None); } #[test] fn test_n_variant() { let str = r#"types: - [f64, f32] n_variant_op: op2 "#; let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter(); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("f64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::NVariantOp(None), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("f64".parse().unwrap()), InputType::Type("f32".parse().unwrap()), InputType::NVariantOp(Some("op2".parse().unwrap())), ])) ); assert_eq!(variants.next(), None) } #[test] fn test_invalid_length() { let str = r#"types: [i32, [[u64], [u32]]]"#; serde_yaml::from_str::(str).expect_err("failure expected"); } #[test] fn test_invalid_predication() { let str = "types: []"; let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); let mut intrinsic = Intrinsic::default(); intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap(); input .variants(&intrinsic) .map(|v| v.collect_vec()) .expect_err("failure expected"); } #[test] fn test_invalid_predication_mask() { "test_intrinsic{_mxy}" .parse::() .expect_err("failure expected"); "test_intrinsic{_}" .parse::() .expect_err("failure expected"); } #[test] fn test_zeroing_predication() { let str = r#"types: [i64] zeroing_method: { drop: inactive }"#; let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); let mut intrinsic = Intrinsic::default(); intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap(); let mut variants = input.variants(&intrinsic).unwrap(); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::PredicateForm(PredicateForm::Merging), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)), ])) ); assert_eq!( variants.next(), Some(InputSet(vec![ InputType::Type("i64".parse().unwrap()), InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop { drop: "inactive".parse().unwrap() })), ])) ); assert_eq!(variants.next(), None) } }