diff --git a/Cargo.lock b/Cargo.lock index 4bbfa0c5f..d84ca6723 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -79,6 +79,16 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "ariadne" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f5e3dca4e09a6f340a61a0e9c7b61e030c69fc27bf29d73218f7e5e3b7638f" +dependencies = [ + "unicode-width", + "yansi", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -186,6 +196,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "convert_case" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baaaa0ecca5b51987b9423ccdc971514dd8b0bb7b4060b983d3664dad3f1f89f" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -328,15 +347,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "flatzinc" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb354e9148694c9d928bdeca0436943cb024e666bda31cf4d1bbbffc7bebf14" -dependencies = [ - "winnow", -] - [[package]] name = "fnv" version = "1.0.7" @@ -349,6 +359,26 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "fzn-rs" +version = "0.1.0" +dependencies = [ + "chumsky", + "fzn-rs-derive", + "thiserror", +] + +[[package]] +name = "fzn-rs-derive" +version = "0.1.0" +dependencies = [ + "convert_case 0.8.0", + "fzn-rs", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -588,7 +618,7 @@ dependencies = [ "bitfield", "bitfield-struct", "clap", - "convert_case", + "convert_case 0.6.0", "downcast-rs", "drcp-format", "enum-map", @@ -618,11 +648,12 @@ dependencies = [ name = "pumpkin-solver" version = "0.2.1" dependencies = [ + "ariadne", "cc", "clap", "env_logger", - "flatzinc", "fnv", + "fzn-rs", "log", "pumpkin-core", "pumpkin-macros", @@ -872,9 +903,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "syn" -version = "2.0.103" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -928,6 +959,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unindent" version = "0.2.4" @@ -1038,13 +1075,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] -name = "winnow" -version = "0.6.26" +name = "yansi" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e90edd2ac1aa278a5c4599b1d89cf03074b610800f866d4026dc199d7929a28" -dependencies = [ - "memchr", -] +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "zerocopy" diff --git a/Cargo.toml b/Cargo.toml index 573b614d0..0d6e864f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*"] +members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*", "./fzn-rs", "fzn-rs-derive"] default-members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros"] resolver = "2" diff --git a/clippy.toml b/clippy.toml index 937f96fe2..607dc0e84 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1 +1 @@ -allowed-duplicate-crates = ["regex-automata", "regex-syntax"] +allowed-duplicate-crates = ["regex-automata", "regex-syntax", "convert_case"] diff --git a/fzn-rs-derive/Cargo.toml b/fzn-rs-derive/Cargo.toml new file mode 100644 index 000000000..a94eeae7a --- /dev/null +++ b/fzn-rs-derive/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fzn-rs-derive" +version = "0.1.0" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[lib] +proc-macro = true + +[dependencies] +convert_case = "0.8.0" +proc-macro2 = "1.0.95" +quote = "1.0.40" +syn = { version = "2.0.104", features = ["extra-traits"] } + +[dev-dependencies] +fzn-rs = { path = "../fzn-rs/" } + +[lints] +workspace = true diff --git a/fzn-rs-derive/src/annotation.rs b/fzn-rs-derive/src/annotation.rs new file mode 100644 index 000000000..4af9dd467 --- /dev/null +++ b/fzn-rs-derive/src/annotation.rs @@ -0,0 +1,114 @@ +use quote::quote; + +/// Construct a token stream that initialises a value with name `value_type` and the arguments +/// described in `fields`. +pub(crate) fn initialise_value( + value_type: &syn::Ident, + fields: &syn::Fields, +) -> proc_macro2::TokenStream { + // For every field, initialise the value for that field. + let field_values = fields.iter().enumerate().map(|(idx, field)| { + let ty = &field.ty; + + // If the field has a name, then prepend the field value with the name: `name: value`. + let value_prefix = if let Some(ident) = &field.ident { + quote! { #ident: } + } else { + quote! {} + }; + + // If there is an `#[annotation]` attribute on the field, then the value is the result of + // parsing a nested annotation. Otherwise, we look at the type of the field + // and parse the value corresponding to that type. + if field.attrs.iter().any(|attr| { + attr.path() + .get_ident() + .is_some_and(|ident| ident == "annotation") + }) { + quote! { + #value_prefix <#ty as ::fzn_rs::FromNestedAnnotation>::from_argument( + &arguments[#idx], + )? + } + } else { + quote! { + #value_prefix <#ty as ::fzn_rs::FromAnnotationArgument>::from_argument( + &arguments[#idx], + )? + } + } + }); + + // Complete the value initialiser by prepending the type name to the field values. + let value_initialiser = match fields { + syn::Fields::Named(_) => quote! { #value_type { #(#field_values),* } }, + syn::Fields::Unnamed(_) => quote! { #value_type ( #(#field_values),* ) }, + syn::Fields::Unit => quote! { #value_type }, + }; + + let num_arguments = fields.len(); + + // Output the final initialisation, with checking of number of arguments. + quote! { + if arguments.len() != #num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #num_arguments, + actual: arguments.len(), + span: annotation.span, + }); + } + + Ok(Some(#value_initialiser)) + } +} + +/// Create the parsing code for one annotation corresponding to the given variant. +pub(crate) fn variant_to_annotation(variant: &syn::Variant) -> proc_macro2::TokenStream { + // Determine the flatzinc annotation name. + let name = match crate::common::get_explicit_name(variant) { + Ok(name) => name, + Err(_) => { + return quote! { + compile_error!("Invalid usage of #[name(...)]"); + } + } + }; + + let variant_name = &variant.ident; + + // If variant argument is a struct, then delegate parsing of the annotation arguments to that + // struct. + if let Some(constraint_type) = crate::common::get_args_type(variant) { + return quote! { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) if name.as_ref() == #name => { + let args = <#constraint_type as ::fzn_rs::FlatZincAnnotation>::from_ast_required(annotation)?; + let value = #variant_name(args); + Ok(Some(value)) + } + }; + } + + // If the variant has no arguments, parse an atom annotaton. Otherwise, initialise the values + // of the variant arguments. + if matches!(variant.fields, syn::Fields::Unit) { + quote! { + ::fzn_rs::ast::Annotation::Atom(ident) if ident.as_ref() == #name => { + Ok(Some(#variant_name)) + } + } + } else { + let value = initialise_value(&variant.ident, &variant.fields); + + quote! { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) if name.as_ref() == #name => { + #value + } + } + } +} diff --git a/fzn-rs-derive/src/common.rs b/fzn-rs-derive/src/common.rs new file mode 100644 index 000000000..6d2a6d83f --- /dev/null +++ b/fzn-rs-derive/src/common.rs @@ -0,0 +1,55 @@ +use convert_case::Case; +use convert_case::Casing; + +/// Get the name of the constraint or annotation from the variant. This either is converting the +/// variant name to snake case, or retrieving the value from the `#[name(...)]` attribute. +pub(crate) fn get_explicit_name(variant: &syn::Variant) -> syn::Result { + variant + .attrs + .iter() + // Find the attribute with a `name` as the path. + .find(|attr| attr.path().get_ident().is_some_and(|ident| ident == "name")) + // Parse the arguments of the attribute to a string literal. + .map(|attr| { + attr.parse_args::() + .map(|string_lit| string_lit.value()) + }) + // If no `name` attribute exists, return the snake-case version of the variant name. + .unwrap_or_else(|| Ok(variant.ident.to_string().to_case(Case::Snake))) +} + +/// Returns the type of the arguments for the variant if the variant has exactly the following +/// shape: +/// +/// ```ignore +/// #[args] +/// Variant(Type) +/// ``` +pub(crate) fn get_args_type(variant: &syn::Variant) -> Option<&syn::Type> { + let has_args_attr = variant + .attrs + .iter() + .any(|attr| attr.path().get_ident().is_some_and(|ident| ident == "args")); + + if !has_args_attr { + return None; + } + + if variant.fields.len() != 1 { + // If there is not exactly one argument for this variant, then it cannot be a struct + // constraint. + return None; + } + + let field = variant + .fields + .iter() + .next() + .expect("there is exactly one field"); + + if field.ident.is_none() { + Some(&field.ty) + } else { + None + } +} diff --git a/fzn-rs-derive/src/constraint.rs b/fzn-rs-derive/src/constraint.rs new file mode 100644 index 000000000..ee819ec58 --- /dev/null +++ b/fzn-rs-derive/src/constraint.rs @@ -0,0 +1,107 @@ +use quote::quote; + +/// Construct a token stream that initialises a value with name `value_type` and the arguments +/// described in `fields`. +pub(crate) fn initialise_value( + identifier: &syn::Ident, + fields: &syn::Fields, +) -> proc_macro2::TokenStream { + match fields { + // In case of named fields, the order of the fields is the order of the flatzinc arguments. + syn::Fields::Named(fields) => { + let arguments = fields.named.iter().enumerate().map(|(idx, field)| { + let field_name = field + .ident + .as_ref() + .expect("we are in a syn::Fields::Named"); + let ty = &field.ty; + + quote! { + #field_name: <#ty as ::fzn_rs::FromArgument>::from_argument( + &constraint.node.arguments[#idx], + )?, + } + }); + + quote! { #identifier { #(#arguments)* } } + } + + syn::Fields::Unnamed(fields) => { + let arguments = fields.unnamed.iter().enumerate().map(|(idx, field)| { + let ty = &field.ty; + + quote! { + <#ty as ::fzn_rs::FromArgument>::from_argument( + &constraint.node.arguments[#idx], + )?, + } + }); + + quote! { #identifier ( #(#arguments)* ) } + } + + syn::Fields::Unit => quote! { + compile_error!("A FlatZinc constraint must have at least one field") + }, + } +} + +/// Generate an implementation of `FlatZincConstraint` for enums. +pub(crate) fn flatzinc_constraint_for_enum( + constraint_enum_name: &syn::Ident, + data_enum: &syn::DataEnum, +) -> proc_macro2::TokenStream { + // For every variant in the enum, create a match arm that matches the constraint name and + // parses the constraint with the appropriate arguments. + let constraints = data_enum.variants.iter().map(|variant| { + // Determine the flatzinc name of the constraint. + let name = match crate::common::get_explicit_name(variant) { + Ok(name) => name, + Err(_) => { + return quote! { + compile_error!("Invalid usage of #[name(...)]"); + } + } + }; + + let variant_name = &variant.ident; + let match_expression = match crate::common::get_args_type(variant) { + Some(constraint_type) => quote! { + Ok(#variant_name (<#constraint_type as ::fzn_rs::FlatZincConstraint>::from_ast(constraint)?)) + }, + None => { + let initialised_value = initialise_value(variant_name, &variant.fields); + let expected_num_arguments = variant.fields.len(); + + quote! { + if constraint.node.arguments.len() != #expected_num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: constraint.node.arguments.len(), + span: constraint.span, + }); + } + + Ok(#initialised_value) + } + } + }; + + quote! { + #name => { + #match_expression + } + } + }); + + quote! { + use #constraint_enum_name::*; + + match constraint.node.name.node.as_ref() { + #(#constraints)* + unknown => Err(::fzn_rs::InstanceError::UnsupportedConstraint( + String::from(unknown) + )), + } + } +} diff --git a/fzn-rs-derive/src/lib.rs b/fzn-rs-derive/src/lib.rs new file mode 100644 index 000000000..1f97c5adb --- /dev/null +++ b/fzn-rs-derive/src/lib.rs @@ -0,0 +1,115 @@ +mod annotation; +mod common; +mod constraint; + +use proc_macro::TokenStream; +use quote::quote; +use syn::parse_macro_input; +use syn::DeriveInput; + +#[proc_macro_derive(FlatZincConstraint, attributes(name, args))] +pub fn derive_flatzinc_constraint(item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + + let type_name = derive_input.ident; + let implementation = match &derive_input.data { + syn::Data::Struct(data_struct) => { + let expected_num_arguments = data_struct.fields.len(); + + let struct_initialiser = constraint::initialise_value(&type_name, &data_struct.fields); + quote! { + if constraint.node.arguments.len() != #expected_num_arguments { + return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: constraint.node.arguments.len(), + span: constraint.span, + }); + } + + Ok(#struct_initialiser) + } + } + syn::Data::Enum(data_enum) => { + constraint::flatzinc_constraint_for_enum(&type_name, data_enum) + } + syn::Data::Union(_) => quote! { + compile_error!("Cannot implement FlatZincConstraint on unions.") + }, + }; + + let token_stream = quote! { + #[automatically_derived] + impl ::fzn_rs::FlatZincConstraint for #type_name { + fn from_ast( + constraint: &::fzn_rs::ast::Node<::fzn_rs::ast::Constraint>, + ) -> Result { + #implementation + } + } + }; + + token_stream.into() +} + +#[proc_macro_derive(FlatZincAnnotation, attributes(name, annotation, args))] +pub fn derive_flatzinc_annotation(item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + let annotatation_enum_name = derive_input.ident; + + let implementation = match derive_input.data { + syn::Data::Struct(data_struct) => { + let initialised_values = + annotation::initialise_value(&annotatation_enum_name, &data_struct.fields); + + let expected_num_arguments = data_struct.fields.len(); + + quote! { + match &annotation.node { + ::fzn_rs::ast::Annotation::Call(::fzn_rs::ast::AnnotationCall { + name, + arguments, + }) => { + #initialised_values + } + + _ => return Err(::fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected: #expected_num_arguments, + actual: 0, + span: annotation.span, + }), + } + } + } + syn::Data::Enum(data_enum) => { + let annotations = data_enum + .variants + .iter() + .map(annotation::variant_to_annotation); + + quote! { + use #annotatation_enum_name::*; + + match &annotation.node { + #(#annotations),* + _ => Ok(None), + } + } + } + syn::Data::Union(_) => quote! { + compile_error!("Cannot implement FlatZincAnnotation on unions.") + }, + }; + + let token_stream = quote! { + #[automatically_derived] + impl ::fzn_rs::FlatZincAnnotation for #annotatation_enum_name { + fn from_ast( + annotation: &::fzn_rs::ast::Node<::fzn_rs::ast::Annotation>, + ) -> Result, ::fzn_rs::InstanceError> { + #implementation + } + } + }; + + token_stream.into() +} diff --git a/fzn-rs-derive/tests/derive_flatzinc_annotation.rs b/fzn-rs-derive/tests/derive_flatzinc_annotation.rs new file mode 100644 index 000000000..aeb62ddc7 --- /dev/null +++ b/fzn-rs-derive/tests/derive_flatzinc_annotation.rs @@ -0,0 +1,397 @@ +#![cfg(test)] // workaround for https://github.com/rust-lang/rust-clippy/issues/11024 + +mod utils; + +use std::collections::BTreeMap; +use std::rc::Rc; + +use fzn_rs::ast::Annotation; +use fzn_rs::ast::AnnotationArgument; +use fzn_rs::ast::AnnotationCall; +use fzn_rs::ast::AnnotationLiteral; +use fzn_rs::ast::Argument; +use fzn_rs::ast::Ast; +use fzn_rs::ast::Domain; +use fzn_rs::ast::Literal; +use fzn_rs::ast::RangeList; +use fzn_rs::ast::Variable; +use fzn_rs::ArrayExpr; +use fzn_rs::TypedInstance; +use fzn_rs::VariableExpr; +use fzn_rs_derive::FlatZincAnnotation; +use fzn_rs_derive::FlatZincConstraint; +use utils::*; + +macro_rules! btreemap { + ($($key:expr => $value:expr,)+) => (btreemap!($($key => $value),+)); + + ( $($key:expr => $value:expr),* ) => { + { + let mut _map = ::std::collections::BTreeMap::new(); + $( + let _ = _map.insert($key, $value); + )* + _map + } + }; +} + +#[test] +fn annotation_without_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + OutputVar, + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Atom("output_var".into()))], + })], + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::OutputVar, + ); +} + +#[test] +fn annotation_with_positional_literal_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + DefinesVar(Rc), + OutputArray(RangeList), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![ + test_node(Annotation::Call(AnnotationCall { + name: "defines_var".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("some_var".into())), + )))], + })), + test_node(Annotation::Call(AnnotationCall { + name: "output_array".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::IntSet(RangeList::from(1..=5))), + )))], + })), + ], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::DefinesVar("some_var".into()), + ); + + assert_eq!( + instance.constraints[0].annotations[1].node, + TypedAnnotation::OutputArray(RangeList::from(1..=5)), + ); +} + +#[test] +fn annotation_with_named_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + DefinesVar { variable_id: Rc }, + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "defines_var".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("some_var".into())), + )))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::DefinesVar { + variable_id: "some_var".into() + }, + ); +} + +#[test] +fn nested_annotation_as_argument() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(#[annotation] SomeAnnotationArgs), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum SomeAnnotationArgs { + ArgOne, + ArgTwo(Rc), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![ + test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("arg_one".into())), + )))], + })), + test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::Annotation(AnnotationCall { + name: "arg_two".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Identifier("ident".into())), + )))], + }), + )))], + })), + ], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::SomeAnnotation(SomeAnnotationArgs::ArgOne), + ); + + assert_eq!( + instance.constraints[0].annotations[1].node, + TypedAnnotation::SomeAnnotation(SomeAnnotationArgs::ArgTwo("ident".into())), + ); +} + +#[test] +fn arrays_as_annotation_arguments_with_literal_elements() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(ArrayExpr), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Array(vec![ + test_node(AnnotationLiteral::BaseLiteral(Literal::Int(1))), + test_node(AnnotationLiteral::BaseLiteral(Literal::Int(2))), + ]))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + let TypedAnnotation::SomeAnnotation(args) = instance.constraints[0].annotations[0].node.clone(); + + let resolved_args = instance + .resolve_array(&args) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(resolved_args, vec![1, 2]); +} + +#[test] +fn arrays_as_annotation_arguments_with_annotation_elements() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + SomeAnnotation(#[annotation] Vec), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum ArrayElements { + ElementOne, + ElementTwo(i64), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![test_node(Argument::Literal(test_node( + Literal::Identifier("x3".into()), + )))], + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![test_node(AnnotationArgument::Array(vec![ + test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier( + "element_one".into(), + ))), + test_node(AnnotationLiteral::Annotation(AnnotationCall { + name: "element_two".into(), + arguments: vec![test_node(AnnotationArgument::Literal(test_node( + AnnotationLiteral::BaseLiteral(Literal::Int(4)), + )))], + })), + ]))], + }))], + })], + + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.constraints[0].annotations[0].node, + TypedAnnotation::SomeAnnotation(vec![ + ArrayElements::ElementOne, + ArrayElements::ElementTwo(4) + ]), + ); +} + +#[test] +fn annotations_can_be_structs_for_arguments() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(VariableExpr), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum TypedAnnotation { + #[args] + SomeAnnotation(AnnotationArgs), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + struct AnnotationArgs { + ident: Rc, + #[annotation] + ann: OtherAnnotation, + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] + enum OtherAnnotation { + ElementOne, + ElementTwo(i64), + } + + type Instance = TypedInstance; + + let ast = Ast { + variables: btreemap! { + "x1".into() => test_node(Variable { + domain: test_node(Domain::UnboundedInt), + value: None, + annotations: vec![test_node(Annotation::Call(AnnotationCall { + name: "some_annotation".into(), + arguments: vec![ + test_node(AnnotationArgument::Literal(test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier("some_ident".into()))))), + test_node(AnnotationArgument::Literal(test_node(AnnotationLiteral::BaseLiteral(Literal::Identifier("element_one".into()))))), + ], + }))], + }), + }, + arrays: BTreeMap::new(), + constraints: vec![], + solve: satisfy_solve(), + }; + + let instance = Instance::from_ast(ast).expect("valid instance"); + + assert_eq!( + instance.variables["x1"].annotations[0].node, + TypedAnnotation::SomeAnnotation(AnnotationArgs { + ident: "some_ident".into(), + ann: OtherAnnotation::ElementOne, + }), + ); +} diff --git a/fzn-rs-derive/tests/derive_flatzinc_constraint.rs b/fzn-rs-derive/tests/derive_flatzinc_constraint.rs new file mode 100644 index 000000000..2edc44525 --- /dev/null +++ b/fzn-rs-derive/tests/derive_flatzinc_constraint.rs @@ -0,0 +1,485 @@ +#![cfg(test)] // workaround for https://github.com/rust-lang/rust-clippy/issues/11024 + +mod utils; + +use std::collections::BTreeMap; + +use fzn_rs::ast::Argument; +use fzn_rs::ast::Array; +use fzn_rs::ast::Ast; +use fzn_rs::ast::Literal; +use fzn_rs::ast::Span; +use fzn_rs::ArrayExpr; +use fzn_rs::InstanceError; +use fzn_rs::TypedInstance; +use fzn_rs::VariableExpr; +use fzn_rs_derive::FlatZincConstraint; +use utils::*; + +#[test] +fn variant_with_unnamed_fields() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe(ArrayExpr, ArrayExpr>, i64), + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::IntLinLe(weights, variables, bound) = + instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn variant_with_named_fields() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + }, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::IntLinLe { + weights, + variables, + bound, + } = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn variant_with_name_attribute() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[name("int_lin_le")] + LinearInequality { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + }, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: BTreeMap::new(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Array(vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ])), + test_node(Argument::Array(vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ])), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + let TypedConstraint::LinearInequality { + weights, + variables, + bound, + } = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn constraint_referencing_arrays() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + IntLinLe(ArrayExpr, ArrayExpr>, i64), + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: [ + ( + "array1".into(), + test_node(Array { + contents: vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ], + annotations: vec![], + }), + ), + ( + "array2".into(), + test_node(Array { + contents: vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ], + annotations: vec![], + }), + ), + ] + .into_iter() + .collect(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Identifier( + "array1".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Identifier( + "array2".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + + let TypedConstraint::IntLinLe(weights, variables, bound) = + instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(bound, 3); +} + +#[test] +fn constraint_as_struct_args() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[args] + IntLinLe(LinearLeq), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + struct LinearLeq { + weights: ArrayExpr, + variables: ArrayExpr>, + bound: i64, + } + + let ast = Ast { + variables: unbounded_variables(["x1", "x2", "x3"]), + arrays: [ + ( + "array1".into(), + test_node(Array { + contents: vec![ + test_node(Literal::Int(2)), + test_node(Literal::Int(3)), + test_node(Literal::Int(5)), + ], + annotations: vec![], + }), + ), + ( + "array2".into(), + test_node(Array { + contents: vec![ + test_node(Literal::Identifier("x1".into())), + test_node(Literal::Identifier("x2".into())), + test_node(Literal::Identifier("x3".into())), + ], + annotations: vec![], + }), + ), + ] + .into_iter() + .collect(), + constraints: vec![test_node(fzn_rs::ast::Constraint { + name: test_node("int_lin_le".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Identifier( + "array1".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Identifier( + "array2".into(), + )))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + })], + solve: satisfy_solve(), + }; + + let instance = TypedInstance::::from_ast(ast).expect("valid instance"); + + let TypedConstraint::IntLinLe(linear) = instance.constraints[0].clone().constraint.node; + + let weights = instance + .resolve_array(&linear.weights) + .unwrap() + .collect::, _>>() + .unwrap(); + let variables = instance + .resolve_array(&linear.variables) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(weights, vec![2, 3, 5]); + assert_eq!( + variables, + vec![ + VariableExpr::Identifier("x1".into()), + VariableExpr::Identifier("x2".into()), + VariableExpr::Identifier("x3".into()) + ] + ); + assert_eq!(linear.bound, 3); +} + +#[test] +fn argument_count_on_tuple_variants() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint(i64), + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} + +#[test] +fn argument_count_on_named_fields_variant() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + SomeConstraint { constant: i64 }, + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} + +#[test] +fn argument_count_on_args_struct() { + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + enum TypedConstraint { + #[args] + SomeConstraint(Args), + } + + #[derive(Clone, Debug, PartialEq, Eq, FlatZincConstraint)] + struct Args { + argument: i64, + } + + let ast = Ast { + variables: BTreeMap::new(), + arrays: BTreeMap::new(), + constraints: vec![node( + fzn_rs::ast::Constraint { + name: test_node("some_constraint".into()), + arguments: vec![ + test_node(Argument::Literal(test_node(Literal::Int(3)))), + test_node(Argument::Literal(test_node(Literal::Int(3)))), + ], + annotations: vec![], + }, + 0, + 10, + )], + solve: satisfy_solve(), + }; + + let error = TypedInstance::::from_ast(ast).expect_err("invalid instance"); + + assert_eq!( + error, + InstanceError::IncorrectNumberOfArguments { + expected: 1, + actual: 2, + span: Span { start: 0, end: 10 }, + } + ); +} diff --git a/fzn-rs-derive/tests/utils.rs b/fzn-rs-derive/tests/utils.rs new file mode 100644 index 000000000..8de6e5edb --- /dev/null +++ b/fzn-rs-derive/tests/utils.rs @@ -0,0 +1,49 @@ +#![allow( + dead_code, + reason = "it is used in other test files, but somehow compiler can't see it" +)] +#![cfg(test)] + +use std::collections::BTreeMap; +use std::rc::Rc; + +use fzn_rs::ast::{self}; + +pub(crate) fn satisfy_solve() -> ast::SolveItem { + ast::SolveItem { + method: test_node(ast::Method::Satisfy), + annotations: vec![], + } +} + +pub(crate) fn test_node(data: T) -> ast::Node { + node(data, usize::MAX, usize::MAX) +} + +pub(crate) fn node(data: T, span_start: usize, span_end: usize) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { + start: span_start, + end: span_end, + }, + } +} + +pub(crate) fn unbounded_variables<'a>( + names: impl IntoIterator, +) -> BTreeMap, ast::Node>> { + names + .into_iter() + .map(|name| { + ( + Rc::from(name), + test_node(ast::Variable { + domain: test_node(ast::Domain::UnboundedInt), + value: None, + annotations: vec![], + }), + ) + }) + .collect() +} diff --git a/fzn-rs/Cargo.toml b/fzn-rs/Cargo.toml new file mode 100644 index 000000000..8b582f898 --- /dev/null +++ b/fzn-rs/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fzn-rs" +version = "0.1.0" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[dependencies] +chumsky = { version = "0.10.1", optional = true } +thiserror = "2.0.12" +fzn-rs-derive = { path = "../fzn-rs-derive/", optional = true } + +[features] +fzn = ["dep:chumsky"] +derive = ["dep:fzn-rs-derive"] + +[package.metadata.docs.rs] +features = ["derive"] + +[lints] +workspace = true diff --git a/fzn-rs/src/ast.rs b/fzn-rs/src/ast.rs new file mode 100644 index 000000000..24deaa413 --- /dev/null +++ b/fzn-rs/src/ast.rs @@ -0,0 +1,383 @@ +//! The AST representing a FlatZinc instance, compatible with both the JSON format and +//! the original FZN format. +//! +//! It is a modified version of the `FlatZinc` type from [`flatzinc-serde`](https://docs.rs/flatzinc-serde). +use std::collections::BTreeMap; +use std::fmt::Display; +use std::iter::FusedIterator; +use std::ops::RangeInclusive; +use std::rc::Rc; + +/// Represents a FlatZinc instance. +/// +/// In the `.fzn` format, identifiers can point to both constants and variables (either single or +/// arrays). In this AST, the constants are immediately resolved and are not kept in their original +/// form. Therefore, any [`Literal::Identifier`] points to a variable or an array. +/// +/// All identifiers are [`Rc`]s to allow parsers to re-use the allocation of the variable name. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Ast { + /// A mapping from identifiers to variables. + pub variables: BTreeMap, Node>>, + /// The arrays in this instance. + pub arrays: BTreeMap, Node>>, + /// A list of constraints. + pub constraints: Vec>, + /// The goal of the model. + pub solve: SolveItem, +} + +/// A decision variable. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Variable { + /// The domain of the variable. + pub domain: Node, + /// Optionally, the value that the variable is equal to. + pub value: Option>, + /// The annotations on this variable. + pub annotations: Vec>, +} + +/// A named array of literals. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Array { + /// The domain of the elements of the array. + pub domain: Node, + /// The elements of the array. + pub contents: Vec>, + /// The annotations associated with this array. + pub annotations: Vec>, +} + +/// The domain of a [`Variable`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Domain { + /// The set of all integers. + UnboundedInt, + /// A finite set of integer values. + Int(RangeList), + /// A boolean domain. + Bool, +} + +/// Holds a non-empty set of values. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RangeList { + /// A sorted list of intervals. + /// + /// Invariant: Consecutive intervals are merged. + intervals: Vec<(E, E)>, +} + +impl RangeList { + /// The smallest element in the set. + pub fn lower_bound(&self) -> &E { + &self.intervals[0].0 + } + + /// The largest element in the set. + pub fn upper_bound(&self) -> &E { + let last_idx = self.intervals.len() - 1; + + &self.intervals[last_idx].1 + } + + /// Returns `true` if the set is a continious range from [`Self::lower_bound`] to + /// [`Self::upper_bound`]. + pub fn is_continuous(&self) -> bool { + self.intervals.len() == 1 + } +} + +macro_rules! impl_iter_fn { + ($int_type:ty) => { + impl<'a> IntoIterator for &'a RangeList<$int_type> { + type Item = $int_type; + + type IntoIter = RangeListIter<'a, $int_type>; + + fn into_iter(self) -> Self::IntoIter { + RangeListIter { + current_interval: self.intervals.first().copied().unwrap_or((1, 0)), + tail: &self.intervals[1..], + } + } + } + }; +} + +impl_iter_fn!(i32); +impl_iter_fn!(i64); + +impl From> for RangeList { + fn from(value: RangeInclusive) -> Self { + RangeList { + intervals: vec![(*value.start(), *value.end())], + } + } +} + +macro_rules! range_list_from_iter { + ($int_type:ty) => { + impl FromIterator<$int_type> for RangeList<$int_type> { + fn from_iter>(iter: T) -> Self { + let mut intervals: Vec<_> = iter.into_iter().map(|e| (e, e)).collect(); + intervals.sort(); + intervals.dedup(); + + let mut idx = 0; + + while idx < intervals.len() - 1 { + let current = intervals[idx]; + let next = intervals[idx + 1]; + + if current.1 >= next.0 - 1 { + intervals[idx] = (current.0, next.1); + let _ = intervals.remove(idx + 1); + } else { + idx += 1; + } + } + + RangeList { intervals } + } + } + }; +} + +range_list_from_iter!(i32); +range_list_from_iter!(i64); + +/// An [`Iterator`] over a [`RangeList`]. +#[derive(Debug)] +pub struct RangeListIter<'a, E> { + current_interval: (E, E), + tail: &'a [(E, E)], +} + +macro_rules! impl_range_list_iter { + ($int_type:ty) => { + impl Iterator for RangeListIter<'_, $int_type> { + type Item = $int_type; + + fn next(&mut self) -> Option { + let (current_lb, current_ub) = self.current_interval; + + if current_lb > current_ub { + let (next_interval, new_tail) = self.tail.split_first()?; + self.current_interval = *next_interval; + self.tail = new_tail; + } + + let current_lb = self.current_interval.0; + self.current_interval.0 += 1; + + Some(current_lb) + } + } + + impl FusedIterator for RangeListIter<'_, $int_type> {} + }; +} + +impl_range_list_iter!(i32); +impl_range_list_iter!(i64); + +/// The foundational element from which expressions are built. Literals are the values/identifiers +/// in the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Literal { + Int(i64), + Identifier(Rc), + Bool(bool), + IntSet(RangeList), +} + +impl From for Literal { + fn from(value: i64) -> Self { + Literal::Int(value) + } +} + +/// The solve item. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SolveItem { + pub method: Node, + pub annotations: Vec>, +} + +/// Whether to satisfy or optimise the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Method { + Satisfy, + Optimize { + direction: OptimizationDirection, + objective: Literal, + }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum OptimizationDirection { + Minimize, + Maximize, +} + +/// A constraint definition. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Constraint { + /// The name of the constraint. + pub name: Node>, + /// The list of arguments. + pub arguments: Vec>, + /// Any annotations on the constraint. + pub annotations: Vec>, +} + +/// An argument for a [`Constraint`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Argument { + Array(Vec>), + Literal(Node), +} + +/// An annotation on any item in the model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Annotation { + /// An annotation without arguments. + Atom(Rc), + /// An annotation with arguments. + Call(AnnotationCall), +} + +impl Annotation { + /// Get the name of the annotation. + pub fn name(&self) -> &str { + match self { + Annotation::Atom(name) => name, + Annotation::Call(call) => &call.name, + } + } +} + +/// An annotation with arguments. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AnnotationCall { + /// The name of the annotation. + pub name: Rc, + /// Any arguments for the annotation. + pub arguments: Vec>, +} + +/// An individual argument for an [`Annotation`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AnnotationArgument { + Array(Vec>), + Literal(Node), +} + +/// An annotation literal is either a regular [`Literal`] or it is another annotation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AnnotationLiteral { + BaseLiteral(Literal), + /// In the FZN grammar, this is an `Annotation` instead of an `AnnotationCall`. We diverge from + /// the grammar to avoid the case where the same input can parse to either a + /// `Annotation::Atom(ident)` or an `Literal::Identifier`. + Annotation(AnnotationCall), +} + +/// Describes a range `[start, end)` in the model file that contains a [`Node`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Span { + /// The index in the source that starts the span. + pub start: usize, + /// The index in the source that ends the span. + /// + /// Note the end is exclusive. + pub end: usize, +} + +impl Display for Span { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}, {})", self.start, self.end) + } +} + +#[cfg(feature = "fzn")] +impl chumsky::span::Span for Span { + type Context = (); + + type Offset = usize; + + fn new(_: Self::Context, range: std::ops::Range) -> Self { + Self { + start: range.start, + end: range.end, + } + } + + fn context(&self) -> Self::Context {} + + fn start(&self) -> Self::Offset { + self.start + } + + fn end(&self) -> Self::Offset { + self.end + } +} + +#[cfg(feature = "fzn")] +impl From for Span { + fn from(value: chumsky::span::SimpleSpan) -> Self { + Span { + start: value.start, + end: value.end, + } + } +} + +#[cfg(feature = "fzn")] +impl From for chumsky::span::SimpleSpan { + fn from(value: Span) -> Self { + chumsky::span::SimpleSpan::from(value.start..value.end) + } +} + +/// A node in the [`Ast`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Node { + /// The span in the source of this node. + pub span: Span, + /// The parsed node. + pub node: T, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rangelist_from_iter_identifies_continuous_ranges() { + let set = RangeList::from_iter([1, 2, 3, 4]); + + assert!(set.is_continuous()); + } + + #[test] + fn rangelist_from_iter_identifiers_non_continuous_ranges() { + let set = RangeList::from_iter([1, 3, 4, 6]); + + assert!(!set.is_continuous()); + } + + #[test] + fn rangelist_iter_produces_elements_in_set() { + let set: RangeList = RangeList::from_iter([1, 3, 5]); + + let mut iter = set.into_iter(); + assert_eq!(Some(1), iter.next()); + assert_eq!(Some(3), iter.next()); + assert_eq!(Some(5), iter.next()); + assert_eq!(None, iter.next()); + } +} diff --git a/fzn-rs/src/error.rs b/fzn-rs/src/error.rs new file mode 100644 index 000000000..337e7c681 --- /dev/null +++ b/fzn-rs/src/error.rs @@ -0,0 +1,92 @@ +use std::fmt::Display; + +use crate::ast; + +#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +pub enum InstanceError { + #[error("constraint '{0}' is not supported")] + UnsupportedConstraint(String), + + #[error("annotation '{0}' is not supported")] + UnsupportedAnnotation(String), + + #[error("expected {expected}, got {actual} at {span}")] + UnexpectedToken { + expected: Token, + actual: Token, + span: ast::Span, + }, + + #[error("array {0} is undefined")] + UndefinedArray(String), + + #[error("expected {expected} arguments, got {actual}")] + IncorrectNumberOfArguments { + expected: usize, + actual: usize, + span: ast::Span, + }, + + #[error("value {0} does not fit in the required integer type")] + IntegerOverflow(i64), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Token { + Identifier, + IntLiteral, + BoolLiteral, + IntSetLiteral, + + Array, + Variable(Box), + AnnotationCall, + Annotation, +} + +impl From<&'_ ast::Literal> for Token { + fn from(value: &'_ ast::Literal) -> Self { + match value { + ast::Literal::Int(_) => Token::IntLiteral, + ast::Literal::Identifier(_) => Token::Identifier, + ast::Literal::Bool(_) => Token::BoolLiteral, + ast::Literal::IntSet(_) => Token::IntSetLiteral, + } + } +} + +impl From<&'_ ast::AnnotationArgument> for Token { + fn from(value: &'_ ast::AnnotationArgument) -> Self { + match value { + ast::AnnotationArgument::Array(_) => Token::Array, + ast::AnnotationArgument::Literal(literal) => (&literal.node).into(), + } + } +} + +impl From<&'_ ast::AnnotationLiteral> for Token { + fn from(value: &'_ ast::AnnotationLiteral) -> Self { + match value { + ast::AnnotationLiteral::BaseLiteral(literal) => literal.into(), + ast::AnnotationLiteral::Annotation(_) => Token::AnnotationCall, + } + } +} + +impl Display for Token { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::Identifier => write!(f, "identifier"), + + Token::IntLiteral => write!(f, "int"), + Token::BoolLiteral => write!(f, "bool"), + Token::IntSetLiteral => write!(f, "int set"), + + Token::AnnotationCall => write!(f, "annotation"), + Token::Annotation => write!(f, "annotation"), + + Token::Array => write!(f, "array"), + Token::Variable(token) => write!(f, "{token} variable"), + } + } +} diff --git a/fzn-rs/src/fzn/mod.rs b/fzn-rs/src/fzn/mod.rs new file mode 100644 index 000000000..95f29f095 --- /dev/null +++ b/fzn-rs/src/fzn/mod.rs @@ -0,0 +1,1115 @@ +use std::collections::BTreeMap; +use std::collections::BTreeSet; +use std::rc::Rc; + +use chumsky::error::Rich; +use chumsky::extra; +use chumsky::input::Input; +use chumsky::input::MapExtra; +use chumsky::input::ValueInput; +use chumsky::prelude::any; +use chumsky::prelude::choice; +use chumsky::prelude::just; +use chumsky::prelude::recursive; +use chumsky::select; +use chumsky::span::SimpleSpan; +use chumsky::IterParser; +use chumsky::Parser; + +use crate::ast; + +mod tokens; + +pub use tokens::Token; +use tokens::Token::*; + +#[derive(Clone, Debug, Default)] +struct ParseState { + /// The identifiers encountered so far. + strings: BTreeSet>, + /// Parameters + parameters: BTreeMap, ParameterValue>, +} + +impl ParseState { + fn get_interned(&mut self, string: &str) -> Rc { + if !self.strings.contains(string) { + let _ = self.strings.insert(Rc::from(string)); + } + + Rc::clone(self.strings.get(string).unwrap()) + } + + fn resolve_literal(&self, literal: ast::Literal) -> ast::Literal { + match literal { + ast::Literal::Identifier(ident) => self + .parameters + .get(&ident) + .map(|value| match value { + ParameterValue::Bool(boolean) => ast::Literal::Bool(*boolean), + ParameterValue::Int(int) => ast::Literal::Int(*int), + ParameterValue::IntSet(set) => ast::Literal::IntSet(set.clone()), + }) + .unwrap_or(ast::Literal::Identifier(ident)), + + lit @ (ast::Literal::Int(_) | ast::Literal::Bool(_) | ast::Literal::IntSet(_)) => lit, + } + } +} + +#[derive(Clone, Debug)] +enum ParameterValue { + Bool(bool), + Int(i64), + IntSet(ast::RangeList), +} + +#[derive(Debug, thiserror::Error)] +pub enum FznError<'src> { + #[error("failed to lex fzn")] + LexError { + reasons: Vec>, + }, + + #[error("failed to parse fzn")] + ParseError { + reasons: Vec, ast::Span>>, + }, +} + +pub fn parse(source: &str) -> Result> { + let mut state = extra::SimpleState(ParseState::default()); + + let tokens = tokens::lex() + .parse(source) + .into_result() + .map_err(|reasons| FznError::LexError { reasons })?; + + let parser_input = tokens.map( + ast::Span { + start: source.len(), + end: source.len(), + }, + |node| (&node.node, &node.span), + ); + + let ast = predicates() + .ignore_then(parameters()) + .ignore_then(arrays()) + .then(variables()) + .then(arrays()) + .then(constraints()) + .then(solve_item()) + .map( + |((((parameter_arrays, variables), variable_arrays), constraints), solve)| { + let mut arrays = parameter_arrays; + arrays.extend(variable_arrays); + + ast::Ast { + variables, + arrays, + constraints, + solve, + } + }, + ) + .parse_with_state(parser_input, &mut state) + .into_result() + .map_err( + |reasons: Vec, _>>| FznError::ParseError { + reasons: reasons + .into_iter() + .map(|error| error.into_owned()) + .collect(), + }, + )?; + + Ok(ast) +} + +/// The extra data attached to the chumsky parsers. +/// +/// We specify a rich error type, as well as an instance of [`ParseState`] for string interning and +/// parameter resolution. +type FznExtra<'tokens, 'src> = + extra::Full, ast::Span>, extra::SimpleState, ()>; + +fn predicates<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + predicate().repeated().collect::>().ignored() +} + +fn predicate<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("predicate")) + .ignore_then(any().and_is(just(SemiColon).not()).repeated()) + .then(just(SemiColon)) + .ignored() +} + +fn parameters<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + parameter().repeated().collect::>().ignored() +} + +fn parameter_type<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("int")), + just(Ident("bool")), + just(Ident("set")) + .then_ignore(just(Ident("of"))) + .then_ignore(just(Ident("int"))), + )) + .ignored() +} + +fn parameter<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, (), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + parameter_type() + .ignore_then(just(Colon)) + .ignore_then(identifier()) + .then_ignore(just(Equal)) + .then(literal()) + .then_ignore(just(SemiColon)) + .try_map_with(|(name, value), extra| { + let state = extra.state(); + + let value = match value.node { + ast::Literal::Int(int) => ParameterValue::Int(int), + ast::Literal::Bool(boolean) => ParameterValue::Bool(boolean), + ast::Literal::IntSet(set) => ParameterValue::IntSet(set), + ast::Literal::Identifier(identifier) => { + return Err(Rich::custom( + value.span, + format!("parameter '{identifier}' is undefined"), + )) + } + }; + + let _ = state.parameters.insert(name, value); + + Ok(()) + }) +} + +fn arrays<'tokens, 'src: 'tokens, I>() -> impl Parser< + 'tokens, + I, + BTreeMap, ast::Node>>, + FznExtra<'tokens, 'src>, +> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + array() + .repeated() + .collect::>() + .map(|arrays| arrays.into_iter().collect()) +} + +fn array<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, (Rc, ast::Node>), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("array")) + .ignore_then(interval_set(integer()).delimited_by(just(OpenBracket), just(CloseBracket))) + .ignore_then(just(Ident("of"))) + .ignore_then(just(Ident("var")).or_not()) + .ignore_then(domain()) + .then_ignore(just(Colon)) + .then(identifier()) + .then(annotations()) + .then_ignore(just(Equal)) + .then( + literal() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)), + ) + .then_ignore(just(SemiColon)) + .map_with(|(((domain, name), annotations), contents), extra| { + ( + name, + ast::Node { + node: ast::Array { + domain, + contents, + annotations, + }, + span: extra.span(), + }, + ) + }) +} + +fn variables<'tokens, 'src: 'tokens, I>() -> impl Parser< + 'tokens, + I, + BTreeMap, ast::Node>>, + FznExtra<'tokens, 'src>, +> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + variable() + .repeated() + .collect::>() + .map(|variables| variables.into_iter().collect()) +} + +fn variable<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, (Rc, ast::Node>), FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("var")) + .ignore_then(domain()) + .then_ignore(just(Colon)) + .then(identifier()) + .then(annotations()) + .then(just(Equal).ignore_then(literal()).or_not()) + .then_ignore(just(SemiColon)) + .map_with(to_node) + .map(|node| { + let ast::Node { + node: (((domain, name), annotations), value), + span, + } = node; + + let variable = ast::Variable { + domain, + value, + annotations, + }; + + ( + name, + ast::Node { + node: variable, + span, + }, + ) + }) +} + +fn domain<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("int")).to(ast::Domain::UnboundedInt), + just(Ident("bool")).to(ast::Domain::Bool), + set_of(integer()).map(ast::Domain::Int), + )) + .map_with(to_node) +} + +fn constraints<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, Vec>, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + constraint().repeated().collect::>() +} + +fn constraint<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("constraint")) + .ignore_then(identifier().map_with(to_node)) + .then( + argument() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenParen), just(CloseParen)), + ) + .then(annotations()) + .then_ignore(just(SemiColon)) + .map(|((name, arguments), annotations)| ast::Constraint { + name, + arguments, + annotations, + }) + .map_with(to_node) +} + +fn argument<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + literal().map(ast::Argument::Literal), + literal() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)) + .map(ast::Argument::Array), + )) + .map_with(to_node) +} + +fn solve_item<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::SolveItem, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(Ident("solve")) + .ignore_then(annotations()) + .then(solve_method()) + .then_ignore(just(SemiColon)) + .map(|(annotations, method)| ast::SolveItem { + method, + annotations, + }) +} + +fn solve_method<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + just(Ident("satisfy")).to(ast::Method::Satisfy), + just(Ident("minimize")) + .ignore_then(identifier()) + .map(|ident| ast::Method::Optimize { + direction: ast::OptimizationDirection::Minimize, + objective: ast::Literal::Identifier(ident), + }), + just(Ident("maximize")) + .ignore_then(identifier()) + .map(|ident| ast::Method::Optimize { + direction: ast::OptimizationDirection::Maximize, + objective: ast::Literal::Identifier(ident), + }), + )) + .map_with(to_node) +} + +fn annotations<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, Vec>, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + annotation().repeated().collect() +} + +fn annotation<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + just(DoubleColon) + .ignore_then(choice(( + annotation_call().map(ast::Annotation::Call), + identifier().map(ast::Annotation::Atom), + ))) + .map_with(to_node) +} + +fn annotation_call<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + recursive(|call| { + identifier() + .then( + annotation_argument(call) + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenParen), just(CloseParen)), + ) + .map(|(name, arguments)| ast::AnnotationCall { name, arguments }) + }) +} + +fn annotation_argument<'tokens, 'src: 'tokens, I>( + call_parser: impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + annotation_literal(call_parser.clone()).map(ast::AnnotationArgument::Literal), + annotation_literal(call_parser) + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBracket), just(CloseBracket)) + .map(ast::AnnotationArgument::Array), + )) + .map_with(to_node) +} + +fn annotation_literal<'tokens, 'src: 'tokens, I>( + call_parser: impl Parser<'tokens, I, ast::AnnotationCall, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + call_parser + .map(ast::AnnotationLiteral::Annotation) + .map_with(to_node), + literal().map(|node| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(node.node), + span: node.span, + }), + )) +} + +fn to_node<'tokens, 'src: 'tokens, I, T>( + node: T, + extra: &mut MapExtra<'tokens, '_, I, FznExtra<'tokens, 'src>>, +) -> ast::Node +where + I: Input<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + ast::Node { + node, + span: extra.span(), + } +} + +fn literal<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, ast::Node, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + choice(( + set_of(integer()).map(ast::Literal::IntSet), + integer().map(ast::Literal::Int), + boolean().map(ast::Literal::Bool), + identifier().map(ast::Literal::Identifier), + )) + .map_with(|literal, extra| { + let state = extra.state(); + state.resolve_literal(literal) + }) + .map_with(to_node) +} + +fn set_of<'tokens, 'src: 'tokens, I, T: Copy + Ord>( + value_parser: impl Parser<'tokens, I, T, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, ast::RangeList, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, + ast::RangeList: FromIterator, +{ + let sparse_set = value_parser + .clone() + .separated_by(just(Comma)) + .collect::>() + .delimited_by(just(OpenBrace), just(CloseBrace)) + .map(ast::RangeList::from_iter); + + choice(( + sparse_set, + interval_set(value_parser).map(|(lb, ub)| ast::RangeList::from(lb..=ub)), + )) +} + +fn interval_set<'tokens, 'src: 'tokens, I, T: Copy + Ord>( + value_parser: impl Parser<'tokens, I, T, FznExtra<'tokens, 'src>> + Clone, +) -> impl Parser<'tokens, I, (T, T), FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + value_parser + .clone() + .then_ignore(just(DoublePeriod)) + .then(value_parser) +} + +fn integer<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, i64, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Integer(int) => int, + } +} + +fn boolean<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, bool, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Boolean(boolean) => boolean, + } +} + +fn identifier<'tokens, 'src: 'tokens, I>( +) -> impl Parser<'tokens, I, Rc, FznExtra<'tokens, 'src>> + Clone +where + I: ValueInput<'tokens, Span = ast::Span, Token = Token<'src>>, +{ + select! { + Ident(ident) => ident, + } + .map_with(|ident, extra| { + let state: &mut extra::SimpleState = extra.state(); + state.get_interned(ident) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! btreemap { + ($($key:expr => $value:expr,)+) => (btreemap!($($key => $value),+)); + + ( $($key:expr => $value:expr),* ) => { + { + let mut _map = ::std::collections::BTreeMap::new(); + $( + let _ = _map.insert($key, $value); + )* + _map + } + }; + } + + #[test] + fn empty_satisfaction_model() { + let source = r#" + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(15, 22, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn comments_are_ignored() { + let source = r#" + % Generated by MiniZinc 2.9.2 + % Solver: bla + solve satisfy; % This is ignored + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(75, 82, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn predicate_statements_are_ignored() { + let source = r#" + predicate some_predicate(int: xs, var int: ys); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(71, 78, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn empty_minimization_model() { + let source = r#" + solve minimize objective; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node( + 15, + 33, + ast::Method::Optimize { + direction: ast::OptimizationDirection::Minimize, + objective: ast::Literal::Identifier("objective".into()), + } + ), + annotations: vec![], + } + } + ); + } + + #[test] + fn empty_maximization_model() { + let source = r#" + solve maximize objective; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node( + 15, + 33, + ast::Method::Optimize { + direction: ast::OptimizationDirection::Maximize, + objective: ast::Literal::Identifier("objective".into()), + } + ), + annotations: vec![], + } + } + ); + } + + #[test] + fn variables() { + let source = r#" + var 1..5: x_interval; + var bool: x_bool; + var {1, 3, 5}: x_sparse; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x_interval".into() => node(9, 30, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![] + }), + "x_bool".into() => node(39, 56, ast::Variable { + domain: node(43, 47, ast::Domain::Bool), + value: None, + annotations: vec![] + }), + "x_sparse".into() => node(65, 89, ast::Variable { + domain: node(69, 78, ast::Domain::Int(ast::RangeList::from_iter([1, 3, 5]))), + value: None, + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(104, 111, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn variable_with_assignment() { + let source = r#" + var 5..5: x1 = 5; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(9, 26, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(24, 25, ast::Literal::Int(5))), + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(41, 48, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn variable_with_assignment_to_named_constant() { + let source = r#" + int: y = 5; + var 5..5: x1 = y; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(29, 46, ast::Variable { + domain: node(33, 37, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(44, 45, ast::Literal::Int(5))), + annotations: vec![] + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(61, 68, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn arrays_of_constants_and_variables() { + let source = r#" + int: p = 5; + array [1..3] of int: ys = [1, 3, p]; + + var int: some_var; + array [1..2] of var int: vars = [1, some_var]; + + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "some_var".into() => node(75, 93, ast::Variable { + domain: node(79, 82, ast::Domain::UnboundedInt), + value: None, + annotations: vec![] + }), + }, + arrays: btreemap! { + "ys".into() => node(29, 65, ast::Array { + domain: node(45, 48, ast::Domain::UnboundedInt), + contents: vec![ + node(56, 57, ast::Literal::Int(1)), + node(59, 60, ast::Literal::Int(3)), + node(62, 63, ast::Literal::Int(5)), + ], + annotations: vec![], + }), + "vars".into() => node(102, 148, ast::Array { + domain: node(122, 125, ast::Domain::UnboundedInt), + contents: vec![ + node(135, 136, ast::Literal::Int(1)), + node(138, 146, ast::Literal::Identifier("some_var".into())), + ], + annotations: vec![], + }), + }, + constraints: vec![], + solve: ast::SolveItem { + method: node(164, 171, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn constraint_item() { + let source = r#" + constraint int_lin_le(weights, [x1, x2, 3], 3); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![node( + 9, + 56, + ast::Constraint { + name: node(20, 30, "int_lin_le".into()), + arguments: vec![ + node( + 31, + 38, + ast::Argument::Literal(node( + 31, + 38, + ast::Literal::Identifier("weights".into()) + )) + ), + node( + 40, + 51, + ast::Argument::Array(vec![ + node(41, 43, ast::Literal::Identifier("x1".into())), + node(45, 47, ast::Literal::Identifier("x2".into())), + node(49, 50, ast::Literal::Int(3)), + ]), + ), + node( + 53, + 54, + ast::Argument::Literal(node(53, 54, ast::Literal::Int(3))) + ), + ], + annotations: vec![], + } + )], + solve: ast::SolveItem { + method: node(71, 78, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_variables() { + let source = r#" + var 5..5: x1 :: output_var = 5; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: btreemap! { + "x1".into() => node(9, 40, ast::Variable { + domain: node(13, 17, ast::Domain::Int(ast::RangeList::from(5..=5))), + value: Some(node(38, 39, ast::Literal::Int(5))), + annotations: vec![node(22, 35, ast::Annotation::Atom("output_var".into()))], + }), + }, + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(55, 62, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_arrays() { + let source = r#" + array [1..2] of var 1..10: xs :: output_array([1..2]) = []; + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: btreemap! { + "xs".into() => node(9, 68, ast::Array { + domain: node(29, 34, ast::Domain::Int(ast::RangeList::from(1..=10))), + contents: vec![], + annotations: vec![ + node(39, 62, ast::Annotation::Call(ast::AnnotationCall { + name: "output_array".into(), + arguments: vec![node(55, 61, + ast::AnnotationArgument::Array(vec![node(56, 60, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::IntSet( + ast::RangeList::from(1..=2) + ) + ) + )]) + )], + })), + ], + }), + }, + constraints: vec![], + solve: ast::SolveItem { + method: node(83, 90, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_constraints() { + let source = r#" + constraint predicate() :: defines_var(x1); + solve satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![node( + 9, + 51, + ast::Constraint { + name: node(20, 29, "predicate".into()), + arguments: vec![], + annotations: vec![node( + 32, + 50, + ast::Annotation::Call(ast::AnnotationCall { + name: "defines_var".into(), + arguments: vec![node( + 47, + 49, + ast::AnnotationArgument::Literal(node( + 47, + 49, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("x1".into()) + ) + )) + )] + }) + )], + } + )], + solve: ast::SolveItem { + method: node(66, 73, ast::Method::Satisfy), + annotations: vec![], + } + } + ); + } + + #[test] + fn annotations_on_solve_item() { + let source = r#" + solve :: int_search(first_fail(xs), indomain_min) satisfy; + "#; + + let ast = parse(source).expect("valid fzn"); + + assert_eq!( + ast, + ast::Ast { + variables: BTreeMap::default(), + arrays: BTreeMap::default(), + constraints: vec![], + solve: ast::SolveItem { + method: node(59, 66, ast::Method::Satisfy), + annotations: vec![node( + 15, + 58, + ast::Annotation::Call(ast::AnnotationCall { + name: "int_search".into(), + arguments: vec![ + node( + 29, + 43, + ast::AnnotationArgument::Literal(node( + 29, + 43, + ast::AnnotationLiteral::Annotation(ast::AnnotationCall { + name: "first_fail".into(), + arguments: vec![node( + 40, + 42, + ast::AnnotationArgument::Literal(node( + 40, + 42, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("xs".into()) + ) + )) + )] + }) + )) + ), + node( + 45, + 57, + ast::AnnotationArgument::Literal(node( + 45, + 57, + ast::AnnotationLiteral::BaseLiteral( + ast::Literal::Identifier("indomain_min".into()) + ) + )) + ), + ] + }) + )], + } + } + ); + } + + fn node(start: usize, end: usize, data: T) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { start, end }, + } + } +} diff --git a/fzn-rs/src/fzn/tokens.rs b/fzn-rs/src/fzn/tokens.rs new file mode 100644 index 000000000..8c8de29b4 --- /dev/null +++ b/fzn-rs/src/fzn/tokens.rs @@ -0,0 +1,113 @@ +use std::fmt::Display; + +use chumsky::error::Rich; +use chumsky::extra::{self}; +use chumsky::prelude::any; +use chumsky::prelude::choice; +use chumsky::prelude::just; +use chumsky::text::ascii::ident; +use chumsky::text::int; +use chumsky::IterParser; +use chumsky::Parser; + +use crate::ast; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Token<'src> { + OpenParen, + CloseParen, + OpenBracket, + CloseBracket, + OpenBrace, + CloseBrace, + Comma, + Colon, + DoubleColon, + SemiColon, + DoublePeriod, + Equal, + Ident(&'src str), + Integer(i64), + Boolean(bool), +} + +impl Display for Token<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::OpenParen => write!(f, "("), + Token::CloseParen => write!(f, ")"), + Token::OpenBracket => write!(f, "["), + Token::CloseBracket => write!(f, "]"), + Token::OpenBrace => write!(f, "{{"), + Token::CloseBrace => write!(f, "}}"), + Token::Comma => write!(f, ","), + Token::Colon => write!(f, ":"), + Token::DoubleColon => write!(f, "::"), + Token::SemiColon => write!(f, ";"), + Token::DoublePeriod => write!(f, ".."), + Token::Equal => write!(f, "="), + Token::Ident(ident) => write!(f, "{ident}"), + Token::Integer(int) => write!(f, "{int}"), + Token::Boolean(boolean) => write!(f, "{boolean}"), + } + } +} + +type LexExtra<'src> = extra::Err>; + +pub(super) fn lex<'src>( +) -> impl Parser<'src, &'src str, Vec>>, LexExtra<'src>> { + token() + .padded_by(comment().repeated()) + .padded() + .repeated() + .collect() +} + +fn comment<'src>() -> impl Parser<'src, &'src str, (), extra::Err>> { + just("%") + .then(any().and_is(just('\n').not()).repeated()) + .padded() + .ignored() +} + +fn token<'src>( +) -> impl Parser<'src, &'src str, ast::Node>, extra::Err>> { + choice(( + // Punctuation + just(";").to(Token::SemiColon), + just("::").to(Token::DoubleColon), + just(":").to(Token::Colon), + just(",").to(Token::Comma), + just("..").to(Token::DoublePeriod), + just("[").to(Token::OpenBracket), + just("]").to(Token::CloseBracket), + just("{").to(Token::OpenBrace), + just("}").to(Token::CloseBrace), + just("(").to(Token::OpenParen), + just(")").to(Token::CloseParen), + just("=").to(Token::Equal), + // Values + just("true").to(Token::Boolean(true)), + just("false").to(Token::Boolean(false)), + int_literal().map(Token::Integer), + // Identifiers (including keywords) + ident().map(Token::Ident), + )) + .map_with(|token, extra| { + let span: chumsky::prelude::SimpleSpan = extra.span(); + + ast::Node { + node: token, + span: span.into(), + } + }) +} + +fn int_literal<'src>() -> impl Parser<'src, &'src str, i64, LexExtra<'src>> { + just("-") + .or_not() + .ignore_then(int(10)) + .to_slice() + .map(|slice: &str| slice.parse().unwrap()) +} diff --git a/fzn-rs/src/lib.rs b/fzn-rs/src/lib.rs new file mode 100644 index 000000000..8ce4468d1 --- /dev/null +++ b/fzn-rs/src/lib.rs @@ -0,0 +1,116 @@ +//! # fzn-rs +//! +//! `fzn-rs` is a crate that allows for easy parsing of FlatZinc instances in Rust. +//! +//! ## Comparison to other FlatZinc crates +//! There are two well-known crates for parsing FlatZinc files: +//! - [flatzinc](https://docs.rs/flatzinc), for parsing the original `fzn` format, +//! - and [flatzinc-serde](https://docs.rs/flatzinc-serde), for parsing `fzn.json`. +//! +//! The goal of this crate is to be able to parse both the original `fzn` format, as well as the +//! newer `fzn.json` format. Additionally, there is a derive macro that allows for strongly-typed +//! constraints as they are supported by your application. Finally, our aim is to improve the error +//! messages that are encountered when parsing invalid FlatZinc files. +//! +//! ## Typed Instance +//! The main type exposed by the crate is [`TypedInstance`], which is a fully typed representation +//! of a FlatZinc model. +//! +//! ``` +//! use fzn_rs::TypedInstance; +//! +//! enum Constraints { +//! // ... +//! } +//! +//! type Instance = TypedInstance; +//! ``` +//! +//! ## Derive Macro +//! When parsing a FlatZinc file, the result is an [`ast::Ast`]. That type describes any valid +//! FlatZinc file. However, when consuming FlatZinc, typically you need to process that AST +//! further. For example, to support the [`int_lin_le`][1] constraint, you have to validate that the +//! [`ast::Constraint`] has three arguments, and that each of the arguments has the correct type. +//! +//! When using this crate with the `derive` feature, you can instead do the following: +//! ```rust +//! use fzn_rs::ArrayExpr; +//! use fzn_rs::FlatZincConstraint; +//! use fzn_rs::VariableExpr; +//! +//! #[derive(FlatZincConstraint)] +//! pub enum MyConstraints { +//! /// The variant name is converted to snake_case to serve as the constraint identifier by +//! /// default. +//! IntLinLe(ArrayExpr, ArrayExpr>, i64), +//! +//! /// If the snake_case version of the variant name is different from the constraint +//! /// identifier, then the `#[name(...)], attribute allows you to set the constraint +//! /// identifier explicitly. +//! #[name("int_lin_eq")] +//! LinearEquality(ArrayExpr, ArrayExpr>, i64), +//! +//! /// Constraint arguments can also be named, but the order determines how they are parsed +//! /// from the AST. +//! Element { +//! index: VariableExpr, +//! array: ArrayExpr, +//! rhs: VariableExpr, +//! }, +//! +//! /// Arguments can also be separate structs, if the enum variant has exactly one argument. +//! #[args] +//! IntTimes(Multiplication), +//! } +//! +//! #[derive(FlatZincConstraint)] +//! pub struct Multiplication { +//! a: VariableExpr, +//! b: VariableExpr, +//! c: VariableExpr, +//! } +//! ``` +//! The macro automatically implements [`FlatZincConstraint`] and will handle the parsing +//! of arguments for you. +//! +//! Similar to typed constraints, the derive macro for [`FlatZincAnnotation`] allows for easy +//! parsing of annotations: +//! ``` +//! use std::rc::Rc; +//! +//! use fzn_rs::FlatZincAnnotation; +//! +//! #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] +//! enum TypedAnnotation { +//! /// Matches the snake-case atom "annotation". +//! Annotation, +//! +//! /// Supports nested annotations with the `#[annotation]` attribute. +//! SomeAnnotation(#[annotation] SomeAnnotationArgs), +//! } +//! +//! #[derive(Clone, Debug, PartialEq, Eq, FlatZincAnnotation)] +//! enum SomeAnnotationArgs { +//! /// Just as constraints, the name can be explicitly set. +//! #[name("arg_one")] +//! Arg1, +//! ArgTwo(Rc), +//! } +//! ``` +//! Different to parsing constraints, is that annotations can be ignored. If the AST contains an +//! annotation whose name does not match one of the variants in the enum, then the annotation is +//! simply ignored. +//! +//! [1]: https://docs.minizinc.dev/en/stable/lib-flatzinc-int.html#int-lin-le + +mod error; +mod typed; + +pub mod ast; +#[cfg(feature = "fzn")] +pub mod fzn; + +pub use error::*; +#[cfg(feature = "derive")] +pub use fzn_rs_derive::*; +pub use typed::*; diff --git a/fzn-rs/src/typed/arrays.rs b/fzn-rs/src/typed/arrays.rs new file mode 100644 index 000000000..2d6fe4fe9 --- /dev/null +++ b/fzn-rs/src/typed/arrays.rs @@ -0,0 +1,194 @@ +use std::collections::BTreeMap; +use std::marker::PhantomData; +use std::rc::Rc; + +use super::FromAnnotationArgument; +use super::FromAnnotationLiteral; +use super::FromArgument; +use super::FromLiteral; +use super::VariableExpr; +use crate::ast; +use crate::InstanceError; +use crate::Token; + +/// Models an array in a constraint argument. +/// +/// ## Example +/// ``` +/// use fzn_rs::ArrayExpr; +/// use fzn_rs::FlatZincConstraint; +/// use fzn_rs::VariableExpr; +/// +/// #[derive(FlatZincConstraint)] +/// struct Linear { +/// /// An array of constants. +/// weights: ArrayExpr, +/// /// An array of variables. +/// variables: ArrayExpr>, +/// } +/// ``` +/// +/// Use [`crate::TypedInstance::resolve_array`] to access the elements in the array. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ArrayExpr { + expr: ArrayExprImpl, + ty: PhantomData, +} + +impl ArrayExpr +where + T: FromAnnotationLiteral, +{ + pub(crate) fn resolve<'a, Ann>( + &'a self, + arrays: &'a BTreeMap, ast::Array>, + ) -> Result> + 'a, Rc> { + match &self.expr { + ArrayExprImpl::Identifier(ident) => arrays + .get(ident) + .map(|array| { + GenericIterator(Box::new( + array.contents.iter().map(::from_literal), + )) + }) + .ok_or_else(|| Rc::clone(ident)), + ArrayExprImpl::Array(array) => Ok(GenericIterator(Box::new( + array.iter().map(::from_literal), + ))), + } + } +} + +impl FromArgument for ArrayExpr { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::Argument::Array(contents) => { + let contents = contents + .iter() + .cloned() + .map(|node| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(node.node), + span: node.span, + }) + .collect(); + + Ok(ArrayExpr { + expr: ArrayExprImpl::Array(contents), + ty: PhantomData, + }) + } + ast::Argument::Literal(ast::Node { + node: ast::Literal::Identifier(ident), + .. + }) => Ok(ArrayExpr { + expr: ArrayExprImpl::Identifier(Rc::clone(ident)), + ty: PhantomData, + }), + ast::Argument::Literal(literal) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: Token::from(&literal.node), + span: literal.span, + }), + } + } +} + +impl FromAnnotationArgument for ArrayExpr { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Array(contents) => Ok(ArrayExpr { + expr: ArrayExprImpl::Array(contents.clone()), + ty: PhantomData, + }), + ast::AnnotationArgument::Literal(ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)), + .. + }) => Ok(ArrayExpr { + expr: ArrayExprImpl::Identifier(Rc::clone(ident)), + ty: PhantomData, + }), + ast::AnnotationArgument::Literal(literal) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: Token::from(&literal.node), + span: literal.span, + }), + } + } +} + +impl From>> for ArrayExpr +where + ast::Literal: From, +{ + fn from(value: Vec>) -> Self { + ArrayExpr { + expr: ArrayExprImpl::Array( + value + .into_iter() + .map(|value| ast::Node { + node: match value { + VariableExpr::Identifier(ident) => { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)) + } + VariableExpr::Constant(value) => { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::from(value)) + } + }, + span: ast::Span { + start: usize::MAX, + end: usize::MAX, + }, + }) + .collect(), + ), + ty: PhantomData, + } + } +} + +impl From> for ArrayExpr +where + ast::Literal: From, +{ + fn from(value: Vec) -> Self { + ArrayExpr { + expr: ArrayExprImpl::Array( + value + .into_iter() + .map(|value| ast::Node { + node: ast::AnnotationLiteral::BaseLiteral(ast::Literal::from(value)), + span: ast::Span { + start: usize::MAX, + end: usize::MAX, + }, + }) + .collect(), + ), + ty: PhantomData, + } + } +} + +/// The actual array expression, which is either an identifier or the array. +/// +/// This is a private type as all access to the array should go through [`ArrayExpr::resolve`]. +#[derive(Clone, Debug, PartialEq, Eq)] +enum ArrayExprImpl { + Identifier(Rc), + /// Regardless of the contents of the array, the elements will always be of type + /// [`ast::AnnotationLiteral`] to support the parsing of annotations from arrays. + Array(Vec>), +} + +/// A boxed dyn [`ExactSizeIterator`] which is returned from [`ArrayExpr::resolve`]. +struct GenericIterator<'a, T>(Box + 'a>); + +impl Iterator for GenericIterator<'_, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl ExactSizeIterator for GenericIterator<'_, T> {} diff --git a/fzn-rs/src/typed/constraint.rs b/fzn-rs/src/typed/constraint.rs new file mode 100644 index 000000000..f607c98ce --- /dev/null +++ b/fzn-rs/src/typed/constraint.rs @@ -0,0 +1,8 @@ +use crate::ast; + +/// A constraint that has annotations attached to it. +#[derive(Clone, Debug)] +pub struct AnnotatedConstraint { + pub constraint: ast::Node, + pub annotations: Vec>, +} diff --git a/fzn-rs/src/typed/flatzinc_annotation.rs b/fzn-rs/src/typed/flatzinc_annotation.rs new file mode 100644 index 000000000..b619714cf --- /dev/null +++ b/fzn-rs/src/typed/flatzinc_annotation.rs @@ -0,0 +1,150 @@ +use std::rc::Rc; + +use super::FromLiteral; +use crate::ast; +use crate::InstanceError; +use crate::Token; + +/// Parse an [`ast::Annotation`] into a specific annotation type. +/// +/// The difference with [`crate::FlatZincConstraint::from_ast`] is that annotations can be ignored. +/// [`FlatZincAnnotation::from_ast`] can successfully parse an annotation into nothing, signifying +/// the annotation is not of interest in the final [`crate::TypedInstance`]. +pub trait FlatZincAnnotation: Sized { + /// Parse a value of `Self` from the annotation node. Return `None` if the annotation node + /// clearly is not relevant for `Self`, e.g. when the name is for a completely different + /// annotation than `Self` models. + fn from_ast(annotation: &ast::Node) -> Result, InstanceError>; + + /// Parse an [`ast::Annotation`] into `Self` and produce an error if the annotation cannot be + /// converted to a value of `Self`. + fn from_ast_required(annotation: &ast::Node) -> Result { + let outcome = Self::from_ast(annotation)?; + + // By default, failing to parse an annotation node into an annotation type is not + // necessarily an error since the annotation node can be ignored. In this case, however, + // we require a value to be present. Hence, if `outcome` is `None`, that is an error. + outcome.ok_or_else(|| InstanceError::UnsupportedAnnotation(annotation.node.name().into())) + } +} + +/// A default implementation that ignores all annotations. +impl FlatZincAnnotation for () { + fn from_ast(_: &ast::Node) -> Result, InstanceError> { + Ok(None) + } +} + +/// Parse a value from an [`ast::AnnotationArgument`]. +/// +/// Any type that implements [`FromAnnotationLiteral`] also implements [`FromAnnotationArgument`]. +pub trait FromAnnotationArgument: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +/// Parse a value from an [`ast::AnnotationLiteral`]. +pub trait FromAnnotationLiteral: FromLiteral + Sized { + fn expected() -> Token; + + fn from_literal(literal: &ast::Node) -> Result; +} + +impl FromAnnotationLiteral for T { + fn expected() -> Token { + T::expected() + } + + fn from_literal(literal: &ast::Node) -> Result { + match &literal.node { + ast::AnnotationLiteral::BaseLiteral(base_literal) => T::from_literal(&ast::Node { + node: base_literal.clone(), + span: literal.span, + }), + ast::AnnotationLiteral::Annotation(_) => Err(InstanceError::UnexpectedToken { + expected: T::expected(), + actual: Token::AnnotationCall, + span: literal.span, + }), + } + } +} + +impl FromAnnotationArgument for T { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Literal(literal) => { + ::from_literal(literal) + } + + node => Err(InstanceError::UnexpectedToken { + expected: ::expected(), + actual: node.into(), + span: argument.span, + }), + } + } +} + +/// Parse an [`ast::AnnotationArgument`] as an annotation. This needs to be a separate trait from +/// [`FromAnnotationArgument`] so it does not collide wiith implementations for literals. +pub trait FromNestedAnnotation: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +/// Converts an [`ast::AnnotationLiteral`] to an [`ast::Annotation`], or produces an error if that +/// is not possible. +fn annotation_literal_to_annotation( + literal: &ast::Node, +) -> Result, InstanceError> { + match &literal.node { + ast::AnnotationLiteral::BaseLiteral(ast::Literal::Identifier(ident)) => Ok(ast::Node { + node: ast::Annotation::Atom(Rc::clone(ident)), + span: literal.span, + }), + ast::AnnotationLiteral::Annotation(annotation_call) => Ok(ast::Node { + node: ast::Annotation::Call(annotation_call.clone()), + span: literal.span, + }), + ast::AnnotationLiteral::BaseLiteral(lit) => Err(InstanceError::UnexpectedToken { + expected: Token::Annotation, + actual: lit.into(), + span: literal.span, + }), + } +} + +impl FromNestedAnnotation for Ann { + fn from_argument(argument: &ast::Node) -> Result { + let annotation = match &argument.node { + ast::AnnotationArgument::Literal(literal) => annotation_literal_to_annotation(literal)?, + ast::AnnotationArgument::Array(_) => { + return Err(InstanceError::UnexpectedToken { + expected: Token::Annotation, + actual: Token::Array, + span: argument.span, + }); + } + }; + + Ann::from_ast_required(&annotation) + } +} + +impl FromNestedAnnotation for Vec { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::AnnotationArgument::Array(elements) => elements + .iter() + .map(|literal| { + let annotation = annotation_literal_to_annotation(literal)?; + Ann::from_ast_required(&annotation) + }) + .collect::>(), + ast::AnnotationArgument::Literal(lit) => Err(InstanceError::UnexpectedToken { + expected: Token::Array, + actual: (&lit.node).into(), + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/flatzinc_constraint.rs b/fzn-rs/src/typed/flatzinc_constraint.rs new file mode 100644 index 000000000..b355a9739 --- /dev/null +++ b/fzn-rs/src/typed/flatzinc_constraint.rs @@ -0,0 +1,27 @@ +use super::FromLiteral; +use crate::ast; +use crate::InstanceError; +use crate::Token; + +/// Parse a constraint from the given [`ast::Constraint`]. +pub trait FlatZincConstraint: Sized { + fn from_ast(constraint: &ast::Node) -> Result; +} + +/// Extract an argument from the [`ast::Argument`] node. +pub trait FromArgument: Sized { + fn from_argument(argument: &ast::Node) -> Result; +} + +impl FromArgument for T { + fn from_argument(argument: &ast::Node) -> Result { + match &argument.node { + ast::Argument::Literal(literal) => T::from_literal(literal), + ast::Argument::Array(_) => Err(InstanceError::UnexpectedToken { + expected: T::expected(), + actual: Token::Array, + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/from_literal.rs b/fzn-rs/src/typed/from_literal.rs new file mode 100644 index 000000000..c41f698c2 --- /dev/null +++ b/fzn-rs/src/typed/from_literal.rs @@ -0,0 +1,132 @@ +use std::rc::Rc; + +use super::VariableExpr; +use crate::ast; +use crate::InstanceError; +use crate::Token; + +/// Extract a value from an [`ast::Literal`]. +pub trait FromLiteral: Sized { + /// The [`Token`] that is expected for this implementation. Used to create error messages. + fn expected() -> Token; + + /// Extract `Self` from a literal AST node. + fn from_literal(node: &ast::Node) -> Result; +} + +impl FromLiteral for i32 { + fn expected() -> Token { + Token::IntLiteral + } + + fn from_literal(node: &ast::Node) -> Result { + let integer = ::from_literal(node)?; + i32::try_from(integer).map_err(|_| InstanceError::IntegerOverflow(integer)) + } +} + +impl FromLiteral for i64 { + fn expected() -> Token { + Token::IntLiteral + } + + fn from_literal(node: &ast::Node) -> Result { + match &node.node { + ast::Literal::Int(value) => Ok(*value), + literal => Err(InstanceError::UnexpectedToken { + expected: Token::IntLiteral, + actual: literal.into(), + span: node.span, + }), + } + } +} + +impl FromLiteral for VariableExpr { + fn expected() -> Token { + Token::Variable(Box::new(T::expected())) + } + + fn from_literal(node: &ast::Node) -> Result { + match &node.node { + ast::Literal::Identifier(identifier) => { + Ok(VariableExpr::Identifier(Rc::clone(identifier))) + } + literal => T::from_literal(node) + .map(VariableExpr::Constant) + .map_err(|_| InstanceError::UnexpectedToken { + expected: ::expected(), + actual: literal.into(), + span: node.span, + }), + } + } +} + +impl FromLiteral for Rc { + fn expected() -> Token { + Token::Identifier + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::Identifier(ident) => Ok(Rc::clone(ident)), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::Identifier, + actual: node.into(), + span: argument.span, + }), + } + } +} + +impl FromLiteral for bool { + fn expected() -> Token { + Token::BoolLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::Bool(boolean) => Ok(*boolean), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::BoolLiteral, + actual: node.into(), + span: argument.span, + }), + } + } +} + +impl FromLiteral for ast::RangeList { + fn expected() -> Token { + Token::IntSetLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + let set = as FromLiteral>::from_literal(argument)?; + + set.into_iter() + .map(|elem| i32::try_from(elem).map_err(|_| InstanceError::IntegerOverflow(elem))) + .collect::>() + } +} + +impl FromLiteral for ast::RangeList { + fn expected() -> Token { + Token::IntSetLiteral + } + + fn from_literal(argument: &ast::Node) -> Result { + match &argument.node { + ast::Literal::IntSet(set) => Ok(set.clone()), + + node => Err(InstanceError::UnexpectedToken { + expected: Token::IntSetLiteral, + actual: node.into(), + span: argument.span, + }), + } + } +} diff --git a/fzn-rs/src/typed/instance.rs b/fzn-rs/src/typed/instance.rs new file mode 100644 index 000000000..ca36d7b83 --- /dev/null +++ b/fzn-rs/src/typed/instance.rs @@ -0,0 +1,196 @@ +use std::collections::BTreeMap; +use std::rc::Rc; + +use super::ArrayExpr; +use super::FlatZincAnnotation; +use super::FlatZincConstraint; +use super::FromAnnotationLiteral; +use super::FromLiteral; +use super::VariableExpr; +use crate::ast; +use crate::AnnotatedConstraint; +use crate::InstanceError; + +/// A fully typed representation of a FlatZinc instance. +/// +/// It is generic over the type of constraints, as well as the annotations for variables, arrays, +/// constraints, and solve. +#[derive(Clone, Debug)] +pub struct TypedInstance< + Int, + Constraint, + VariableAnnotations = (), + ArrayAnnotations = (), + ConstraintAnnotations = (), + SolveAnnotations = (), +> { + /// The variables that are in the instance. + /// + /// The key is the identifier of the variable, and the value is the domain of the variable. + pub variables: BTreeMap, ast::Variable>, + + /// The arrays in the instance. + /// + /// The key is the identifier of the array, and the value is the array itself. + pub arrays: BTreeMap, ast::Array>, + + /// The constraints in the instance. + pub constraints: Vec>, + + /// The solve item indicating how to solve the model. + pub solve: Solve, +} + +/// Specifies how to solve a [`TypedInstance`]. +/// +/// This is generic over the integer type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Solve { + pub method: ast::Node>, + pub annotations: Vec>, +} + +/// Indicate whether the model is an optimisation or satisfaction model. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Method { + Satisfy, + Optimize { + direction: ast::OptimizationDirection, + objective: VariableExpr, + }, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[error("array '{0}' is undefined")] +pub struct UndefinedArrayError(pub String); + +impl + TypedInstance +{ + /// Get the elements in an [`ArrayExpr`]. + pub fn resolve_array<'a, T>( + &'a self, + array_expr: &'a ArrayExpr, + ) -> Result> + 'a, UndefinedArrayError> + where + T: FromAnnotationLiteral, + { + array_expr + .resolve(&self.arrays) + .map_err(|identifier| UndefinedArrayError(identifier.as_ref().into())) + } +} + +impl + TypedInstance +where + TConstraint: FlatZincConstraint, + VAnnotations: FlatZincAnnotation, + AAnotations: FlatZincAnnotation, + CAnnotations: FlatZincAnnotation, + SAnnotations: FlatZincAnnotation, + VariableExpr: FromLiteral, +{ + /// Create a [`TypedInstance`] from an [`ast::Ast`]. + /// + /// This parses the constraints and annotations, and can fail e.g. if the number or type of + /// arguments do not match what is expected in the parser. + /// + /// This does _not_ type-check the variables. I.e., if a constraint takes a `var int`, but + /// is provided with an identifier of a `var bool`, then this function will gladly accept that. + pub fn from_ast(ast: ast::Ast) -> Result { + let variables = ast + .variables + .into_iter() + .map(|(id, variable)| { + let variable = ast::Variable { + domain: variable.node.domain, + value: variable.node.value, + annotations: map_annotations(&variable.node.annotations)?, + }; + + Ok((id, variable)) + }) + .collect::>()?; + + let arrays = ast + .arrays + .into_iter() + .map(|(id, array)| { + let array = ast::Array { + domain: array.node.domain, + contents: array.node.contents, + annotations: map_annotations(&array.node.annotations)?, + }; + + Ok((id, array)) + }) + .collect::>()?; + + let constraints = ast + .constraints + .iter() + .map(|constraint| { + let annotations = map_annotations(&constraint.node.annotations)?; + + let instance_constraint = TConstraint::from_ast(constraint)?; + + Ok(AnnotatedConstraint { + constraint: ast::Node { + node: instance_constraint, + span: constraint.span, + }, + annotations, + }) + }) + .collect::>()?; + + let solve = Solve { + method: match ast.solve.method.node { + ast::Method::Satisfy => ast::Node { + node: Method::Satisfy, + span: ast.solve.method.span, + }, + ast::Method::Optimize { + direction, + objective, + } => ast::Node { + node: Method::Optimize { + direction, + objective: as FromLiteral>::from_literal(&ast::Node { + node: objective, + span: ast.solve.method.span, + })?, + }, + span: ast.solve.method.span, + }, + }, + annotations: map_annotations(&ast.solve.annotations)?, + }; + + Ok(TypedInstance { + variables, + arrays, + constraints, + solve, + }) + } +} + +fn map_annotations( + annotations: &[ast::Node], +) -> Result>, InstanceError> { + annotations + .iter() + .filter_map(|annotation| { + Ann::from_ast(annotation) + .map(|maybe_node| { + maybe_node.map(|node| ast::Node { + node, + span: annotation.span, + }) + }) + .transpose() + }) + .collect() +} diff --git a/fzn-rs/src/typed/mod.rs b/fzn-rs/src/typed/mod.rs new file mode 100644 index 000000000..7a3217c51 --- /dev/null +++ b/fzn-rs/src/typed/mod.rs @@ -0,0 +1,23 @@ +mod arrays; +mod constraint; +mod flatzinc_annotation; +mod flatzinc_constraint; +mod from_literal; +mod instance; + +use std::rc::Rc; + +pub use arrays::*; +pub use constraint::*; +pub use flatzinc_annotation::*; +pub use flatzinc_constraint::*; +pub use from_literal::*; +pub use instance::*; + +/// Models a variable in the FlatZinc AST. Since `var T` is a subtype of `T`, a variable can also +/// be a constant. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum VariableExpr { + Identifier(Rc), + Constant(T), +} diff --git a/pumpkin-solver/Cargo.toml b/pumpkin-solver/Cargo.toml index 269628c0e..6eaa0b9ee 100644 --- a/pumpkin-solver/Cargo.toml +++ b/pumpkin-solver/Cargo.toml @@ -11,12 +11,13 @@ repository.workspace = true [dependencies] clap = { version = "4.5.17", features = ["derive"] } env_logger = "0.10.0" -flatzinc = "0.3.21" fnv = "1.0.7" log = "0.4.27" pumpkin-core = { version = "0.2.1", path = "../pumpkin-crates/core/", features = ["clap"] } signal-hook = "0.3.18" thiserror = "2.0.12" +fzn-rs = { version = "0.1.0", path = "../fzn-rs/", features = ["derive", "fzn"] } +ariadne = "0.5.1" [dev-dependencies] clap = { version = "4.5.17", features = ["derive"] } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/ast.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/ast.rs index 0d56d835d..0190acd4c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/ast.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/ast.rs @@ -1,4 +1,9 @@ +use fzn_rs::ast::RangeList; +use fzn_rs::ArrayExpr; +use fzn_rs::FromAnnotationArgument; +use fzn_rs::VariableExpr; use log::warn; +use pumpkin_core::proof::ConstraintTag; use pumpkin_solver::branching::value_selection::DynamicValueSelector; use pumpkin_solver::branching::value_selection::InDomainInterval; use pumpkin_solver::branching::value_selection::InDomainMax; @@ -20,13 +25,10 @@ use pumpkin_solver::branching::variable_selection::InputOrder; use pumpkin_solver::branching::variable_selection::Largest; use pumpkin_solver::branching::variable_selection::MaxRegret; use pumpkin_solver::branching::variable_selection::Smallest; -use pumpkin_solver::pumpkin_assert_eq_simple; -use pumpkin_solver::pumpkin_assert_simple; use pumpkin_solver::variables::DomainId; use pumpkin_solver::variables::Literal; -use super::error::FlatZincError; -#[derive(Debug)] +#[derive(fzn_rs::FlatZincAnnotation)] pub(crate) enum VariableSelectionStrategy { AntiFirstFail, DomWDeg, @@ -99,48 +101,48 @@ impl VariableSelectionStrategy { } } -#[derive(Debug)] +#[derive(fzn_rs::FlatZincAnnotation)] pub(crate) enum ValueSelectionStrategy { - InDomain, - InDomainInterval, - InDomainMax, - InDomainMedian, - InDomainMiddle, - InDomainMin, - InDomainRandom, - InDomainReverseSplit, - InDomainSplit, - InDomainSplitRandom, - OutDomainMax, - OutDomainMedian, - OutDomainMin, - OutDomainRandom, + Indomain, + IndomainInterval, + IndomainMax, + IndomainMedian, + IndomainMiddle, + IndomainMin, + IndomainRandom, + IndomainReverseSplit, + IndomainSplit, + IndomainSplitRandom, + OutdomainMax, + OutdomainMedian, + OutdomainMin, + OutdomainRandom, } impl ValueSelectionStrategy { pub(crate) fn create_for_literals(&self) -> DynamicValueSelector { DynamicValueSelector::new(match self { - ValueSelectionStrategy::InDomain - | ValueSelectionStrategy::InDomainInterval - | ValueSelectionStrategy::InDomainMin - | ValueSelectionStrategy::InDomainSplit - | ValueSelectionStrategy::OutDomainMax => Box::new(InDomainMin), - ValueSelectionStrategy::InDomainMax - | ValueSelectionStrategy::InDomainReverseSplit - | ValueSelectionStrategy::OutDomainMin => Box::new(InDomainMax), - ValueSelectionStrategy::InDomainMedian => { - warn!("InDomainMedian does not make sense for propositional variables, defaulting to InDomainMin..."); + ValueSelectionStrategy::Indomain + | ValueSelectionStrategy::IndomainInterval + | ValueSelectionStrategy::IndomainMin + | ValueSelectionStrategy::IndomainSplit + | ValueSelectionStrategy::OutdomainMax => Box::new(InDomainMin), + ValueSelectionStrategy::IndomainMax + | ValueSelectionStrategy::IndomainReverseSplit + | ValueSelectionStrategy::OutdomainMin => Box::new(InDomainMax), + ValueSelectionStrategy::IndomainMedian => { + warn!("indomain_median does not make sense for propositional variables, defaulting to indomain_min..."); Box::new(InDomainMin) } - ValueSelectionStrategy::InDomainMiddle => { - warn!("InDomainMiddle does not make sense for propositional variables, defaulting to InDomainMin..."); + ValueSelectionStrategy::IndomainMiddle => { + warn!("indomain_middle does not make sense for propositional variables, defaulting to indomain_min..."); Box::new(InDomainMin) } - ValueSelectionStrategy::InDomainRandom - | ValueSelectionStrategy::InDomainSplitRandom - | ValueSelectionStrategy::OutDomainRandom => Box::new(InDomainRandom), - ValueSelectionStrategy::OutDomainMedian => { - warn!("OutDomainMedian does not make sense for propositional variables, defaulting to InDomainMin..."); + ValueSelectionStrategy::IndomainRandom + | ValueSelectionStrategy::IndomainSplitRandom + | ValueSelectionStrategy::OutdomainRandom => Box::new(InDomainRandom), + ValueSelectionStrategy::OutdomainMedian => { + warn!("outdomain_median does not make sense for propositional variables, defaulting to indomain_min..."); Box::new(InDomainMin) } }) @@ -148,282 +150,132 @@ impl ValueSelectionStrategy { pub(crate) fn create_for_domains(&self) -> DynamicValueSelector { DynamicValueSelector::new(match self { - ValueSelectionStrategy::InDomain => Box::new(InDomainMin), - ValueSelectionStrategy::InDomainInterval => Box::new(InDomainInterval), - ValueSelectionStrategy::InDomainMax => Box::new(InDomainMax), - ValueSelectionStrategy::InDomainMedian => Box::new(InDomainMedian), - ValueSelectionStrategy::InDomainMiddle => Box::new(InDomainMiddle), - ValueSelectionStrategy::InDomainMin => Box::new(InDomainMin), - ValueSelectionStrategy::InDomainRandom => Box::new(InDomainRandom), - ValueSelectionStrategy::InDomainReverseSplit => Box::new(ReverseInDomainSplit), - ValueSelectionStrategy::InDomainSplit => Box::new(InDomainSplit), - ValueSelectionStrategy::InDomainSplitRandom => Box::new(InDomainSplitRandom), - ValueSelectionStrategy::OutDomainMax => Box::new(OutDomainMax), - ValueSelectionStrategy::OutDomainMedian => Box::new(OutDomainMedian), - ValueSelectionStrategy::OutDomainMin => Box::new(OutDomainMin), - ValueSelectionStrategy::OutDomainRandom => Box::new(OutDomainRandom), + ValueSelectionStrategy::Indomain => Box::new(InDomainMin), + ValueSelectionStrategy::IndomainInterval => Box::new(InDomainInterval), + ValueSelectionStrategy::IndomainMax => Box::new(InDomainMax), + ValueSelectionStrategy::IndomainMedian => Box::new(InDomainMedian), + ValueSelectionStrategy::IndomainMiddle => Box::new(InDomainMiddle), + ValueSelectionStrategy::IndomainMin => Box::new(InDomainMin), + ValueSelectionStrategy::IndomainRandom => Box::new(InDomainRandom), + ValueSelectionStrategy::IndomainReverseSplit => Box::new(ReverseInDomainSplit), + ValueSelectionStrategy::IndomainSplit => Box::new(InDomainSplit), + ValueSelectionStrategy::IndomainSplitRandom => Box::new(InDomainSplitRandom), + ValueSelectionStrategy::OutdomainMax => Box::new(OutDomainMax), + ValueSelectionStrategy::OutdomainMedian => Box::new(OutDomainMedian), + ValueSelectionStrategy::OutdomainMin => Box::new(OutDomainMin), + ValueSelectionStrategy::OutdomainRandom => Box::new(OutDomainRandom), }) } } -#[derive(Debug)] -pub(crate) enum Search { - Bool(SearchStrategy), - Int(SearchStrategy), - Seq(Vec), - Unspecified, - WarmStartInt { - variables: flatzinc::AnnExpr, - values: flatzinc::AnnExpr, - }, - WarmStartBool { - variables: flatzinc::AnnExpr, - values: flatzinc::AnnExpr, - }, - WarmStartArray(Vec), +/// The exploration strategies for search annotations. +/// +/// See +/// https://docs.minizinc.dev/en/stable/lib-stdlib-annotations.html#exploration-strategy-annotations. +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) enum Exploration { + Complete, } -#[derive(Debug)] -pub(crate) struct SearchStrategy { - pub(crate) variables: flatzinc::AnnExpr, +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) enum SearchAnnotation { + #[args] + BoolSearch(BoolSearchArgs), + #[args] + IntSearch(IntSearchArgs), + Seq(#[annotation] Vec), + #[args] + WarmStartBool(WarmStartBoolArgs), + #[args] + WarmStartInt(WarmStartIntArgs), + WarmStartArray(#[annotation] Vec), +} + +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) struct IntSearchArgs { + pub(crate) variables: ArrayExpr>, + #[annotation] pub(crate) variable_selection_strategy: VariableSelectionStrategy, + #[annotation] pub(crate) value_selection_strategy: ValueSelectionStrategy, + #[allow( + dead_code, + reason = "the int_search annotation has this argument, so it needs to be present here" + )] + #[annotation] + pub(crate) exploration: Exploration, } -pub(crate) struct FlatZincAst { - pub(crate) parameter_decls: Vec, - pub(crate) single_variables: Vec, - pub(crate) variable_arrays: Vec, - pub(crate) constraint_decls: Vec, - pub(crate) solve_item: flatzinc::SolveItem, - pub(crate) search: Search, +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) struct BoolSearchArgs { + pub(crate) variables: ArrayExpr>, + #[annotation] + pub(crate) variable_selection_strategy: VariableSelectionStrategy, + #[annotation] + pub(crate) value_selection_strategy: ValueSelectionStrategy, + #[allow( + dead_code, + reason = "the int_search annotation has this argument, so it needs to be present here" + )] + #[annotation] + pub(crate) exploration: Exploration, } -impl FlatZincAst { - pub(crate) fn builder() -> FlatZincAstBuilder { - FlatZincAstBuilder { - parameter_decls: vec![], - single_variables: vec![], - variable_arrays: vec![], - constraint_decls: vec![], - solve_item: None, - search: None, - } - } +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) struct WarmStartBoolArgs { + pub(crate) variables: ArrayExpr>, + pub(crate) values: ArrayExpr, } -pub(crate) struct FlatZincAstBuilder { - parameter_decls: Vec, - single_variables: Vec, - variable_arrays: Vec, - constraint_decls: Vec, - solve_item: Option, - - search: Option, +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) struct WarmStartIntArgs { + pub(crate) variables: ArrayExpr>, + pub(crate) values: ArrayExpr, } -impl FlatZincAstBuilder { - pub(crate) fn add_parameter_decl(&mut self, parameter_decl: flatzinc::ParDeclItem) { - self.parameter_decls.push(parameter_decl); - } - - pub(crate) fn add_variable_decl(&mut self, variable_decl: SingleVarDecl) { - self.single_variables.push(variable_decl); - } - - pub(crate) fn add_variable_array(&mut self, array_decl: VarArrayDecl) { - self.variable_arrays.push(array_decl); - } - - pub(crate) fn add_constraint(&mut self, constraint: flatzinc::ConstraintItem) { - self.constraint_decls.push(constraint); - } - - pub(crate) fn set_solve_item(&mut self, solve_item: flatzinc::SolveItem) { - if let Some(annotation) = solve_item.annotations.first() { - self.search = FlatZincAstBuilder::find_search(annotation); - } else { - self.search = Some(Search::Unspecified) - } - let _ = self.solve_item.insert(solve_item); - } - - fn find_search(annotation: &flatzinc::Annotation) -> Option { - match &annotation.id[..] { - "bool_search" => Some(Search::Bool(FlatZincAstBuilder::find_direct_search( - annotation, - ))), - "float_search" => panic!("Search over floats is currently not supported"), - "int_search" => Some(Search::Int(FlatZincAstBuilder::find_direct_search( - annotation, - ))), - "seq_search" => { - pumpkin_assert_eq_simple!( - annotation.expressions.len(), - 1, - "Expected a single expression for sequential search" - ); - Some(Search::Seq(match &annotation.expressions[0] { - flatzinc::AnnExpr::Annotations(annotations) => annotations - .iter() - .filter_map(FlatZincAstBuilder::find_search) - .collect::>(), - other => { - panic!("Expected a list of annotations for `seq_search` but was {other:?}") - } - })) - } - "set_search" => panic!("Search over sets is currently not supported"), - "warm_start_int" => Some(Search::WarmStartInt { - variables: annotation.expressions[0].clone(), - values: annotation.expressions[1].clone(), - }), - "warm_start_bool" => Some(Search::WarmStartBool { - variables: annotation.expressions[0].clone(), - values: annotation.expressions[1].clone(), - }), - "warm_start_array" => { - Some(Search::WarmStartArray(match &annotation.expressions[0] { - flatzinc::AnnExpr::Annotations(annotations) => annotations - .iter() - .filter_map(FlatZincAstBuilder::find_search) - .collect::>(), - other => { - panic!("Expected a list of annotations for `warm_start_array` but was {other:?}") - } - })) - } - "constraint_name" => { - warn!("`constraint_name` is currently not supported; ignoring search annotation"); - None - } - other => panic!("Did not recognise search strategy {other}"), - } - } +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) enum VariableAnnotations { + OutputVar, +} - fn find_direct_search(annotation: &flatzinc::Annotation) -> SearchStrategy { - // First element is the optimization variable - // Second element is the variable selection strategy - // Third element is the value selection strategy - // (Optional) Fourth element is the exploration strategy (e.g. complete search) - pumpkin_assert_simple!( - annotation.expressions.len() >= 3, - "Expected the search annotation to have 3 or 4 elements but it has {} elements", - annotation.expressions.len() - ); +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) enum ArrayAnnotations { + OutputArray(ArrayExpr>), +} - SearchStrategy { - variables: annotation.expressions[0].clone(), - variable_selection_strategy: FlatZincAstBuilder::find_variable_selection_strategy( - &annotation.expressions[1], - ), - value_selection_strategy: FlatZincAstBuilder::find_value_selection_strategy( - &annotation.expressions[2], - ), - } - } +#[derive(fzn_rs::FlatZincAnnotation)] +pub(crate) enum ConstraintAnnotations { + ConstraintTag(TagAnnotation), +} - fn find_variable_selection_strategy(input: &flatzinc::AnnExpr) -> VariableSelectionStrategy { - match input { - flatzinc::AnnExpr::Expr(inner) => match inner { - flatzinc::Expr::VarParIdentifier(identifier) => match &identifier[..] { - "anti_first_fail" => VariableSelectionStrategy::AntiFirstFail, - "dom_w_deg" => VariableSelectionStrategy::DomWDeg, - "first_fail" => VariableSelectionStrategy::FirstFail, - "impact" => VariableSelectionStrategy::Impact, - "input_order" => VariableSelectionStrategy::InputOrder, - "largest" => VariableSelectionStrategy::Largest, - "max_regret" => VariableSelectionStrategy::MaxRegret, - "most_constrained" => VariableSelectionStrategy::MostConstrained, - "occurrence" => VariableSelectionStrategy::Occurrence, - "smallest" => VariableSelectionStrategy::Smallest, - other => panic!("Did not recognise variable selection strategy {other}"), - }, - other => panic!("Expected VarParIdentifier but got {other:?}"), - }, - other => panic!("Expected an expression but got {other:?}"), - } - } +#[derive(Clone, Copy, Debug)] +pub(crate) struct TagAnnotation(ConstraintTag); - fn find_value_selection_strategy(input: &flatzinc::AnnExpr) -> ValueSelectionStrategy { - match input { - flatzinc::AnnExpr::Expr(inner) => match inner { - flatzinc::Expr::VarParIdentifier(identifier) => match &identifier[..] { - "indomain" => ValueSelectionStrategy::InDomain, - "indomain_interval" => ValueSelectionStrategy::InDomainInterval, - "indomain_max" => ValueSelectionStrategy::InDomainMax, - "indomain_median" => ValueSelectionStrategy::InDomainMedian, - "indomain_middle" => ValueSelectionStrategy::InDomainMiddle, - "indomain_min" => ValueSelectionStrategy::InDomainMin, - "indomain_random" => ValueSelectionStrategy::InDomainRandom, - "indomain_reverse_split" => ValueSelectionStrategy::InDomainReverseSplit, - "indomain_split" => ValueSelectionStrategy::InDomainSplit, - "indomain_split_random" => ValueSelectionStrategy::InDomainSplitRandom, - "outdomain_max" => ValueSelectionStrategy::OutDomainMax, - "outdomain_median" => ValueSelectionStrategy::OutDomainMedian, - "outdomain_min" => ValueSelectionStrategy::OutDomainMin, - "outdomain_random" => ValueSelectionStrategy::OutDomainRandom, - other => panic!("Did not recognise value selection strategy {other}"), - }, - other => panic!("Expected VarParIdentifier but got {other:?}"), - }, - other => panic!("Expected an expression but got {other:?}"), - } +impl From for TagAnnotation { + fn from(value: ConstraintTag) -> Self { + TagAnnotation(value) } +} - pub(crate) fn build(self) -> Result { - let FlatZincAstBuilder { - parameter_decls, - single_variables, - variable_arrays, - constraint_decls, - solve_item, - search, - } = self; - - Ok(FlatZincAst { - parameter_decls, - single_variables, - variable_arrays, - constraint_decls, - solve_item: solve_item.ok_or(FlatZincError::MissingSolveItem)?, - search: search.ok_or(FlatZincError::MissingSolveItem)?, - }) +impl From for ConstraintTag { + fn from(value: TagAnnotation) -> Self { + value.0 } } -pub(crate) enum SingleVarDecl { - Bool { - id: String, - expr: Option, - annos: flatzinc::expressions::Annotations, - }, - - IntInRange { - id: String, - lb: i128, - ub: i128, - expr: Option, - annos: flatzinc::expressions::Annotations, - }, - - IntInSet { - id: String, - set: Vec, - - annos: flatzinc::expressions::Annotations, - }, +impl FromAnnotationArgument for TagAnnotation { + fn from_argument( + _: &fzn_rs::ast::Node, + ) -> Result { + unreachable!("This never gets parsed from source") + } } -pub(crate) enum VarArrayDecl { - Bool { - id: String, - annos: Vec, - array_expr: Option, - }, - Int { - id: String, - annos: Vec, - array_expr: Option, - }, -} +pub(crate) type Instance = fzn_rs::TypedInstance< + i32, + super::constraints::Constraints, + VariableAnnotations, + ArrayAnnotations, + ConstraintAnnotations, + SearchAnnotation, +>; diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs index 9e57e74e7..9eb82d83c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs @@ -2,101 +2,80 @@ use std::rc::Rc; -use flatzinc::Annotation; -use pumpkin_core::containers::HashMap; -use pumpkin_core::variables::DomainId; -use pumpkin_core::Solver; -use pumpkin_solver::variables::Literal; +use fzn_rs::ast; +use pumpkin_core::variables::Literal; use super::context::CompilationContext; use super::context::Domain; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::ast::SingleVarDecl; +use crate::flatzinc::ast::Instance; +use crate::flatzinc::ast::VariableAnnotations; use crate::flatzinc::instance::Output; use crate::flatzinc::FlatZincError; pub(crate) fn run( - ast: &FlatZincAst, + instance: &Instance, context: &mut CompilationContext, ) -> Result<(), FlatZincError> { - for single_var_decl in &ast.single_variables { - match single_var_decl { - SingleVarDecl::Bool { id, annos, .. } => { - let id = context.identifiers.get_interned(id); - - let representative = context.equivalences.representative(&id); - let domain = context.equivalences.domain(&id); + for (name, variable) in &instance.variables { + match &variable.domain.node { + ast::Domain::Bool => { + let representative = context.equivalences.representative(name)?; + let domain = context.equivalences.domain(name); let domain_id = *context .variable_map - .entry(Rc::clone(&representative)) - .or_insert_with(|| { - create_integer_domain( - context.solver, - &mut context.constant_domain_ids, - representative, - domain, - ) - }); + .entry(representative) + .or_insert_with(|| domain.into_variable(context.solver, name.to_string())); let literal = Literal::new(domain_id); - if is_output_variable(annos) { - context.outputs.push(Output::bool(id, literal)); + if is_output_variable(variable) { + context.outputs.push(Output::bool(Rc::clone(name), literal)); } } - SingleVarDecl::IntInRange { id, annos, .. } - | SingleVarDecl::IntInSet { - id, set: _, annos, .. - } => { - let id = context.identifiers.get_interned(id); - - let representative = context.equivalences.representative(&id); - let domain = context.equivalences.domain(&id); + ast::Domain::Int(_) => { + let representative = context.equivalences.representative(name)?; + let domain = context.equivalences.domain(name); let domain_id = *context .variable_map .entry(Rc::clone(&representative)) .or_insert_with(|| { - create_integer_domain( - context.solver, - &mut context.constant_domain_ids, - representative, - domain, - ) + if domain.is_constant() { + *context + .constant_domain_ids + .entry(match &domain { + Domain::IntervalDomain { lb, ub: _ } => *lb, + Domain::SparseDomain { values } => values[0], + }) + .or_insert_with(|| { + domain.into_variable(context.solver, name.to_string()) + }) + } else { + domain.into_variable(context.solver, name.to_string()) + } }); - if is_output_variable(annos) { - context.outputs.push(Output::int(id, domain_id)); + if is_output_variable(variable) { + context + .outputs + .push(Output::int(Rc::clone(name), domain_id)); } } + + ast::Domain::UnboundedInt => { + return Err(FlatZincError::UnsupportedVariable(name.as_ref().into())) + } } } Ok(()) } -fn create_integer_domain( - solver: &mut Solver, - constant_domains: &mut HashMap, - identifier: Rc, - domain: Domain, -) -> DomainId { - if domain.is_constant() { - let value = match &domain { - Domain::IntervalDomain { lb, ub: _ } => *lb, - Domain::SparseDomain { values } => values[0], - }; - - *constant_domains - .entry(value) - .or_insert_with(|| domain.into_variable(solver, value.to_string())) - } else { - domain.into_variable(solver, identifier.to_string()) - } -} - -fn is_output_variable(annos: &[Annotation]) -> bool { - annos.iter().any(|ann| ann.id == "output_var") +fn is_output_variable(variable: &ast::Variable) -> bool { + variable + .annotations + .iter() + .any(|ann| matches!(ann.node, VariableAnnotations::OutputVar)) } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs index c6a0d0275..35b5c1ba3 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs @@ -3,14 +3,16 @@ use std::cell::RefMut; use std::collections::BTreeSet; use std::rc::Rc; +use fzn_rs::ast::RangeList; +use fzn_rs::ArrayExpr; +use fzn_rs::VariableExpr; use log::warn; use pumpkin_solver::containers::HashMap; -use pumpkin_solver::containers::HashSet; -use pumpkin_solver::proof::ConstraintTag; use pumpkin_solver::variables::DomainId; use pumpkin_solver::variables::Literal; use pumpkin_solver::Solver; +use crate::flatzinc::ast::Instance; use crate::flatzinc::instance::Output; use crate::flatzinc::FlatZincError; @@ -18,10 +20,6 @@ pub(crate) struct CompilationContext<'a> { /// The solver to compile the FlatZinc into. pub(crate) solver: &'a mut Solver, - /// All identifiers occuring in the model. The identifiers are interned, to support cheap - /// cloning. - pub(crate) identifiers: Identifiers, - /// Identifiers of variables that are outputs. pub(crate) outputs: Vec, @@ -41,36 +39,13 @@ pub(crate) struct CompilationContext<'a> { pub(crate) true_literal: Literal, /// Literal which is always false pub(crate) false_literal: Literal, - /// All boolean parameters. - pub(crate) boolean_parameters: HashMap, bool>, - /// All boolean array parameters. - pub(crate) boolean_array_parameters: HashMap, Rc<[bool]>>, - /// A mapping from boolean variable array identifiers to slices of literals. - pub(crate) boolean_variable_arrays: HashMap, Rc<[Literal]>>, - /// All integer parameters. - pub(crate) integer_parameters: HashMap, i32>, - /// All integer array parameters. - pub(crate) integer_array_parameters: HashMap, Rc<[i32]>>, + // A literal which is always true, can be used when using bool constants in the solver + // pub(crate) constant_bool_true: BooleanDomainId, + // A literal which is always false, can be used when using bool constants in the solver + // pub(crate) constant_bool_false: BooleanDomainId, /// The equivalence classes for integer variables. The associated data is the bounds for the /// Only instantiate single domain for every constant variable. pub(crate) constant_domain_ids: HashMap, - /// A mapping from integer variable array identifiers to slices of domain ids. - pub(crate) integer_variable_arrays: HashMap, Rc<[DomainId]>>, - - /// All set parameters. - pub(crate) set_constants: HashMap, Set>, - - /// All the constraints with their constraint tags. - pub(crate) constraints: Vec<(ConstraintTag, flatzinc::ConstraintItem)>, -} - -/// A set parameter. -#[derive(Clone, Debug)] -pub(crate) enum Set { - /// A set defined by the interval `lower_bound..=upper_bound`. - Interval { lower_bound: i32, upper_bound: i32 }, - /// A set defined by some values. - Sparse { values: Box<[i32]> }, } impl CompilationContext<'_> { @@ -80,426 +55,118 @@ impl CompilationContext<'_> { CompilationContext { solver, - identifiers: Default::default(), outputs: Default::default(), equivalences: Default::default(), + variable_map: Default::default(), true_literal, false_literal, - boolean_parameters: Default::default(), - boolean_array_parameters: Default::default(), - boolean_variable_arrays: Default::default(), - integer_parameters: Default::default(), - integer_array_parameters: Default::default(), - variable_map: Default::default(), constant_domain_ids: Default::default(), - integer_variable_arrays: Default::default(), - - set_constants: Default::default(), - - constraints: Default::default(), } } - pub(crate) fn is_identifier_parameter(&mut self, identifier: &str) -> bool { - self.integer_parameters.contains_key(identifier) - } - - // pub fn resolve_bool_constant(&self, identifier: &str) -> Option { - // self.boolean_parameters.get(identifier).copied() - // } - - // pub fn resolve_int_constant(&self, identifier: &str) -> Option { - // self.integer_parameters.get(identifier).copied() - // } - pub(crate) fn resolve_bool_variable( - &mut self, - expr: &flatzinc::Expr, - ) -> Result { - match expr { - flatzinc::Expr::VarParIdentifier(id) => self.resolve_bool_variable_from_identifier(id), - flatzinc::Expr::Bool(value) => { - if *value { - Ok(self.solver.get_true_literal()) - } else { - Ok(self.solver.get_false_literal()) - } - } - _ => Err(FlatZincError::UnexpectedExpr), - } - } - - pub(crate) fn resolve_bool_variable_from_identifier( &self, - identifier: &str, + variable: &VariableExpr, ) -> Result { - if let Some(domain_id) = self - .variable_map - .get(&self.equivalences.representative(identifier)) - { - Ok(Literal::new(*domain_id)) - } else { - self.boolean_parameters - .get(&self.equivalences.representative(identifier)) - .map(|value| { - if *value { - self.solver.get_true_literal() - } else { - self.solver.get_false_literal() - } - }) - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: identifier.into(), - expected_type: "bool variable".into(), - }) - } - } + match variable { + VariableExpr::Identifier(ident) => { + let representative = self.equivalences.representative(ident)?; - pub(crate) fn resolve_bool_variable_array( - &self, - expr: &flatzinc::Expr, - ) -> Result, FlatZincError> { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - if let Some(literal) = self.boolean_variable_arrays.get(id.as_str()) { - Ok(Rc::clone(literal)) - } else { - self.boolean_array_parameters - .get(id.as_str()) - .map(|array| { - array - .iter() - .map(|value| { - if *value { - self.solver.get_true_literal() - } else { - self.solver.get_false_literal() - } - }) - .collect() - }) + let domain_id = + self.variable_map + .get(&representative) + .copied() .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "boolean variable array".into(), - }) - } - } - flatzinc::Expr::ArrayOfBool(array) => array - .iter() - .map(|elem| match elem { - flatzinc::BoolExpr::VarParIdentifier(id) => { - self.resolve_bool_variable_from_identifier(id) - } - flatzinc::BoolExpr::Bool(true) => Ok(self.solver.get_true_literal()), - flatzinc::BoolExpr::Bool(false) => Ok(self.solver.get_false_literal()), - }) - .collect(), - flatzinc::Expr::ArrayOfInt(array) => array - .iter() - .map(|elem| match elem { - flatzinc::IntExpr::VarParIdentifier(id) => { - self.resolve_bool_variable_from_identifier(id) - } - _ => panic!("Bool search should not be over integer variable"), - }) - .collect(), - _ => Err(FlatZincError::UnexpectedExpr), - } - } + identifier: Rc::clone(ident), + expected_type: "bool var".into(), + })?; - pub(crate) fn resolve_bool_constants( - &self, - expr: &flatzinc::Expr, - ) -> Result, FlatZincError> { - match expr { - flatzinc::Expr::VarParIdentifier(id) => self - .boolean_array_parameters - .get(id.as_str()) - .cloned() - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "constant boolean array".into(), - }), - flatzinc::Expr::ArrayOfBool(exprs) => exprs - .iter() - .map(|e| self.resolve_bool_expr_to_const(e)) - .collect::, _>>(), - _ => Err(FlatZincError::UnexpectedExpr), - } - } - - pub(crate) fn resolve_array_integer_constants( - &self, - expr: &flatzinc::Expr, - ) -> Result, FlatZincError> { - match expr { - flatzinc::Expr::VarParIdentifier(id) => self - .integer_array_parameters - .get(id.as_str()) - .cloned() - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "constant integer array".into(), - }), - flatzinc::Expr::ArrayOfInt(exprs) => exprs - .iter() - .map(|e| self.resolve_int_expr_to_const(e)) - .collect::, _>>(), - _ => Err(FlatZincError::UnexpectedExpr), - } - } - - pub(crate) fn resolve_integer_constant_from_identifier( - &mut self, - identifier: &str, - ) -> Result { - let value = self.resolve_int_expr_to_const(&flatzinc::IntExpr::VarParIdentifier( - identifier.to_owned(), - ))?; - Ok(*self.constant_domain_ids.entry(value).or_insert_with(|| { - self.solver - .new_named_bounded_integer(value, value, identifier.to_owned()) - })) - } - - pub(crate) fn resolve_integer_constant_from_expr( - &self, - expr: &flatzinc::Expr, - ) -> Result { - fn try_into_int_expr(expr: flatzinc::Expr) -> Option { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - Some(flatzinc::IntExpr::VarParIdentifier(id)) - } - flatzinc::Expr::Int(value) => Some(flatzinc::IntExpr::Int(value)), - _ => None, + Ok(Literal::new(domain_id)) } + VariableExpr::Constant(true) => Ok(self.true_literal), + VariableExpr::Constant(false) => Ok(self.false_literal), } - try_into_int_expr(expr.clone()) - .ok_or(FlatZincError::UnexpectedExpr) - .and_then(|e| self.resolve_int_expr_to_const(&e)) } - pub(crate) fn resolve_bool_expr_to_const( + pub(crate) fn resolve_bool_array( &self, - expr: &flatzinc::BoolExpr, - ) -> Result { - match expr { - flatzinc::BoolExpr::Bool(value) => Ok(*value), - flatzinc::BoolExpr::VarParIdentifier(id) => self - .boolean_parameters - .get(id.as_str()) - .copied() - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "constant boolean".into(), - }), - } + instance: &Instance, + array: &ArrayExpr, + ) -> Result, FlatZincError> { + instance + .resolve_array(array) + .map_err(|err| FlatZincError::UndefinedArray(err.0))? + .map(|maybe_int| maybe_int.map_err(FlatZincError::from)) + .collect() } - pub(crate) fn resolve_int_expr_to_const( + pub(crate) fn resolve_bool_variable_array( &self, - expr: &flatzinc::IntExpr, - ) -> Result { - match expr { - flatzinc::IntExpr::Int(value) => i32::try_from(*value).map_err(Into::into), - flatzinc::IntExpr::VarParIdentifier(id) => self - .integer_parameters - .get(id.as_str()) - .copied() - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "constant integer".into(), - }), - } - } - - pub(crate) fn resolve_int_expr( - &mut self, - expr: &flatzinc::IntExpr, - ) -> Result { - match expr { - flatzinc::IntExpr::Int(value) => Ok(*self - .constant_domain_ids - .entry(*value as i32) - .or_insert_with(|| { - self.solver.new_named_bounded_integer( - *value as i32, - *value as i32, - value.to_string(), - ) - })), - flatzinc::IntExpr::VarParIdentifier(id) => { - self.resolve_integer_variable_from_identifier(id) - } - } + instance: &Instance, + array: &ArrayExpr>, + ) -> Result, FlatZincError> { + instance + .resolve_array(array) + .map_err(|err| FlatZincError::UndefinedArray(err.0))? + .map(|expr_result| { + let expr = expr_result?; + self.resolve_bool_variable(&expr) + }) + .collect() } pub(crate) fn resolve_integer_variable( &mut self, - expr: &flatzinc::Expr, + variable: &VariableExpr, ) -> Result { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - self.resolve_integer_variable_from_identifier(id) + match variable { + VariableExpr::Identifier(ident) => { + let representative = self.equivalences.representative(ident)?; + + self.variable_map + .get(&representative) + .copied() + .ok_or_else(|| FlatZincError::InvalidIdentifier { + identifier: Rc::clone(ident), + expected_type: "int var".into(), + }) } - flatzinc::Expr::Int(val) => Ok(*self - .constant_domain_ids - .entry(*val as i32) - .or_insert_with(|| { + VariableExpr::Constant(value) => { + Ok(*self.constant_domain_ids.entry(*value).or_insert_with(|| { self.solver - .new_named_bounded_integer(*val as i32, *val as i32, val.to_string()) - })), - _ => Err(FlatZincError::UnexpectedExpr), + .new_named_bounded_integer(*value, *value, value.to_string()) + })) + } } } - pub(crate) fn resolve_integer_variable_from_identifier( - &mut self, - identifier: &str, - ) -> Result { - if !self.equivalences.classes.contains_key(identifier) { - return Err(FlatZincError::InvalidIdentifier { - identifier: identifier.into(), - expected_type: "integer".into(), - }); - } - if let Some(domain_id) = self - .variable_map - .get(&self.equivalences.representative(identifier)) - { - Ok(*domain_id) - } else { - self.integer_parameters - .get(&self.equivalences.representative(identifier)) - .map(|value| { - *self.constant_domain_ids.entry(*value).or_insert_with(|| { - self.solver - .new_named_bounded_integer(*value, *value, value.to_string()) - }) - }) - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: identifier.into(), - expected_type: "integer variable".into(), - }) - } + pub(crate) fn resolve_integer_array( + &self, + instance: &Instance, + array: &ArrayExpr, + ) -> Result, FlatZincError> { + instance + .resolve_array(array) + .map_err(|err| FlatZincError::UndefinedArray(err.0))? + .map(|maybe_int| maybe_int.map_err(FlatZincError::from)) + .collect() } pub(crate) fn resolve_integer_variable_array( &mut self, - expr: &flatzinc::Expr, - ) -> Result, FlatZincError> { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - if let Some(domain_id) = self.integer_variable_arrays.get(id.as_str()) { - Ok(Rc::clone(domain_id)) - } else { - self.integer_array_parameters - .get(id.as_str()) - .map(|array| { - array - .iter() - .map(|value| { - *self.constant_domain_ids.entry(*value).or_insert_with(|| { - self.solver.new_named_bounded_integer( - *value, - *value, - value.to_string(), - ) - }) - }) - .collect() - }) - .ok_or_else(|| FlatZincError::InvalidIdentifier { - identifier: id.as_str().into(), - expected_type: "integer variable array".into(), - }) - } - } - flatzinc::Expr::ArrayOfInt(array) => array - .iter() - .map(|elem| self.resolve_int_expr(elem)) - .collect::, _>>(), - - // The AST is not correct here. Since the type of an in-place array containing only - // identifiers cannot be determined, and the parser attempts to parse ArrayOfBool - // first, we may also get this variant even when parsing integer arrays. - flatzinc::Expr::ArrayOfBool(array) => array - .iter() - .map(|elem| { - if let flatzinc::BoolExpr::VarParIdentifier(id) = elem { - self.resolve_integer_variable_from_identifier(id) - } else { - Err(FlatZincError::UnexpectedExpr) - } - }) - .collect(), - _ => Err(FlatZincError::UnexpectedExpr), - } - } - - pub(crate) fn resolve_set_constant(&self, expr: &flatzinc::Expr) -> Result { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - self.set_constants.get(id.as_str()).cloned().ok_or( - FlatZincError::InvalidIdentifier { - identifier: id.clone().into(), - expected_type: "set of int".into(), - }, - ) - } - - flatzinc::Expr::Set(set_literal) => match set_literal { - flatzinc::SetLiteralExpr::IntInRange(lower_bound_expr, upper_bound_expr) => { - let lower_bound = self.resolve_int_expr_to_const(lower_bound_expr)?; - let upper_bound = self.resolve_int_expr_to_const(upper_bound_expr)?; - - Ok(Set::Interval { - lower_bound, - upper_bound, - }) - } - flatzinc::SetLiteralExpr::SetInts(exprs) => { - let values = exprs - .iter() - .map(|expr| self.resolve_int_expr_to_const(expr)) - .collect::>()?; - - Ok(Set::Sparse { values }) - } - - flatzinc::SetLiteralExpr::BoundedFloat(_, _) - | flatzinc::SetLiteralExpr::SetFloats(_) => panic!("float values are unsupported"), - }, - - flatzinc::Expr::Bool(_) - | flatzinc::Expr::Int(_) - | flatzinc::Expr::Float(_) - | flatzinc::Expr::ArrayOfBool(_) - | flatzinc::Expr::ArrayOfInt(_) - | flatzinc::Expr::ArrayOfFloat(_) - | flatzinc::Expr::ArrayOfSet(_) => Err(FlatZincError::UnexpectedExpr), - } - } -} - -#[derive(Default, Debug)] -pub(crate) struct Identifiers { - interned_identifiers: HashSet>, -} - -impl Identifiers { - pub(crate) fn get_interned(&mut self, identifier: &str) -> Rc { - if let Some(interned) = self.interned_identifiers.get(identifier) { - Rc::clone(interned) - } else { - let interned: Rc = identifier.into(); - let _ = self.interned_identifiers.insert(Rc::clone(&interned)); - - interned - } + instance: &Instance, + array: &ArrayExpr>, + ) -> Result, FlatZincError> { + instance + .resolve_array(array) + .map_err(|err| FlatZincError::UndefinedArray(err.0))? + .map(|expr_result| { + let expr = expr_result?; + self.resolve_integer_variable(&expr) + }) + .collect() } } @@ -578,13 +245,26 @@ impl VariableEquivalences { /// Get the name of the representative variable of the equivalence class the given variable /// belongs to. /// If the variable doesn't belong to an equivalence class, this method panics. - pub(crate) fn representative(&self, variable: &str) -> Rc { - self.classes[variable] + pub(crate) fn representative(&self, variable: &str) -> Result, FlatZincError> { + let equiv_class = + self.classes + .get(variable) + .ok_or_else(|| FlatZincError::InvalidIdentifier { + identifier: variable.into(), + // Since you should never see this error message, we give a dummy value. We + // cannot panic, due to the `identify_output_arrays` implementation that will + // try to resolve non-existent variable names. + expected_type: "?".into(), + })?; + + let ident = equiv_class .borrow() .variables .first() .cloned() - .expect("all classes have at least one representative") + .expect("all classes have at least one representative"); + + Ok(ident) } /// Get the domain for the given variable, based on the equivalence class it belongs to. @@ -610,19 +290,17 @@ pub(crate) enum Domain { SparseDomain { values: Vec }, } -impl From for Domain { - fn from(value: Set) -> Self { - match value { - Set::Interval { - lower_bound, - upper_bound, - } => Domain::IntervalDomain { - lb: lower_bound, - ub: upper_bound, - }, - Set::Sparse { values } => Domain::SparseDomain { - values: values.to_vec(), - }, +impl From<&'_ RangeList> for Domain { + fn from(value: &'_ RangeList) -> Self { + if value.is_continuous() { + Domain::IntervalDomain { + lb: *value.lower_bound(), + ub: *value.upper_bound(), + } + } else { + let values = value.into_iter().collect::<_>(); + + Domain::SparseDomain { values } } } } @@ -705,7 +383,7 @@ mod tests { let c = Rc::from("c"); equivs.create_equivalence_class(Rc::clone(&a), 0, 1); assert!(equivs.is_defined(&a)); - assert_eq!(equivs.representative(&a), a); + assert_eq!(equivs.representative(&a).unwrap(), a); assert_eq!( equivs.domain(&a), Domain::from_lower_bound_and_upper_bound(0, 1) @@ -713,7 +391,7 @@ mod tests { equivs.create_equivalence_class(Rc::clone(&b), 1, 3); assert!(equivs.is_defined(&b)); - assert_eq!(equivs.representative(&b), b); + assert_eq!(equivs.representative(&b).unwrap(), b); assert_eq!( equivs.domain(&b), Domain::from_lower_bound_and_upper_bound(1, 3) @@ -721,7 +399,7 @@ mod tests { equivs.create_equivalence_class(Rc::clone(&c), 5, 10); assert!(equivs.is_defined(&c)); - assert_eq!(equivs.representative(&c), c); + assert_eq!(equivs.representative(&c).unwrap(), c); assert_eq!( equivs.domain(&c), Domain::from_lower_bound_and_upper_bound(5, 10) @@ -729,21 +407,21 @@ mod tests { equivs.merge(Rc::clone(&a), Rc::clone(&b)); assert!(equivs.is_defined(&a)); - assert_eq!(equivs.representative(&a), a); + assert_eq!(equivs.representative(&a).unwrap(), a); assert_eq!( equivs.domain(&a), Domain::from_lower_bound_and_upper_bound(1, 1) ); assert!(equivs.is_defined(&b)); - assert_eq!(equivs.representative(&b), a); + assert_eq!(equivs.representative(&b).unwrap(), a); assert_eq!( equivs.domain(&b), Domain::from_lower_bound_and_upper_bound(1, 1) ); assert!(equivs.is_defined(&c)); - assert_eq!(equivs.representative(&c), c); + assert_eq!(equivs.representative(&c).unwrap(), c); assert_eq!( equivs.domain(&c), Domain::from_lower_bound_and_upper_bound(5, 10) diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_objective.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_objective.rs index eeda3c8d0..97725cc50 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_objective.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_objective.rs @@ -1,44 +1,30 @@ //! Add objective function to solver -use flatzinc::BoolExpr; -use flatzinc::Goal; - use super::context::CompilationContext; -use crate::flatzinc::ast::FlatZincAst; +use crate::flatzinc::ast::Instance; use crate::flatzinc::instance::FlatzincObjective; use crate::flatzinc::FlatZincError; pub(crate) fn run( - ast: &FlatZincAst, + typed_ast: &Instance, context: &mut CompilationContext, ) -> Result, FlatZincError> { - match &ast.solve_item.goal { - Goal::Satisfy => Ok(None), - Goal::OptimizeBool(optimization_type, bool_expr) => { - // The objective function will be parsed as a bool because that is the first identifier - // it will find For now we assume that the objective function is a single - // integer + match &typed_ast.solve.method.node { + fzn_rs::Method::Satisfy => Ok(None), + fzn_rs::Method::Optimize { + direction, + objective, + } => { + let variable = context.resolve_integer_variable(objective)?; - let domain = match bool_expr { - BoolExpr::Bool(_) => unreachable!( - "We do not expect a constant to be present in the objective function!" - ), - BoolExpr::VarParIdentifier(x) => { - if context.is_identifier_parameter(x) { - context.resolve_integer_constant_from_identifier(x)? - } else { - context.resolve_integer_variable_from_identifier(x)? - } + match direction { + fzn_rs::ast::OptimizationDirection::Minimize => { + Ok(Some(FlatzincObjective::Minimize(variable))) } - }; - - Ok(Some(match optimization_type { - flatzinc::OptimizationType::Minimize => FlatzincObjective::Minimize(domain), - flatzinc::OptimizationType::Maximize => FlatzincObjective::Maximize(domain), - })) + fzn_rs::ast::OptimizationDirection::Maximize => { + Ok(Some(FlatzincObjective::Maximize(variable))) + } + } } - _ => todo!( - "For now we assume that the optimisation function is a single integer to optimise" - ), } } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs index c9f5d7617..918e7bfac 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use pumpkin_solver::branching::branchers::dynamic_brancher::DynamicBrancher; use pumpkin_solver::branching::branchers::independent_variable_value_brancher::IndependentVariableValueBrancher; use pumpkin_solver::branching::branchers::warm_start::WarmStart; @@ -11,164 +9,137 @@ use pumpkin_solver::variables::DomainId; use pumpkin_solver::variables::Literal; use super::context::CompilationContext; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::ast::Search; -use crate::flatzinc::ast::SearchStrategy; +use crate::flatzinc::ast::BoolSearchArgs; +use crate::flatzinc::ast::Instance; +use crate::flatzinc::ast::IntSearchArgs; +use crate::flatzinc::ast::SearchAnnotation; use crate::flatzinc::ast::ValueSelectionStrategy; use crate::flatzinc::ast::VariableSelectionStrategy; use crate::flatzinc::error::FlatZincError; use crate::flatzinc::instance::FlatzincObjective; pub(crate) fn run( - ast: &FlatZincAst, + typed_ast: &Instance, context: &mut CompilationContext, objective: Option, ) -> Result { - create_from_search_strategy(&ast.search, context, true, objective) + let search = typed_ast + .solve + .annotations + .iter() + .map(|node| &node.node) + .next(); + + create_from_search_strategy(typed_ast, search, context, true, objective) } fn create_from_search_strategy( - strategy: &Search, + typed_ast: &Instance, + strategy: Option<&SearchAnnotation>, context: &mut CompilationContext, append_default_search: bool, objective: Option, ) -> Result { let mut brancher = match strategy { - Search::Bool(SearchStrategy { - variables, + Some(SearchAnnotation::BoolSearch(BoolSearchArgs { + variables, + variable_selection_strategy, + value_selection_strategy, + .. + })) => { + let search_variables = context.resolve_bool_variable_array(typed_ast, variables)?; + + create_search_over_propositional_variables( + &search_variables, variable_selection_strategy, value_selection_strategy, - }) => { - let search_variables = match variables { - flatzinc::AnnExpr::String(identifier) => { - vec![context.resolve_bool_variable_from_identifier(identifier)?] - } - flatzinc::AnnExpr::Expr(expr) => { - context.resolve_bool_variable_array(expr)?.as_ref().to_vec() - } - other => panic!("Expected string or expression but got {other:?}"), - }; - - create_search_over_propositional_variables( - &search_variables, - variable_selection_strategy, - value_selection_strategy, - ) - } - Search::Int(SearchStrategy { - variables, + ) + } + Some(SearchAnnotation::IntSearch(IntSearchArgs { + variables, + variable_selection_strategy, + value_selection_strategy, + .. + })) => { + let search_variables = context.resolve_integer_variable_array(typed_ast, variables)?; + + create_search_over_domains( + &search_variables, variable_selection_strategy, value_selection_strategy, - }) => { - let search_variables = match variables { - flatzinc::AnnExpr::String(identifier) => { - // TODO: unnecessary to create Rc here, for now it's just for the return type - Rc::new([context.resolve_integer_variable_from_identifier(identifier)?]) - } - flatzinc::AnnExpr::Expr(expr) => context.resolve_integer_variable_array(expr)?, - other => panic!("Expected string or expression but got {other:?}"), - }; - create_search_over_domains( - &search_variables, - variable_selection_strategy, - value_selection_strategy, - ) - } - Search::Seq(search_strategies) => DynamicBrancher::new( - search_strategies - .iter() - .map(|strategy| { - let downcast: Box = Box::new( - create_from_search_strategy(strategy, context, false, objective) - .expect("Expected nested sequential strategy to be able to be created"), - ); - downcast - }) - .collect::>(), - ), - Search::Unspecified => { - assert!( - append_default_search, - "when no search is specified, we must add a default search" - ); - - // The default search will be added below, so we give an empty brancher here. - DynamicBrancher::new(vec![]) - } - Search::WarmStartInt { variables, values } => { - match variables { - flatzinc::AnnExpr::String(identifier) => { - panic!("Expected either an array of integers or an array of booleans; not an identifier {identifier}") - } - flatzinc::AnnExpr::Expr(expr) => { - let int_variable_array = context.resolve_integer_variable_array(expr)?; - match values { - flatzinc::AnnExpr::Expr(expr) => { - let int_values_array = context.resolve_array_integer_constants(expr)?; - DynamicBrancher::new(vec![Box::new(WarmStart::new( - &int_variable_array, - &int_values_array, - ))]) - } - x => panic!("Expected an array of integers or an array of booleans; but got {x:?}"), - } - } - other => panic!("Expected expression but got {other:?}"), - } - }, - Search::WarmStartBool{ variables, values } => { - match variables { - flatzinc::AnnExpr::String(identifier) => { - panic!("Expected either an array of integers or an array of booleans; not an identifier {identifier}") - } - flatzinc::AnnExpr::Expr(expr) => { - let bool_variable_array = context - .resolve_bool_variable_array(expr)? - .iter() - .map(|literal| literal.get_integer_variable()) - .collect::>(); - - match values { - flatzinc::AnnExpr::Expr(expr) => { - let bool_values_array = context - .resolve_bool_constants(expr)? - .iter() - .map(|&bool_value| if bool_value { 1 } else { 0 }) - .collect::>(); - DynamicBrancher::new(vec![Box::new(WarmStart::new( - &bool_variable_array, - &bool_values_array, - ))]) - } - x => panic!("Expected an array of integers or an array of booleans; but got {x:?}"), - } - } - other => panic!("Expected expression but got {other:?}"), - } - } - Search::WarmStartArray(search_strategies) => DynamicBrancher::new( - search_strategies - .iter() - .map(|strategy| { - assert!( - matches!(strategy, Search::WarmStartBool { variables: _, values: _ }) || - matches!( - strategy, - Search::WarmStartInt { - variables: _, - values: _ - } - ) || matches!(strategy, Search::WarmStartArray(_)) - , "Expected warm start strategy to consist of either `warm_start` or other `warm_start_array` annotations" - ); - let downcast: Box = Box::new( - create_from_search_strategy(strategy, context, false, objective) - .expect("Expected nested sequential strategy to be able to be created"), - ); - downcast - }) - .collect::>(), - ), + ) + } + Some(SearchAnnotation::Seq(search_strategies)) => DynamicBrancher::new( + search_strategies + .iter() + .map(|strategy| { + let downcast: Box = Box::new( + create_from_search_strategy( + typed_ast, + Some(strategy), + context, + false, + objective, + ) + .expect("Expected nested sequential strategy to be able to be created"), + ); + downcast + }) + .collect::>(), + ), + + Some(SearchAnnotation::WarmStartInt(args)) => { + let search_variables = + context.resolve_integer_variable_array(typed_ast, &args.variables)?; + let values = context.resolve_integer_array(typed_ast, &args.values)?; + + DynamicBrancher::new(vec![Box::new(WarmStart::new(&search_variables, &values))]) + } + + Some(SearchAnnotation::WarmStartBool(args)) => { + let search_variables = context + .resolve_bool_variable_array(typed_ast, &args.variables)? + .into_iter() + .map(|literal| literal.get_integer_variable()) + .collect::>(); + let values = context + .resolve_bool_array(typed_ast, &args.values)? + .into_iter() + .map(|boolean| if boolean { 1 } else { 0 }) + .collect::>(); + + DynamicBrancher::new(vec![Box::new(WarmStart::new(&search_variables, &values))]) + } + + Some(SearchAnnotation::WarmStartArray(warm_starts)) => { + let nested_warm_starts = warm_starts + .iter() + .map(|strategy| { + #[allow(trivial_casts, reason = "otherwise we get a compiler error")] + let brancher = Box::new(create_from_search_strategy( + typed_ast, + Some(strategy), + context, + append_default_search, + objective, + )?) as Box; + + Ok(brancher) + }) + .collect::, FlatZincError>>()?; + + DynamicBrancher::new(nested_warm_starts) + } + + None => { + assert!( + append_default_search, + "when no search is specified, we must add a default search" + ); + + // The default search will be added below, so we give an empty brancher here. + DynamicBrancher::new(vec![]) + } }; if append_default_search { diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_constants.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_constants.rs deleted file mode 100644 index 658e80bef..000000000 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_constants.rs +++ /dev/null @@ -1,84 +0,0 @@ -//! Compilation phase that processes the parameter declarations into constants. - -use std::rc::Rc; - -use super::context::CompilationContext; -use super::context::Set; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::FlatZincError; - -pub(crate) fn run( - ast: &FlatZincAst, - context: &mut CompilationContext, -) -> Result<(), FlatZincError> { - for parameter_decl in &ast.parameter_decls { - match parameter_decl { - flatzinc::ParDeclItem::Bool { id, bool } => { - let _ = context - .boolean_parameters - .insert(context.identifiers.get_interned(id), *bool); - } - - flatzinc::ParDeclItem::Int { id, int } => { - let value = i32::try_from(*int)?; - - let _ = context - .integer_parameters - .insert(context.identifiers.get_interned(id), value); - } - - flatzinc::ParDeclItem::ArrayOfBool { id, v, .. } => { - let _ = context - .boolean_array_parameters - .insert(context.identifiers.get_interned(id), v.clone().into()); - } - - flatzinc::ParDeclItem::ArrayOfInt { id, v, .. } => { - let value = v - .iter() - .map(|value| i32::try_from(*value)) - .collect::, _>>()?; - - let _ = context - .integer_array_parameters - .insert(context.identifiers.get_interned(id), value); - } - - flatzinc::ParDeclItem::SetOfInt { id, set_literal } => { - let set = match set_literal { - flatzinc::SetLiteral::IntRange(lower_bound, upper_bound) => Set::Interval { - lower_bound: i32::try_from(*lower_bound)?, - upper_bound: i32::try_from(*upper_bound)?, - }, - - flatzinc::SetLiteral::SetInts(values) => { - let values = values - .iter() - .copied() - .map(i32::try_from) - .collect::>()?; - - Set::Sparse { values } - } - - flatzinc::SetLiteral::BoundedFloat(_, _) - | flatzinc::SetLiteral::SetFloats(_) => panic!("float values are unsupported"), - }; - - let _ = context - .set_constants - .insert(context.identifiers.get_interned(id), set); - } - - flatzinc::ParDeclItem::ArrayOfSet { .. } => { - todo!("implement array of integer set parameters") - } - - flatzinc::ParDeclItem::Float { .. } | flatzinc::ParDeclItem::ArrayOfFloat { .. } => { - panic!("floats are not supported") - } - } - } - - Ok(()) -} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/handle_set_in.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/handle_set_in.rs index d975f6eec..f2db2eaec 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/handle_set_in.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/handle_set_in.rs @@ -1,31 +1,37 @@ //! Scan through all constraint definition and determine whether a `set_in` constraint is present; -//! if this is the case then update the domain of the variable directly. +//! is this is the case then update the domain of the variable directly. +use std::rc::Rc; + +use fzn_rs::VariableExpr; + use super::context::CompilationContext; +use crate::flatzinc::ast::Instance; +use crate::flatzinc::constraints::Constraints; use crate::flatzinc::error::FlatZincError; -pub(crate) fn run(context: &mut CompilationContext) -> Result<(), FlatZincError> { - for (_, constraint_item) in &context.constraints { - let flatzinc::ConstraintItem { - id, - exprs, - annos: _, - } = constraint_item; - if id != "set_in" { - continue; - } - - let set = context.resolve_set_constant(&exprs[1])?; +pub(crate) fn run( + instance: &mut Instance, + context: &mut CompilationContext, +) -> Result<(), FlatZincError> { + instance.constraints.retain(|constraint| { + let (variable, set) = match &constraint.constraint.node { + Constraints::SetIn(variable, set) => (variable, set), + _ => return true, + }; - let id = context.identifiers.get_interned(match &exprs[0] { - flatzinc::Expr::VarParIdentifier(id) => id, - _ => return Err(FlatZincError::UnexpectedExpr), - }); + let id = match variable { + VariableExpr::Identifier(id) => Rc::clone(id), + _ => unreachable!("This constraint makes no sense with a constant."), + }; let mut domain = context.equivalences.get_mut_domain(&id); // We take the intersection between the two domains let new_domain = domain.merge(&set.into()); *domain = new_domain; - } + + false + }); + Ok(()) } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/identify_output_arrays.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/identify_output_arrays.rs new file mode 100644 index 000000000..8a56b8ec7 --- /dev/null +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/identify_output_arrays.rs @@ -0,0 +1,84 @@ +use std::rc::Rc; + +use fzn_rs::ast::RangeList; +use fzn_rs::ast::{self}; +use fzn_rs::FromLiteral; +use fzn_rs::VariableExpr; + +use super::CompilationContext; +use crate::flatzinc::ast::ArrayAnnotations; +use crate::flatzinc::ast::Instance; +use crate::flatzinc::error::FlatZincError; +use crate::flatzinc::instance::Output; + +pub(crate) fn run( + instance: &Instance, + context: &mut CompilationContext, +) -> Result<(), FlatZincError> { + for (name, array) in &instance.arrays { + #[allow( + clippy::unnecessary_find_map, + reason = "it is only unnecessary because ArrayAnnotations has one variant" + )] + let Some(shape) = array.annotations.iter().find_map(|ann| match &ann.node { + ArrayAnnotations::OutputArray(array_expr) => { + let shape = instance + .resolve_array(array_expr) + .map_err(|err| FlatZincError::UndefinedArray(err.0)) + .and_then(|iter| iter.collect::, _>>().map_err(Into::into)) + .map(parse_array_shape); + + Some(shape) + } + }) else { + continue; + }; + + let shape = shape?; + + let output = match array.domain.node { + ast::Domain::UnboundedInt | ast::Domain::Int(_) => { + let variables = array + .contents + .iter() + .map(|node| { + let variable = as FromLiteral>::from_literal(node)?; + + let solver_variable = context.resolve_integer_variable(&variable)?; + Ok(solver_variable) + }) + .collect::, FlatZincError>>()?; + + Output::array_of_int(Rc::clone(name), shape.clone(), variables) + } + + ast::Domain::Bool => { + let variables = array + .contents + .iter() + .map(|node| { + let variable = as FromLiteral>::from_literal(node)?; + + let solver_variable = context.resolve_bool_variable(&variable)?; + Ok(solver_variable) + }) + .collect::, FlatZincError>>()?; + + Output::array_of_bool(Rc::clone(name), shape.clone(), variables) + } + }; + + context.outputs.push(output); + } + + Ok(()) +} + +/// Parse an array of ranges, which is the argument to the `output_array` annotation, to a slice of +/// pairs which is expect by our output system. +fn parse_array_shape(ranges: Vec>) -> Box<[(i32, i32)]> { + ranges + .iter() + .map(|ranges| (*ranges.lower_bound(), *ranges.upper_bound())) + .collect() +} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/merge_equivalences.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/merge_equivalences.rs index 48792d033..7066c301a 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/merge_equivalences.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/merge_equivalences.rs @@ -1,24 +1,27 @@ //! Merge equivalence classes of each variable definition that refers to another variable. -use flatzinc::ConstraintItem; +use std::rc::Rc; + +use fzn_rs::ast; +use fzn_rs::VariableExpr; use log::warn; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::ast::SingleVarDecl; +use crate::flatzinc::ast::Instance; use crate::flatzinc::compiler::context::CompilationContext; -use crate::flatzinc::compiler::context::Identifiers; -use crate::flatzinc::compiler::context::VariableEquivalences; +use crate::flatzinc::constraints::Binary; +use crate::flatzinc::constraints::BoolToIntArgs; +use crate::flatzinc::constraints::Constraints; use crate::flatzinc::FlatZincError; use crate::FlatZincOptions; use crate::ProofType; pub(crate) fn run( - ast: &mut FlatZincAst, + typed_ast: &mut Instance, context: &mut CompilationContext, options: &FlatZincOptions, ) -> Result<(), FlatZincError> { - handle_variable_equality_expressions(ast, context, options)?; - remove_int_eq_constraints(context, options)?; + handle_variable_equality_expressions(typed_ast, context, options)?; + remove_equality_constraints(typed_ast, context, options)?; Ok(()) } @@ -41,68 +44,48 @@ fn panic_if_logging_proof(options: &FlatZincOptions) { } fn handle_variable_equality_expressions( - ast: &FlatZincAst, + typed_ast: &Instance, context: &mut CompilationContext, options: &FlatZincOptions, ) -> Result<(), FlatZincError> { - for single_var_decl in &ast.single_variables { - match single_var_decl { - SingleVarDecl::Bool { id, expr, .. } => { - let id = context.identifiers.get_interned(id); - - let Some(flatzinc::BoolExpr::VarParIdentifier(identifier)) = expr else { - continue; - }; - - if !context.equivalences.is_defined(&id) - && context.boolean_parameters.contains_key(&id) - { - // The identifier points to a parameter. - continue; - } + for (name, variable) in typed_ast.variables.iter() { + let other_variable = match &variable.value { + Some(ast::Node { + node: ast::Literal::Identifier(id), + .. + }) => Rc::clone(id), + _ => continue, + }; - if !context.equivalences.is_defined(&id) { + match variable.domain.node { + ast::Domain::Bool => { + if !context.equivalences.is_defined(&other_variable) { return Err(FlatZincError::InvalidIdentifier { - identifier: id.as_ref().into(), + identifier: other_variable.as_ref().into(), expected_type: "var bool".into(), }); } panic_if_logging_proof(options); - let other_id = context.identifiers.get_interned(identifier); - context.equivalences.merge(id, other_id); + context.equivalences.merge(other_variable, Rc::clone(name)); } - SingleVarDecl::IntInRange { id, expr, .. } => { - let id = context.identifiers.get_interned(id); - - let Some(flatzinc::IntExpr::VarParIdentifier(identifier)) = expr else { - continue; - }; - - if !context.equivalences.is_defined(&id) - && context.integer_parameters.contains_key(&id) - { - // The identifier points to a parameter. - continue; - } - - if !context.equivalences.is_defined(&id) { + ast::Domain::Int(_) => { + if !context.equivalences.is_defined(&other_variable) { return Err(FlatZincError::InvalidIdentifier { - identifier: id.as_ref().into(), + identifier: other_variable.as_ref().into(), expected_type: "var bool".into(), }); } panic_if_logging_proof(options); - let other_id = context.identifiers.get_interned(identifier); - context.equivalences.merge(id, other_id); + context.equivalences.merge(other_variable, Rc::clone(name)); } - SingleVarDecl::IntInSet { .. } => { - // We do not handle exquivalences for sparse-set domains. + ast::Domain::UnboundedInt => { + return Err(FlatZincError::UnsupportedVariable(name.as_ref().into())) } } } @@ -110,7 +93,8 @@ fn handle_variable_equality_expressions( Ok(()) } -fn remove_int_eq_constraints( +fn remove_equality_constraints( + typed_ast: &mut Instance, context: &mut CompilationContext, options: &FlatZincOptions, ) -> Result<(), FlatZincError> { @@ -118,209 +102,173 @@ fn remove_int_eq_constraints( return Ok(()); } - context.constraints.retain(|(_, constraint)| { - should_keep_constraint( - constraint, - &mut context.equivalences, - &mut context.identifiers, - ) - }); + typed_ast + .constraints + .retain(|constraint| should_keep_constraint(constraint, context)); Ok(()) } -fn should_keep_int_eq_constraint( - constraint: &ConstraintItem, - identifiers: &mut Identifiers, - equivalences: &mut VariableEquivalences, +/// Possibly merges some equivalence classes based on the constraint. Returns `true` if the +/// constraint needs to be retained, and `false` if it can be removed from the AST. +fn should_keep_constraint( + constraint: &fzn_rs::AnnotatedConstraint, + context: &mut CompilationContext, ) -> bool { - let v1 = match &constraint.exprs[0] { - flatzinc::Expr::VarParIdentifier(id) => identifiers.get_interned(id), - flatzinc::Expr::Int(_) => { - // I don't expect this to be called, but I am not sure. To make it obvious when it does - // happen, the warning is logged. - warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); - return true; - } - flatzinc::Expr::Float(_) - | flatzinc::Expr::Bool(_) - | flatzinc::Expr::Set(_) - | flatzinc::Expr::ArrayOfBool(_) - | flatzinc::Expr::ArrayOfInt(_) - | flatzinc::Expr::ArrayOfFloat(_) - | flatzinc::Expr::ArrayOfSet(_) => unreachable!(), - }; + let (v1, v2) = match &constraint.constraint.node { + Constraints::IntEq(Binary(lhs, rhs)) => { + let v1 = match lhs { + VariableExpr::Identifier(id) => Rc::clone(id), + VariableExpr::Constant(_) => { + // I don't expect this to be called, but I am not sure. To make it obvious when + // it does happen, the warning is logged. + warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); + return true; + } + }; + + let v2 = match rhs { + VariableExpr::Identifier(id) => Rc::clone(id), + VariableExpr::Constant(_) => { + // I don't expect this to be called, but I am not sure. To make it obvious when + // it does happen, the warning is logged. + warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); + return true; + } + }; - let v2 = match &constraint.exprs[1] { - flatzinc::Expr::VarParIdentifier(id) => identifiers.get_interned(id), - flatzinc::Expr::Int(_) => { - // I don't expect this to be called, but I am not sure. To make it obvious when it does - // happen, the warning is logged. - warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); - return true; + (v1, v2) } - flatzinc::Expr::Float(_) - | flatzinc::Expr::Bool(_) - | flatzinc::Expr::Set(_) - | flatzinc::Expr::ArrayOfBool(_) - | flatzinc::Expr::ArrayOfInt(_) - | flatzinc::Expr::ArrayOfFloat(_) - | flatzinc::Expr::ArrayOfSet(_) => unreachable!(), - }; - equivalences.merge(v1, v2); - - false -} + Constraints::BoolToInt(BoolToIntArgs { boolean, integer }) => { + let v1 = match boolean { + VariableExpr::Identifier(id) => Rc::clone(id), + VariableExpr::Constant(_) => { + // I don't expect this to be called, but I am not sure. To make it obvious when + // it does happen, the warning is logged. + warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); + return true; + } + }; + + let v2 = match integer { + VariableExpr::Identifier(id) => Rc::clone(id), + VariableExpr::Constant(_) => { + // I don't expect this to be called, but I am not sure. To make it obvious when + // it does happen, the warning is logged. + warn!("'int_eq' with constant argument, ignoring it for merging equivalences"); + return true; + } + }; -fn should_keep_bool2int_constraint( - constraint: &ConstraintItem, - identifiers: &mut Identifiers, - equivalences: &mut VariableEquivalences, -) -> bool { - let v1 = match &constraint.exprs[0] { - flatzinc::Expr::VarParIdentifier(id) => identifiers.get_interned(id), - flatzinc::Expr::Bool(_) => { - // I don't expect this to be called, but I am not sure. To make it obvious when it does - // happen, the warning is logged. - warn!("'bool2int' with constant argument, ignoring it for merging equivalences"); - return true; + (v1, v2) } - flatzinc::Expr::Float(_) - | flatzinc::Expr::Int(_) - | flatzinc::Expr::Set(_) - | flatzinc::Expr::ArrayOfBool(_) - | flatzinc::Expr::ArrayOfInt(_) - | flatzinc::Expr::ArrayOfFloat(_) - | flatzinc::Expr::ArrayOfSet(_) => unreachable!(), - }; - let v2 = match &constraint.exprs[1] { - flatzinc::Expr::VarParIdentifier(id) => identifiers.get_interned(id), - flatzinc::Expr::Bool(_) => { - // I don't expect this to be called, but I am not sure. To make it obvious when it does - // happen, the warning is logged. - warn!("'bool2int' with constant argument, ignoring it for merging equivalences"); - return true; - } - flatzinc::Expr::Float(_) - | flatzinc::Expr::Int(_) - | flatzinc::Expr::Set(_) - | flatzinc::Expr::ArrayOfBool(_) - | flatzinc::Expr::ArrayOfInt(_) - | flatzinc::Expr::ArrayOfFloat(_) - | flatzinc::Expr::ArrayOfSet(_) => unreachable!(), + _ => return true, }; - equivalences.merge(v1, v2); + context.equivalences.merge(v1, v2); false } -/// Possibly merges some equivalence classes based on the constraint. Returns `true` if the -/// constraint needs to be retained, and `false` if it can be removed from the AST. -fn should_keep_constraint( - constraint: &ConstraintItem, - equivalences: &mut VariableEquivalences, - identifiers: &mut Identifiers, -) -> bool { - match constraint.id.as_str() { - "int_eq" => should_keep_int_eq_constraint(constraint, identifiers, equivalences), - "bool2int" => should_keep_bool2int_constraint(constraint, identifiers, equivalences), - _ => true, - } -} - #[cfg(test)] mod tests { - use flatzinc::ConstraintItem; - use flatzinc::Expr; - use flatzinc::SolveItem; + use std::collections::BTreeMap; + + use fzn_rs::AnnotatedConstraint; + use fzn_rs::Method; + use fzn_rs::Solve; use pumpkin_solver::Solver; use super::*; #[test] fn int_eq_constraints_cause_merging_of_equivalence_classes() { - let mut ast_builder = FlatZincAst::builder(); - - ast_builder.add_variable_decl(SingleVarDecl::IntInRange { - id: "x".into(), - lb: 1, - ub: 5, - expr: None, - annos: vec![], - }); - ast_builder.add_variable_decl(SingleVarDecl::IntInRange { - id: "y".into(), - lb: 1, - ub: 5, - expr: None, - annos: vec![], - }); - ast_builder.add_constraint(ConstraintItem { - id: "int_eq".into(), - exprs: vec![ - Expr::VarParIdentifier("x".into()), - Expr::VarParIdentifier("y".into()), - ], - annos: vec![], - }); - ast_builder.set_solve_item(SolveItem { - goal: flatzinc::Goal::Satisfy, - annotations: vec![], - }); - - let mut ast = ast_builder.build().expect("valid ast"); + let mut instance = Instance { + variables: BTreeMap::from([ + ( + "x".into(), + ast::Variable { + domain: test_node(ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![], + }, + ), + ( + "y".into(), + ast::Variable { + domain: test_node(ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![], + }, + ), + ]), + arrays: BTreeMap::new(), + constraints: vec![AnnotatedConstraint { + constraint: test_node(Constraints::IntEq(Binary( + VariableExpr::Identifier("x".into()), + VariableExpr::Identifier("y".into()), + ))), + annotations: vec![], + }], + solve: Solve { + method: test_node(Method::Satisfy), + annotations: vec![], + }, + }; + let mut solver = Solver::default(); let mut context = CompilationContext::new(&mut solver); let options = FlatZincOptions::default(); - super::super::reserve_constraint_tags::run(&ast, &mut context) + super::super::prepare_variables::run(&instance, &mut context) .expect("step should not fail"); - super::super::prepare_variables::run(&ast, &mut context).expect("step should not fail"); - run(&mut ast, &mut context, &options).expect("step should not fail"); + run(&mut instance, &mut context, &options).expect("step should not fail"); assert_eq!( - context.equivalences.representative("x"), - context.equivalences.representative("y") + context.equivalences.representative("x").unwrap(), + context.equivalences.representative("y").unwrap(), ); - assert!(context.constraints.is_empty()); + assert!(instance.constraints.is_empty()); } #[test] fn int_eq_does_not_merge_when_full_proof_is_being_logged() { - let mut ast_builder = FlatZincAst::builder(); - - ast_builder.add_variable_decl(SingleVarDecl::IntInRange { - id: "x".into(), - lb: 1, - ub: 5, - expr: None, - annos: vec![], - }); - ast_builder.add_variable_decl(SingleVarDecl::IntInRange { - id: "y".into(), - lb: 1, - ub: 5, - expr: None, - annos: vec![], - }); - ast_builder.add_constraint(ConstraintItem { - id: "int_eq".into(), - exprs: vec![ - Expr::VarParIdentifier("x".into()), - Expr::VarParIdentifier("y".into()), - ], - annos: vec![], - }); - ast_builder.set_solve_item(SolveItem { - goal: flatzinc::Goal::Satisfy, - annotations: vec![], - }); - - let mut ast = ast_builder.build().expect("valid ast"); + let mut instance = Instance { + variables: BTreeMap::from([ + ( + "x".into(), + ast::Variable { + domain: test_node(ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![], + }, + ), + ( + "y".into(), + ast::Variable { + domain: test_node(ast::Domain::Int(ast::RangeList::from(1..=5))), + value: None, + annotations: vec![], + }, + ), + ]), + arrays: BTreeMap::new(), + constraints: vec![AnnotatedConstraint { + constraint: test_node(Constraints::IntEq(Binary( + VariableExpr::Identifier("x".into()), + VariableExpr::Identifier("y".into()), + ))), + annotations: vec![], + }], + solve: Solve { + method: test_node(Method::Satisfy), + annotations: vec![], + }, + }; + let mut solver = Solver::default(); let mut context = CompilationContext::new(&mut solver); let options = FlatZincOptions { @@ -328,16 +276,22 @@ mod tests { ..Default::default() }; - super::super::reserve_constraint_tags::run(&ast, &mut context) + super::super::prepare_variables::run(&instance, &mut context) .expect("step should not fail"); - super::super::prepare_variables::run(&ast, &mut context).expect("step should not fail"); - run(&mut ast, &mut context, &options).expect("step should not fail"); + run(&mut instance, &mut context, &options).expect("step should not fail"); assert_ne!( - context.equivalences.representative("x"), - context.equivalences.representative("y") + context.equivalences.representative("x").unwrap(), + context.equivalences.representative("y").unwrap(), ); - assert_eq!(context.constraints.len(), 1); + assert_eq!(instance.constraints.len(), 1); + } + + fn test_node(data: T) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { start: 0, end: 0 }, + } } } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/mod.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/mod.rs index 4ea95c35e..6eb72711c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/mod.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/mod.rs @@ -2,41 +2,41 @@ mod collect_domains; mod context; mod create_objective; mod create_search_strategy; -mod define_constants; -mod define_variable_arrays; mod handle_set_in; +mod identify_output_arrays; mod merge_equivalences; mod post_constraints; mod prepare_variables; mod remove_unused_variables; mod reserve_constraint_tags; -use context::CompilationContext; +pub(crate) use context::CompilationContext; use pumpkin_solver::Solver; -use super::ast::FlatZincAst; use super::instance::FlatZincInstance; use super::FlatZincError; use super::FlatZincOptions; pub(crate) fn compile( - mut ast: FlatZincAst, + mut ast: fzn_rs::ast::Ast, solver: &mut Solver, options: FlatZincOptions, ) -> Result { let mut context = CompilationContext::new(solver); - define_constants::run(&ast, &mut context)?; - reserve_constraint_tags::run(&ast, &mut context)?; - remove_unused_variables::run(&mut ast, &mut context)?; - prepare_variables::run(&ast, &mut context)?; - merge_equivalences::run(&mut ast, &mut context, &options)?; - handle_set_in::run(&mut context)?; - collect_domains::run(&ast, &mut context)?; - define_variable_arrays::run(&ast, &mut context)?; - post_constraints::run(&ast, &mut context, &options)?; - let objective_function = create_objective::run(&ast, &mut context)?; - let search = create_search_strategy::run(&ast, &mut context, objective_function)?; + remove_unused_variables::run(&mut ast)?; + + let mut typed_ast = super::ast::Instance::from_ast(ast)?; + reserve_constraint_tags::run(&mut typed_ast, &mut context)?; + + prepare_variables::run(&typed_ast, &mut context)?; + merge_equivalences::run(&mut typed_ast, &mut context, &options)?; + handle_set_in::run(&mut typed_ast, &mut context)?; + collect_domains::run(&typed_ast, &mut context)?; + identify_output_arrays::run(&typed_ast, &mut context)?; + post_constraints::run(&typed_ast, &mut context, &options)?; + let objective_function = create_objective::run(&typed_ast, &mut context)?; + let search = create_search_strategy::run(&typed_ast, &mut context, objective_function)?; Ok(FlatZincInstance { outputs: context.outputs, diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index 5536e5500..af1c794a2 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -1,7 +1,6 @@ //! Compile constraints into CP propagators -use std::rc::Rc; - +use pumpkin_core::variables::Literal; use pumpkin_solver::constraints; use pumpkin_solver::constraints::Constraint; use pumpkin_solver::constraints::NegatableConstraint; @@ -13,228 +12,312 @@ use pumpkin_solver::variables::DomainId; use pumpkin_solver::variables::TransformableVariable; use super::context::CompilationContext; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::compiler::context::Set; +use crate::flatzinc::ast::ConstraintAnnotations; +use crate::flatzinc::ast::Instance; +use crate::flatzinc::constraints::ArrayBoolArgs; +use crate::flatzinc::constraints::Binary; +use crate::flatzinc::constraints::BinaryBool; +use crate::flatzinc::constraints::BinaryBoolReif; +use crate::flatzinc::constraints::BoolClauseArgs; +use crate::flatzinc::constraints::BoolElementArgs; +use crate::flatzinc::constraints::BoolLinEqArgs; +use crate::flatzinc::constraints::BoolLinLeArgs; +use crate::flatzinc::constraints::BoolToIntArgs; +use crate::flatzinc::constraints::Constraints; +use crate::flatzinc::constraints::CumulativeArgs; +use crate::flatzinc::constraints::IntElementArgs; +use crate::flatzinc::constraints::Linear; +use crate::flatzinc::constraints::ReifiedBinary; +use crate::flatzinc::constraints::ReifiedLinear; +use crate::flatzinc::constraints::SetInReifArgs; +use crate::flatzinc::constraints::TableInt; +use crate::flatzinc::constraints::TableIntReif; +use crate::flatzinc::constraints::TernaryIntArgs; use crate::flatzinc::FlatZincError; use crate::flatzinc::FlatZincOptions; pub(crate) fn run( - _: &FlatZincAst, + instance: &Instance, context: &mut CompilationContext, options: &FlatZincOptions, ) -> Result<(), FlatZincError> { - for (constraint_tag, constraint_item) in std::mem::take(&mut context.constraints) { - let flatzinc::ConstraintItem { id, exprs, annos } = &constraint_item; - - let is_satisfiable: bool = match id.as_str() { - "array_int_maximum" => compile_array_int_maximum(context, exprs, constraint_tag)?, - "array_int_minimum" => compile_array_int_minimum(context, exprs, constraint_tag)?, - "int_max" => { - compile_ternary_int_predicate(context, exprs, annos, "int_max", constraint_tag, |a, b, c, constraint_tag| { - constraints::maximum([a, b], c, constraint_tag) - })? + use Constraints::*; + + for constraint in &instance.constraints { + #[allow( + clippy::unnecessary_find_map, + reason = "when there are more variants on ConstraintAnnotations, this is the cleaner way" + )] + let constraint_tag = constraint + .annotations + .iter() + .find_map(|ann| match &ann.node { + ConstraintAnnotations::ConstraintTag(tag) => Some((*tag).into()), + }) + .expect("every constraint should have been associated with a tag at an earlier stage"); + + let is_satisfiable: bool = match &constraint.constraint.node { + ArrayIntMinimum(args) => { + let array = context.resolve_integer_variable_array(instance, &args.array)?; + let rhs = context.resolve_integer_variable(&args.extremum)?; + + constraints::minimum(array, rhs, constraint_tag) + .post(context.solver) + .is_ok() } - "int_min" => { - compile_ternary_int_predicate(context, exprs, annos, "int_min", constraint_tag, |a, b, c, constraint_tag| { - constraints::minimum([a, b], c, constraint_tag) - })? + + ArrayIntMaximum(args) => { + let array = context.resolve_integer_variable_array(instance, &args.array)?; + let rhs = context.resolve_integer_variable(&args.extremum)?; + + constraints::maximum(array, rhs, constraint_tag) + .post(context.solver) + .is_ok() } // We rewrite `array_int_element` to `array_var_int_element`. - "array_int_element" => compile_array_var_int_element(context, exprs, constraint_tag)?, - "array_var_int_element" => compile_array_var_int_element(context, exprs, constraint_tag)?, - - "int_eq_imp" => compile_binary_int_imp(context, exprs, annos, "int_eq_imp", constraint_tag, constraints::binary_equals)?, - "int_ge_imp" => compile_binary_int_imp(context, exprs, annos, "int_ge_imp", constraint_tag, constraints::binary_greater_than_or_equals)?, - "int_gt_imp" => compile_binary_int_imp(context, exprs, annos, "int_gt_imp", constraint_tag, constraints::binary_greater_than)?, - "int_le_imp" => compile_binary_int_imp(context, exprs, annos, "int_le_imp", constraint_tag, constraints::binary_less_than_or_equals)?, - "int_lt_imp" => compile_binary_int_imp(context, exprs, annos, "int_lt_imp", constraint_tag, constraints::binary_less_than)?, - "int_ne_imp" => compile_binary_int_imp(context, exprs, annos, "int_ne_imp", constraint_tag, constraints::binary_not_equals)?, - - "int_lin_eq_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_eq_imp", constraint_tag, constraints::equals)?, - "int_lin_ge_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_ge_imp", constraint_tag, constraints::greater_than_or_equals)?, - "int_lin_gt_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_gt_imp", constraint_tag, constraints::greater_than)?, - "int_lin_le_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_le_imp", constraint_tag, constraints::less_than_or_equals)?, - "int_lin_lt_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_lt_imp", constraint_tag, constraints::less_than)?, - "int_lin_ne_imp" => compile_int_lin_imp_predicate(context, exprs, annos, "int_lin_ne_imp", constraint_tag, constraints::not_equals)?, - - "int_lin_ne" => compile_int_lin_predicate( + ArrayIntElement(args) => { + compile_array_var_int_element(instance, context, args, constraint_tag)? + } + ArrayVarIntElement(args) => { + compile_array_var_int_element(instance, context, args, constraint_tag)? + } + + IntEqImp(args) => { + compile_binary_int_imp(context, args, constraint_tag, constraints::binary_equals)? + } + IntGeImp(args) => compile_binary_int_imp( context, - exprs, - annos, - "int_lin_ne", + args, constraint_tag, - constraints::not_equals, + constraints::binary_greater_than_or_equals, + )?, + IntGtImp(args) => compile_binary_int_imp( + context, + args, + constraint_tag, + constraints::binary_greater_than, )?, - "int_lin_ne_reif" => compile_reified_int_lin_predicate( + IntLeImp(args) => compile_binary_int_imp( context, - exprs, - annos, - "int_lin_ne_reif", + args, + constraint_tag, + constraints::binary_less_than_or_equals, + )?, + IntLtImp(args) => compile_binary_int_imp( + context, + args, + constraint_tag, + constraints::binary_less_than, + )?, + IntNeImp(args) => compile_binary_int_imp( + context, + args, + constraint_tag, + constraints::binary_not_equals, + )?, + + IntLinNe(args) => compile_int_lin_predicate( + instance, + context, + args, constraint_tag, constraints::not_equals, )?, - "int_lin_le" => compile_int_lin_predicate( + IntLinLe(args) => compile_int_lin_predicate( + instance, context, - exprs, - annos, - "int_lin_le", + args, constraint_tag, constraints::less_than_or_equals, )?, - "int_lin_le_reif" => compile_reified_int_lin_predicate( + IntLinEq(args) => compile_int_lin_predicate( + instance, + context, + args, + constraint_tag, + constraints::equals, + )?, + + IntLinNeReif(args) => compile_reified_int_lin_predicate( + instance, context, - exprs, - annos, - "int_lin_le_reif", + args, + constraint_tag, + constraints::not_equals, + )?, + IntLinLeReif(args) => compile_reified_int_lin_predicate( + instance, + context, + args, constraint_tag, constraints::less_than_or_equals, )?, - "int_lin_eq" => { - compile_int_lin_predicate(context, exprs, annos, "int_lin_eq", constraint_tag, constraints::equals)? - } - "int_lin_eq_reif" => compile_reified_int_lin_predicate( + IntLinEqReif(args) => compile_reified_int_lin_predicate( + instance, context, - exprs, - annos, - "int_lin_eq_reif", + args, constraint_tag, constraints::equals, )?, - "int_ne" => compile_binary_int_predicate( + + IntLinNeImp(args) => compile_implied_int_lin_predicate( + instance, context, - exprs, - annos, - "int_ne", + args, constraint_tag, - constraints::binary_not_equals, + constraints::not_equals, )?, - "int_ne_reif" => compile_reified_binary_int_predicate( + IntLinLeImp(args) => compile_implied_int_lin_predicate( + instance, context, - exprs, - annos, - "int_ne_reif", + args, constraint_tag, - constraints::binary_not_equals, + constraints::less_than_or_equals, )?, - "int_eq" => compile_binary_int_predicate( + IntLinEqImp(args) => compile_implied_int_lin_predicate( + instance, context, - exprs, - annos, - "int_eq", + args, constraint_tag, - constraints::binary_equals, + constraints::equals, )?, - "int_eq_reif" => compile_reified_binary_int_predicate( + + IntEq(args) => compile_binary_int_predicate( context, - exprs, - annos, - "int_eq_reif", + args, constraint_tag, constraints::binary_equals, )?, - "int_le" => compile_binary_int_predicate( + IntNe(args) => compile_binary_int_predicate( context, - exprs, - annos, - "int_le", + args, constraint_tag, - constraints::binary_less_than_or_equals, + constraints::binary_not_equals, )?, - "int_le_reif" => compile_reified_binary_int_predicate( + IntLe(args) => compile_binary_int_predicate( context, - exprs, - annos, - "int_le_reif", + args, constraint_tag, constraints::binary_less_than_or_equals, )?, - "int_lt" => compile_binary_int_predicate( + IntLt(args) => compile_binary_int_predicate( context, - exprs, - annos, - "int_lt", + args, constraint_tag, constraints::binary_less_than, )?, - "int_lt_reif" => compile_reified_binary_int_predicate( + IntAbs(args) => { + compile_binary_int_predicate(context, args, constraint_tag, constraints::absolute)? + } + + IntEqReif(args) => compile_reified_binary_int_predicate( + context, + args, + constraint_tag, + constraints::binary_equals, + )?, + IntNeReif(args) => compile_reified_binary_int_predicate( + context, + args, + constraint_tag, + constraints::binary_not_equals, + )?, + IntLtReif(args) => compile_reified_binary_int_predicate( context, - exprs, - annos, - "int_lt_reif", + args, constraint_tag, constraints::binary_less_than, )?, - - "int_plus" => { - compile_ternary_int_predicate(context, exprs, annos, "int_plus", constraint_tag, constraints::plus)? - } - - "int_times" => compile_ternary_int_predicate( + IntLeReif(args) => compile_reified_binary_int_predicate( context, - exprs, - annos, - "int_times", + args, constraint_tag, - constraints::times, + constraints::binary_less_than_or_equals, )?, - "int_div" => compile_ternary_int_predicate( + + IntMax(args) => compile_ternary_int_predicate( context, - exprs, - annos, - "int_div", + args, constraint_tag, - constraints::division, + |a, b, c, constraint_tag| constraints::maximum([a, b], c, constraint_tag), )?, - "int_abs" => compile_binary_int_predicate( + + IntMin(args) => compile_ternary_int_predicate( context, - exprs, - annos, - "int_abs", + args, constraint_tag, - constraints::absolute, + |a, b, c, constraint_tag| constraints::minimum([a, b], c, constraint_tag), )?, - "pumpkin_all_different" => compile_all_different(context, exprs, annos, constraint_tag)?, - "pumpkin_table_int" => compile_table(context, exprs, annos, constraint_tag)?, - "pumpkin_table_int_reif" => compile_table_reif(context, exprs, annos, constraint_tag)?, + IntTimes(args) => { + compile_ternary_int_predicate(context, args, constraint_tag, constraints::times)? + } + IntDiv(args) => { + compile_ternary_int_predicate(context, args, constraint_tag, constraints::division)? + } + IntPlus(args) => { + compile_ternary_int_predicate(context, args, constraint_tag, constraints::plus)? + } - "array_bool_and" => compile_array_bool_and(context, exprs, constraint_tag)?, - "array_bool_element" => { - compile_array_var_bool_element(context, exprs, "array_bool_element", constraint_tag)? + AllDifferent(array) => { + let variables = context.resolve_integer_variable_array(instance, array)?; + constraints::all_different(variables, constraint_tag) + .post(context.solver) + .is_ok() } - "array_var_bool_element" => { - compile_array_var_bool_element(context, exprs, "array_var_bool_element", constraint_tag)? + + Table(table) => compile_table(instance, context, table, constraint_tag)?, + TableReif(table_reif) => { + compile_table_reif(instance, context, table_reif, constraint_tag)? } - "array_bool_or" => compile_bool_or(context, exprs, constraint_tag)?, - "pumpkin_bool_xor" => compile_bool_xor(context, exprs, constraint_tag)?, - "pumpkin_bool_xor_reif" => compile_bool_xor_reif(context, exprs, constraint_tag)?, - "bool2int" => compile_bool2int(context, exprs, constraint_tag)?, + ArrayBoolAnd(args) => compile_array_bool( + instance, + context, + args, + constraint_tag, + constraints::conjunction, + )?, - "bool_lin_eq" => { - compile_bool_lin_eq_predicate(context, exprs, constraint_tag)? + ArrayBoolOr(args) => { + compile_array_bool(instance, context, args, constraint_tag, constraints::clause)? } - "bool_lin_le" => { - compile_bool_lin_le_predicate(context, exprs, constraint_tag)? + BoolXor(args) => compile_bool_xor(context, args, constraint_tag)?, + BoolXorReif(args) => compile_bool_xor_reif(context, args, constraint_tag)?, + + BoolLinEq(args) => { + compile_bool_lin_eq_predicate(instance, context, args, constraint_tag)? } + BoolLinLe(args) => { + compile_bool_lin_le_predicate(instance, context, args, constraint_tag)? + } + + BoolAnd(args) => compile_bool_and(context, args, constraint_tag)?, + BoolEq(args) => compile_bool_eq(context, args, constraint_tag)?, + BoolEqReif(args) => compile_bool_eq_reif(context, args, constraint_tag)?, + BoolNot(args) => compile_bool_not(context, args, constraint_tag)?, + BoolClause(args) => compile_bool_clause(instance, context, args, constraint_tag)?, + + ArrayBoolElement(args) => { + compile_array_var_bool_element(instance, context, args, constraint_tag)? + } + ArrayVarBoolElement(args) => { + compile_array_var_bool_element(instance, context, args, constraint_tag)? + } + + BoolToInt(args) => compile_bool2int(context, args, constraint_tag)?, - "bool_and" => compile_bool_and(context, exprs, constraint_tag)?, - "bool_clause" => compile_bool_clause(context, exprs, constraint_tag)?, - "bool_eq" => compile_bool_eq(context, exprs, constraint_tag)?, - "bool_eq_reif" => compile_bool_eq_reif(context, exprs, constraint_tag)?, - "bool_not" => compile_bool_not(context, exprs, constraint_tag)?, - "set_in_reif" => compile_set_in_reif(context, exprs, constraint_tag)?, - "set_in" => { - // 'set_in' constraints are handled in pre-processing steps. - // TODO: remove it from the AST, so it does not need to be matched here - true + SetIn(_, _) => { + unreachable!("should be removed from the AST at previous stages") } - "pumpkin_cumulative" => compile_cumulative(context, exprs, options, constraint_tag)?, - "pumpkin_cumulative_var" => todo!("The `cumulative` constraint with variable duration/resource consumption/bound is not implemented yet!"), - unknown => todo!("unsupported constraint {unknown}"), + SetInReif(args) => compile_set_in_reif(context, args, constraint_tag)?, + + Cumulative(args) => { + compile_cumulative(instance, context, args, options, constraint_tag)? + } }; if !is_satisfiable { @@ -245,36 +328,20 @@ pub(crate) fn run( Ok(()) } -macro_rules! check_parameters { - ($exprs:ident, $num_parameters:expr, $name:expr) => { - if $exprs.len() != $num_parameters { - return Err(FlatZincError::IncorrectNumberOfArguments { - constraint_id: $name.into(), - expected: $num_parameters, - actual: $exprs.len(), - }); - } - }; -} - fn compile_cumulative( + instance: &Instance, context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &CumulativeArgs, options: &FlatZincOptions, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 4, "pumpkin_cumulative"); - - let start_times = context.resolve_integer_variable_array(&exprs[0])?; - let durations = context.resolve_array_integer_constants(&exprs[1])?; - let resource_requirements = context.resolve_array_integer_constants(&exprs[2])?; - let resource_capacity = context.resolve_integer_constant_from_expr(&exprs[3])?; + let start_times = context.resolve_integer_variable_array(instance, &args.start_times)?; let post_result = constraints::cumulative_with_options( - start_times.iter().copied(), - durations.iter().copied(), - resource_requirements.iter().copied(), - resource_capacity, + start_times, + context.resolve_integer_array(instance, &args.durations)?, + context.resolve_integer_array(instance, &args.resource_requirements)?, + args.resource_capacity, options.cumulative_options, constraint_tag, ) @@ -282,146 +349,94 @@ fn compile_cumulative( Ok(post_result.is_ok()) } -fn compile_array_int_maximum( - context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], - constraint_tag: ConstraintTag, -) -> Result { - check_parameters!(exprs, 2, "array_int_maximum"); - - let rhs = context.resolve_integer_variable(&exprs[0])?; - let array = context.resolve_integer_variable_array(&exprs[1])?; - - Ok( - constraints::maximum(array.as_ref().to_owned(), rhs, constraint_tag) - .post(context.solver) - .is_ok(), - ) -} - -fn compile_array_int_minimum( - context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], - constraint_tag: ConstraintTag, -) -> Result { - check_parameters!(exprs, 2, "array_int_minimum"); - - let rhs = context.resolve_integer_variable(&exprs[0])?; - let array = context.resolve_integer_variable_array(&exprs[1])?; - - Ok( - constraints::minimum(array.as_ref().to_owned(), rhs, constraint_tag) - .post(context.solver) - .is_ok(), - ) -} - fn compile_set_in_reif( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &SetInReifArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "set_in_reif"); - - let variable = context.resolve_integer_variable(&exprs[0])?; - let set = context.resolve_set_constant(&exprs[1])?; - let reif = context.resolve_bool_variable(&exprs[2])?; - - let success = match set { - Set::Interval { - lower_bound, - upper_bound, - } => { - // `reif -> x \in S` - // Decomposed to `reif -> x >= lb /\ reif -> x <= ub` - let forward = context + let variable = context.resolve_integer_variable(&args.variable)?; + let reif = context.resolve_bool_variable(&args.reification)?; + + let success = if args.set.is_continuous() { + // `reif -> x \in S` + // Decomposed to `reif -> x >= lb /\ reif -> x <= ub` + let forward = context + .solver + .add_clause( + [ + !reif.get_true_predicate(), + predicate![variable >= *args.set.lower_bound()], + ], + constraint_tag, + ) + .is_ok() + && context .solver .add_clause( [ !reif.get_true_predicate(), - predicate![variable >= lower_bound], - ], - constraint_tag, - ) - .is_ok() - && context - .solver - .add_clause( - [ - !reif.get_true_predicate(), - !predicate![variable >= upper_bound + 1], - ], - constraint_tag, - ) - .is_ok(); - - // `!reif -> x \notin S` - // Decomposed to `!reif -> (x < lb \/ x > ub)` - let backward = context - .solver - .add_clause( - [ - reif.get_true_predicate(), - !predicate![variable >= lower_bound], - predicate![variable >= upper_bound + 1], + !predicate![variable >= *args.set.upper_bound() + 1], ], constraint_tag, ) .is_ok(); - forward && backward - } + // `!reif -> x \notin S` + // Decomposed to `!reif -> (x < lb \/ x > ub)` + let backward = context + .solver + .add_clause( + [ + reif.get_true_predicate(), + !predicate![variable >= *args.set.lower_bound()], + predicate![variable >= *args.set.upper_bound() + 1], + ], + constraint_tag, + ) + .is_ok(); + + forward && backward + } else { + let clause = args + .set + .into_iter() + .map(|value| { + context + .solver + .new_literal_for_predicate(predicate![variable == value], constraint_tag) + }) + .collect::>(); - Set::Sparse { values } => { - let clause = values - .iter() - .map(|&value| { - context - .solver - .new_literal_for_predicate(predicate![variable == value], constraint_tag) - }) - .collect::>(); - - constraints::clause(clause, constraint_tag) - .reify(context.solver, reif) - .is_ok() - } + constraints::clause(clause, constraint_tag) + .reify(context.solver, reif) + .is_ok() }; Ok(success) } fn compile_array_var_int_element( + instance: &Instance, context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &IntElementArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "array_var_int_element"); - - let index = context.resolve_integer_variable(&exprs[0])?.offset(-1); - let array = context.resolve_integer_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_variable(&exprs[2])?; + let index = context.resolve_integer_variable(&args.index)?.offset(-1); + let array = context.resolve_integer_variable_array(instance, &args.array)?; + let rhs = context.resolve_integer_variable(&args.rhs)?; - Ok( - constraints::element(index, array.as_ref().to_owned(), rhs, constraint_tag) - .post(context.solver) - .is_ok(), - ) + Ok(constraints::element(index, array, rhs, constraint_tag) + .post(context.solver) + .is_ok()) } fn compile_bool_not( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBool, constraint_tag: ConstraintTag, ) -> Result { - // TODO: Take this constraint into account when creating variables, as these can be opposite - // literals of the same PropositionalVariable. Unsure how often this actually appears in models - // though. - - check_parameters!(exprs, 2, "bool_not"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_bool_variable(&exprs[1])?; + let a = context.resolve_bool_variable(&args.a)?; + let b = context.resolve_bool_variable(&args.b)?; Ok(constraints::binary_not_equals(a, b, constraint_tag) .post(context.solver) @@ -430,14 +445,12 @@ fn compile_bool_not( fn compile_bool_eq_reif( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBoolReif, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "bool_eq_reif"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_bool_variable(&exprs[1])?; - let r = context.resolve_bool_variable(&exprs[2])?; + let a = context.resolve_bool_variable(&args.a)?; + let b = context.resolve_bool_variable(&args.b)?; + let r = context.resolve_bool_variable(&args.reification)?; Ok(constraints::binary_equals(a, b, constraint_tag) .reify(context.solver, r) @@ -446,15 +459,11 @@ fn compile_bool_eq_reif( fn compile_bool_eq( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBool, constraint_tag: ConstraintTag, ) -> Result { - // TODO: Take this constraint into account when merging equivalence classes. Unsure how often - // this actually appears in models though. - check_parameters!(exprs, 2, "bool_eq"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_bool_variable(&exprs[1])?; + let a = context.resolve_bool_variable(&args.a)?; + let b = context.resolve_bool_variable(&args.b)?; Ok(constraints::binary_equals(a, b, constraint_tag) .post(context.solver) @@ -462,14 +471,13 @@ fn compile_bool_eq( } fn compile_bool_clause( + instance: &Instance, context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BoolClauseArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 2, "bool_clause"); - - let clause_1 = context.resolve_bool_variable_array(&exprs[0])?; - let clause_2 = context.resolve_bool_variable_array(&exprs[1])?; + let clause_1 = context.resolve_bool_variable_array(instance, &args.clause_1)?; + let clause_2 = context.resolve_bool_variable_array(instance, &args.clause_2)?; let clause: Vec = clause_1 .iter() @@ -483,14 +491,12 @@ fn compile_bool_clause( fn compile_bool_and( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBoolReif, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 2, "bool_and"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_bool_variable(&exprs[1])?; - let r = context.resolve_bool_variable(&exprs[2])?; + let a = context.resolve_bool_variable(&args.a)?; + let b = context.resolve_bool_variable(&args.b)?; + let r = context.resolve_bool_variable(&args.reification)?; Ok(constraints::conjunction([a, b], constraint_tag) .reify(context.solver, r) @@ -499,17 +505,11 @@ fn compile_bool_and( fn compile_bool2int( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BoolToIntArgs, constraint_tag: ConstraintTag, ) -> Result { - // TODO: Perhaps we want to add a phase in the compiler that directly uses the literal - // corresponding to the predicate [b = 1] for the boolean parameter in this constraint. - // See https://emir-demirovic.atlassian.net/browse/PUM-89 - - check_parameters!(exprs, 2, "bool2int"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; + let a = context.resolve_bool_variable(&args.boolean)?; + let b = context.resolve_integer_variable(&args.integer)?; Ok( constraints::binary_equals(a.get_integer_variable(), b.scaled(1), constraint_tag) @@ -518,34 +518,13 @@ fn compile_bool2int( ) } -fn compile_bool_or( - context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], - constraint_tag: ConstraintTag, -) -> Result { - check_parameters!(exprs, 2, "bool_or"); - - let clause = context.resolve_bool_variable_array(&exprs[0])?; - let r = context.resolve_bool_variable(&exprs[1])?; - - Ok(constraints::clause(clause.as_ref(), constraint_tag) - .reify(context.solver, r) - .is_ok()) -} - fn compile_bool_xor( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBool, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 2, "pumpkin_bool_xor"); - - let a = context - .resolve_bool_variable(&exprs[0])? - .get_true_predicate(); - let b = context - .resolve_bool_variable(&exprs[1])? - .get_true_predicate(); + let a = context.resolve_bool_variable(&args.a)?.get_true_predicate(); + let b = context.resolve_bool_variable(&args.b)?.get_true_predicate(); let c1 = context.solver.add_clause([!a, !b], constraint_tag).is_ok(); let c2 = context.solver.add_clause([b, a], constraint_tag).is_ok(); @@ -555,14 +534,12 @@ fn compile_bool_xor( fn compile_bool_xor_reif( context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &BinaryBoolReif, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "pumpkin_bool_xor_reif"); - - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_bool_variable(&exprs[1])?; - let r = context.resolve_bool_variable(&exprs[2])?; + let a = context.resolve_bool_variable(&args.a)?; + let b = context.resolve_bool_variable(&args.b)?; + let r = context.resolve_bool_variable(&args.reification)?; let c1 = constraints::clause([!a, !b, !r], constraint_tag) .post(context.solver) @@ -581,16 +558,14 @@ fn compile_bool_xor_reif( } fn compile_array_var_bool_element( + instance: &Instance, context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], - name: &str, + args: &BoolElementArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, name); - - let index = context.resolve_integer_variable(&exprs[0])?.offset(-1); - let array = context.resolve_bool_variable_array(&exprs[1])?; - let rhs = context.resolve_bool_variable(&exprs[2])?; + let index = context.resolve_integer_variable(&args.index)?.offset(-1); + let array = context.resolve_bool_variable_array(instance, &args.array)?; + let rhs = context.resolve_bool_variable(&args.rhs)?; Ok( constraints::element(index, array.iter().cloned(), rhs, constraint_tag) @@ -599,36 +574,30 @@ fn compile_array_var_bool_element( ) } -fn compile_array_bool_and( +fn compile_array_bool( + instance: &Instance, context: &mut CompilationContext<'_>, - exprs: &[flatzinc::Expr], + args: &ArrayBoolArgs, constraint_tag: ConstraintTag, + create_constraint: impl FnOnce(Vec, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 2, "array_bool_and"); - - let conjunction = context.resolve_bool_variable_array(&exprs[0])?; - let r = context.resolve_bool_variable(&exprs[1])?; + let conjunction = context.resolve_bool_variable_array(instance, &args.booleans)?; + let r = context.resolve_bool_variable(&args.reification)?; - Ok( - constraints::conjunction(conjunction.as_ref(), constraint_tag) - .reify(context.solver, r) - .is_ok(), - ) + Ok(create_constraint(conjunction, constraint_tag) + .reify(context.solver, r) + .is_ok()) } fn compile_ternary_int_predicate( context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + ternary_int_args: &TernaryIntArgs, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(DomainId, DomainId, DomainId, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 3, predicate_name); - - let a = context.resolve_integer_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; - let c = context.resolve_integer_variable(&exprs[2])?; + let a = context.resolve_integer_variable(&ternary_int_args.a)?; + let b = context.resolve_integer_variable(&ternary_int_args.b)?; + let c = context.resolve_integer_variable(&ternary_int_args.c)?; let constraint = create_constraint(a, b, c, constraint_tag); Ok(constraint.post(context.solver).is_ok()) @@ -636,16 +605,12 @@ fn compile_ternary_int_predicate( fn compile_binary_int_predicate( context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + Binary(lhs, rhs): &Binary, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(DomainId, DomainId, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 2, predicate_name); - - let a = context.resolve_integer_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; + let a = context.resolve_integer_variable(lhs)?; + let b = context.resolve_integer_variable(rhs)?; let constraint = create_constraint(a, b, constraint_tag); Ok(constraint.post(context.solver).is_ok()) @@ -653,23 +618,19 @@ fn compile_binary_int_predicate( fn compile_reified_binary_int_predicate( context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + args: &ReifiedBinary, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(DomainId, DomainId, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 3, predicate_name); - - let a = context.resolve_integer_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; - let reif = context.resolve_bool_variable(&exprs[2])?; + let a = context.resolve_integer_variable(&args.a)?; + let b = context.resolve_integer_variable(&args.b)?; + let reif = context.resolve_bool_variable(&args.reification)?; let constraint = create_constraint(a, b, constraint_tag); Ok(constraint.reify(context.solver, reif).is_ok()) } -fn weighted_vars(weights: Rc<[i32]>, vars: Rc<[DomainId]>) -> Box<[AffineView]> { +fn weighted_vars(weights: &[i32], vars: Vec) -> Box<[AffineView]> { vars.iter() .zip(weights.iter()) .filter(|(_, &w)| w != 0) @@ -678,153 +639,110 @@ fn weighted_vars(weights: Rc<[i32]>, vars: Rc<[DomainId]>) -> Box<[AffineView( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + args: &Linear, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(Box<[AffineView]>, i32, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 3, predicate_name); - - let weights = context.resolve_array_integer_constants(&exprs[0])?; - let vars = context.resolve_integer_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_constant_from_expr(&exprs[2])?; + let vars = context.resolve_integer_variable_array(instance, &args.variables)?; + let weights = context.resolve_integer_array(instance, &args.weights)?; + let terms = weighted_vars(&weights, vars); - let terms = weighted_vars(weights, vars); - - let constraint = create_constraint(terms, rhs, constraint_tag); + let constraint = create_constraint(terms, args.rhs, constraint_tag); Ok(constraint.post(context.solver).is_ok()) } fn compile_reified_int_lin_predicate( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + args: &ReifiedLinear, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(Box<[AffineView]>, i32, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 4, predicate_name); - - let weights = context.resolve_array_integer_constants(&exprs[0])?; - let vars = context.resolve_integer_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_constant_from_expr(&exprs[2])?; - let reif = context.resolve_bool_variable(&exprs[3])?; + let vars = context.resolve_integer_variable_array(instance, &args.variables)?; + let weights = context.resolve_integer_array(instance, &args.weights)?; + let reif = context.resolve_bool_variable(&args.reification)?; - let terms = weighted_vars(weights, vars); + let terms = weighted_vars(&weights, vars); - let constraint = create_constraint(terms, rhs, constraint_tag); + let constraint = create_constraint(terms, args.rhs, constraint_tag); Ok(constraint.reify(context.solver, reif).is_ok()) } -fn compile_int_lin_imp_predicate( +fn compile_implied_int_lin_predicate( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + args: &ReifiedLinear, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(Box<[AffineView]>, i32, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 4, predicate_name); - - let weights = context.resolve_array_integer_constants(&exprs[0])?; - let vars = context.resolve_integer_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_constant_from_expr(&exprs[2])?; - let reif = context.resolve_bool_variable(&exprs[3])?; + let vars = context.resolve_integer_variable_array(instance, &args.variables)?; + let weights = context.resolve_integer_array(instance, &args.weights)?; + let reif = context.resolve_bool_variable(&args.reification)?; - let terms = weighted_vars(weights, vars); + let terms = weighted_vars(&weights, vars); - let constraint = create_constraint(terms, rhs, constraint_tag); + let constraint = create_constraint(terms, args.rhs, constraint_tag); Ok(constraint.implied_by(context.solver, reif).is_ok()) } fn compile_binary_int_imp( context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - predicate_name: &str, + args: &ReifiedBinary, constraint_tag: ConstraintTag, create_constraint: impl FnOnce(DomainId, DomainId, ConstraintTag) -> C, ) -> Result { - check_parameters!(exprs, 3, predicate_name); - - let a = context.resolve_integer_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; - let reif = context.resolve_bool_variable(&exprs[2])?; + let a = context.resolve_integer_variable(&args.a)?; + let b = context.resolve_integer_variable(&args.b)?; + let reif = context.resolve_bool_variable(&args.reification)?; let constraint = create_constraint(a, b, constraint_tag); Ok(constraint.implied_by(context.solver, reif).is_ok()) } fn compile_bool_lin_eq_predicate( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], + args: &BoolLinEqArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "bool_lin_eq"); - - let weights = context.resolve_array_integer_constants(&exprs[0])?; - let bools = context.resolve_bool_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_variable(&exprs[2])?; + let bools = context.resolve_bool_variable_array(instance, &args.variables)?; + let weights = context.resolve_integer_array(instance, &args.weights)?; + let rhs = context.resolve_integer_variable(&args.sum)?; - Ok(constraints::boolean_equals( - weights.as_ref().to_owned(), - bools.as_ref().to_owned(), - rhs, - constraint_tag, + Ok( + constraints::boolean_equals(weights, bools, rhs, constraint_tag) + .post(context.solver) + .is_ok(), ) - .post(context.solver) - .is_ok()) } fn compile_bool_lin_le_predicate( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], + args: &BoolLinLeArgs, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "bool_lin_le"); - - let weights = context.resolve_array_integer_constants(&exprs[0])?; - let bools = context.resolve_bool_variable_array(&exprs[1])?; - let rhs = context.resolve_integer_constant_from_expr(&exprs[2])?; + let bools = context.resolve_bool_variable_array(instance, &args.variables)?; + let weights = context.resolve_integer_array(instance, &args.weights)?; - Ok(constraints::boolean_less_than_or_equals( - weights.as_ref().to_owned(), - bools.as_ref().to_owned(), - rhs, - constraint_tag, + Ok( + constraints::boolean_less_than_or_equals(weights, bools, args.bound, constraint_tag) + .post(context.solver) + .is_ok(), ) - .post(context.solver) - .is_ok()) -} - -fn compile_all_different( - context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], - constraint_tag: ConstraintTag, -) -> Result { - check_parameters!(exprs, 1, "fzn_all_different"); - - let variables = context.resolve_integer_variable_array(&exprs[0])?.to_vec(); - Ok(constraints::all_different(variables, constraint_tag) - .post(context.solver) - .is_ok()) } fn compile_table( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], + table: &TableInt, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 2, "pumpkin_table_int"); - - let variables = context.resolve_integer_variable_array(&exprs[0])?.to_vec(); - - let flat_table = context.resolve_array_integer_constants(&exprs[1])?; - let table = create_table(flat_table, variables.len()); + let variables = context.resolve_integer_variable_array(instance, &table.variables)?; + let flat_table = context.resolve_integer_array(instance, &table.table)?; + let table = create_table(&flat_table, variables.len()); Ok(constraints::table(variables, table, constraint_tag) .post(context.solver) @@ -832,26 +750,24 @@ fn compile_table( } fn compile_table_reif( + instance: &Instance, context: &mut CompilationContext, - exprs: &[flatzinc::Expr], - _: &[flatzinc::Annotation], + table_reif: &TableIntReif, constraint_tag: ConstraintTag, ) -> Result { - check_parameters!(exprs, 3, "pumpkin_table_int_reif"); - - let variables = context.resolve_integer_variable_array(&exprs[0])?.to_vec(); - - let flat_table = context.resolve_array_integer_constants(&exprs[1])?; - let table = create_table(flat_table, variables.len()); - - let reified = context.resolve_bool_variable(&exprs[2])?; + let variables = context + .resolve_integer_variable_array(instance, &table_reif.variables)? + .to_vec(); + let flat_table = context.resolve_integer_array(instance, &table_reif.table)?; + let table = create_table(&flat_table, variables.len()); + let reified = context.resolve_bool_variable(&table_reif.reification)?; Ok(constraints::table(variables, table, constraint_tag) .reify(context.solver, reified) .is_ok()) } -fn create_table(flat_table: Rc<[i32]>, num_variables: usize) -> Vec> { +fn create_table(flat_table: &[i32], num_variables: usize) -> Vec> { let table = flat_table .iter() .copied() diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/prepare_variables.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/prepare_variables.rs index 365ae2e8a..37b2b8762 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/prepare_variables.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/prepare_variables.rs @@ -1,85 +1,93 @@ //! Scan through all the variable definitions to prepare the equivalence classes for each of them. -use flatzinc::BoolExpr; -use flatzinc::IntExpr; +use std::rc::Rc; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::ast::SingleVarDecl; +use fzn_rs::ast; + +use crate::flatzinc::ast::Instance; use crate::flatzinc::compiler::context::CompilationContext; use crate::flatzinc::FlatZincError; pub(crate) fn run( - ast: &FlatZincAst, + typed_ast: &Instance, context: &mut CompilationContext, ) -> Result<(), FlatZincError> { - for single_var_decl in &ast.single_variables { - match single_var_decl { - SingleVarDecl::Bool { id, expr, annos: _ } => { - let id = context.identifiers.get_interned(id); - - let (lb, ub) = match expr { + for (name, variable) in &typed_ast.variables { + match &variable.domain.node { + ast::Domain::Bool => { + let (lb, ub) = match &variable.value { None => (0, 1), - Some(BoolExpr::Bool(true)) => (1, 1), - Some(BoolExpr::Bool(false)) => (0, 0), + Some(ast::Node { node, .. }) => match node { + ast::Literal::Bool(true) => (1, 1), + ast::Literal::Bool(false) => (0, 0), - Some(BoolExpr::VarParIdentifier(identifier)) => { - let other_id = context.identifiers.get_interned(identifier); + // The variable is assigned to another variable, but we don't handle that + // case here. + ast::Literal::Identifier(_) => (0, 1), - match context.boolean_parameters.get(&other_id) { - Some(true) => (1, 1), - Some(false) => (0, 0), - None => (0, 1), - } - } + _ => panic!("expected boolean value or identifier, got {node:?}"), + }, }; - context.equivalences.create_equivalence_class(id, lb, ub); + context + .equivalences + .create_equivalence_class(Rc::clone(name), lb, ub); } - SingleVarDecl::IntInRange { - id, - lb, - ub, - expr, - annos: _, - } => { - let id = context.identifiers.get_interned(id); - - let lb = i32::try_from(*lb)?; - let ub = i32::try_from(*ub)?; - - let (lb, ub) = match expr { - None => (lb, ub), - - Some(IntExpr::Int(value)) => { - let value = i32::try_from(*value)?; - (value, value) - } - - Some(IntExpr::VarParIdentifier(identifier)) => { - let other_id = context.identifiers.get_interned(identifier); - - match context.integer_parameters.get(&other_id) { - Some(int) => (*int, *int), - None => (lb, ub), - } - } + ast::Domain::Int(set) if set.is_continuous() => { + let (lb, ub) = match &variable.value { + None => (*set.lower_bound(), *set.upper_bound()), + + Some(ast::Node { node, .. }) => match node { + ast::Literal::Int(int) => (*int, *int), + + // The variable is assigned to another variable, but we don't handle that + // case here. + ast::Literal::Identifier(_) => (*set.lower_bound(), *set.upper_bound()), + + _ => panic!("expected integer value or identifier, got {node:?}"), + }, }; - context.equivalences.create_equivalence_class(id, lb, ub); + let domain_span = variable.domain.span; + + let lb = i32::try_from(lb).map_err(|_| FlatZincError::IntegerTooBig { + integer: lb.to_string(), + span_start: domain_span.start, + span_end: domain_span.end, + })?; + let ub = i32::try_from(ub).map_err(|_| FlatZincError::IntegerTooBig { + integer: ub.to_string(), + span_start: domain_span.start, + span_end: domain_span.end, + })?; + + context + .equivalences + .create_equivalence_class(Rc::clone(name), lb, ub); } - SingleVarDecl::IntInSet { id, set, .. } => { - let id = context.identifiers.get_interned(id); + ast::Domain::Int(set) => { + assert!(!set.is_continuous()); context.equivalences.create_equivalence_class_sparse( - id, - set.iter() - .map(|&value| i32::try_from(value)) + Rc::clone(name), + set.into_iter() + .map(|value| { + i32::try_from(value).map_err(|_| FlatZincError::IntegerTooBig { + integer: value.to_string(), + span_start: variable.domain.span.start, + span_end: variable.domain.span.end, + }) + }) .collect::, _>>()?, ) } + + ast::Domain::UnboundedInt => { + return Err(FlatZincError::UnsupportedVariable(name.as_ref().into())) + } } } @@ -88,21 +96,26 @@ pub(crate) fn run( #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use fzn_rs::Method; + use fzn_rs::Solve; use pumpkin_solver::Solver; use super::*; - use crate::flatzinc::ast::SearchStrategy; - use crate::flatzinc::ast::SingleVarDecl; + use crate::flatzinc::ast::VariableAnnotations; use crate::flatzinc::compiler::context::Domain; #[test] fn bool_variable_creates_equivalence_class() { - let ast = create_dummy_ast([SingleVarDecl::Bool { - id: "SomeVar".into(), - expr: None, - annos: vec![], - }]); + let ast = create_dummy_instance([( + "SomeVar", + ast::Variable { + domain: test_node(ast::Domain::Bool), + value: None, + annotations: vec![], + }, + )]); let mut solver = Solver::default(); let mut context = CompilationContext::new(&mut solver); @@ -115,17 +128,23 @@ mod tests { #[test] fn bool_variable_equal_to_constant_as_singleton_domain() { - let ast = create_dummy_ast([ - SingleVarDecl::Bool { - id: "SomeVar".into(), - expr: Some(BoolExpr::Bool(false)), - annos: vec![], - }, - SingleVarDecl::Bool { - id: "OtherVar".into(), - expr: Some(BoolExpr::Bool(true)), - annos: vec![], - }, + let ast = create_dummy_instance([ + ( + "SomeVar", + ast::Variable { + domain: test_node(ast::Domain::Bool), + value: Some(test_node(ast::Literal::Bool(false))), + annotations: vec![], + }, + ), + ( + "OtherVar", + ast::Variable { + domain: test_node(ast::Domain::Bool), + value: Some(test_node(ast::Literal::Bool(true))), + annotations: vec![], + }, + ), ]); let mut solver = Solver::default(); @@ -143,33 +162,16 @@ mod tests { ); } - #[test] - fn bool_expr_resolves_parameter() { - let ast = create_dummy_ast([SingleVarDecl::Bool { - id: "SomeVar".into(), - expr: Some(BoolExpr::VarParIdentifier("FalsePar".into())), - annos: vec![], - }]); - - let mut solver = Solver::default(); - let mut context = CompilationContext::new(&mut solver); - let _ = context.boolean_parameters.insert("FalsePar".into(), false); - - run(&ast, &mut context).expect("no errors"); - - assert_eq!( - Domain::from_lower_bound_and_upper_bound(0, 0), - context.equivalences.domain("SomeVar") - ); - } - #[test] fn bool_expr_ignores_reference_to_non_existent_identifier() { - let ast = create_dummy_ast([SingleVarDecl::Bool { - id: "SomeVar".into(), - expr: Some(BoolExpr::VarParIdentifier("OtherVar".into())), - annos: vec![], - }]); + let ast = create_dummy_instance([( + "SomeVar", + ast::Variable { + domain: test_node(ast::Domain::Bool), + value: Some(test_node(ast::Literal::Identifier("OtherVar".into()))), + annotations: vec![], + }, + )]); let mut solver = Solver::default(); let mut context = CompilationContext::new(&mut solver); @@ -184,38 +186,17 @@ mod tests { #[test] fn int_expr_constant_is_parsed() { - let ast = create_dummy_ast([SingleVarDecl::IntInRange { - id: "SomeVar".into(), - lb: 1, - ub: 5, - expr: Some(IntExpr::Int(3)), - annos: vec![], - }]); - - let mut solver = Solver::default(); - let mut context = CompilationContext::new(&mut solver); - - run(&ast, &mut context).expect("no errors"); - - assert_eq!( - Domain::from_lower_bound_and_upper_bound(3, 3), - context.equivalences.domain("SomeVar") - ); - } - - #[test] - fn int_expr_named_constant_is_resolved() { - let ast = create_dummy_ast([SingleVarDecl::IntInRange { - id: "SomeVar".into(), - lb: 1, - ub: 5, - expr: Some(IntExpr::VarParIdentifier("IntPar".into())), - annos: vec![], - }]); + let ast = create_dummy_instance([( + "SomeVar", + ast::Variable { + domain: test_node(ast::Domain::Int(ast::RangeList::from(1..=5))), + value: Some(test_node(ast::Literal::Int(3))), + annotations: vec![], + }, + )]); let mut solver = Solver::default(); let mut context = CompilationContext::new(&mut solver); - let _ = context.integer_parameters.insert("IntPar".into(), 3); run(&ast, &mut context).expect("no errors"); @@ -225,22 +206,27 @@ mod tests { ); } - fn create_dummy_ast(decls: impl IntoIterator) -> FlatZincAst { - FlatZincAst { - parameter_decls: vec![], - single_variables: decls.into_iter().collect(), - variable_arrays: vec![], - constraint_decls: vec![], - solve_item: flatzinc::SolveItem { - goal: flatzinc::Goal::Satisfy, + fn create_dummy_instance( + variables: impl IntoIterator)>, + ) -> Instance { + Instance { + variables: variables + .into_iter() + .map(|(name, data)| (Rc::from(name), data)) + .collect(), + arrays: BTreeMap::new(), + constraints: vec![], + solve: Solve { + method: test_node(Method::Satisfy), annotations: vec![], }, - search: crate::flatzinc::ast::Search::Int(SearchStrategy { - variables: flatzinc::AnnExpr::String("test".to_owned()), - variable_selection_strategy: - crate::flatzinc::ast::VariableSelectionStrategy::AntiFirstFail, - value_selection_strategy: crate::flatzinc::ast::ValueSelectionStrategy::InDomain, - }), + } + } + + fn test_node(data: T) -> ast::Node { + ast::Node { + node: data, + span: ast::Span { start: 0, end: 0 }, } } } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/remove_unused_variables.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/remove_unused_variables.rs index 684dc918f..344abfe6e 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/remove_unused_variables.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/remove_unused_variables.rs @@ -8,143 +8,68 @@ use std::collections::BTreeSet; use std::rc::Rc; -use super::context::CompilationContext; -use crate::flatzinc::ast::FlatZincAst; -use crate::flatzinc::ast::VarArrayDecl; +use fzn_rs::ast::Ast; + use crate::flatzinc::error::FlatZincError; -pub(crate) fn run( - ast: &mut FlatZincAst, - context: &mut CompilationContext, -) -> Result<(), FlatZincError> { +pub(crate) fn run(ast: &mut Ast) -> Result<(), FlatZincError> { let mut marked_identifiers = BTreeSet::new(); - mark_identifiers_in_constraints(ast, context, &mut marked_identifiers); - mark_identifiers_in_arrays(ast, context, &mut marked_identifiers); + mark_identifiers_in_constraints(ast, &mut marked_identifiers); + mark_identifiers_in_arrays(ast, &mut marked_identifiers); // Make sure the objective, which can be unconstrained, is always marked. - match &ast.solve_item.goal { - flatzinc::Goal::OptimizeBool(_, flatzinc::BoolExpr::VarParIdentifier(id)) - | flatzinc::Goal::OptimizeInt(_, flatzinc::IntExpr::VarParIdentifier(id)) - | flatzinc::Goal::OptimizeFloat(_, flatzinc::FloatExpr::VarParIdentifier(id)) - | flatzinc::Goal::OptimizeSet(_, flatzinc::SetExpr::VarParIdentifier(id)) => { - let _ = marked_identifiers.insert(context.identifiers.get_interned(id)); - } - _ => {} + if let fzn_rs::ast::Method::Optimize { + objective: fzn_rs::ast::Literal::Identifier(objective), + .. + } = &ast.solve.method.node + { + let _ = marked_identifiers.insert(Rc::clone(objective)); } - ast.single_variables.retain(|decl| match decl { - crate::flatzinc::ast::SingleVarDecl::Bool { id, annos, .. } - | crate::flatzinc::ast::SingleVarDecl::IntInRange { id, annos, .. } - | crate::flatzinc::ast::SingleVarDecl::IntInSet { id, annos, .. } => { - // If the variable is an output variable, then always keep it. - if annos.iter().any(|annotation| annotation.id == "output_var") { - return true; - } - - marked_identifiers.contains(id.as_str()) - } + ast.variables.retain(|name, variable| { + marked_identifiers.contains(name) + || variable + .node + .annotations + .iter() + .any(|node| node.node.name() == "output_var") }); Ok(()) } -macro_rules! mark_literal_exprs { - ($exprs:ident, $expr_type:ident, $identifiers:ident, $context:ident) => {{ - for expr in $exprs { - if let flatzinc::$expr_type::VarParIdentifier(id) = expr { - let _ = $identifiers.insert($context.identifiers.get_interned(id)); - } - } - }}; -} - /// Go over all arrays and mark the identifiers that are elements of the array. -fn mark_identifiers_in_arrays( - ast: &mut FlatZincAst, - context: &mut CompilationContext, - marked_identifiers: &mut BTreeSet>, -) { - for array in &ast.variable_arrays { - match array { - VarArrayDecl::Bool { - array_expr: Some(expr), - .. - } => match expr { - flatzinc::ArrayOfBoolExpr::Array(exprs) => { - mark_literal_exprs!(exprs, BoolExpr, marked_identifiers, context) - } - flatzinc::ArrayOfBoolExpr::VarParIdentifier(_) => { - // This is the following case: - // - // array [1..4] of var int: as = [...]; - // array [1..4] of var int: bs = as; - // - // I don't think this can happen, so for now we panic. If it does happen we - // need to implement it otherwise we may be removing variables that we need - // later on. - panic!("Cannot handle array declarations that are assigned to other arrays.") - } - }, - VarArrayDecl::Int { - array_expr: Some(expr), - .. - } => match expr { - flatzinc::ArrayOfIntExpr::Array(exprs) => { - mark_literal_exprs!(exprs, IntExpr, marked_identifiers, context) - } - flatzinc::ArrayOfIntExpr::VarParIdentifier(_) => { - // This is the following case: - // - // array [1..4] of var int: as = [...]; - // array [1..4] of var int: bs = as; - // - // I don't think this can happen, so for now we panic. If it does happen we - // need to implement it otherwise we may be removing variables that we need - // later on. - panic!("Cannot handle array declarations that are assigned to other arrays.") - } - }, - _ => {} - } - } +fn mark_identifiers_in_arrays(ast: &Ast, marked_identifiers: &mut BTreeSet>) { + ast.arrays + .values() + .flat_map(|array| array.node.contents.iter()) + .for_each(|node| { + mark_literal(&node.node, marked_identifiers); + }); } /// Go over all constraints and add any identifier in the arguments to the `marked_identifiers` set. -fn mark_identifiers_in_constraints( - ast: &mut FlatZincAst, - context: &mut CompilationContext, - marked_identifiers: &mut BTreeSet>, -) { - for expr in ast - .constraint_decls +fn mark_identifiers_in_constraints(ast: &Ast, marked_identifiers: &mut BTreeSet>) { + for argument_node in ast + .constraints .iter() - .flat_map(|constraint| &constraint.exprs) + .flat_map(|constraint| &constraint.node.arguments) { - match expr { - flatzinc::Expr::VarParIdentifier(id) => { - let _ = marked_identifiers.insert(context.identifiers.get_interned(id)); - } - - flatzinc::Expr::ArrayOfBool(exprs) => { - mark_literal_exprs!(exprs, BoolExpr, marked_identifiers, context) - } + match &argument_node.node { + fzn_rs::ast::Argument::Array(nodes) => nodes.iter().for_each(|node| { + mark_literal(&node.node, marked_identifiers); + }), - flatzinc::Expr::ArrayOfInt(exprs) => { - mark_literal_exprs!(exprs, IntExpr, marked_identifiers, context) - } - flatzinc::Expr::ArrayOfFloat(exprs) => { - mark_literal_exprs!(exprs, FloatExpr, marked_identifiers, context) + fzn_rs::ast::Argument::Literal(node) => { + mark_literal(&node.node, marked_identifiers); } - - flatzinc::Expr::ArrayOfSet(exprs) => { - mark_literal_exprs!(exprs, SetExpr, marked_identifiers, context) - } - - flatzinc::Expr::Bool(_) - | flatzinc::Expr::Int(_) - | flatzinc::Expr::Float(_) - | flatzinc::Expr::Set(_) => {} } } } + +fn mark_literal(literal: &fzn_rs::ast::Literal, marked_identifiers: &mut BTreeSet>) { + if let fzn_rs::ast::Literal::Identifier(ident) = literal { + let _ = marked_identifiers.insert(Rc::clone(ident)); + } +} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/reserve_constraint_tags.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/reserve_constraint_tags.rs index acd908182..2bd75b708 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/reserve_constraint_tags.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/reserve_constraint_tags.rs @@ -3,18 +3,35 @@ //! However, we assume that the first n constraint tags are the flatzinc constraints. Therefore, //! the root-level inferences would throw off that mapping. +use fzn_rs::ast; + use super::context::CompilationContext; -use crate::flatzinc::ast::FlatZincAst; +use crate::flatzinc::ast::ConstraintAnnotations; +use crate::flatzinc::ast::Instance; use crate::flatzinc::error::FlatZincError; pub(crate) fn run( - ast: &FlatZincAst, + instance: &mut Instance, context: &mut CompilationContext, ) -> Result<(), FlatZincError> { - for decl in &ast.constraint_decls { + for constraint in instance.constraints.iter_mut() { let tag = context.solver.new_constraint_tag(); - context.constraints.push((tag, decl.clone())); + constraint + .annotations + .push(generated_node(ConstraintAnnotations::ConstraintTag( + tag.into(), + ))); } Ok(()) } + +fn generated_node(data: T) -> ast::Node { + ast::Node { + span: ast::Span { + start: usize::MAX, + end: usize::MAX, + }, + node: data, + } +} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/constraints.rs new file mode 100644 index 000000000..95057c1ac --- /dev/null +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/constraints.rs @@ -0,0 +1,270 @@ +use fzn_rs::ast::RangeList; +use fzn_rs::ArrayExpr; +use fzn_rs::VariableExpr; + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) enum Constraints { + SetIn(VariableExpr, RangeList), + + #[args] + SetInReif(SetInReifArgs), + + #[args] + ArrayIntMinimum(ArrayExtremum), + #[args] + ArrayIntMaximum(ArrayExtremum), + + #[args] + IntMax(TernaryIntArgs), + #[args] + IntMin(TernaryIntArgs), + + #[args] + ArrayIntElement(IntElementArgs), + #[args] + ArrayVarIntElement(IntElementArgs), + + #[args] + IntEqImp(ReifiedBinary), + #[args] + IntGeImp(ReifiedBinary), + #[args] + IntGtImp(ReifiedBinary), + #[args] + IntLeImp(ReifiedBinary), + #[args] + IntLtImp(ReifiedBinary), + #[args] + IntNeImp(ReifiedBinary), + + #[args] + IntLinLe(Linear), + #[args] + IntLinEq(Linear), + #[args] + IntLinNe(Linear), + + #[args] + IntLinLeReif(ReifiedLinear), + #[args] + IntLinEqReif(ReifiedLinear), + #[args] + IntLinNeReif(ReifiedLinear), + + #[args] + IntLinLeImp(ReifiedLinear), + #[args] + IntLinEqImp(ReifiedLinear), + #[args] + IntLinNeImp(ReifiedLinear), + + #[args] + IntEq(Binary), + #[args] + IntNe(Binary), + #[args] + IntLe(Binary), + #[args] + IntLt(Binary), + #[args] + IntAbs(Binary), + + #[args] + IntEqReif(ReifiedBinary), + #[args] + IntNeReif(ReifiedBinary), + #[args] + IntLeReif(ReifiedBinary), + #[args] + IntLtReif(ReifiedBinary), + + #[args] + IntTimes(TernaryIntArgs), + #[args] + IntPlus(TernaryIntArgs), + #[args] + IntDiv(TernaryIntArgs), + + #[name("pumpkin_all_different")] + AllDifferent(ArrayExpr>), + + #[args] + #[name("pumpkin_table_int")] + Table(TableInt), + + #[args] + #[name("pumpkin_table_int_reif")] + TableReif(TableIntReif), + + #[args] + ArrayBoolAnd(ArrayBoolArgs), + + #[args] + ArrayBoolOr(ArrayBoolArgs), + + #[args] + BoolClause(BoolClauseArgs), + + #[args] + BoolEq(BinaryBool), + + #[args] + BoolNot(BinaryBool), + + #[args] + #[name("pumpkin_bool_xor")] + BoolXor(BinaryBool), + + #[args] + #[name("pumpkin_bool_xor_reif")] + BoolXorReif(BinaryBoolReif), + + #[args] + #[name("bool2int")] + BoolToInt(BoolToIntArgs), + + #[args] + BoolLinEq(BoolLinEqArgs), + + #[args] + BoolLinLe(BoolLinLeArgs), + + #[args] + BoolAnd(BinaryBoolReif), + #[args] + BoolEqReif(BinaryBoolReif), + + #[args] + ArrayBoolElement(BoolElementArgs), + #[args] + ArrayVarBoolElement(BoolElementArgs), + + #[args] + #[name("pumpkin_cumulative")] + Cumulative(CumulativeArgs), +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct CumulativeArgs { + pub(crate) start_times: ArrayExpr>, + pub(crate) durations: ArrayExpr, + pub(crate) resource_requirements: ArrayExpr, + pub(crate) resource_capacity: i32, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct SetInReifArgs { + pub(crate) variable: VariableExpr, + pub(crate) set: RangeList, + pub(crate) reification: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BoolClauseArgs { + pub(crate) clause_1: ArrayExpr>, + pub(crate) clause_2: ArrayExpr>, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BoolLinEqArgs { + pub(crate) weights: ArrayExpr, + pub(crate) variables: ArrayExpr>, + pub(crate) sum: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BoolLinLeArgs { + pub(crate) weights: ArrayExpr, + pub(crate) variables: ArrayExpr>, + pub(crate) bound: i32, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BoolToIntArgs { + pub(crate) boolean: VariableExpr, + pub(crate) integer: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct ArrayBoolArgs { + pub(crate) booleans: ArrayExpr>, + pub(crate) reification: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BinaryBool { + pub(crate) a: VariableExpr, + pub(crate) b: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BinaryBoolReif { + pub(crate) a: VariableExpr, + pub(crate) b: VariableExpr, + pub(crate) reification: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct TableInt { + pub(crate) variables: ArrayExpr>, + pub(crate) table: ArrayExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct TableIntReif { + pub(crate) variables: ArrayExpr>, + pub(crate) table: ArrayExpr, + pub(crate) reification: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct ArrayExtremum { + pub(crate) extremum: VariableExpr, + pub(crate) array: ArrayExpr>, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct TernaryIntArgs { + pub(crate) a: VariableExpr, + pub(crate) b: VariableExpr, + pub(crate) c: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct IntElementArgs { + pub(crate) index: VariableExpr, + pub(crate) array: ArrayExpr>, + pub(crate) rhs: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct BoolElementArgs { + pub(crate) index: VariableExpr, + pub(crate) array: ArrayExpr>, + pub(crate) rhs: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct Binary(pub(crate) VariableExpr, pub(crate) VariableExpr); + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct ReifiedBinary { + pub(crate) a: VariableExpr, + pub(crate) b: VariableExpr, + pub(crate) reification: VariableExpr, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct Linear { + pub(crate) weights: ArrayExpr, + pub(crate) variables: ArrayExpr>, + pub(crate) rhs: i32, +} + +#[derive(fzn_rs::FlatZincConstraint)] +pub(crate) struct ReifiedLinear { + pub(crate) weights: ArrayExpr, + pub(crate) variables: ArrayExpr>, + pub(crate) rhs: i32, + pub(crate) reification: VariableExpr, +} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/error.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/error.rs index e9481f997..4874da0e5 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/error.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/error.rs @@ -1,4 +1,4 @@ -use std::num::TryFromIntError; +use std::rc::Rc; use thiserror::Error; @@ -7,31 +7,135 @@ pub(crate) enum FlatZincError { #[error("failed to read instance file: {0}")] Io(#[from] std::io::Error), - #[error("{0}")] - SyntaxError(Box), - #[error("{0} variables are not supported")] UnsupportedVariable(Box), - #[error("integer too big")] - IntegerTooBig(#[from] TryFromIntError), + #[error("integer {integer} is too big for our integer representation")] + IntegerTooBig { + integer: String, + span_start: usize, + span_end: usize, + }, + + #[error("the identifier '{identifier}' does not resolve to an '{expected_type}'")] + InvalidIdentifier { + identifier: Rc, + expected_type: Box, + }, + + #[error("use of undefined array '{0}'")] + UndefinedArray(String), + + #[error("constraint '{0}' is not supported")] + UnsupportedConstraint(String), + + /// Occurs when parsing a nested annotation. + /// + /// In this case, all possible arguments must be parsable into an annotation. If there is a + /// value that cannot be parsed, this error variant is returned. + #[error("annotation '{0}' is not supported")] + UnsupportedAnnotation(String), + + #[error("expected {expected}, got {actual} at ({span_start}, {span_end})")] + UnexpectedToken { + expected: String, + actual: String, + span_start: usize, + span_end: usize, + }, - #[error("constraint {constraint_id} expects {expected} arguments, got {actual}")] + #[error("expected {expected} arguments, got {actual} at ({span_start}, {span_end})")] IncorrectNumberOfArguments { - constraint_id: Box, expected: usize, actual: usize, + span_start: usize, + span_end: usize, }, - #[error("unexpected expression")] - UnexpectedExpr, + #[error("value {0} does not fit in the required integer type")] + IntegerOverflow(i64), +} - #[error("the identifier '{identifier}' does not resolve to an '{expected_type}'")] - InvalidIdentifier { - identifier: Box, - expected_type: Box, - }, +impl From> for FlatZincError { + fn from(value: fzn_rs::fzn::FznError<'_>) -> Self { + match value { + fzn_rs::fzn::FznError::LexError { reasons } => { + // For now we only look at the first error. In the future, fzn-rs may produce + // multiple errors. + let reason = reasons[0].clone(); + + let span = reason.span(); + let expected = reason + .expected() + .map(|pattern| format!("{pattern}, ")) + .collect::(); + + FlatZincError::UnexpectedToken { + expected, + actual: reason + .found() + .map(|c| format!("{c}")) + .unwrap_or("".to_owned()), + span_start: span.start, + span_end: span.end, + } + } + fzn_rs::fzn::FznError::ParseError { reasons } => { + // For now we only look at the first error. In the future, fzn-rs may produce + // multiple errors. + let reason = reasons[0].clone(); + + let span = reason.span(); + let expected = reason + .expected() + .map(|pattern| format!("{pattern}, ")) + .collect::(); + + FlatZincError::UnexpectedToken { + expected, + actual: reason + .found() + .map(|token| format!("{token}")) + .unwrap_or("".to_owned()), + span_start: span.start, + span_end: span.end, + } + } + } + } +} - #[error("missing solve item")] - MissingSolveItem, +impl From for FlatZincError { + fn from(value: fzn_rs::InstanceError) -> Self { + match value { + fzn_rs::InstanceError::UnsupportedConstraint(c) => { + FlatZincError::UnsupportedConstraint(c) + } + fzn_rs::InstanceError::UnsupportedAnnotation(a) => { + FlatZincError::UnsupportedAnnotation(a) + } + fzn_rs::InstanceError::UnexpectedToken { + expected, + actual, + span, + } => FlatZincError::UnexpectedToken { + expected: format!("{expected}"), + actual: format!("{actual}"), + span_start: span.start, + span_end: span.end, + }, + fzn_rs::InstanceError::UndefinedArray(a) => FlatZincError::UndefinedArray(a), + fzn_rs::InstanceError::IncorrectNumberOfArguments { + expected, + actual, + span, + } => FlatZincError::IncorrectNumberOfArguments { + expected, + actual, + span_start: span.start, + span_end: span.end, + }, + fzn_rs::InstanceError::IntegerOverflow(num) => FlatZincError::IntegerOverflow(num), + } + } } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs index 80b6e3a2b..a47f351b0 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs @@ -58,7 +58,7 @@ impl Output { pub(crate) fn array_of_bool( id: Rc, shape: Box<[(i32, i32)]>, - contents: Rc<[Literal]>, + contents: Vec, ) -> Output { Output::ArrayOfBool(ArrayOutput { id, @@ -77,7 +77,7 @@ impl Output { pub(crate) fn array_of_int( id: Rc, shape: Box<[(i32, i32)]>, - contents: Rc<[DomainId]>, + contents: Vec, ) -> Output { Output::ArrayOfInt(ArrayOutput { id, @@ -108,7 +108,7 @@ pub(crate) struct ArrayOutput { /// Example: [(1, 5), (2, 4)] describes a 2d array, where the first dimension in indexed with /// an element of 1..5, and the second dimension is indexed with an element from 2..4. shape: Box<[(i32, i32)]>, - contents: Rc<[T]>, + contents: Vec, } impl ArrayOutput { diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs index 37cf19838..2a3701331 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs @@ -1,8 +1,8 @@ mod ast; mod compiler; +mod constraints; pub(crate) mod error; mod instance; -mod parser; use std::fs::File; use std::io::Read; @@ -108,20 +108,113 @@ fn solution_callback( pub(crate) fn solve( mut solver: Solver, - instance: impl AsRef, + instance_path: impl AsRef, time_limit: Option, options: FlatZincOptions, ) -> Result<(), FlatZincError> { - let init_start_time = Instant::now(); - - let instance = File::open(instance)?; + let instance = File::open(&instance_path)?; let mut termination = Combinator::new( OsSignal::install(), time_limit.map(TimeBudget::starting_now), ); - let instance = parse_and_compile(&mut solver, instance, options)?; + let init_start_time = Instant::now(); + + let instance = match parse_and_compile(&mut solver, instance, options) { + Ok(instance) => instance, + Err(FlatZincError::IntegerTooBig { + integer, + span_start, + span_end, + }) => { + let instance_path_str = instance_path.as_ref().display().to_string(); + let source = std::fs::read_to_string(instance_path).unwrap(); + + ariadne::Report::build( + ariadne::ReportKind::Error, + (&instance_path_str, span_start..span_end), + ) + .with_message("Integer value too big") + .with_label( + ariadne::Label::new((&instance_path_str, span_start..span_end)).with_message( + format!("value {integer} does not fit into our integer representation"), + ), + ) + .finish() + .print((&instance_path_str, ariadne::Source::from(source))) + .unwrap(); + + return Err(FlatZincError::IntegerTooBig { + integer, + span_start, + span_end, + }); + } + Err(FlatZincError::IncorrectNumberOfArguments { + expected, + actual, + span_start, + span_end, + }) => { + let instance_path_str = instance_path.as_ref().display().to_string(); + let source = std::fs::read_to_string(instance_path).unwrap(); + + ariadne::Report::build( + ariadne::ReportKind::Error, + (&instance_path_str, span_start..span_end), + ) + .with_message("Incorrect number of arguments") + .with_label( + ariadne::Label::new((&instance_path_str, span_start..span_end)) + .with_message(format!("Expected {expected} arguments, got {actual}.")), + ) + .finish() + .print((&instance_path_str, ariadne::Source::from(source))) + .unwrap(); + + return Err(FlatZincError::IncorrectNumberOfArguments { + expected, + actual, + span_start, + span_end, + }); + } + Err(FlatZincError::UnexpectedToken { + expected, + actual, + span_start, + span_end, + }) => { + let instance_path_str = instance_path.as_ref().display().to_string(); + let source = std::fs::read_to_string(instance_path).unwrap(); + + ariadne::Report::build( + ariadne::ReportKind::Error, + (&instance_path_str, span_start..span_end), + ) + .with_message("Unexpected input") + .with_label( + ariadne::Label::new((&instance_path_str, span_start..span_end)) + .with_message(format!("Expected {expected}")), + ) + .finish() + .print((&instance_path_str, ariadne::Source::from(source))) + .unwrap(); + + return Err(FlatZincError::UnexpectedToken { + expected, + actual, + span_start, + span_end, + }); + } + + Err(e) => { + return Err(e); + } + }; + let outputs = instance.outputs.clone(); let init_time = init_start_time.elapsed(); @@ -311,10 +404,14 @@ fn satisfy( fn parse_and_compile( solver: &mut Solver, - instance: impl Read, + mut instance: impl Read, options: FlatZincOptions, ) -> Result { - let ast = parser::parse(instance)?; + let mut source = String::new(); + let _ = instance.read_to_string(&mut source)?; + + let ast = fzn_rs::fzn::parse(&source)?; + compiler::compile(ast, solver, options) } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/parser.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/parser.rs deleted file mode 100644 index 2ab73049c..000000000 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/parser.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::io::BufRead; -use std::io::BufReader; -use std::io::Read; -use std::str::FromStr; - -use super::ast::FlatZincAst; -use super::ast::FlatZincAstBuilder; -use super::ast::SingleVarDecl; -use super::ast::VarArrayDecl; -use super::FlatZincError; - -pub(crate) fn parse(source: impl Read) -> Result { - let reader = BufReader::new(source); - - let mut ast_builder = FlatZincAst::builder(); - - for line in reader.lines() { - let line = line?; - - match flatzinc::statements::Stmt::from_str(&line) { - Ok(stmt) => match stmt { - // Ignore. - flatzinc::Stmt::Comment(_) | flatzinc::Stmt::Predicate(_) => {} - - flatzinc::Stmt::Parameter(decl) => ast_builder.add_parameter_decl(decl), - flatzinc::Stmt::Variable(decl) => parse_var_decl(&mut ast_builder, decl)?, - flatzinc::Stmt::Constraint(constraint) => ast_builder.add_constraint(constraint), - flatzinc::Stmt::SolveItem(solve_item) => ast_builder.set_solve_item(solve_item), - }, - Err(msg) => { - return Err(FlatZincError::SyntaxError(msg.into())); - } - } - } - - ast_builder.build() -} - -fn parse_var_decl( - ast: &mut FlatZincAstBuilder, - decl: flatzinc::VarDeclItem, -) -> Result<(), FlatZincError> { - match decl { - flatzinc::VarDeclItem::Bool { id, expr, annos } => { - ast.add_variable_decl(SingleVarDecl::Bool { id, expr, annos }); - Ok(()) - } - - flatzinc::VarDeclItem::IntInRange { - id, - lb, - ub, - expr, - annos, - } => { - ast.add_variable_decl(SingleVarDecl::IntInRange { - id, - lb, - ub, - expr, - annos, - }); - Ok(()) - } - - flatzinc::VarDeclItem::IntInSet { - id, - set, - expr: _, - annos, - } => { - ast.add_variable_decl(SingleVarDecl::IntInSet { id, set, annos }); - Ok(()) - } - - flatzinc::VarDeclItem::ArrayOfBool { - ix: _, - id, - annos, - array_expr, - } => { - ast.add_variable_array(VarArrayDecl::Bool { - id, - annos, - array_expr, - }); - - Ok(()) - } - - flatzinc::VarDeclItem::ArrayOfInt { - ix: _, - id, - annos, - array_expr, - } => { - ast.add_variable_array(VarArrayDecl::Int { - id, - annos, - array_expr, - }); - Ok(()) - } - flatzinc::VarDeclItem::ArrayOfIntInRange { - ix: _, - id, - annos, - array_expr, - .. - } => { - ast.add_variable_array(VarArrayDecl::Int { - id, - annos, - array_expr, - }); - Ok(()) - } - - flatzinc::VarDeclItem::ArrayOfIntInSet { - ix: _, - id, - annos, - array_expr, - set: _, - } => { - ast.add_variable_array(VarArrayDecl::Int { - id, - annos, - array_expr, - }); - Ok(()) - } - - flatzinc::VarDeclItem::Int { .. } => { - Err(FlatZincError::UnsupportedVariable("unbounded int".into())) - } - - flatzinc::VarDeclItem::Float { .. } - | flatzinc::VarDeclItem::BoundedFloat { .. } - | flatzinc::VarDeclItem::ArrayOfFloat { .. } - | flatzinc::VarDeclItem::ArrayOfBoundedFloat { .. } => { - Err(FlatZincError::UnsupportedVariable("float".into())) - } - - flatzinc::VarDeclItem::SetOfInt { .. } - | flatzinc::VarDeclItem::SubSetOfIntSet { .. } - | flatzinc::VarDeclItem::SubSetOfIntRange { .. } - | flatzinc::VarDeclItem::ArrayOfSet { .. } - | flatzinc::VarDeclItem::ArrayOfSubSetOfIntRange { .. } - | flatzinc::VarDeclItem::ArrayOfSubSetOfIntSet { .. } => { - Err(FlatZincError::UnsupportedVariable("set".into())) - } - } -} diff --git a/pumpkin-solver/src/bin/pumpkin-solver/main.rs b/pumpkin-solver/src/bin/pumpkin-solver/main.rs index 3ea95c5b9..e96865588 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/main.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/main.rs @@ -15,6 +15,7 @@ use std::time::Duration; use clap::Parser; use clap::ValueEnum; use file_format::FileFormat; +use flatzinc::error::FlatZincError; use log::error; use log::info; use log::warn; @@ -488,6 +489,14 @@ fn configure_logging_sat( fn main() { match run() { Ok(()) => {} + + // These errors are printed in the flatzinc code. + Err(PumpkinError::FlatZinc(FlatZincError::UnexpectedToken { .. })) + | Err(PumpkinError::FlatZinc(FlatZincError::IntegerTooBig { .. })) + | Err(PumpkinError::FlatZinc(FlatZincError::IncorrectNumberOfArguments { .. })) => { + std::process::exit(1) + } + Err(e) => { error!("Execution failed, error: {e}"); std::process::exit(1);