Skip to content

Commit

Permalink
Refactor operation bulids in dialect macro (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Dec 5, 2023
1 parent f85bca8 commit 0b64220
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 64 deletions.
74 changes: 36 additions & 38 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<'a> Operation<'a> {

let arguments = Self::dag_constraints(definition, "arguments")?;
let regions = Self::collect_regions(definition)?;
let (results, variable_length_results_count) = Self::collect_results(
let (results, unfixed_results_count) = Self::collect_results(
definition,
has_trait("::mlir::OpTrait::SameVariadicResultSize"),
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
Expand Down Expand Up @@ -86,7 +86,7 @@ impl<'a> Operation<'a> {
can_infer_type: traits.iter().any(|r#trait| {
(r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType")
|| r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType"))
&& variable_length_results_count == 0
&& unfixed_results_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
Expand Down Expand Up @@ -162,22 +162,22 @@ impl<'a> Operation<'a> {
}

fn collect_traits(definition: Record<'a>) -> Result<Vec<Trait>, Error> {
let mut work_list = vec![definition.list_value("traits")?];
let mut traits = Vec::new();
let mut trait_lists = vec![definition.list_value("traits")?];
let mut traits = vec![];

while let Some(trait_definition) = work_list.pop() {
for value in trait_definition.iter() {
let trait_def: Record = value
while let Some(trait_list) = trait_lists.pop() {
for value in trait_list.iter() {
let definition: Record = value
.try_into()
.map_err(|error: tblgen::Error| error.set_location(definition))?;

if trait_def.subclass_of("TraitList") {
work_list.push(trait_def.list_value("traits")?);
if definition.subclass_of("TraitList") {
trait_lists.push(definition.list_value("traits")?);
} else {
if trait_def.subclass_of("Interface") {
work_list.push(trait_def.list_value("baseInterfaces")?);
if definition.subclass_of("Interface") {
trait_lists.push(definition.list_value("baseInterfaces")?);
}
traits.push(Trait::new(trait_def)?)
traits.push(Trait::new(definition)?)
}
}
}
Expand All @@ -193,23 +193,23 @@ impl<'a> Operation<'a> {
.dag_value(dag_field_name)?
.args()
.map(|(name, argument)| {
let mut argument_definition: Record = argument
let mut definition: Record = argument
.try_into()
.map_err(|error: tblgen::Error| error.set_location(definition))?;

if argument_definition.subclass_of("OpVariable") {
argument_definition = argument_definition.def_value("constraint")?;
if definition.subclass_of("OpVariable") {
definition = definition.def_value("constraint")?;
}

Ok((name, argument_definition))
Ok((name, definition))
})
.collect()
}

fn collect_results(
def: Record<'a>,
same_size: bool,
attr_sized: bool,
attribute_sized: bool,
) -> Result<(Vec<OperationField>, usize), Error> {
Self::collect_elements(
&Self::dag_constraints(def, "results")?
Expand All @@ -218,24 +218,24 @@ impl<'a> Operation<'a> {
.collect::<Vec<_>>(),
ElementKind::Result,
same_size,
attr_sized,
attribute_sized,
)
}

fn collect_operands(
arguments: &[(&'a str, Record<'a>)],
same_size: bool,
attr_sized: bool,
attribute_sized: bool,
) -> Result<Vec<OperationField<'a>>, Error> {
Ok(Self::collect_elements(
&arguments
.iter()
.filter(|(_, arg_def)| arg_def.subclass_of("TypeConstraint"))
.map(|(name, arg_def)| (*name, TypeConstraint::new(*arg_def)))
.filter(|(_, definition)| definition.subclass_of("TypeConstraint"))
.map(|(name, definition)| (*name, TypeConstraint::new(*definition)))
.collect::<Vec<_>>(),
ElementKind::Operand,
same_size,
attr_sized,
attribute_sized,
)?
.0)
}
Expand All @@ -244,13 +244,13 @@ impl<'a> Operation<'a> {
elements: &[(&'a str, TypeConstraint<'a>)],
element_kind: ElementKind,
same_size: bool,
attr_sized: bool,
attribute_sized: bool,
) -> Result<(Vec<OperationField<'a>>, usize), Error> {
let variable_length_count = elements
let unfixed_count = elements
.iter()
.filter(|(_, constraint)| constraint.has_variable_length())
.filter(|(_, constraint)| constraint.has_unfixed())
.count();
let mut variadic_kind = VariadicKind::new(variable_length_count, same_size, attr_sized);
let mut variadic_kind = VariadicKind::new(unfixed_count, same_size, attribute_sized);
let mut fields = vec![];

for (index, (name, constraint)) in elements.iter().enumerate() {
Expand All @@ -266,42 +266,40 @@ impl<'a> Operation<'a> {
)?);

match &mut variadic_kind {
VariadicKind::Simple {
variable_length_seen: seen_variable_length,
} => {
if constraint.has_variable_length() {
*seen_variable_length = true;
VariadicKind::Simple { unfixed_seen } => {
if constraint.has_unfixed() {
*unfixed_seen = true;
}
}
VariadicKind::SameSize {
preceding_simple_count,
preceding_variadic_count,
..
} => {
if constraint.has_variable_length() {
if constraint.has_unfixed() {
*preceding_variadic_count += 1;
} else {
*preceding_simple_count += 1;
}
}
VariadicKind::AttrSized {} => {}
VariadicKind::AttributeSized => {}
}
}

Ok((fields, variable_length_count))
Ok((fields, unfixed_count))
}

fn collect_attributes(
arguments: &[(&'a str, Record<'a>)],
) -> Result<Vec<OperationField<'a>>, Error> {
arguments
.iter()
.filter(|(_, arg_def)| arg_def.subclass_of("Attr"))
.map(|(name, arg_def)| {
.filter(|(_, definition)| definition.subclass_of("Attr"))
.map(|(name, definition)| {
// TODO: Replace assert! with Result
assert!(!arg_def.subclass_of("DerivedAttr"));
assert!(!definition.subclass_of("DerivedAttr"));

OperationField::new_attribute(name, AttributeConstraint::new(*arg_def))
OperationField::new_attribute(name, AttributeConstraint::new(*definition))
})
.collect()
}
Expand Down
22 changes: 10 additions & 12 deletions macro/src/dialect/operation/accessors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ impl<'a> OperationField<'a> {
let name = self.name;

Some(match variadic_kind {
VariadicKind::Simple {
variable_length_seen: seen_variable_length,
} => {
VariadicKind::Simple { unfixed_seen } => {
if constraint.is_optional() {
// Optional element, and some singular elements.
// Only present if the amount of groups is at least the number of
Expand All @@ -40,15 +38,15 @@ impl<'a> OperationField<'a> {
}
}
} else if constraint.is_variadic() {
// A variable length group
// A unfixed group
// Length computed by subtracting the amount of other
// singular elements from the number of elements.
quote! {
let group_length = self.operation.#count() - #len + 1;
self.operation.#plural().skip(#index).take(group_length)
}
} else if *seen_variable_length {
// Single element after variable length group
} else if *unfixed_seen {
// Single element after unfixed group
// Compute the length of that variable group and take the next element
quote! {
let group_length = self.operation.#count() - #len + 1;
Expand All @@ -62,16 +60,16 @@ impl<'a> OperationField<'a> {
}
}
VariadicKind::SameSize {
variable_length_count,
unfixed_count,
preceding_simple_count,
preceding_variadic_count,
} => {
let compute_start_length = quote! {
let total_var_len = self.operation.#count() - #variable_length_count + 1;
let group_len = total_var_len / #variable_length_count;
let total_var_len = self.operation.#count() - #unfixed_count + 1;
let group_len = total_var_len / #unfixed_count;
let start = #preceding_simple_count + #preceding_variadic_count * group_len;
};
let get_elements = if constraint.has_variable_length() {
let get_elements = if constraint.has_unfixed() {
quote! {
self.operation.#plural().skip(start).take(group_len)
}
Expand All @@ -83,7 +81,7 @@ impl<'a> OperationField<'a> {

quote! { #compute_start_length #get_elements }
}
VariadicKind::AttrSized {} => {
VariadicKind::AttributeSized => {
let attribute_name = format!("{}_segment_sizes", kind_str);
let compute_start_length = quote! {
let attribute =
Expand All @@ -98,7 +96,7 @@ impl<'a> OperationField<'a> {
.sum::<i32>() as usize;
let group_len = attribute.element(#index)? as usize;
};
let get_elements = if !constraint.has_variable_length() {
let get_elements = if !constraint.has_unfixed() {
quote! {
self.operation.#kind_ident(start)
}
Expand Down
2 changes: 1 addition & 1 deletion macro/src/dialect/operation/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl<'o> OperationBuilder<'o> {
// arguments
let add_arguments = match &field.kind {
FieldKind::Element { constraint, .. } => {
if constraint.has_variable_length() && !constraint.is_optional() {
if constraint.has_unfixed() && !constraint.is_optional() {
quote! { #name }
} else {
quote! { &[#name] }
Expand Down
2 changes: 1 addition & 1 deletion macro/src/dialect/operation/field_kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<'a> FieldKind<'a> {
};
if !constraint.is_variadic() {
Self::create_result_type(base_type)
} else if let VariadicKind::AttrSized {} = variadic_kind {
} else if let VariadicKind::AttributeSized = variadic_kind {
Self::create_result_type(Self::create_iterator_type(base_type))
} else {
Self::create_iterator_type(base_type)
Expand Down
22 changes: 11 additions & 11 deletions macro/src/dialect/operation/variadic_kind.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
#[derive(Clone, Debug)]
pub enum VariadicKind {
Simple {
variable_length_seen: bool,
unfixed_seen: bool,
},
SameSize {
variable_length_count: usize,
unfixed_count: usize,
preceding_simple_count: usize,
preceding_variadic_count: usize,
},
AttrSized {},
AttributeSized,
}

impl VariadicKind {
pub fn new(variable_length_count: usize, same_size: bool, attr_sized: bool) -> Self {
if variable_length_count <= 1 {
VariadicKind::Simple {
variable_length_seen: false,
pub fn new(unfixed_count: usize, same_size: bool, attribute_sized: bool) -> Self {
if unfixed_count <= 1 {
Self::Simple {
unfixed_seen: false,
}
} else if same_size {
VariadicKind::SameSize {
variable_length_count,
Self::SameSize {
unfixed_count,
preceding_simple_count: 0,
preceding_variadic_count: 0,
}
} else if attr_sized {
VariadicKind::AttrSized {}
} else if attribute_sized {
Self::AttributeSized
} else {
unimplemented!()
}
Expand Down
2 changes: 1 addition & 1 deletion macro/src/dialect/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a> TypeConstraint<'a> {
self.0.subclass_of("VariadicOfVariadic")
}

pub fn has_variable_length(&self) -> bool {
pub fn has_unfixed(&self) -> bool {
self.is_variadic() || self.is_optional()
}
}
Expand Down

0 comments on commit 0b64220

Please sign in to comment.