diff --git a/Cargo.toml b/Cargo.toml index 7e11b767..ec8fabfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "reveaal" version = "0.1.0" build = "src/build.rs" authors = ["Thomas Lohse", "Sebastian Lund", "Thorulf Neustrup", "Peter Greve"] -edition = "2018" +edition = "2021" [lib] name = "reveaal" @@ -27,22 +27,21 @@ xml-rs = "0.8.3" serde-xml-rs = "0.6.0" elementtree = "1.2.2" dyn-clone = "1.0" -tonic = "0.8.3" -prost = "0.11.0" -tokio = { version = "1.0", features = ["macros", "rt"] } +tonic = "0.11.0" +prost = "0.12.3" +tokio = { version = "1.36.0", features = ["macros", "rt"] } colored = "2.0.0" -simple-error = "0.2.3" force_graph = "0.3.2" rand = "0.8.5" futures = "0.3.21" edbm = { git = "https://github.com/Ecdar/EDBM" } log = "0.4.17" -env_logger = { version = "0.9.0", optional = true } +env_logger = { version = "0.11.2", optional = true } chrono = { version = "0.4.22", optional = true } -test-case = "2.2.2" +test-case = "3.3.1" num_cpus = "1.13.1" -lru = "0.8.1" -itertools = "0.10.5" +lru = "0.12.2" +itertools = "0.12.1" regex = "1" rayon = "1.6.1" lazy_static = "1.4.0" @@ -53,14 +52,14 @@ num = "0.4.1" opt-level = 3 [build-dependencies] -tonic-build = "0.8.2" +tonic-build = "0.11.0" [dev-dependencies] -test-case = "2.2.2" -criterion = { version = "0.4.0", features = ["async_futures"] } +test-case = "3.3.1" +criterion = { version = "0.5.1", features = ["async_futures"] } [target.'cfg(unix)'.dev-dependencies] -pprof = { version = "0.10.1", features = ["flamegraph"] } +pprof = { version = "0.13.0", features = ["flamegraph"] } [[bench]] name = "refinement_bench" diff --git a/src/data_reader/grammars/edge_grammar.pest b/src/data_reader/grammars/edge_grammar.pest index 68bc4736..3c9c1cfb 100644 --- a/src/data_reader/grammars/edge_grammar.pest +++ b/src/data_reader/grammars/edge_grammar.pest @@ -30,12 +30,12 @@ bool_op = _{ and | or } and = { "&&" } or = { "||" } -arith_op = _{ add | sub | mul | div | mod } +arith_op = _{ add | sub | mul | div | modulo } add = { "+" } // Addition sub = { "-" } // Subtraction mul = { "*" } // Multiplication div = { "/" } // Division -mod = { "%" } // Modulo +modulo = { "%" } // Modulo compare_op = _{ geq | leq | eq | lt | gt } geq = { ">=" } // Greater than or equal to diff --git a/src/data_reader/parse_edge.rs b/src/data_reader/parse_edge.rs index 81645b01..ef5b555e 100644 --- a/src/data_reader/parse_edge.rs +++ b/src/data_reader/parse_edge.rs @@ -22,7 +22,7 @@ lazy_static! { .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) - | Op::infix(Rule::r#mod, Assoc::Left)) + | Op::infix(Rule::modulo, Assoc::Left)) .op(Op::infix(Rule::and, Assoc::Left)) .op(Op::infix(Rule::or, Assoc::Left)); } @@ -180,7 +180,7 @@ fn parse_arith_expr(pair: pest::iterators::Pair) -> ArithExpression { Rule::sub => ArithExpression::Difference(left, right), Rule::mul => ArithExpression::Multiplication(left, right), Rule::div => ArithExpression::Division(left, right), - Rule::r#mod => ArithExpression::Modulo(left, right), + Rule::modulo => ArithExpression::Modulo(left, right), _ => unreachable!("Unable to match: {:?} as rule, arith", op), } }) diff --git a/src/lib.rs b/src/lib.rs index 04ec87c1..e7817673 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,6 @@ extern crate colored; extern crate core; extern crate serde; extern crate serde_xml_rs; -extern crate simple_error; extern crate xml; #[macro_use] extern crate lazy_static; diff --git a/src/model_objects/expressions/arith_expression.rs b/src/model_objects/expressions/arith_expression.rs index 117047a6..71260e28 100644 --- a/src/model_objects/expressions/arith_expression.rs +++ b/src/model_objects/expressions/arith_expression.rs @@ -122,7 +122,7 @@ impl ArithExpression { let mut new_constraint = 0; self.iterate_constraints(&mut |left, right| { - //Start by matching left and right operands to get constant, this might fail if it does we skip constraint defaulting to 0 + //Start by matching left and right operands to get constant, this might fail. If it does, we skip constraint defaulting to 0 let constant = ArithExpression::get_constant(left, right, clock, clock_name); if new_constraint < constant { diff --git a/src/model_objects/expressions/bool_expression.rs b/src/model_objects/expressions/bool_expression.rs index 8aa5e31a..fc6403fa 100644 --- a/src/model_objects/expressions/bool_expression.rs +++ b/src/model_objects/expressions/bool_expression.rs @@ -309,81 +309,50 @@ impl BoolExpression { fn simplify_helper(&mut self) -> bool { let mut changed = false; let mut value = None; + let mut handle = |l: &mut Box, + r: &mut Box, + //r: ArithExpression, + cmp: &(dyn Fn(&i32, &i32) -> bool)| { + **l = l.simplify().expect("Can't simplify"); + **r = r.simplify().expect("Can't simplify"); + if let ArithExpression::Int(x) = **l { + if let ArithExpression::Int(y) = **r { + value = Some(BoolExpression::Bool(cmp(&x, &y))) + } + } + }; match self { BoolExpression::AndOp(left, right) => { changed |= left.simplify_helper(); changed |= right.simplify_helper(); - match **left { - BoolExpression::Bool(false) => value = Some(BoolExpression::Bool(false)), - BoolExpression::Bool(true) => value = Some((**right).clone()), - _ => {} - } - match **right { - BoolExpression::Bool(false) => value = Some(BoolExpression::Bool(false)), - BoolExpression::Bool(true) => value = Some((**left).clone()), - _ => {} - } + + value = match (left.as_ref(), right.as_ref()) { + // Short-circuiting + (BoolExpression::Bool(false), _) => Some(BoolExpression::Bool(false)), + (BoolExpression::Bool(true), BoolExpression::Bool(b)) => { + Some(BoolExpression::Bool(*b)) + } + (_, _) => None, + }; } BoolExpression::OrOp(left, right) => { changed |= left.simplify_helper(); changed |= right.simplify_helper(); - match **left { - BoolExpression::Bool(true) => value = Some(BoolExpression::Bool(true)), - BoolExpression::Bool(false) => value = Some((**right).clone()), - _ => {} - } - match **right { - BoolExpression::Bool(true) => value = Some(BoolExpression::Bool(true)), - BoolExpression::Bool(false) => value = Some((**left).clone()), - _ => {} - } - } - - BoolExpression::LessEQ(l, r) => { - **l = l.simplify().expect("Can't simplify"); - **r = r.simplify().expect("Can't simplify"); - if let ArithExpression::Int(x) = **l { - if let ArithExpression::Int(y) = **r { - value = Some(BoolExpression::Bool(x <= y)) + value = match (left.as_ref(), right.as_ref()) { + // Short-circuiting + (BoolExpression::Bool(true), _) => Some(BoolExpression::Bool(true)), + (BoolExpression::Bool(false), BoolExpression::Bool(b)) => { + Some(BoolExpression::Bool(*b)) } - } - } - BoolExpression::GreatEQ(l, r) => { - **l = l.simplify().expect("Can't simplify"); - **r = r.simplify().expect("Can't simplify"); - if let ArithExpression::Int(x) = **l { - if let ArithExpression::Int(y) = **r { - value = Some(BoolExpression::Bool(x >= y)) - } - } - } - BoolExpression::LessT(l, r) => { - **l = l.simplify().expect("Can't simplify"); - **r = r.simplify().expect("Can't simplify"); - if let ArithExpression::Int(x) = **l { - if let ArithExpression::Int(y) = **r { - value = Some(BoolExpression::Bool(x < y)) - } - } - } - BoolExpression::GreatT(l, r) => { - **l = l.simplify().expect("Can't simplify"); - **r = r.simplify().expect("Can't simplify"); - if let ArithExpression::Int(x) = **l { - if let ArithExpression::Int(y) = **r { - value = Some(BoolExpression::Bool(x > y)) - } - } - } - BoolExpression::EQ(l, r) => { - **l = l.simplify().expect("Can't simplify"); - **r = r.simplify().expect("Can't simplify"); - if let ArithExpression::Int(x) = **l { - if let ArithExpression::Int(y) = **r { - value = Some(BoolExpression::Bool(x == y)) - } - } + (_, _) => None, + }; } + + BoolExpression::LessEQ(l, r) => handle(l, r, &i32::le), + BoolExpression::GreatEQ(l, r) => handle(l, r, &i32::ge), + BoolExpression::LessT(l, r) => handle(l, r, &i32::lt), + BoolExpression::GreatT(l, r) => handle(l, r, &i32::gt), + BoolExpression::EQ(l, r) => handle(l, r, &i32::eq), BoolExpression::Bool(_) => {} } diff --git a/src/protobuf_server/ecdar_backend.rs b/src/protobuf_server/ecdar_backend.rs index 855508f9..0f313413 100644 --- a/src/protobuf_server/ecdar_backend.rs +++ b/src/protobuf_server/ecdar_backend.rs @@ -61,14 +61,16 @@ where } } - match future.catch_unwind().await { - Ok(response) => response, - Err(e) => Err(Status::internal(format!( - "{}, please report this bug to the developers", - downcast_to_string(e) - ))), - } - .map(Response::new) + future + .catch_unwind() + .await + .unwrap_or_else(|e| { + Err(Status::internal(format!( + "{}, please report this bug to the developers", + downcast_to_string(e) + ))) + }) + .map(Response::new) } impl ConcreteEcdarBackend {} @@ -81,7 +83,7 @@ impl EcdarBackend for ConcreteEcdarBackend { ) -> Result, Status> { let id = self.num.fetch_add(1, Ordering::SeqCst); let token_response = UserTokenResponse { user_id: id }; - Result::Ok(Response::new(token_response)) + Ok(Response::new(token_response)) } async fn send_query( diff --git a/src/protobuf_server/proto_conversions.rs b/src/protobuf_server/proto_conversions.rs index e16e2e9c..6ebacb84 100644 --- a/src/protobuf_server/proto_conversions.rs +++ b/src/protobuf_server/proto_conversions.rs @@ -144,7 +144,6 @@ impl From for ProtoConstraint { impl From for ProtoClock { fn from(clock: SpecificClockVar) -> Self { - use std::convert::TryFrom; match clock { SpecificClockVar::Zero => Self { clock: Some(ProtoClockEnum::ZeroClock(Default::default())), diff --git a/src/simulation/graph_layout.rs b/src/simulation/graph_layout.rs index 36a5cdb0..9a219426 100644 --- a/src/simulation/graph_layout.rs +++ b/src/simulation/graph_layout.rs @@ -36,13 +36,10 @@ impl Default for Config { } fn get_config() -> Config { - match read_config("config.json") { - Ok(config) => config, - Err(_) => { - info!("Could not find graph layout config, using defaults"); - Config::default() - } - } + read_config("config.json").unwrap_or_else(|_| { + info!("Could not find graph layout config, using defaults"); + Config::default() + }) } fn read_config>(path: P) -> Result> { diff --git a/src/system/extract_system_rep.rs b/src/system/extract_system_rep.rs index 36da99df..12c634ab 100644 --- a/src/system/extract_system_rep.rs +++ b/src/system/extract_system_rep.rs @@ -16,7 +16,6 @@ use super::query_failures::{SyntaxResult, SystemRecipeFailure}; use crate::system::pruning; use edbm::util::constraints::ClockIndex; use log::debug; -use simple_error::bail; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExecutableQueryError { @@ -157,10 +156,10 @@ pub fn create_executable_query<'a>( } // Should handle consistency, Implementation, determinism and specification here, but we cant deal with it atm anyway - _ => bail!("Not yet setup to handle query"), + _ => Err("Not yet setup to handle query".into()), } } else { - bail!("No query was supplied for extraction") + Err("No query was supplied for extraction".into()) } } diff --git a/src/tests/failure_message/actions_test.rs b/src/tests/failure_message/actions_test.rs index cc854433..4b9e8b0a 100644 --- a/src/tests/failure_message/actions_test.rs +++ b/src/tests/failure_message/actions_test.rs @@ -3,7 +3,7 @@ mod test { use crate::system::query_failures::{ - ConsistencyFailure, DeterminismFailure, DeterminismResult, QueryResult, RefinementFailure, + ConsistencyFailure, DeterminismFailure, QueryResult, RefinementFailure, RefinementPrecondition, }; use crate::system::specifics::SpecificLocation; @@ -14,7 +14,7 @@ mod test { fn determinism_test() { let expected_action = String::from("1"); let expected_location = SpecificLocation::new("NonDeterministic1", "L1", 0); //LocationID::Simple(String::from("L1")); - if let QueryResult::Determinism(DeterminismResult::Err(DeterminismFailure { + if let QueryResult::Determinism(Err(DeterminismFailure { state: actual_state, action: actual_action, system: actual_system, @@ -40,7 +40,7 @@ mod test { })) = json_run_query(PATH, "consistency: NonConsistent").unwrap() { let actual_location = actual_state.locations; - assert_eq!((expected_location), (actual_location)); + assert_eq!(expected_location, actual_location); assert_eq!(actual_system, "NonConsistent"); } else { panic!("Models in samples/action have been changed, REVERT!"); @@ -87,7 +87,7 @@ mod test { ))) = json_run_query(PATH, "refinement: NonConsistent <= CorrectComponent").unwrap() { let actual_location = actual_state.locations; - assert_eq!((expected_location), (actual_location)); + assert_eq!(expected_location, actual_location); assert_eq!(actual_system, "NonConsistent"); } else { panic!("Models in samples/action have been changed, REVERT!"); diff --git a/src/tests/failure_message/consistency_test.rs b/src/tests/failure_message/consistency_test.rs index a5bdc6c7..22f84a52 100644 --- a/src/tests/failure_message/consistency_test.rs +++ b/src/tests/failure_message/consistency_test.rs @@ -2,7 +2,7 @@ mod test { use crate::{ - system::query_failures::{ConsistencyFailure, ConsistencyResult, QueryResult}, + system::query_failures::{ConsistencyFailure, QueryResult}, tests::refinement::helper::json_run_query, }; @@ -13,9 +13,7 @@ mod test { let actual = json_run_query(PATH, "consistency: notConsistent").unwrap(); assert!(matches!( actual, - QueryResult::Consistency(ConsistencyResult::Err( - ConsistencyFailure::InconsistentFrom { .. } - )) + QueryResult::Consistency(Err(ConsistencyFailure::InconsistentFrom { .. })) )); } } diff --git a/src/tests/failure_message/syntax_test.rs b/src/tests/failure_message/syntax_test.rs index c30bb8f2..bbe2bb51 100644 --- a/src/tests/failure_message/syntax_test.rs +++ b/src/tests/failure_message/syntax_test.rs @@ -2,7 +2,7 @@ mod test { use crate::{ - system::query_failures::{QueryResult, SyntaxFailure, SyntaxResult}, + system::query_failures::{QueryResult, SyntaxFailure}, tests::refinement::helper::json_run_query, }; @@ -13,7 +13,7 @@ mod test { let actual = json_run_query(PATH, "syntax: syntaxFailure").unwrap(); assert!(matches!( actual, - QueryResult::Syntax(SyntaxResult::Err(SyntaxFailure::Unparsable { .. })) + QueryResult::Syntax(Err(SyntaxFailure::Unparsable { .. })) )); } } diff --git a/src/tests/reachability/partial_state.rs b/src/tests/reachability/partial_state.rs index a18496dc..30e12812 100644 --- a/src/tests/reachability/partial_state.rs +++ b/src/tests/reachability/partial_state.rs @@ -5,6 +5,7 @@ mod reachability_partial_states_test { use crate::model_objects::{Declarations, Location, LocationType}; use crate::transition_systems::CompositionType; use crate::transition_systems::LocationTree; + use test_case::test_case; fn build_location_tree_helper(id: &str, location_type: LocationType) -> Rc { diff --git a/src/tests/reachability/search_algorithm_test.rs b/src/tests/reachability/search_algorithm_test.rs index 5ff71c2e..5ab7d1db 100644 --- a/src/tests/reachability/search_algorithm_test.rs +++ b/src/tests/reachability/search_algorithm_test.rs @@ -77,7 +77,7 @@ mod reachability_search_algorithm_test { ) }); let path = actual_path.path; - assert!(expected_path.len() == path.len(), "Query: {}\nThe length of the actual and expected are not the same.\nexpected_path.len = {}\nactual_path.len = {} \n", query, expected_path.len(),path.len()); + assert_eq!(expected_path.len(), path.len(), "Query: {}\nThe length of the actual and expected are not the same.\nexpected_path.len = {}\nactual_path.len = {} \n", query, expected_path.len(), path.len()); for i in 0..path.len() { let edges: Vec<_> = path[i].edges.iter().map(|e| e.edge_id.clone()).collect(); assert_eq!( @@ -86,8 +86,8 @@ mod reachability_search_algorithm_test { "Query: {}\nThere should only be one edge in the path \n", query ); - assert!( - expected_path[i] == edges[0], + assert_eq!( + expected_path[i], edges[0], "Query: {}\nThe actual and expected is not the same \n", query ); @@ -114,7 +114,7 @@ mod reachability_search_algorithm_test { ) }); let path = actual_path.path; - assert!(expected_path.len() == path.len(), "Query: {}\nThe length of the actual and expected are not the same.\nexpected_path.len = {}\nactual_path.len = {} \n", query, expected_path.len(),path.len()); + assert_eq!(expected_path.len(), path.len(), "Query: {}\nThe length of the actual and expected are not the same.\nexpected_path.len = {}\nactual_path.len = {} \n", query, expected_path.len(), path.len()); for i in 0..path.len() { let edges: Vec<_> = path[i].edges.iter().map(|e| e.edge_id.clone()).collect(); assert_eq!( @@ -123,8 +123,8 @@ mod reachability_search_algorithm_test { "Query: {}\nThere should only be one edge in the path \n", query ); - assert!( - expected_path[i] == edges, + assert_eq!( + expected_path[i], edges, "Query: {}\nThe actual and expected is not the same \n", query ); diff --git a/src/transition_systems/compiled_component.rs b/src/transition_systems/compiled_component.rs index 0afe086c..955109fb 100644 --- a/src/transition_systems/compiled_component.rs +++ b/src/transition_systems/compiled_component.rs @@ -170,6 +170,10 @@ impl TransitionSystem for CompiledComponent { self.locations.values().cloned().collect() } + fn get_location(&self, id: &LocationID) -> Option> { + self.locations.get(id).cloned() + } + fn get_decls(&self) -> Vec<&Declarations> { vec![&self.comp_info.declarations] } @@ -196,14 +200,6 @@ impl TransitionSystem for CompiledComponent { CompositionType::Simple } - fn get_location(&self, id: &LocationID) -> Option> { - self.locations.get(id).cloned() - } - - fn component_names(&self) -> Vec<&str> { - vec![&self.comp_info.name] - } - fn comp_infos(&'_ self) -> ComponentInfoTree<'_> { ComponentInfoTree::Info(&self.comp_info) } @@ -212,6 +208,10 @@ impl TransitionSystem for CompiledComponent { self.comp_info.name.clone() } + fn component_names(&self) -> Vec<&str> { + vec![&self.comp_info.name] + } + fn construct_location_tree( &self, target: SpecificLocation, diff --git a/src/transition_systems/composition.rs b/src/transition_systems/composition.rs index 42d8f266..6d0e48a7 100644 --- a/src/transition_systems/composition.rs +++ b/src/transition_systems/composition.rs @@ -109,6 +109,11 @@ impl ComposedTransitionSystem for Composition { unreachable!() } + fn check_local_consistency(&self) -> ConsistencyResult { + self.left.check_local_consistency()?; + self.right.check_local_consistency() + } + fn get_children(&self) -> (&TransitionSystemPtr, &TransitionSystemPtr) { (&self.left, &self.right) } @@ -132,9 +137,4 @@ impl ComposedTransitionSystem for Composition { fn get_output_actions(&self) -> HashSet { self.outputs.clone() } - - fn check_local_consistency(&self) -> ConsistencyResult { - self.left.check_local_consistency()?; - self.right.check_local_consistency() - } } diff --git a/src/transition_systems/conjunction.rs b/src/transition_systems/conjunction.rs index 5cfe2d17..ed8a8b46 100644 --- a/src/transition_systems/conjunction.rs +++ b/src/transition_systems/conjunction.rs @@ -39,8 +39,7 @@ impl Conjunction { (right.as_ref(), right_out), ) .map_err(|e| e.to_rfconj(left, right)); - } - if !left_out.is_disjoint(&right_in) { + } else if !left_out.is_disjoint(&right_in) { return ActionFailure::not_disjoint( (left.as_ref(), left_out), (right.as_ref(), right_in), diff --git a/src/transition_systems/location_id.rs b/src/transition_systems/location_id.rs index b2c4cb07..503e2fab 100644 --- a/src/transition_systems/location_id.rs +++ b/src/transition_systems/location_id.rs @@ -1,6 +1,7 @@ use std::fmt::{Display, Formatter}; use crate::parse_queries::parse_to_system_expr; +use crate::transition_systems::variant_eq; use crate::{model_objects::expressions::SystemExpression, system::specifics::SpecialLocation}; #[derive(Debug, Clone, Eq, Hash, PartialEq)] @@ -78,42 +79,44 @@ impl From for LocationID { impl Display for LocationID { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut handle = |left: &LocationID, + right: &LocationID, + eq: &LocationID, + sep: &str| + -> std::fmt::Result { + match ( + variant_eq(left, eq) || variant_eq(left, &LocationID::Simple("".to_string())), + variant_eq(right, eq) || variant_eq(right, &LocationID::Simple("".to_string())), + ) { + (true, true) => write!(f, "{} {} {}", left, sep, right)?, + (true, false) => write!(f, "{} {} ({})", left, sep, right)?, + (false, true) => write!(f, "({}) {} {}", left, sep, right)?, + (false, false) => write!(f, "({}) {} ({})", left, sep, right)?, + }; + Ok(()) + }; match self { - LocationID::Conjunction(left, right) => { - match **left { - LocationID::Conjunction(_, _) => write!(f, "{}", (*left))?, - LocationID::Simple(_) => write!(f, "{}", (*left))?, - _ => write!(f, "({})", (*left))?, - }; - write!(f, "&&")?; - match **right { - LocationID::Conjunction(_, _) => write!(f, "{}", (*right))?, - LocationID::Simple(_) => write!(f, "{}", (*right))?, - _ => write!(f, "({})", (*right))?, - }; - } - LocationID::Composition(left, right) => { - match **left { - LocationID::Composition(_, _) => write!(f, "{}", (*left))?, - LocationID::Simple(_) => write!(f, "{}", (*left))?, - _ => write!(f, "({})", (*left))?, - }; - write!(f, "||")?; - match **right { - LocationID::Composition(_, _) => write!(f, "{}", (*right))?, - LocationID::Simple(_) => write!(f, "{}", (*right))?, - _ => write!(f, "({})", (*right))?, - }; - } + LocationID::Conjunction(left, right) => handle( + left.as_ref(), + right.as_ref(), + &LocationID::Conjunction(left.clone(), right.clone()), + "&&", + )?, + LocationID::Composition(left, right) => handle( + left.as_ref(), + right.as_ref(), + &LocationID::Composition(left.clone(), right.clone()), + "||", + )?, LocationID::Quotient(left, right) => { match **left { - LocationID::Simple(_) => write!(f, "{}", (*left))?, - _ => write!(f, "({})", (*left))?, + LocationID::Simple(_) => write!(f, "{}", *left)?, + _ => write!(f, "({})", *left)?, }; write!(f, "\\\\")?; match **right { - LocationID::Simple(_) => write!(f, "{}", (*right))?, - _ => write!(f, "({})", (*right))?, + LocationID::Simple(_) => write!(f, "{}", *right)?, + _ => write!(f, "({})", *right)?, }; } LocationID::Simple(location_id) => { diff --git a/src/transition_systems/mod.rs b/src/transition_systems/mod.rs index f2d2552f..70016260 100644 --- a/src/transition_systems/mod.rs +++ b/src/transition_systems/mod.rs @@ -17,3 +17,7 @@ pub use location_tree::{CompositionType, LocationTree}; pub use quotient::Quotient; pub use transition_id::TransitionID; pub use transition_system::{TransitionSystem, TransitionSystemPtr}; + +pub fn variant_eq(a: &T, b: &T) -> bool { + std::mem::discriminant(a) == std::mem::discriminant(b) +} diff --git a/src/transition_systems/transition_id.rs b/src/transition_systems/transition_id.rs index 3b8a942c..511e84b6 100644 --- a/src/transition_systems/transition_id.rs +++ b/src/transition_systems/transition_id.rs @@ -1,3 +1,4 @@ +use crate::transition_systems::variant_eq; use std::fmt::{Display, Formatter}; /// TransitionID is used to represent which edges a given transition consists of. @@ -116,26 +117,36 @@ impl TransitionID { impl Display for TransitionID { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut handle = |left: &TransitionID, right: &TransitionID| -> std::fmt::Result { - match (left, right) { - ( - TransitionID::Composition(_, _) | TransitionID::Simple(_), - TransitionID::Composition(_, _) | TransitionID::Simple(_), - ) => write!(f, "{} || {}", left, right)?, - (TransitionID::Composition(_, _) | TransitionID::Simple(_), _) => { - write!(f, "{} || ({})", left, right)? - } - (_, TransitionID::Composition(_, _) | TransitionID::Simple(_)) => { - write!(f, "({}) || {}", left, right)? - } - (_, _) => write!(f, "({}) || ({})", left, right)?, + let mut handle = |left: &TransitionID, + right: &TransitionID, + eq: &TransitionID, + sep: &str| + -> std::fmt::Result { + match ( + variant_eq(left, eq) || variant_eq(left, &TransitionID::Simple("".to_string())), + variant_eq(right, eq) || variant_eq(right, &TransitionID::Simple("".to_string())), + ) { + (true, true) => write!(f, "{} {} {}", left, sep, right)?, + (true, false) => write!(f, "{} {} ({})", left, sep, right)?, + (false, true) => write!(f, "({}) {} {}", left, sep, right)?, + (false, false) => write!(f, "({}) {} ({})", left, sep, right)?, } Ok(()) }; match self { - TransitionID::Conjunction(left, right) | TransitionID::Composition(left, right) => { - handle(left.as_ref(), right.as_ref())?; - } + TransitionID::Conjunction(left, right) => handle( + left.as_ref(), + right.as_ref(), + &TransitionID::Conjunction(left.clone(), right.clone()), + "&&", + )?, + TransitionID::Composition(left, right) => handle( + left.as_ref(), + right.as_ref(), + &TransitionID::Composition(left.clone(), right.clone()), + "||", + )?, + TransitionID::Quotient(left, right) => { for l in left { match *(l) {