diff --git a/Cargo.lock b/Cargo.lock index 2351369f5..d37948010 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.19" @@ -73,6 +79,15 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "ar_archive_writer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c269894b6fe5e9d7ada0cf69b5bf847ff35bc25fc271f08e1d080fce80339a" +dependencies = [ + "object", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -121,10 +136,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ + "find-msvc-tools", "shlex", ] @@ -134,6 +150,20 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +[[package]] +name = "chumsky" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14377e276b2c8300513dff55ba4cc4142b44e5d6de6d00eb5b2307d650bb4ec1" +dependencies = [ + "hashbrown", + "regex-automata 0.3.9", + "serde", + "stacker", + "unicode-ident", + "unicode-segmentation", +] + [[package]] name = "clap" version = "4.5.40" @@ -189,6 +219,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "convert_case" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baaaa0ecca5b51987b9423ccdc971514dd8b0bb7b4060b983d3664dad3f1f89f" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -320,6 +359,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "find-msvc-tools" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" + [[package]] name = "flate2" version = "1.1.2" @@ -344,6 +389,32 @@ name = "fnv" version = "1.0.7" source = "git+https://github.com/servo/rust-fnv?branch=main#4e55052a343a4372c191141f29a17ab829cf1dbc" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "fzn-rs" +version = "0.1.0" +dependencies = [ + "chumsky", + "fzn-rs-derive", + "thiserror", +] + +[[package]] +name = "fzn-rs-derive" +version = "0.1.0" +dependencies = [ + "convert_case 0.8.0", + "fzn-rs", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -362,6 +433,11 @@ name = "hashbrown" version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -580,6 +656,15 @@ dependencies = [ "libm", ] +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -622,6 +707,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11f2fedc3b7dafdc2851bc52f277377c5473d378859be234bc7ebb593144d01" +dependencies = [ + "ar_archive_writer", + "cc", +] + [[package]] name = "pumpkin-core" version = "0.2.2" @@ -629,7 +724,7 @@ dependencies = [ "bitfield", "bitfield-struct", "clap", - "convert_case", + "convert_case 0.6.0", "downcast-rs", "drcp-format", "enum-map", @@ -795,8 +890,19 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.5", ] [[package]] @@ -807,9 +913,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -905,6 +1017,19 @@ dependencies = [ "libc", ] +[[package]] +name = "stacker" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1f8b29fb42aafcea4edeeb6b2f2d7ecd0d969c48b4cf0d2e64aafc471dd6e59" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys", +] + [[package]] name = "stringcase" version = "0.3.0" @@ -919,9 +1044,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "syn" -version = "2.0.103" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index e61605b3b..6e279b2c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,5 @@ [workspace] -members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*"] -default-members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./pumpkin-crates/*"] +members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*", "./fzn-rs", "./fzn-rs-derive"] resolver = "2" [workspace.package] diff --git a/fzn-rs-derive/Cargo.toml b/fzn-rs-derive/Cargo.toml new file mode 100644 index 000000000..a94eeae7a --- /dev/null +++ b/fzn-rs-derive/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fzn-rs-derive" +version = "0.1.0" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[lib] +proc-macro = true + +[dependencies] +convert_case = "0.8.0" +proc-macro2 = "1.0.95" +quote = "1.0.40" +syn = { version = "2.0.104", features = ["extra-traits"] } + +[dev-dependencies] +fzn-rs = { path = "../fzn-rs/" } + +[lints] +workspace = true diff --git a/fzn-rs-derive/src/annotation.rs b/fzn-rs-derive/src/annotation.rs new file mode 100644 index 000000000..d16e9e644 --- /dev/null +++ b/fzn-rs-derive/src/annotation.rs @@ -0,0 +1,114 @@ +use quote::quote; + +/// Construct a token stream that initialises an annotation with name `value_type` and the arguments +/// described in `fields`. +pub(crate) fn initialise_value( + value_type: &syn::Ident, + fields: &syn::Fields, +) -> proc_macro2::TokenStream { + // For every field, initialise the value for that field. + let field_values = fields.iter().enumerate().map(|(idx, field)| { + let ty = &field.ty; + + // If the field has a name, then prepend the field value with the name: `name: value`. + let value_prefix = if let Some(ident) = &field.ident { + quote! { #ident: } + } else { + quote! {} + }; + + // If there is an `#[annotation]` attribute on the field, then the value is the result of + // parsing a nested annotation. Otherwise, we look at the type of the field + // and parse the value corresponding to that type. + if field.attrs.iter().any(|attr| { + attr.path() + .get_ident() + .is_some_and(|ident| ident == "annotation") + }) { + quote! { + #value_prefix <#ty as ::fzn_rs::FromNestedAnnotation>::from_argument( + &arguments[#idx], + )? + } + } else { + quote! { + #value_prefix <#ty as ::fzn_rs::FromAnnotationArgument>::from_argument( + &arguments[#idx], + )? + } + } + }); + + // Complete the value initialiser by prepending the type name to the field values. + let value_initialiser = match fields { + syn::Fields::Named(_) => quote! { #value_type { #(#field_values),* } }, + syn::Fields::Unnamed(_) => quote! { #value_type ( #(#field_values),* ) }, + syn::Fields::Unit => quote! { #value_type }, + }; + + let num_arguments = fields.len(); + + // Output the final initialisation, with checking of number of arguments. + quote! { + if arguments.len() != #num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #num_arguments, + actual: arguments.len(), + span: annotation.span, + }); + } + + Ok(Some(#value_initialiser)) + } +} + +/// Create the parsing code for one annotation corresponding to the given variant. +pub(crate) fn variant_to_annotation(variant: &syn::Variant) -> proc_macro2::TokenStream { + // Determine the flatzinc annotation name. + let name = match crate::common::get_explicit_name(variant) { + Ok(name) => name, + Err(_) => { + return quote! { + compile_error!("Invalid usage of #[name(...)]"); + }; + } + }; + + let variant_name = &variant.ident; + + // If variant argument is a struct, then delegate parsing of the annotation arguments to that + // struct. + if let Some(constraint_type) = crate::common::get_args_type(variant) { + return quote! { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) if name.as_ref() == #name => { + let args = <#constraint_type as ::fzn_rs::FlatZincAnnotation>::from_ast_required(annotation)?; + let value = #variant_name(args); + Ok(Some(value)) + } + }; + } + + // If the variant has no arguments, parse an atom annotaton. Otherwise, initialise the values + // of the variant arguments. + if matches!(variant.fields, syn::Fields::Unit) { + quote! { + ::fzn_rs::ast::Annotation::Atom(ident) if ident.as_ref() == #name => { + Ok(Some(#variant_name)) + } + } + } else { + let value = initialise_value(&variant.ident, &variant.fields); + + quote! { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) if name.as_ref() == #name => { + #value + } + } + } +} diff --git a/fzn-rs-derive/src/common.rs b/fzn-rs-derive/src/common.rs new file mode 100644 index 000000000..6d2a6d83f --- /dev/null +++ b/fzn-rs-derive/src/common.rs @@ -0,0 +1,55 @@ +use convert_case::Case; +use convert_case::Casing; + +/// Get the name of the constraint or annotation from the variant. This either is converting the +/// variant name to snake case, or retrieving the value from the `#[name(...)]` attribute. +pub(crate) fn get_explicit_name(variant: &syn::Variant) -> syn::Result { + variant + .attrs + .iter() + // Find the attribute with a `name` as the path. + .find(|attr| attr.path().get_ident().is_some_and(|ident| ident == "name")) + // Parse the arguments of the attribute to a string literal. + .map(|attr| { + attr.parse_args::() + .map(|string_lit| string_lit.value()) + }) + // If no `name` attribute exists, return the snake-case version of the variant name. + .unwrap_or_else(|| Ok(variant.ident.to_string().to_case(Case::Snake))) +} + +/// Returns the type of the arguments for the variant if the variant has exactly the following +/// shape: +/// +/// ```ignore +/// #[args] +/// Variant(Type) +/// ``` +pub(crate) fn get_args_type(variant: &syn::Variant) -> Option<&syn::Type> { + let has_args_attr = variant + .attrs + .iter() + .any(|attr| attr.path().get_ident().is_some_and(|ident| ident == "args")); + + if !has_args_attr { + return None; + } + + if variant.fields.len() != 1 { + // If there is not exactly one argument for this variant, then it cannot be a struct + // constraint. + return None; + } + + let field = variant + .fields + .iter() + .next() + .expect("there is exactly one field"); + + if field.ident.is_none() { + Some(&field.ty) + } else { + None + } +} diff --git a/fzn-rs-derive/src/constraint.rs b/fzn-rs-derive/src/constraint.rs new file mode 100644 index 000000000..bfb3ef49b --- /dev/null +++ b/fzn-rs-derive/src/constraint.rs @@ -0,0 +1,107 @@ +use quote::quote; + +/// Construct a token stream that initialises a constraint with value name `value_type` and the +/// arguments described in `fields`. +pub(crate) fn initialise_value( + identifier: &syn::Ident, + fields: &syn::Fields, +) -> proc_macro2::TokenStream { + match fields { + // In case of named fields, the order of the fields is the order of the flatzinc arguments. + syn::Fields::Named(fields) => { + let arguments = fields.named.iter().enumerate().map(|(idx, field)| { + let field_name = field + .ident + .as_ref() + .expect("we are in a syn::Fields::Named"); + let ty = &field.ty; + + quote! { + #field_name: <#ty as ::fzn_rs::FromArgument>::from_argument( + &constraint.node.arguments[#idx], + )?, + } + }); + + quote! { #identifier { #(#arguments)* } } + } + + syn::Fields::Unnamed(fields) => { + let arguments = fields.unnamed.iter().enumerate().map(|(idx, field)| { + let ty = &field.ty; + + quote! { + <#ty as ::fzn_rs::FromArgument>::from_argument( + &constraint.node.arguments[#idx], + )?, + } + }); + + quote! { #identifier ( #(#arguments)* ) } + } + + syn::Fields::Unit => quote! { + compile_error!("A FlatZinc constraint must have at least one field") + }, + } +} + +/// Generate an implementation of `FlatZincConstraint` for enums. +pub(crate) fn flatzinc_constraint_for_enum( + constraint_enum_name: &syn::Ident, + data_enum: &syn::DataEnum, +) -> proc_macro2::TokenStream { + // For every variant in the enum, create a match arm that matches the constraint name and + // parses the constraint with the appropriate arguments. + let constraints = data_enum.variants.iter().map(|variant| { + // Determine the flatzinc name of the constraint. + let name = match crate::common::get_explicit_name(variant) { + Ok(name) => name, + Err(_) => { + return quote! { + compile_error!("Invalid usage of #[name(...)]"); + } + } + }; + + let variant_name = &variant.ident; + let match_expression = match crate::common::get_args_type(variant) { + Some(constraint_type) => quote! { + Ok(#variant_name (<#constraint_type as ::fzn_rs::FlatZincConstraint>::from_ast(constraint)?)) + }, + None => { + let initialised_value = initialise_value(variant_name, &variant.fields); + let expected_num_arguments = variant.fields.len(); + + quote! { + if constraint.node.arguments.len() != #expected_num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: constraint.node.arguments.len(), + span: constraint.span, + }); + } + + Ok(#initialised_value) + } + } + }; + + quote! { + #name => { + #match_expression + } + } + }); + + quote! { + use #constraint_enum_name::*; + + match constraint.node.name.node.as_ref() { + #(#constraints)* + unknown => Err(::fzn_rs::InstanceError::UnsupportedConstraint( + String::from(unknown) + )), + } + } +} diff --git a/fzn-rs-derive/src/lib.rs b/fzn-rs-derive/src/lib.rs new file mode 100644 index 000000000..0fd29dedf --- /dev/null +++ b/fzn-rs-derive/src/lib.rs @@ -0,0 +1,115 @@ +mod annotation; +mod common; +mod constraint; + +use proc_macro::TokenStream; +use quote::quote; +use syn::DeriveInput; +use syn::parse_macro_input; + +#[proc_macro_derive(FlatZincConstraint, attributes(name, args))] +pub fn derive_flatzinc_constraint(item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + + let type_name = derive_input.ident; + let implementation = match &derive_input.data { + syn::Data::Struct(data_struct) => { + let expected_num_arguments = data_struct.fields.len(); + + let struct_initialiser = constraint::initialise_value(&type_name, &data_struct.fields); + quote! { + if constraint.node.arguments.len() != #expected_num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: constraint.node.arguments.len(), + span: constraint.span, + }); + } + + Ok(#struct_initialiser) + } + } + syn::Data::Enum(data_enum) => { + constraint::flatzinc_constraint_for_enum(&type_name, data_enum) + } + syn::Data::Union(_) => quote! { + compile_error!("Cannot implement FlatZincConstraint on unions.") + }, + }; + + let token_stream = quote! { + #[automatically_derived] + impl ::fzn_rs::FlatZincConstraint for #type_name { + fn from_ast( + constraint: &::fzn_rs::ast::Node<::fzn_rs::ast::Constraint>, + ) -> Result { + #implementation + } + } + }; + + token_stream.into() +} + +#[proc_macro_derive(FlatZincAnnotation, attributes(name, annotation, args))] +pub fn derive_flatzinc_annotation(item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + let annotatation_enum_name = derive_input.ident; + + let implementation = match derive_input.data { + syn::Data::Struct(data_struct) => { + let initialised_values = + annotation::initialise_value(&annotatation_enum_name, &data_struct.fields); + + let expected_num_arguments = data_struct.fields.len(); + + quote! { + match &annotation.node { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) => { + #initialised_values + } + + _ => return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: 0, + span: annotation.span, + }), + } + } + } + syn::Data::Enum(data_enum) => { + let annotations = data_enum + .variants + .iter() + .map(annotation::variant_to_annotation); + + quote! { + use #annotatation_enum_name::*; + + match &annotation.node { + #(#annotations),* + _ => Ok(None), + } + } + } + syn::Data::Union(_) => quote! { + compile_error!("Cannot implement FlatZincAnnotation on unions.") + }, + }; + + let token_stream = quote! { + #[automatically_derived] + impl ::fzn_rs::FlatZincAnnotation for #annotatation_enum_name { + fn from_ast( + annotation: &::fzn_rs::ast::Node<::fzn_rs::ast::Annotation>, + ) -> Result, ::fzn_rs::InstanceError> { + #implementation + } + } + }; + + token_stream.into() +} diff --git a/fzn-rs-derive/tests/derive_flatzinc_annotation.rs b/fzn-rs-derive/tests/derive_flatzinc_annotation.rs new file mode 100644 index 000000000..f6e5a1ef8 --- /dev/null +++ b/fzn-rs-derive/tests/derive_flatzinc_annotation.rs @@ -0,0 +1,397 @@ +#![cfg(test)] // workaround for https://github.com/rust-lang/rust-clippy/issues/11024 + +mod utils; + +use std::collections::BTreeMap; +use std::rc::Rc; + +use fzn_rs::ArrayExpr; +use fzn_rs::TypedInstance; +use fzn_rs::VariableExpr; +use fzn_rs::ast::Annotation; +use fzn_rs::ast::AnnotationArgument; +use fzn_rs::ast::AnnotationCall; +use fzn_rs::ast::AnnotationLiteral; +use fzn_rs::ast::Argument; +use fzn_rs::ast::Ast; +use fzn_rs::ast::Domain; +use fzn_rs::ast::Literal; +use fzn_rs::ast::RangeList; +use fzn_rs::ast::Variable; +use fzn_rs_derive::FlatZincAnnotation; +use fzn_rs_derive::FlatZincConstraint; +use utils::*; + +macro_rules! btreemap { + ($($key:expr => $value:expr,)+) => (btreemap!($($key => $value),+)); + + ( $($key:expr => $value:expr),* ) => { + { + let mut _map = ::std::collections::BTreeMap::new(); + $( + let _ = _map.insert($key, $value); + )* + _map + } + }; +} + +#[test] +fn annotation_without_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + OutputVar, + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Atom("output_var".into()))], + })], + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::OutputVar, + ); +} + +#[test] +fn annotation_with_positional_literal_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + DefinesVar(Rc), + OutputArray(RangeList), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![ + test_node(Annotation::Call(AnnotationCall { + name: "defines_var".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("some_var".into())), + )))], + })), + test_node(Annotation::Call(AnnotationCall { + name: "output_array".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::IntSet(RangeList::from(1..=5))), + )))], + })), + ], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::DefinesVar("some_var".into()), + ); + + assert_eq!( + instance.constraints[0].annotations[1].node, + TypedAnnotation::OutputArray(RangeList::from(1..=5)), + ); +} + +#[test] +fn annotation_with_named_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + DefinesVar { variable_id: Rc }, + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "defines_var".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("some_var".into())), + )))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::DefinesVar { + variable_id: "some_var".into() + }, + ); +} + +#[test] +fn nested_annotation_as_argument() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(#[annotation] SomeAnnotationArgs), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum SomeAnnotationArgs { + ArgOne, + ArgTwo(Rc), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![ + test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("arg_one".into())), + )))], + })), + test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::Annotation(AnnotationCall { + name: "arg_two".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("ident".into())), + )))], + }), + )))], + })), + ], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::SomeAnnotation(SomeAnnotationArgs::ArgOne), + ); + + assert_eq!( + instance.constraints[0].annotations[1].node, + TypedAnnotation::SomeAnnotation(SomeAnnotationArgs::ArgTwo("ident".into())), + ); +} + +#[test] +fn arrays_as_annotation_arguments_with_literal_elements() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(ArrayExpr), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Array(vec![ + test_node(AnnotationLiteral::BaseLiteral(Literal::Int(1))), + test_node(AnnotationLiteral::BaseLiteral(Literal::Int(2))), + ]))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + let TypedAnnotation::SomeAnnotation(args) = instance.constraints[0].annotations[0].node.clone(); + + let resolved_args = instance + .resolve_array(&args) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(resolved_args, vec![1, 2]); +} + +#[test] +fn arrays_as_annotation_arguments_with_annotation_elements() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(#[annotation] Vec), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum ArrayElements { + ElementOne, + ElementTwo(i64), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Array(vec![ + test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier( + "element_one".into(), + ))), + test_node(AnnotationLiteral::Annotation(AnnotationCall { + name: "element_two".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Int(4)), + )))], + })), + ]))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::SomeAnnotation(vec![ + ArrayElements::ElementOne, + ArrayElements::ElementTwo(4) + ]), + ); +} + +#[test] +fn annotations_can_be_structs_for_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + #[args] + SomeAnnotation(AnnotationArgs), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + struct AnnotationArgs { + ident: Rc, + #[annotation] + ann: OtherAnnotation, + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum OtherAnnotation { + ElementOne, + ElementTwo(i64), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: btreemap! { + "x1".into() => test_node(Variable { + domain: test_node(Domain::UnboundedInt), + value: None, + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![ + test_node(AnnotationArgument::Literal(test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier("some_ident".into()))))), + test_node(AnnotationArgument::Literal(test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier("element_one".into()))))), + ], + }))], + }), + }, + arrays: BTreeMap::new(), + constraints: vec![], + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.variables["x1"].annotations[0].node, + TypedAnnotation::SomeAnnotation(AnnotationArgs { + ident: "some_ident".into(), + ann: OtherAnnotation::ElementOne, + }), + ); +} diff --git a/fzn-rs-derive/tests/derive_flatzinc_constraint.rs b/fzn-rs-derive/tests/derive_flatzinc_constraint.rs new file mode 100644 index 000000000..f0662bad2 --- /dev/null +++ b/fzn-rs-derive/tests/derive_flatzinc_constraint.rs @@ -0,0 +1,489 @@ +#![cfg(test)] // workaround for https://github.com/rust-lang/rust-clippy/issues/11024 + +mod utils; + +use std::collections::BTreeMap; + +use fzn_rs::ArrayExpr; +use fzn_rs::InstanceError; +use fzn_rs::TypedInstance; +use fzn_rs::VariableExpr; +use fzn_rs::ast::Argument; +use fzn_rs::ast::Array; +use fzn_rs::ast::Ast; +use fzn_rs::ast::Literal; +use fzn_rs::ast::Span; +use fzn_rs_derive::FlatZincConstraint; +use utils::*; + +#[test] +fn variant_with_unnamed_fields() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe(ArrayExpr, ArrayExpr>, i64), + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::IntLinLe(weights, variables, bound) = + instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn variant_with_named_fields() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + }, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::IntLinLe { + weights, + variables, + bound, + } = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn variant_with_name_attribute() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[name("int_lin_le")] + LinearInequality { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + }, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::LinearInequality { + weights, + variables, + bound, + } = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn constraint_referencing_arrays() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe(ArrayExpr, ArrayExpr>, i64), + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: [ + ( + "array1".into(), + test_node(Array { + domain: test_node(fzn_rs::ast::Domain::UnboundedInt), + contents: vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ], + annotations: vec![], + }), + ), + ( + "array2".into(), + test_node(Array { + domain: test_node(fzn_rs::ast::Domain::UnboundedInt), + contents: vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ], + annotations: vec![], + }), + ), + ] + .into_iter() + .collect(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Identifier( + "array1".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Identifier( + "array2".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + + let TypedConstraint::IntLinLe(weights, variables, bound) = + instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn constraint_as_struct_args() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[args] + IntLinLe(LinearLeq), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + struct LinearLeq { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: [ + ( + "array1".into(), + test_node(Array { + domain: test_node(fzn_rs::ast::Domain::UnboundedInt), + contents: vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ], + annotations: vec![], + }), + ), + ( + "array2".into(), + test_node(Array { + domain: test_node(fzn_rs::ast::Domain::UnboundedInt), + contents: vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ], + annotations: vec![], + }), + ), + ] + .into_iter() + .collect(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Identifier( + "array1".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Identifier( + "array2".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + + let TypedConstraint::IntLinLe(linear) = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&linear.weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&linear.variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(linear.bound, 3); +} + +#[test] +fn argument_count_on_tuple_variants() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(i64), + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} + +#[test] +fn argument_count_on_named_fields_variant() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint { constant: i64 }, + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} + +#[test] +fn argument_count_on_args_struct() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[args] + SomeConstraint(Args), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + struct Args { + argument: i64, + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} diff --git a/fzn-rs-derive/tests/utils.rs b/fzn-rs-derive/tests/utils.rs new file mode 100644 index 000000000..8de6e5edb --- /dev/null +++ b/fzn-rs-derive/tests/utils.rs @@ -0,0 +1,49 @@ +#![allow( + dead_code, + reason = "it is used in other test files, but somehow compiler can't see it" +)] +#![cfg(test)] + +use std::collections::BTreeMap; +use std::rc::Rc; + +use fzn_rs::ast::{self}; + +pub(crate) fn satisfy_solve() -> ast::SolveItem { + ast::SolveItem { + method: test_node(ast::Method::Satisfy), + annotations: vec![], + } +} + +pub(crate) fn test_node(data: T) -> ast::Node { + node(data, usize::MAX, usize::MAX) +} + +pub(crate) fn node(data: T, span_start: usize, span_end: usize) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { + start: span_start, + end: span_end, + }, + } +} + +pub(crate) fn unbounded_variables<'a>( + names: impl IntoIterator, +) -> BTreeMap, ast::Node>> { + names + .into_iter() + .map(|name| { + ( + Rc::from(name), + test_node(ast::Variable { + domain: test_node(ast::Domain::UnboundedInt), + value: None, + annotations: vec![], + }), + ) + }) + .collect() +} diff --git a/fzn-rs/Cargo.toml b/fzn-rs/Cargo.toml new file mode 100644 index 000000000..2656fb78d --- /dev/null +++ b/fzn-rs/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "fzn-rs" +version = "0.1.0" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[dependencies] +chumsky = { version = "0.10.1" } +thiserror = "2.0.12" +fzn-rs-derive = { path = "../fzn-rs-derive/" } + +[lints] +workspace = true diff --git a/fzn-rs/src/ast.rs b/fzn-rs/src/ast.rs new file mode 100644 index 000000000..9b23dd148 --- /dev/null +++ b/fzn-rs/src/ast.rs @@ -0,0 +1,380 @@ +//! The AST representing a FlatZinc instance, compatible with both the JSON format and +//! the original FZN format. +//! +//! It is a modified version of the `FlatZinc` type from [`flatzinc-serde`](https://docs.rs/flatzinc-serde). +use std::collections::BTreeMap; +use std::fmt::Display; +use std::iter::FusedIterator; +use std::ops::RangeInclusive; +use std::rc::Rc; + +/// Represents a FlatZinc instance. +/// +/// In the `.fzn` format, identifiers can point to both constants and variables (either single or +/// arrays). In this AST, the constants are immediately resolved and are not kept in their original +/// form. Therefore, any [`Literal::Identifier`] points to a variable or an array. +/// +/// All identifiers are [`Rc`]s to allow parsers to re-use the allocation of the variable name. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Ast { + /// A mapping from identifiers to variables. + pub variables: BTreeMap, Node>>, + /// The arrays in this instance. + pub arrays: BTreeMap, Node>>, + /// A list of constraints. + pub constraints: Vec>, + /// The goal of the model. + pub solve: SolveItem, +} + +/// A decision variable. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Variable { + /// The domain of the variable. + pub domain: Node, + /// Optionally, the value that the variable is equal to. + pub value: Option>, + /// The annotations on this variable. + pub annotations: Vec>, +} + +/// A named array of literals. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Array { + /// The domain of the elements of the array. + pub domain: Node, + /// The elements of the array. + pub contents: Vec>, + /// The annotations associated with this array. + pub annotations: Vec>, +} + +/// The domain of a [`Variable`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Domain { + /// The set of all integers. + UnboundedInt, + /// A finite set of integer values. + Int(RangeList), + /// A boolean domain. + Bool, +} + +/// Holds a non-empty set of values. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RangeList { + /// A sorted list of intervals. + /// + /// Invariant: Consecutive intervals are merged. + intervals: Vec<(E, E)>, +} + +impl RangeList { + /// The smallest element in the set. + pub fn lower_bound(&self) -> &E { + &self.intervals[0].0 + } + + /// The largest element in the set. + pub fn upper_bound(&self) -> &E { + let last_idx = self.intervals.len() - 1; + + &self.intervals[last_idx].1 + } + + /// Returns `true` if the set is a continious range from [`Self::lower_bound`] to + /// [`Self::upper_bound`]. + pub fn is_continuous(&self) -> bool { + self.intervals.len() == 1 + } +} + +macro_rules! impl_iter_fn { + ($int_type:ty) => { + impl<'a> IntoIterator for &'a RangeList<$int_type> { + type Item = $int_type; + + type IntoIter = RangeListIter<'a, $int_type>; + + fn into_iter(self) -> Self::IntoIter { + RangeListIter { + current_interval: self.intervals.first().copied().unwrap_or((1, 0)), + tail: &self.intervals[1..], + } + } + } + }; +} + +impl_iter_fn!(i32); +impl_iter_fn!(i64); + +impl From> for RangeList { + fn from(value: RangeInclusive) -> Self { + RangeList { + intervals: vec![(*value.start(), *value.end())], + } + } +} + +macro_rules! range_list_from_iter { + ($int_type:ty) => { + impl FromIterator<$int_type> for RangeList<$int_type> { + fn from_iter>(iter: T) -> Self { + let mut intervals: Vec<_> = iter.into_iter().map(|e| (e, e)).collect(); + intervals.sort(); + intervals.dedup(); + + let mut idx = 0; + + while idx < intervals.len() - 1 { + let current = intervals[idx]; + let next = intervals[idx + 1]; + + if current.1 >= next.0 - 1 { + intervals[idx] = (current.0, next.1); + let _ = intervals.remove(idx + 1); + } else { + idx += 1; + } + } + + RangeList { intervals } + } + } + }; +} + +range_list_from_iter!(i32); +range_list_from_iter!(i64); + +/// An [`Iterator`] over a [`RangeList`]. +#[derive(Debug)] +pub struct RangeListIter<'a, E> { + current_interval: (E, E), + tail: &'a [(E, E)], +} + +macro_rules! impl_range_list_iter { + ($int_type:ty) => { + impl Iterator for RangeListIter<'_, $int_type> { + type Item = $int_type; + + fn next(&mut self) -> Option { + let (current_lb, current_ub) = self.current_interval; + + if current_lb > current_ub { + let (next_interval, new_tail) = self.tail.split_first()?; + self.current_interval = *next_interval; + self.tail = new_tail; + } + + let current_lb = self.current_interval.0; + self.current_interval.0 += 1; + + Some(current_lb) + } + } + + impl FusedIterator for RangeListIter<'_, $int_type> {} + }; +} + +impl_range_list_iter!(i32); +impl_range_list_iter!(i64); + +/// The foundational element from which expressions are built. Literals are the values/identifiers +/// in the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Literal { + Int(i64), + Identifier(Rc), + Bool(bool), + IntSet(RangeList), +} + +impl From for Literal { + fn from(value: i64) -> Self { + Literal::Int(value) + } +} + +/// The solve item. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SolveItem { + pub method: Node, + pub annotations: Vec>, +} + +/// Whether to satisfy or optimise the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Method { + Satisfy, + Optimize { + direction: OptimizationDirection, + objective: Literal, + }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum OptimizationDirection { + Minimize, + Maximize, +} + +/// A constraint definition. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Constraint { + /// The name of the constraint. + pub name: Node>, + /// The list of arguments. + pub arguments: Vec>, + /// Any annotations on the constraint. + pub annotations: Vec>, +} + +/// An argument for a [`Constraint`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Argument { + Array(Vec>), + Literal(Node), +} + +/// An annotation on any item in the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Annotation { + /// An annotation without arguments. + Atom(Rc), + /// An annotation with arguments. + Call(AnnotationCall), +} + +impl Annotation { + /// Get the name of the annotation. + pub fn name(&self) -> &str { + match self { + Annotation::Atom(name) => name, + Annotation::Call(call) => &call.name, + } + } +} + +/// An annotation with arguments. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AnnotationCall { + /// The name of the annotation. + pub name: Rc, + /// Any arguments for the annotation. + pub arguments: Vec>, +} + +/// An individual argument for an [`Annotation`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AnnotationArgument { + Array(Vec>), + Literal(Node), +} + +/// An annotation literal is either a regular [`Literal`] or it is another annotation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AnnotationLiteral { + BaseLiteral(Literal), + /// In the FZN grammar, this is an `Annotation` instead of an `AnnotationCall`. We diverge from + /// the grammar to avoid the case where the same input can parse to either a + /// `Annotation::Atom(ident)` or an `Literal::Identifier`. + Annotation(AnnotationCall), +} + +/// Describes a range `[start, end)` in the model file that contains a [`Node`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Span { + /// The index in the source that starts the span. + pub start: usize, + /// The index in the source that ends the span. + /// + /// Note the end is exclusive. + pub end: usize, +} + +impl Display for Span { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}, {})", self.start, self.end) + } +} + +impl chumsky::span::Span for Span { + type Context = (); + + type Offset = usize; + + fn new(_: Self::Context, range: std::ops::Range) -> Self { + Self { + start: range.start, + end: range.end, + } + } + + fn context(&self) -> Self::Context {} + + fn start(&self) -> Self::Offset { + self.start + } + + fn end(&self) -> Self::Offset { + self.end + } +} + +impl From for Span { + fn from(value: chumsky::span::SimpleSpan) -> Self { + Span { + start: value.start, + end: value.end, + } + } +} + +impl From for chumsky::span::SimpleSpan { + fn from(value: Span) -> Self { + chumsky::span::SimpleSpan::from(value.start..value.end) + } +} + +/// A node in the [`Ast`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Node { + /// The span in the source of this node. + pub span: Span, + /// The parsed node. + pub node: T, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rangelist_from_iter_identifies_continuous_ranges() { + let set = RangeList::from_iter([1, 2, 3, 4]); + + assert!(set.is_continuous()); + } + + #[test] + fn rangelist_from_iter_identifiers_non_continuous_ranges() { + let set = RangeList::from_iter([1, 3, 4, 6]); + + assert!(!set.is_continuous()); + } + + #[test] + fn rangelist_iter_produces_elements_in_set() { + let set: RangeList = RangeList::from_iter([1, 3, 5]); + + let mut iter = set.into_iter(); + assert_eq!(Some(1), iter.next()); + assert_eq!(Some(3), iter.next()); + assert_eq!(Some(5), iter.next()); + assert_eq!(None, iter.next()); + } +} diff --git a/fzn-rs/src/error.rs b/fzn-rs/src/error.rs new file mode 100644 index 000000000..337e7c681 --- /dev/null +++ b/fzn-rs/src/error.rs @@ -0,0 +1,92 @@ +use std::fmt::Display; + +use crate::ast; + +#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +pub enum InstanceError { + #[error("constraint '{0}' is not supported")] + UnsupportedConstraint(String), + + #[error("annotation '{0}' is not supported")] + UnsupportedAnnotation(String), + + #[error("expected {expected}, got {actual} at {span}")] + UnexpectedToken { + expected: Token, + actual: Token, + span: ast::Span, + }, + + #[error("array {0} is undefined")] + UndefinedArray(String), + + #[error("expected {expected} arguments, got {actual}")] + IncorrectNumberOfArguments { + expected: usize, + actual: usize, + span: ast::Span, + }, + + #[error("value {0} does not fit in the required integer type")] + IntegerOverflow(i64), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Token { + Identifier, + IntLiteral, + BoolLiteral, + IntSetLiteral, + + Array, + Variable(Box), + AnnotationCall, + Annotation, +} + +impl From<&'_ ast::Literal> for Token { + fn from(value: &'_ ast::Literal) -> Self { + match value { + ast::Literal::Int(_) => Token::IntLiteral, + ast::Literal::Identifier(_) => Token::Identifier, + ast::Literal::Bool(_) => Token::BoolLiteral, + ast::Literal::IntSet(_) => Token::IntSetLiteral, + } + } +} + +impl From<&'_ ast::AnnotationArgument> for Token { + fn from(value: &'_ ast::AnnotationArgument) -> Self { + match value { + ast::AnnotationArgument::Array(_) => Token::Array, + ast::AnnotationArgument::Literal(literal) => (&literal.node).into(), + } + } +} + +impl From<&'_ ast::AnnotationLiteral> for Token { + fn from(value: &'_ ast::AnnotationLiteral) -> Self { + match value { + ast::AnnotationLiteral::BaseLiteral(literal) => literal.into(), + ast::AnnotationLiteral::Annotation(_) => Token::AnnotationCall, + } + } +} + +impl Display for Token { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::Identifier => write!(f, "identifier"), + + Token::IntLiteral => write!(f, "int"), + Token::BoolLiteral => write!(f, "bool"), + Token::IntSetLiteral => write!(f, "int set"), + + Token::AnnotationCall => write!(f, "annotation"), + Token::Annotation => write!(f, "annotation"), + + Token::Array => write!(f, "array"), + Token::Variable(token) => write!(f, "{token} variable"), + } + } +} diff --git a/fzn-rs/src/fzn/mod.rs b/fzn-rs/src/fzn/mod.rs new file mode 100644 index 000000000..235c7f26c --- /dev/null +++ b/fzn-rs/src/fzn/mod.rs @@ -0,0 +1,1115 @@ +use std::collections::BTreeMap; +use std::collections::BTreeSet; +use std::rc::Rc; + +use chumsky::IterParser; +use chumsky::Parser; +use chumsky::error::Rich; +use chumsky::extra; +use chumsky::input::Input; +use chumsky::input::MapExtra; +use chumsky::input::ValueInput; +use chumsky::prelude::any; +use chumsky::prelude::choice; +use chumsky::prelude::just; +use chumsky::prelude::recursive; +use chumsky::select; +use chumsky::span::SimpleSpan; + +use crate::ast; + +mod tokens; + +pub use tokens::Token; +use tokens::Token::*; + +#[derive(Clone, Debug, Default)] +struct ParseState { + /// The identifiers encountered so far. + strings: BTreeSet>, + /// Parameters + parameters: BTreeMap, ParameterValue>, +} + +impl ParseState { + fn get_interned(&mut self, string: &str) -> Rc { + if !self.strings.contains(string) { + let _ = self.strings.insert(Rc::from(string)); + } + + Rc::clone(self.strings.get(string).unwrap()) + } + + fn resolve_literal(&self, literal: ast::Literal) -> ast::Literal { + match literal { + ast::Literal::Identifier(ident) => self + .parameters + .get(&ident) + .map(|value| match value { + ParameterValue::Bool(boolean) => ast::Literal::Bool(*boolean), + ParameterValue::Int(int) => ast::Literal::Int(*int), + ParameterValue::IntSet(set) => ast::Literal::IntSet(set.clone()), + }) + .unwrap_or(ast::Literal::Identifier(ident)), + + lit @ (ast::Literal::Int(_) | ast::Literal::Bool(_) | ast::Literal::IntSet(_)) => lit, + } + } +} + +#[derive(Clone, Debug)] +enum ParameterValue { + Bool(bool), + Int(i64), + IntSet(ast::RangeList), +} + +#[derive(Debug, thiserror::Error)] +pub enum FznError<'src> { + #[error("failed to lex fzn")] + LexError { + reasons: Vec>, + }, + + #[error("failed to parse fzn")] + ParseError { + reasons: Vec, ast::Span>>, + }, +} + +pub fn parse(source: &str) -> Result> { + let mut state = extra::SimpleState(ParseState::default()); + + let tokens = tokens::lex() + .parse(source) + .into_result() + .map_err(|reasons| FznError::LexError { reasons })?; + + let parser_input = tokens.map( + ast::Span { + start: source.len(), + end: source.len(), + }, + |node| (&node.node, &node.span), + ); + + let ast = predicates() + .ignore_then(parameters()) + .ignore_then(arrays()) + .then(variables()) + .then(arrays()) + .then(constraints()) + .then(solve_item()) + .map( + |((((parameter_arrays, variables), variable_arrays), constraints), solve)| { + let mut arrays = parameter_arrays; + arrays.extend(variable_arrays); + + ast::Ast { + variables, + arrays, + constraints, + solve, + } + }, + ) + .parse_with_state(parser_input, &mut state) + .into_result() + .map_err( + |reasons: Vec, _>>| FznError::ParseError { + reasons: reasons + .into_iter() + .map(|error| error.into_owned()) + .collect(), + }, + )?; + + Ok(ast) +} + +/// The extra data attached to the chumsky parsers. +/// +/// We specify a rich error type, as well as an instance of [`ParseState`] for string interning and +/// parameter resolution. +type FznExtra<'tokens, 'src> = + extra::Full, ast::Span>, extra::SimpleState, ()>; + +fn predicates<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + predicate().repeated().collect::>().ignored() +} + +fn predicate<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("predicate")) + .ignore_then(any().and_is(just(SemiColon).not()).repeated()) + .then(just(SemiColon)) + .ignored() +} + +fn parameters<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + parameter().repeated().collect::>().ignored() +} + +fn parameter_type<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("int")), + just(Ident("bool")), + just(Ident("set")) + .then_ignore(just(Ident("of"))) + .then_ignore(just(Ident("int"))), + )) + .ignored() +} + +fn parameter<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + parameter_type() + .ignore_then(just(Colon)) + .ignore_then(identifier()) + .then_ignore(just(Equal)) + .then(literal()) + .then_ignore(just(SemiColon)) + .try_map_with(|(name, value), extra| { + let state = extra.state(); + + let value = match value.node { + ast::Literal::Int(int) => ParameterValue::Int(int), + ast::Literal::Bool(boolean) => ParameterValue::Bool(boolean), + ast::Literal::IntSet(set) => ParameterValue::IntSet(set), + ast::Literal::Identifier(identifier) => { + return Err(Rich::custom( + value.span, + format!("parameter '{identifier}' is undefined"), + )); + } + }; + + let _ = state.parameters.insert(name, value); + + Ok(()) + }) +} + +fn arrays<'tokens, 'src: 'tokens, I>() -> impl Parser< + 'tokens, + I, + BTreeMap, ast::Node>>, + FznExtra<'tokens, 'src>, +> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + array() + .repeated() + .collect::>() + .map(|arrays| arrays.into_iter().collect()) +} + +fn array<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, (Rc, ast::Node>), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("array")) + .ignore_then(interval_set(integer()).delimited_by(just(OpenBracket), just(CloseBracket))) + .ignore_then(just(Ident("of"))) + .ignore_then(just(Ident("var")).or_not()) + .ignore_then(domain()) + .then_ignore(just(Colon)) + .then(identifier()) + .then(annotations()) + .then_ignore(just(Equal)) + .then( + literal() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)), + ) + .then_ignore(just(SemiColon)) + .map_with(|(((domain, name), annotations), contents), extra| { + ( + name, + ast::Node { + node: ast::Array { + domain, + contents, + annotations, + }, + span: extra.span(), + }, + ) + }) +} + +fn variables<'tokens, 'src: 'tokens, I>() -> impl Parser< + 'tokens, + I, + BTreeMap, ast::Node>>, + FznExtra<'tokens, 'src>, +> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + variable() + .repeated() + .collect::>() + .map(|variables| variables.into_iter().collect()) +} + +fn variable<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, (Rc, ast::Node>), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("var")) + .ignore_then(domain()) + .then_ignore(just(Colon)) + .then(identifier()) + .then(annotations()) + .then(just(Equal).ignore_then(literal()).or_not()) + .then_ignore(just(SemiColon)) + .map_with(to_node) + .map(|node| { + let ast::Node { + node: (((domain, name), annotations), value), + span, + } = node; + + let variable = ast::Variable { + domain, + value, + annotations, + }; + + ( + name, + ast::Node { + node: variable, + span, + }, + ) + }) +} + +fn domain<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("int")).to(ast::Domain::UnboundedInt), + just(Ident("bool")).to(ast::Domain::Bool), + set_of(integer()).map(ast::Domain::Int), + )) + .map_with(to_node) +} + +fn constraints<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, Vec>, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + constraint().repeated().collect::>() +} + +fn constraint<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("constraint")) + .ignore_then(identifier().map_with(to_node)) + .then( + argument() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenParen), just(CloseParen)), + ) + .then(annotations()) + .then_ignore(just(SemiColon)) + .map(|((name, arguments), annotations)| ast::Constraint { + name, + arguments, + annotations, + }) + .map_with(to_node) +} + +fn argument<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + literal().map(ast::Argument::Literal), + literal() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)) + .map(ast::Argument::Array), + )) + .map_with(to_node) +} + +fn solve_item<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::SolveItem, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("solve")) + .ignore_then(annotations()) + .then(solve_method()) + .then_ignore(just(SemiColon)) + .map(|(annotations, method)| ast::SolveItem { + method, + annotations, + }) +} + +fn solve_method<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("satisfy")).to(ast::Method::Satisfy), + just(Ident("minimize")) + .ignore_then(identifier()) + .map(|ident| ast::Method::Optimize { + direction: ast::OptimizationDirection::Minimize, + objective: ast::Literal::Identifier(ident), + }), + just(Ident("maximize")) + .ignore_then(identifier()) + .map(|ident| ast::Method::Optimize { + direction: ast::OptimizationDirection::Maximize, + objective: ast::Literal::Identifier(ident), + }), + )) + .map_with(to_node) +} + +fn annotations<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, Vec>, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + annotation().repeated().collect() +} + +fn annotation<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(DoubleColon) + .ignore_then(choice(( + annotation_call().map(ast::Annotation::Call), + identifier().map(ast::Annotation::Atom), + ))) + .map_with(to_node) +} + +fn annotation_call<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + recursive(|call| { + identifier() + .then( + annotation_argument(call) + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenParen), just(CloseParen)), + ) + .map(|(name, arguments)| ast::AnnotationCall { name, arguments }) + }) +} + +fn annotation_argument<'tokens, 'src: 'tokens, I>( + call_parser: impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + annotation_literal(call_parser.clone()).map(ast::AnnotationArgument::Literal), + annotation_literal(call_parser) + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)) + .map(ast::AnnotationArgument::Array), + )) + .map_with(to_node) +} + +fn annotation_literal<'tokens, 'src: 'tokens, I>( + call_parser: impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + call_parser + .map(ast::AnnotationLiteral::Annotation) + .map_with(to_node), + literal().map(|node| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(node.node), + span: node.span, + }), + )) +} + +fn to_node<'tokens, 'src: 'tokens, I, T>( + node: T, + extra: &mut MapExtra<'tokens, '_, I, FznExtra<'tokens, 'src>>, +) -> ast::Node +where + I: Input<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + ast::Node { + node, + span: extra.span(), + } +} + +fn literal<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + set_of(integer()).map(ast::Literal::IntSet), + integer().map(ast::Literal::Int), + boolean().map(ast::Literal::Bool), + identifier().map(ast::Literal::Identifier), + )) + .map_with(|literal, extra| { + let state = extra.state(); + state.resolve_literal(literal) + }) + .map_with(to_node) +} + +fn set_of<'tokens, 'src: 'tokens, I, T: Copy + Ord>( + value_parser: impl Parser<'tokens, I, T, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::RangeList, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, + ast::RangeList: FromIterator, +{ + let sparse_set = value_parser + .clone() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBrace), just(CloseBrace)) + .map(ast::RangeList::from_iter); + + choice(( + sparse_set, + interval_set(value_parser).map(|(lb, ub)| ast::RangeList::from(lb..=ub)), + )) +} + +fn interval_set<'tokens, 'src: 'tokens, I, T: Copy + Ord>( + value_parser: impl Parser<'tokens, I, T, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, (T, T), FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + value_parser + .clone() + .then_ignore(just(DoublePeriod)) + .then(value_parser) +} + +fn integer<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, i64, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Integer(int) => int, + } +} + +fn boolean<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, bool, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Boolean(boolean) => boolean, + } +} + +fn identifier<'tokens, 'src: 'tokens, I>() +-> impl Parser<'tokens, I, Rc, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Ident(ident) => ident, + } + .map_with(|ident, extra| { + let state: &mut extra::SimpleState = extra.state(); + state.get_interned(ident) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! btreemap { + ($($key:expr => $value:expr,)+) => (btreemap!($($key => $value),+)); + + ( $($key:expr => $value:expr),* ) => { + { + let mut _map = ::std::collections::BTreeMap::new(); + $( + let _ = _map.insert($key, $value); + )* + _map + } + }; + } + + #[test] + fn empty_satisfaction_model() { + let source = r#" + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(15, 22, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn comments_are_ignored() { + let source = r#" + % Generated by MiniZinc 2.9.2 + % Solver: bla + solve satisfy; % This is ignored + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(75, 82, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn predicate_statements_are_ignored() { + let source = r#" + predicate some_predicate(int: xs, var int: ys); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(71, 78, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn empty_minimization_model() { + let source = r#" + solve minimize objective; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node( + 15, + 33, + ast::Method::Optimize { + direction: ast::OptimizationDirection::Minimize, + objective: ast::Literal::Identifier("objective".into()), + } + ), + annotations: vec![], + } + } + ); + } + + #[test] + fn empty_maximization_model() { + let source = r#" + solve maximize objective; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node( + 15, + 33, + ast::Method::Optimize { + direction: ast::OptimizationDirection::Maximize, + objective: ast::Literal::Identifier("objective".into()), + } + ), + annotations: vec![], + } + } + ); + } + + #[test] + fn variables() { + let source = r#" + var 1..5: x_interval; + var bool: x_bool; + var {1, 3, 5}: x_sparse; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x_interval".into() => node(9, 30, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![] + }), + "x_bool".into() => node(39, 56, ast::Variable { + domain: node(43, 47, ast::Domain::Bool), + value: None, + annotations: vec![] + }), + "x_sparse".into() => node(65, 89, ast::Variable { + domain: node(69, 78, ast::Domain::Int(ast::RangeList::from_iter([1, 3, 5]))), + value: None, + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(104, 111, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn variable_with_assignment() { + let source = r#" + var 5..5: x1 = 5; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(9, 26, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(24, 25, ast::Literal::Int(5))), + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(41, 48, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn variable_with_assignment_to_named_constant() { + let source = r#" + int: y = 5; + var 5..5: x1 = y; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(29, 46, ast::Variable { + domain: node(33, 37, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(44, 45, ast::Literal::Int(5))), + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(61, 68, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn arrays_of_constants_and_variables() { + let source = r#" + int: p = 5; + array [1..3] of int: ys = [1, 3, p]; + + var int: some_var; + array [1..2] of var int: vars = [1, some_var]; + + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "some_var".into() => node(75, 93, ast::Variable { + domain: node(79, 82, ast::Domain::UnboundedInt), + value: None, + annotations: vec![] + }), + }, + arrays: btreemap! { + "ys".into() => node(29, 65, ast::Array { + domain: node(45, 48, ast::Domain::UnboundedInt), + contents: vec![ + node(56, 57, ast::Literal::Int(1)), + node(59, 60, ast::Literal::Int(3)), + node(62, 63, ast::Literal::Int(5)), + ], + annotations: vec![], + }), + "vars".into() => node(102, 148, ast::Array { + domain: node(122, 125, ast::Domain::UnboundedInt), + contents: vec![ + node(135, 136, ast::Literal::Int(1)), + node(138, 146, ast::Literal::Identifier("some_var".into())), + ], + annotations: vec![], + }), + }, + constraints: vec![], + solve: ast::SolveItem { + method: node(164, 171, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn constraint_item() { + let source = r#" + constraint int_lin_le(weights, [x1, x2, 3], 3); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![node( + 9, + 56, + ast::Constraint { + name: node(20, 30, "int_lin_le".into()), + arguments: vec![ + node( + 31, + 38, + ast::Argument::Literal(node( + 31, + 38, + ast::Literal::Identifier("weights".into()) + )) + ), + node( + 40, + 51, + ast::Argument::Array(vec![ + node(41, 43, ast::Literal::Identifier("x1".into())), + node(45, 47, ast::Literal::Identifier("x2".into())), + node(49, 50, ast::Literal::Int(3)), + ]), + ), + node( + 53, + 54, + ast::Argument::Literal(node(53, 54, ast::Literal::Int(3))) + ), + ], + annotations: vec![], + } + )], + solve: ast::SolveItem { + method: node(71, 78, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_variables() { + let source = r#" + var 5..5: x1 :: output_var = 5; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(9, 40, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(38, 39, ast::Literal::Int(5))), + annotations: vec![node(22, 35, ast::Annotation::Atom("output_var".into()))], + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(55, 62, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_arrays() { + let source = r#" + array [1..2] of var 1..10: xs :: output_array([1..2]) = []; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: btreemap! { + "xs".into() => node(9, 68, ast::Array { + domain: node(29, 34, ast::Domain::Int(ast::RangeList::from(1..=10))), + contents: vec![], + annotations: vec![ + node(39, 62, ast::Annotation::Call(ast::AnnotationCall { + name: "output_array".into(), + arguments: vec![node(55, 61, + ast::AnnotationArgument::Array(vec![node(56, 60, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::IntSet( + ast::RangeList::from(1..=2) + ) + ) + )]) + )], + })), + ], + }), + }, + constraints: vec![], + solve: ast::SolveItem { + method: node(83, 90, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_constraints() { + let source = r#" + constraint predicate() :: defines_var(x1); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![node( + 9, + 51, + ast::Constraint { + name: node(20, 29, "predicate".into()), + arguments: vec![], + annotations: vec![node( + 32, + 50, + ast::Annotation::Call(ast::AnnotationCall { + name: "defines_var".into(), + arguments: vec![node( + 47, + 49, + ast::AnnotationArgument::Literal(node( + 47, + 49, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("x1".into()) + ) + )) + )] + }) + )], + } + )], + solve: ast::SolveItem { + method: node(66, 73, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_solve_item() { + let source = r#" + solve :: int_search(first_fail(xs), indomain_min) satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(59, 66, ast::Method::Satisfy), + annotations: vec![node( + 15, + 58, + ast::Annotation::Call(ast::AnnotationCall { + name: "int_search".into(), + arguments: vec![ + node( + 29, + 43, + ast::AnnotationArgument::Literal(node( + 29, + 43, + ast::AnnotationLiteral::Annotation(ast::AnnotationCall { + name: "first_fail".into(), + arguments: vec![node( + 40, + 42, + ast::AnnotationArgument::Literal(node( + 40, + 42, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("xs".into()) + ) + )) + )] + }) + )) + ), + node( + 45, + 57, + ast::AnnotationArgument::Literal(node( + 45, + 57, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("indomain_min".into()) + ) + )) + ), + ] + }) + )], + } + } + ); + } + + fn node(start: usize, end: usize, data: T) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { start, end }, + } + } +} diff --git a/fzn-rs/src/fzn/tokens.rs b/fzn-rs/src/fzn/tokens.rs new file mode 100644 index 000000000..3b3da7fc5 --- /dev/null +++ b/fzn-rs/src/fzn/tokens.rs @@ -0,0 +1,113 @@ +use std::fmt::Display; + +use chumsky::IterParser; +use chumsky::Parser; +use chumsky::error::Rich; +use chumsky::extra::{self}; +use chumsky::prelude::any; +use chumsky::prelude::choice; +use chumsky::prelude::just; +use chumsky::text::ascii::ident; +use chumsky::text::int; + +use crate::ast; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Token<'src> { + OpenParen, + CloseParen, + OpenBracket, + CloseBracket, + OpenBrace, + CloseBrace, + Comma, + Colon, + DoubleColon, + SemiColon, + DoublePeriod, + Equal, + Ident(&'src str), + Integer(i64), + Boolean(bool), +} + +impl Display for Token<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::OpenParen => write!(f, "("), + Token::CloseParen => write!(f, ")"), + Token::OpenBracket => write!(f, "["), + Token::CloseBracket => write!(f, "]"), + Token::OpenBrace => write!(f, "{{"), + Token::CloseBrace => write!(f, "}}"), + Token::Comma => write!(f, ","), + Token::Colon => write!(f, ":"), + Token::DoubleColon => write!(f, "::"), + Token::SemiColon => write!(f, ";"), + Token::DoublePeriod => write!(f, ".."), + Token::Equal => write!(f, "="), + Token::Ident(ident) => write!(f, "{ident}"), + Token::Integer(int) => write!(f, "{int}"), + Token::Boolean(boolean) => write!(f, "{boolean}"), + } + } +} + +type LexExtra<'src> = extra::Err>; + +pub(super) fn lex<'src>() +-> impl Parser<'src, &'src str, Vec>>, LexExtra<'src>> { + token() + .padded_by(comment().repeated()) + .padded() + .repeated() + .collect() +} + +fn comment<'src>() -> impl Parser<'src, &'src str, (), extra::Err>> { + just("%") + .then(any().and_is(just('\n').not()).repeated()) + .padded() + .ignored() +} + +fn token<'src>() +-> impl Parser<'src, &'src str, ast::Node>, extra::Err>> { + choice(( + // Punctuation + just(";").to(Token::SemiColon), + just("::").to(Token::DoubleColon), + just(":").to(Token::Colon), + just(",").to(Token::Comma), + just("..").to(Token::DoublePeriod), + just("[").to(Token::OpenBracket), + just("]").to(Token::CloseBracket), + just("{").to(Token::OpenBrace), + just("}").to(Token::CloseBrace), + just("(").to(Token::OpenParen), + just(")").to(Token::CloseParen), + just("=").to(Token::Equal), + // Values + just("true").to(Token::Boolean(true)), + just("false").to(Token::Boolean(false)), + int_literal().map(Token::Integer), + // Identifiers (including keywords) + ident().map(Token::Ident), + )) + .map_with(|token, extra| { + let span: chumsky::prelude::SimpleSpan = extra.span(); + + ast::Node { + node: token, + span: span.into(), + } + }) +} + +fn int_literal<'src>() -> impl Parser<'src, &'src str, i64, LexExtra<'src>> { + just("-") + .or_not() + .ignore_then(int(10)) + .to_slice() + .map(|slice: &str| slice.parse().unwrap()) +} diff --git a/fzn-rs/src/lib.rs b/fzn-rs/src/lib.rs new file mode 100644 index 000000000..305118728 --- /dev/null +++ b/fzn-rs/src/lib.rs @@ -0,0 +1,121 @@ +//! # fzn-rs +//! +//! `fzn-rs` is a crate that allows for easy parsing of FlatZinc instances in Rust. It facilitates +//! type-driven parsing of a FlatZinc file using derive macros. +//! +//! ## Features +//! - `fzn-parser`: Include the parser for fzn files in the traditional `.fzn` format. In the +//! future, a parser for the JSON format will be included as well, behind a separate feature. +//! - `derive`: Include the derive macro's to parse the AST into a strongly-typed model. +//! +//! ## Example +//! ``` +//! use fzn_rs::ArrayExpr; +//! use fzn_rs::FlatZincConstraint; +//! use fzn_rs::TypedInstance; +//! use fzn_rs::VariableExpr; +//! +//! /// The FlatZincConstraint derive macro enables the parsing of a strongly typed constraint +//! /// based on the FlatZinc Ast. +//! #[derive(FlatZincConstraint)] +//! pub enum MyConstraints { +//! /// The variant name is converted to snake_case to serve as the constraint identifier by +//! /// default. +//! IntLinLe(ArrayExpr, ArrayExpr>, i64), +//! +//! /// If the snake_case version of the variant name is different from the constraint +//! /// identifier, then the `#[name(...)], attribute allows you to set the constraint +//! /// identifier explicitly. +//! #[name("int_lin_eq")] +//! LinearEquality(ArrayExpr, ArrayExpr>, i64), +//! +//! /// Constraint arguments can also be named, but the order determines how they are parsed +//! /// from the AST. +//! Element { +//! index: VariableExpr, +//! array: ArrayExpr, +//! rhs: VariableExpr, +//! }, +//! +//! /// Arguments can also be separate structs, if the enum variant has exactly one argument. +//! #[args] +//! IntTimes(Multiplication), +//! } +//! +//! #[derive(FlatZincConstraint)] +//! pub struct Multiplication { +//! a: VariableExpr, +//! b: VariableExpr, +//! c: VariableExpr, +//! } +//! +//! /// The `TypedInstance` is parameterized by the constraint type, as well as any annotations you +//! /// may need to parse. It uses `i64` to represent integers. +//! type MyInstance = TypedInstance; +//! +//! fn parse_flatzinc(source: &str) -> MyInstance { +//! // First, the source string is parsed into a structured representation. +//! // +//! // Note: the `fzn_rs::fzn` module is only available with the `fzn-parser` feature enabled. +//! let ast = fzn_rs::fzn::parse(source).expect("source is valid flatzinc"); +//! +//! // Then, the strongly-typed instance is created from the AST +//! MyInstance::from_ast(ast).expect("type-checking passes") +//! } +//! ``` +//! +//! ## Derive Macros +//! When parsing a FlatZinc file, the result is an [`ast::Ast`]. That type describes any valid +//! FlatZinc file. However, when consuming FlatZinc, typically you need to process that AST +//! further. For example, to support the [`int_lin_le`][1] constraint, you have to validate that the +//! [`ast::Constraint`] has three arguments, and that each of the arguments has the correct type. +//! +//! Similar to typed constraints, the derive macro for [`FlatZincAnnotation`] allows for easy +//! parsing of annotations: +//! ``` +//! use std::rc::Rc; +//! +//! use fzn_rs::FlatZincAnnotation; +//! +//! #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] +//! enum TypedAnnotation { +//! /// Matches the snake-case atom "annotation". +//! Annotation, +//! +//! /// Supports nested annotations with the `#[annotation]` attribute. +//! SomeAnnotation(#[annotation] SomeAnnotationArgs), +//! } +//! +//! #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] +//! enum SomeAnnotationArgs { +//! /// Just as constraints, the name can be explicitly set. +//! #[name("arg_one")] +//! Arg1, +//! ArgTwo(Rc), +//! } +//! ``` +//! Different to parsing constraints, is that annotations can be ignored. If the AST contains an +//! annotation whose name does not match one of the variants in the enum, then the annotation is +//! simply ignored. +//! +//! ## Comparison to other FlatZinc crates +//! There are two well-known crates for parsing FlatZinc files: +//! - [flatzinc](https://docs.rs/flatzinc), for parsing the original `fzn` format, +//! - and [flatzinc-serde](https://docs.rs/flatzinc-serde), for parsing `fzn.json`. +//! +//! These crates produce what we call the [`ast::Ast`] in this crate, although the concrete types +//! can be different. `fzn-rs` builds the strong typing of constraints and annotations on-top of +//! a unified AST for both file formats. Finally, our aim is to improve the error messages that +//! are encountered when parsing invalid FlatZinc files. +//! +//! [1]: https://docs.minizinc.dev/en/stable/lib-flatzinc-int.html#int-lin-le + +mod error; +mod typed; + +pub mod ast; +pub mod fzn; + +pub use error::*; +pub use fzn_rs_derive::*; +pub use typed::*; diff --git a/fzn-rs/src/typed/arrays.rs b/fzn-rs/src/typed/arrays.rs new file mode 100644 index 000000000..1c076fc85 --- /dev/null +++ b/fzn-rs/src/typed/arrays.rs @@ -0,0 +1,194 @@ +use std::collections::BTreeMap; +use std::marker::PhantomData; +use std::rc::Rc; + +use super::FromAnnotationArgument; +use super::FromAnnotationLiteral; +use super::FromArgument; +use super::FromLiteral; +use super::VariableExpr; +use crate::InstanceError; +use crate::Token; +use crate::ast; + +/// Models an array in a constraint argument. +/// +/// ## Example +/// ``` +/// use fzn_rs::ArrayExpr; +/// use fzn_rs::FlatZincConstraint; +/// use fzn_rs::VariableExpr; +/// +/// #[derive(FlatZincConstraint)] +/// struct Linear { +/// /// An array of constants. +/// weights: ArrayExpr, +/// /// An array of variables. +/// variables: ArrayExpr>, +/// } +/// ``` +/// +/// Use [`crate::TypedInstance::resolve_array`] to access the elements in the array. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ArrayExpr { + expr: ArrayExprImpl, + ty: PhantomData, +} + +impl ArrayExpr +where + T: FromAnnotationLiteral, +{ + pub(crate) fn resolve<'a, Ann>( + &'a self, + arrays: &'a BTreeMap, ast::Array>, + ) -> Result> + 'a, Rc> { + match &self.expr { + ArrayExprImpl::Identifier(ident) => arrays + .get(ident) + .map(|array| { + GenericIterator(Box::new( + array.contents.iter().map(::from_literal), + )) + }) + .ok_or_else(|| Rc::clone(ident)), + ArrayExprImpl::Array(array) => Ok(GenericIterator(Box::new( + array.iter().map(::from_literal), + ))), + } + } +} + +impl FromArgument for ArrayExpr { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::Argument::Array(contents) => { + let contents = contents + .iter() + .cloned() + .map(|node| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(node.node), + span: node.span, + }) + .collect(); + + Ok(ArrayExpr { + expr: ArrayExprImpl::Array(contents), + ty: PhantomData, + }) + } + ast::Argument::Literal(ast::Node { + node: ast::Literal::Identifier(ident), + .. + }) => Ok(ArrayExpr { + expr: ArrayExprImpl::Identifier(Rc::clone(ident)), + ty: PhantomData, + }), + ast::Argument::Literal(literal) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: Token::from(&literal.node), + span: literal.span, + }), + } + } +} + +impl FromAnnotationArgument for ArrayExpr { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Array(contents) => Ok(ArrayExpr { + expr: ArrayExprImpl::Array(contents.clone()), + ty: PhantomData, + }), + ast::AnnotationArgument::Literal(ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)), + .. + }) => Ok(ArrayExpr { + expr: ArrayExprImpl::Identifier(Rc::clone(ident)), + ty: PhantomData, + }), + ast::AnnotationArgument::Literal(literal) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: Token::from(&literal.node), + span: literal.span, + }), + } + } +} + +impl From>> for ArrayExpr +where + ast::Literal: From, +{ + fn from(value: Vec>) -> Self { + ArrayExpr { + expr: ArrayExprImpl::Array( + value + .into_iter() + .map(|value| ast::Node { + node: match value { + VariableExpr::Identifier(ident) => { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)) + } + VariableExpr::Constant(value) => { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::from(value)) + } + }, + span: ast::Span { + start: usize::MAX, + end: usize::MAX, + }, + }) + .collect(), + ), + ty: PhantomData, + } + } +} + +impl From> for ArrayExpr +where + ast::Literal: From, +{ + fn from(value: Vec) -> Self { + ArrayExpr { + expr: ArrayExprImpl::Array( + value + .into_iter() + .map(|value| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(ast::Literal::from(value)), + span: ast::Span { + start: usize::MAX, + end: usize::MAX, + }, + }) + .collect(), + ), + ty: PhantomData, + } + } +} + +/// The actual array expression, which is either an identifier or the array. +/// +/// This is a private type as all access to the array should go through [`ArrayExpr::resolve`]. +#[derive(Clone, Debug, PartialEq, Eq)] +enum ArrayExprImpl { + Identifier(Rc), + /// Regardless of the contents of the array, the elements will always be of type + /// [`ast::AnnotationLiteral`] to support the parsing of annotations from arrays. + Array(Vec>), +} + +/// A boxed dyn [`ExactSizeIterator`] which is returned from [`ArrayExpr::resolve`]. +struct GenericIterator<'a, T>(Box + 'a>); + +impl Iterator for GenericIterator<'_, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl ExactSizeIterator for GenericIterator<'_, T> {} diff --git a/fzn-rs/src/typed/constraint.rs b/fzn-rs/src/typed/constraint.rs new file mode 100644 index 000000000..f607c98ce --- /dev/null +++ b/fzn-rs/src/typed/constraint.rs @@ -0,0 +1,8 @@ +use crate::ast; + +/// A constraint that has annotations attached to it. +#[derive(Clone, Debug)] +pub struct AnnotatedConstraint { + pub constraint: ast::Node, + pub annotations: Vec>, +} diff --git a/fzn-rs/src/typed/flatzinc_annotation.rs b/fzn-rs/src/typed/flatzinc_annotation.rs new file mode 100644 index 000000000..b9b894a3b --- /dev/null +++ b/fzn-rs/src/typed/flatzinc_annotation.rs @@ -0,0 +1,150 @@ +use std::rc::Rc; + +use super::FromLiteral; +use crate::InstanceError; +use crate::Token; +use crate::ast; + +/// Parse an [`ast::Annotation`] into a specific annotation type. +/// +/// The difference with [`crate::FlatZincConstraint::from_ast`] is that annotations can be ignored. +/// [`FlatZincAnnotation::from_ast`] can successfully parse an annotation into nothing, signifying +/// the annotation is not of interest in the final [`crate::TypedInstance`]. +pub trait FlatZincAnnotation: Sized { + /// Parse a value of `Self` from the annotation node. Return `None` if the annotation node + /// clearly is not relevant for `Self`, e.g. when the name is for a completely different + /// annotation than `Self` models. + fn from_ast(annotation: &ast::Node) -> Result, InstanceError>; + + /// Parse an [`ast::Annotation`] into `Self` and produce an error if the annotation cannot be + /// converted to a value of `Self`. + fn from_ast_required(annotation: &ast::Node) -> Result { + let outcome = Self::from_ast(annotation)?; + + // By default, failing to parse an annotation node into an annotation type is not + // necessarily an error since the annotation node can be ignored. In this case, however, + // we require a value to be present. Hence, if `outcome` is `None`, that is an error. + outcome.ok_or_else(|| InstanceError::UnsupportedAnnotation(annotation.node.name().into())) + } +} + +/// A default implementation that ignores all annotations. +impl FlatZincAnnotation for () { + fn from_ast(_: &ast::Node) -> Result, InstanceError> { + Ok(None) + } +} + +/// Parse a value from an [`ast::AnnotationArgument`]. +/// +/// Any type that implements [`FromAnnotationLiteral`] also implements [`FromAnnotationArgument`]. +pub trait FromAnnotationArgument: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +/// Parse a value from an [`ast::AnnotationLiteral`]. +pub trait FromAnnotationLiteral: FromLiteral + Sized { + fn expected() -> Token; + + fn from_literal(literal: &ast::Node) -> Result; +} + +impl FromAnnotationLiteral for T { + fn expected() -> Token { + T::expected() + } + + fn from_literal(literal: &ast::Node) -> Result { + match &literal.node { + ast::AnnotationLiteral::BaseLiteral(base_literal) => T::from_literal(&ast::Node { + node: base_literal.clone(), + span: literal.span, + }), + ast::AnnotationLiteral::Annotation(_) => Err(InstanceError::UnexpectedToken { + expected: T::expected(), + actual: Token::AnnotationCall, + span: literal.span, + }), + } + } +} + +impl FromAnnotationArgument for T { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Literal(literal) => { + ::from_literal(literal) + } + + node => Err(InstanceError::UnexpectedToken { + expected: ::expected(), + actual: node.into(), + span: argument.span, + }), + } + } +} + +/// Parse an [`ast::AnnotationArgument`] as an annotation. This needs to be a separate trait from +/// [`FromAnnotationArgument`] so it does not collide wiith implementations for literals. +pub trait FromNestedAnnotation: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +/// Converts an [`ast::AnnotationLiteral`] to an [`ast::Annotation`], or produces an error if that +/// is not possible. +fn annotation_literal_to_annotation( + literal: &ast::Node, +) -> Result, InstanceError> { + match &literal.node { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)) => Ok(ast::Node { + node: ast::Annotation::Atom(Rc::clone(ident)), + span: literal.span, + }), + ast::AnnotationLiteral::Annotation(annotation_call) => Ok(ast::Node { + node: ast::Annotation::Call(annotation_call.clone()), + span: literal.span, + }), + ast::AnnotationLiteral::BaseLiteral(lit) => Err(InstanceError::UnexpectedToken { + expected: Token::Annotation, + actual: lit.into(), + span: literal.span, + }), + } +} + +impl FromNestedAnnotation for Ann { + fn from_argument(argument: &ast::Node) -> Result { + let annotation = match &argument.node { + ast::AnnotationArgument::Literal(literal) => annotation_literal_to_annotation(literal)?, + ast::AnnotationArgument::Array(_) => { + return Err(InstanceError::UnexpectedToken { + expected: Token::Annotation, + actual: Token::Array, + span: argument.span, + }); + } + }; + + Ann::from_ast_required(&annotation) + } +} + +impl FromNestedAnnotation for Vec { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Array(elements) => elements + .iter() + .map(|literal| { + let annotation = annotation_literal_to_annotation(literal)?; + Ann::from_ast_required(&annotation) + }) + .collect::>(), + ast::AnnotationArgument::Literal(lit) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: (&lit.node).into(), + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/flatzinc_constraint.rs b/fzn-rs/src/typed/flatzinc_constraint.rs new file mode 100644 index 000000000..e45890fbe --- /dev/null +++ b/fzn-rs/src/typed/flatzinc_constraint.rs @@ -0,0 +1,27 @@ +use super::FromLiteral; +use crate::InstanceError; +use crate::Token; +use crate::ast; + +/// Parse a constraint from the given [`ast::Constraint`]. +pub trait FlatZincConstraint: Sized { + fn from_ast(constraint: &ast::Node) -> Result; +} + +/// Extract an argument from the [`ast::Argument`] node. +pub trait FromArgument: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +impl FromArgument for T { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::Argument::Literal(literal) => T::from_literal(literal), + ast::Argument::Array(_) => Err(InstanceError::UnexpectedToken { + expected: T::expected(), + actual: Token::Array, + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/from_literal.rs b/fzn-rs/src/typed/from_literal.rs new file mode 100644 index 000000000..069a6b54e --- /dev/null +++ b/fzn-rs/src/typed/from_literal.rs @@ -0,0 +1,132 @@ +use std::rc::Rc; + +use super::VariableExpr; +use crate::InstanceError; +use crate::Token; +use crate::ast; + +/// Extract a value from an [`ast::Literal`]. +pub trait FromLiteral: Sized { + /// The [`Token`] that is expected for this implementation. Used to create error messages. + fn expected() -> Token; + + /// Extract `Self` from a literal AST node. + fn from_literal(node: &ast::Node) -> Result; +} + +impl FromLiteral for i32 { + fn expected() -> Token { + Token::IntLiteral + } + + fn from_literal(node: &ast::Node) -> Result { + let integer = ::from_literal(node)?; + i32::try_from(integer).map_err(|_| InstanceError::IntegerOverflow(integer)) + } +} + +impl FromLiteral for i64 { + fn expected() -> Token { + Token::IntLiteral + } + + fn from_literal(node: &ast::Node) -> Result { + match &node.node { + ast::Literal::Int(value) => Ok(*value), + literal => Err(InstanceError::UnexpectedToken { + expected: Token::IntLiteral, + actual: literal.into(), + span: node.span, + }), + } + } +} + +impl FromLiteral for VariableExpr { + fn expected() -> Token { + Token::Variable(Box::new(T::expected())) + } + + fn from_literal(node: &ast::Node) -> Result { + match &node.node { + ast::Literal::Identifier(identifier) => { + Ok(VariableExpr::Identifier(Rc::clone(identifier))) + } + literal => T::from_literal(node) + .map(VariableExpr::Constant) + .map_err(|_| InstanceError::UnexpectedToken { + expected: ::expected(), + actual: literal.into(), + span: node.span, + }), + } + } +} + +impl FromLiteral for Rc { + fn expected() -> Token { + Token::Identifier + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::Identifier(ident) => Ok(Rc::clone(ident)), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::Identifier, + actual: node.into(), + span: argument.span, + }), + } + } +} + +impl FromLiteral for bool { + fn expected() -> Token { + Token::BoolLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::Bool(boolean) => Ok(*boolean), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::BoolLiteral, + actual: node.into(), + span: argument.span, + }), + } + } +} + +impl FromLiteral for ast::RangeList { + fn expected() -> Token { + Token::IntSetLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + let set = as FromLiteral>::from_literal(argument)?; + + set.into_iter() + .map(|elem| i32::try_from(elem).map_err(|_| InstanceError::IntegerOverflow(elem))) + .collect::>() + } +} + +impl FromLiteral for ast::RangeList { + fn expected() -> Token { + Token::IntSetLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::IntSet(set) => Ok(set.clone()), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::IntSetLiteral, + actual: node.into(), + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/instance.rs b/fzn-rs/src/typed/instance.rs new file mode 100644 index 000000000..6d43b93d2 --- /dev/null +++ b/fzn-rs/src/typed/instance.rs @@ -0,0 +1,193 @@ +use std::collections::BTreeMap; +use std::rc::Rc; + +use super::ArrayExpr; +use super::FlatZincAnnotation; +use super::FlatZincConstraint; +use super::FromAnnotationLiteral; +use super::FromLiteral; +use super::VariableExpr; +use crate::AnnotatedConstraint; +use crate::InstanceError; +use crate::ast; + +/// A fully typed representation of a FlatZinc instance. +/// +/// It is generic over the type of constraints, as well as the annotations for variables, arrays, +/// constraints, and solve. +#[derive(Clone, Debug)] +pub struct TypedInstance< + Int, + Constraint, + VariableAnnotations = (), + ArrayAnnotations = (), + ConstraintAnnotations = (), + SolveAnnotations = (), +> { + /// The variables that are in the instance. + /// + /// The key is the identifier of the variable, and the value is the domain of the variable. + pub variables: BTreeMap, ast::Variable>, + + /// The arrays in the instance. + /// + /// The key is the identifier of the array, and the value is the array itself. + pub arrays: BTreeMap, ast::Array>, + + /// The constraints in the instance. + pub constraints: Vec>, + + /// The solve item indicating how to solve the model. + pub solve: SolveItem, +} + +/// Specifies how to solve a [`TypedInstance`]. +/// +/// This is generic over the integer type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SolveItem { + pub method: ast::Node>, + pub annotations: Vec>, +} + +/// Indicate whether the model is an optimisation or satisfaction model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Method { + Satisfy, + Optimize { + direction: ast::OptimizationDirection, + objective: VariableExpr, + }, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[error("array '{0}' is undefined")] +pub struct UndefinedArrayError(pub String); + +impl + TypedInstance +{ + /// Get the elements in an [`ArrayExpr`]. + pub fn resolve_array<'a, T>( + &'a self, + array_expr: &'a ArrayExpr, + ) -> Result> + 'a, UndefinedArrayError> + where + T: FromAnnotationLiteral, + { + array_expr + .resolve(&self.arrays) + .map_err(|identifier| UndefinedArrayError(identifier.as_ref().into())) + } +} + +impl + TypedInstance +where + TConstraint: FlatZincConstraint, + VAnnotations: FlatZincAnnotation, + AAnotations: FlatZincAnnotation, + CAnnotations: FlatZincAnnotation, + SAnnotations: FlatZincAnnotation, + VariableExpr: FromLiteral, +{ + /// Create a [`TypedInstance`] from an [`ast::Ast`]. + /// + /// This parses the constraints and annotations, and can fail e.g. if the number or type of + /// arguments do not match what is expected in the parser. + pub fn from_ast(ast: ast::Ast) -> Result { + let variables = ast + .variables + .into_iter() + .map(|(id, variable)| { + let variable = ast::Variable { + domain: variable.node.domain, + value: variable.node.value, + annotations: map_annotations(&variable.node.annotations)?, + }; + + Ok((id, variable)) + }) + .collect::>()?; + + let arrays = ast + .arrays + .into_iter() + .map(|(id, array)| { + let array = ast::Array { + domain: array.node.domain, + contents: array.node.contents, + annotations: map_annotations(&array.node.annotations)?, + }; + + Ok((id, array)) + }) + .collect::>()?; + + let constraints = ast + .constraints + .iter() + .map(|constraint| { + let annotations = map_annotations(&constraint.node.annotations)?; + + let instance_constraint = TConstraint::from_ast(constraint)?; + + Ok(AnnotatedConstraint { + constraint: ast::Node { + node: instance_constraint, + span: constraint.span, + }, + annotations, + }) + }) + .collect::>()?; + + let solve = SolveItem { + method: match ast.solve.method.node { + ast::Method::Satisfy => ast::Node { + node: Method::Satisfy, + span: ast.solve.method.span, + }, + ast::Method::Optimize { + direction, + objective, + } => ast::Node { + node: Method::Optimize { + direction, + objective: as FromLiteral>::from_literal(&ast::Node { + node: objective, + span: ast.solve.method.span, + })?, + }, + span: ast.solve.method.span, + }, + }, + annotations: map_annotations(&ast.solve.annotations)?, + }; + + Ok(TypedInstance { + variables, + arrays, + constraints, + solve, + }) + } +} + +fn map_annotations( + annotations: &[ast::Node], +) -> Result>, InstanceError> { + annotations + .iter() + .filter_map(|annotation| { + Ann::from_ast(annotation) + .map(|maybe_node| { + maybe_node.map(|node| ast::Node { + node, + span: annotation.span, + }) + }) + .transpose() + }) + .collect() +} diff --git a/fzn-rs/src/typed/mod.rs b/fzn-rs/src/typed/mod.rs new file mode 100644 index 000000000..7a3217c51 --- /dev/null +++ b/fzn-rs/src/typed/mod.rs @@ -0,0 +1,23 @@ +mod arrays; +mod constraint; +mod flatzinc_annotation; +mod flatzinc_constraint; +mod from_literal; +mod instance; + +use std::rc::Rc; + +pub use arrays::*; +pub use constraint::*; +pub use flatzinc_annotation::*; +pub use flatzinc_constraint::*; +pub use from_literal::*; +pub use instance::*; + +/// Models a variable in the FlatZinc AST. Since `var T` is a subtype of `T`, a variable can also +/// be a constant. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum VariableExpr { + Identifier(Rc), + Constant(T), +}