diff --git a/Cargo.lock b/Cargo.lock index c2414f6..26258a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,7 +136,7 @@ dependencies = [ [[package]] name = "shakemyleg" -version = "2.2.1" +version = "2.4.0" dependencies = [ "itertools", "json", diff --git a/Cargo.toml b/Cargo.toml index a99b0c8..54161c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "shakemyleg" description = "A simple state machine definition language and interpreter." repository = "https://github.com/cbosoft/sml" license = "MIT" -version = "2.2.1" +version = "2.4.0" edition = "2021" authors = ["Christopher Boyle"] @@ -15,3 +15,6 @@ regex = "1.10.5" serde = { version = "1.0.204", features = ["derive"] } serde_json = "1.0.120" thiserror = "1.0.62" + +[features] +thread_safe = [] diff --git a/README.md b/README.md index 8c936ad..b50d0a8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ ![tests badge](https://github.com/cbosoft/sml/actions/workflows/tests.yml/badge.svg) +[![Written up - here!](https://img.shields.io/static/v1?label=Written+up&message=here!&color=2ea44f)](https://cmjb.tech/blog/2024/08/02/shakemyleg/) + # SML - ShakeMyLeg, is that a State Machine Language? A simple state machine definition language and interpreter. @@ -10,12 +12,12 @@ A very simple example `shakemyleg` machine: # flip_flip.sml state A: - when true: + always: outputs.bar = inputs.bar + 1 changeto B state B: - when true: + always: outputs.bar = inputs.bar + 1 changeto A ``` @@ -34,11 +36,11 @@ struct Foo { let src = r#" state A: - when true: + always: outputs.bar = inputs.bar + 1 changeto B state B: - when true: + always: outputs.bar = inputs.bar + 1 changeto A "#; diff --git a/src/compiler.rs b/src/compiler.rs index b4b60a2..ef04d8d 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,6 +1,14 @@ -use std::{collections::HashMap, rc::Rc}; +use std::collections::HashMap; -use crate::{error::{SML_Error, SML_Result}, expression::Expression, identifier::Identifier, operation::BinaryOperation, state::{State, StateOp}, value::Value, StateMachine}; + +use crate::error::{SML_Error, SML_Result}; +use crate::expression::Expression; +use crate::identifier::Identifier; +use crate::operation::BinaryOperation; +use crate::state::{State, StateOp}; +use crate::value::Value; +use crate::StateMachine; +use crate::refcount::RefCount; // Algorithm from: https://faculty.cs.niu.edu/~hutchins/csci241/eval.htm @@ -41,7 +49,7 @@ impl Token { else if let Ok(v) = s.parse::() { Self::Number(v) } - else if let "+" | "-" | "*" | "/" | "=" | "==" | "<" | "<=" | ">" | ">=" | "!=" = s.as_str() { + else if let "+" | "-" | "*" | "/" | "=" | "==" | "<" | "<=" | ">" | ">=" | "!=" | "&&" | "||" | "contains" = s.as_str() { Self::Operator(s) } else if s == "(" { @@ -226,18 +234,68 @@ fn expr_from_str(s: &str, lineno: usize) -> SML_Result { } +struct StateData { + pub name: String, + pub head: Vec, + pub branches: Vec, + pub has_default: bool, + pub has_otherwise: bool, + pub has_always: bool +} + +impl StateData { + fn new(name: String) -> Self { + Self { + name, + head: Vec::new(), + branches: Vec::new(), + has_default: false, + has_otherwise: false, + has_always: false, + } + } +} + +impl From for State { + fn from(state_data: StateData) -> Self { + let name = state_data.name; + let head = state_data.head; + let body = state_data.branches.into_iter().map(|b| (b.condition, b.body, b.state_op)).collect(); + State::new(name, head, body) + } +} + +struct StateBranchData { + condition: Expression, + body: Vec, + state_op: StateOp, + is_default: bool, +} + +impl StateBranchData { + fn new(condition: Expression) -> Self { + Self { + condition, + body: Vec::new(), + state_op: StateOp::Stay, + is_default: false, + } + } +} + + /// Take a string of SML source and compile to state machine. /// ``` /// use shakemyleg::compile; /// /// let src = r#" /// state init: -/// when inputs.b < 10: +/// when inputs.b <= 10: /// outputs.b = inputs.b + 1 -/// when inputs.b >= 10: +/// otherwise: /// changeto second /// state second: -/// when true: +/// always: /// outputs.c = inputs.c + 2 /// "#; /// @@ -247,10 +305,10 @@ pub fn compile(s: &str) -> SML_Result { let mut c_state_stack = vec![CompileState::TopLevel]; let lines: Vec<_> = s.lines().collect(); let mut i = 0usize; - let mut state_data = None; - let mut state_branch_data = None; + let mut state_data: Option = None; + let mut state_branch_data: Option = None; let mut default_head = Vec::new(); - let mut states = Vec::new(); + let mut states: Vec = Vec::new(); let mut leading_ws = None; let nlines = lines.len(); @@ -277,7 +335,7 @@ pub fn compile(s: &str) -> SML_Result { // if let Some(sname_colon) = line.strip_prefix("state ") { if let Some(sname) = sname_colon.strip_suffix(":") { - state_data = Some((sname.to_string(), Vec::new(), Vec::new())); + state_data = Some(StateData::new(sname.to_string())); c_state_stack.push(CompileState::State); true } @@ -307,25 +365,65 @@ pub fn compile(s: &str) -> SML_Result { c_state_stack.push(CompileState::StateHead); } else if let Some(expr_colon) = line_trim.strip_prefix("when ") { + let has_always = state_data.as_ref().unwrap().has_always; + let has_otherwise = state_data.as_ref().unwrap().has_otherwise; + if has_always || has_otherwise { + return Err(SML_Error::SyntaxError(format!("Branch defined after always or otherwise on line {i}"))); + } + if let Some(expr) = expr_colon.strip_suffix(":") { let cond = expr_from_str(expr, i)?; - state_branch_data = Some((cond, Vec::new(), StateOp::Stay)); + state_branch_data = Some(StateBranchData::new(cond)); c_state_stack.push(CompileState::StateBranch); } else { return Err(SML_Error::SyntaxError(format!("Missing colon on line {i}:{line}"))); } } + else if line_trim == "always:" { + let has_always = state_data.as_ref().unwrap().has_always; + let has_otherwise = state_data.as_ref().unwrap().has_otherwise; + if has_always || has_otherwise { + return Err(SML_Error::SyntaxError(format!("Branch defined after always or otherwise on line {i}."))); + } + + let has_other_branches = state_data.as_ref().unwrap().branches.len() > 0; + if has_other_branches { + return Err(SML_Error::SyntaxError(format!("Always defined after another branch on line {i}. Always must be the other branch."))); + } + + let cond = Expression::Value(Value::Bool(true)); + state_branch_data = Some(StateBranchData::new(cond)); + state_data.as_mut().unwrap().has_always = true; + c_state_stack.push(CompileState::StateBranch); + } + else if line_trim == "otherwise:" { + let has_always = state_data.as_ref().unwrap().has_always; + let has_otherwise = state_data.as_ref().unwrap().has_otherwise; + if has_always || has_otherwise { + return Err(SML_Error::SyntaxError(format!("Branch defined after always or otherwise on line {i}."))); + } + + let has_other_branches = state_data.as_ref().unwrap().branches.len() > 0; + if !has_other_branches { + return Err(SML_Error::SyntaxError(format!("Otherwise defined alone on line {i}. Otherwise must come after at least one other branch."))); + } + + let cond = Expression::Value(Value::Bool(true)); + state_branch_data = Some(StateBranchData::new(cond)); + state_data.as_mut().unwrap().has_otherwise = true; + c_state_stack.push(CompileState::StateBranch); + } else { eprintln!("{}", lines[i-1]); - return Err(SML_Error::SyntaxError(format!("Expect head or when after state intro on line {i}:{line}"))); + return Err(SML_Error::SyntaxError(format!("Expected ['head:', 'when :', 'always:', 'otherwise:'] after state intro on line {i}:{line}"))); } } true } else { - if let Some((name, head, body)) = state_data.take() { - states.push(State::new(name, head, body)); + if let Some(state_data) = state_data.take() { + states.push(state_data.into()); c_state_stack.pop(); false } @@ -351,7 +449,7 @@ pub fn compile(s: &str) -> SML_Result { if line.starts_with(&leading_ws.as_ref().unwrap().1) { let line = line.trim_start(); let expr = expr_from_str(line, i)?; - state_data.as_mut().unwrap().1.push(expr); + state_data.as_mut().unwrap().head.push(expr); true } else { @@ -363,23 +461,33 @@ pub fn compile(s: &str) -> SML_Result { if line.starts_with(&leading_ws.as_ref().unwrap().1) { let line = line.trim_start(); if let Some(state_name) = line.strip_prefix("changeto ") { - state_branch_data.as_mut().unwrap().2 = StateOp::ChangeTo(state_name.to_string()); + state_branch_data.as_mut().unwrap().state_op = StateOp::ChangeTo(state_name.to_string()); } else if line == "end" { - state_branch_data.as_mut().unwrap().2 = StateOp::End; + state_branch_data.as_mut().unwrap().state_op = StateOp::End; } else if line == "stay" { - state_branch_data.as_mut().unwrap().2 = StateOp::Stay; + state_branch_data.as_mut().unwrap().state_op = StateOp::Stay; + } + else if line == "default" { + if state_data.as_ref().unwrap().has_default { + let name = &state_data.as_ref().unwrap().name; + return Err(SML_Error::SyntaxError(format!("Multiple branches marked as default in state {name}. On line {i}."))); + } + else { + state_branch_data.as_mut().unwrap().is_default = true; + state_data.as_mut().unwrap().has_default = true; + } } else { let expr = expr_from_str(line, i)?; - state_branch_data.as_mut().unwrap().1.push(expr); + state_branch_data.as_mut().unwrap().body.push(expr); } true } else { let branch = state_branch_data.take().unwrap(); - state_data.as_mut().unwrap().2.push(branch); + state_data.as_mut().unwrap().branches.push(branch); c_state_stack.pop(); false } @@ -392,18 +500,18 @@ pub fn compile(s: &str) -> SML_Result { } if let Some(branch) = state_branch_data { - state_data.as_mut().unwrap().2.push(branch); + state_data.as_mut().unwrap().branches.push(branch); } - if let Some((name, head, body)) = state_data { - states.push(State::new(name, head, body)); + if let Some(state_data) = state_data { + states.push(state_data.into()); } let initial_state = states[0].name().clone(); let states_iter = states.into_iter(); let mut states = HashMap::new(); for state in states_iter { - states.insert(state.name().clone(), Rc::new(state)); + states.insert(state.name().clone(), RefCount::new(state)); } let initial_state = states.get(&initial_state).unwrap().clone(); @@ -555,9 +663,71 @@ state B: let _ = compile(SRC).unwrap(); } + #[test] + fn test_compile_always_otherwise_1() { + const SRC: &'static str = r#" +state A: + always: + changeto B +state B: + when false: + changeto A + otherwise: + changeto A +"#; + let _ = compile(SRC).unwrap(); + } + + #[test] + #[should_panic] + fn test_compile_always_otherwise_2() { + const SRC: &'static str = r#" +state A: + always: + changeto B + otherwise: + changeto A +state B: + always: + changeto A +"#; + let _ = compile(SRC).unwrap(); + } + + #[test] + #[should_panic] + fn test_compile_always_otherwise_3() { + const SRC: &'static str = r#" +state A: + otherwise: + changeto B + when true: + changeto A +state B: + always: + changeto A +"#; + let _ = compile(SRC).unwrap(); + } + + #[test] + fn test_compile_always_otherwise_4() { + const SRC: &'static str = r#" +state A: + when false: + changeto A + otherwise: + changeto B +state B: + always: + changeto A +"#; + let _ = compile(SRC).unwrap(); + } + #[derive(Serialize)] struct InFoo { - foo: u8 + foo: Vec } #[derive(Deserialize)] @@ -569,19 +739,56 @@ state B: fn test_compile_end() { const SRC: &'static str = r#" state final: - when true: + always: outputs.bar = 1 end "#; let mut sm = compile(SRC).unwrap(); - let i = InFoo { foo: 0u8 }; + let i = InFoo { foo: vec![0u8] }; let o: OutBar = sm.run(i).unwrap().unwrap(); assert_eq!(o.bar, 1u8); - let i = InFoo { foo: 0u8 }; + let i = InFoo { foo: vec![0u8] }; let rv: SML_Result> = sm.run(i); assert!(matches!(rv, Ok(None))); } + #[test] + fn test_compile_contais_1() { + const SRC: &'static str = r#" +state final: + when inputs.foo contains 0: + outputs.bar = 1 + otherwise: + outputs.bar = 0 +"#; + let mut sm = compile(SRC).unwrap(); + + let i = InFoo { foo: vec![0, 1, 2, 3] }; + let o: OutBar = sm.run(i).unwrap().unwrap(); + assert_eq!(o.bar, 1u8); + + let i = InFoo { foo: vec![1, 2, 3] }; + let o: OutBar = sm.run(i).unwrap().unwrap(); + assert_eq!(o.bar, 0u8); + } + + #[test] + #[should_panic] + fn test_compile_contais_2() { + const SRC: &'static str = r#" +state final: + when outputs.bar contains 0: + outputs.bar = 1 + otherwise: + outputs.bar = 0 +"#; + let mut sm = compile(SRC).unwrap(); + + let i = InFoo { foo: vec![0, 1, 2, 3] }; + let o: OutBar = sm.run(i).unwrap().unwrap(); + assert_eq!(o.bar, 1u8); + } + } diff --git a/src/examples.rs b/src/examples.rs index 993a6ab..ae18a9e 100644 --- a/src/examples.rs +++ b/src/examples.rs @@ -1,5 +1,7 @@ //! # Examples //! +//! +//! Compile SML source into a [StateMachine] and run it. //! ``` //! use shakemyleg::compile; //! use serde::{Serialize, Deserialize}; @@ -30,7 +32,7 @@ //! outputs.state = "first" //! when inputs.bar < 10: //! outputs.bar = inputs.bar + 1 -//! when outputs.bar >= 10: +//! otherwise: //! changeto second //! //! state second: @@ -38,7 +40,7 @@ //! outputs.state = "second" //! when inputs.bar > 1: //! outputs.bar = inputs.bar - 1 -//! when outputs.bar <= 1: +//! otherwise: //! changeto first //! "#; //! @@ -50,4 +52,39 @@ //! assert_eq!(o, Outputs { state: "first".to_string(), foo: 1u8, bar: 4u8 }); //! ``` //! +//! We can't define a list literal in SML, but we can interact with lists passed into the machine: +//! ``` +//! use shakemyleg::compile; +//! use serde::{Serialize, Deserialize}; +//! +//! #[derive(Serialize)] +//! struct Globals { +//! things: Vec +//! } +//! +//! #[derive(Serialize)] +//! struct Inputs { +//! thing: String +//! } +//! +//! #[derive(Deserialize, PartialEq, Debug)] +//! struct Outputs { +//! things: Vec +//! } +//! +//! let src = r#" +//! state ThingAccumulator: +//! head: +//! globals.things = globals.things + inputs.thing +//! outputs.things = globals.things +//! when globals.things contains "lastthing": +//! end +//! "#; +//! +//! let mut sm = compile(src).unwrap(); +//! sm.reinit(Globals { things: Vec::new() }); //! +//! let i = Inputs { thing: "FirstThing".to_string() }; +//! let o: Outputs = sm.run(i).unwrap().unwrap(); +//! assert_eq!(o, Outputs { things: vec![ "FirstThing".to_string() ] }); +//! ``` diff --git a/src/identifier.rs b/src/identifier.rs index 9afcbe3..9ae1e6d 100644 --- a/src/identifier.rs +++ b/src/identifier.rs @@ -86,12 +86,7 @@ impl Identifier { store = &mut store[node]; } - let json_value = match v { - Value::Bool(b) => JsonValue::Boolean(*b), - Value::String(s) => JsonValue::String(s.to_string()), - Value::Number(n) => JsonValue::Number((*n).into()), - }; - + let json_value = v.as_json(); store[key.unwrap()] = json_value; Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 8ef3c80..4d8d37e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ mod operation; mod expression; mod state; mod state_machine; +mod refcount; pub use crate::error::{SML_Error, SML_Result}; pub use crate::state_machine::StateMachine; diff --git a/src/operation.rs b/src/operation.rs index 76b72e3..6d474c0 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -60,6 +60,9 @@ pub enum BinaryOperation { // Boolean And, Or, + + // List ops + Contains, } impl BinaryOperation { @@ -86,6 +89,9 @@ impl BinaryOperation { "&&" => Self::And, "||" => Self::Or, + // List + "contains" => Self::Contains, + s => { return Err(SML_Error::SyntaxError(format!("Invalid binary operation {s}"))); } }; @@ -102,31 +108,37 @@ impl BinaryOperation { Self::Add => { match (left, right) { (Value::Number(left), Value::Number(right)) => Ok(Value::Number(left + right)), - _ => Err(SML_Error::BadOperation("Arithmetic only valid for boolean operands.".to_string())) + (Value::List(l), new_value) => { + let mut l = l.clone(); + let new_value = Box::new(new_value.clone()); + l.push(new_value); + Ok(Value::List(l)) + }, + _ => Err(SML_Error::BadOperation("'+' only valid for numerical operands or to add a value to a list.".to_string())) } }, Self::Subtract => { match (left, right) { (Value::Number(left), Value::Number(right)) => Ok(Value::Number(left - right)), - _ => Err(SML_Error::BadOperation("Arithmetic only valid for boolean operands.".to_string())) + _ => Err(SML_Error::BadOperation("Arithmetic only valid for numerical operands.".to_string())) } }, Self::Multiply => { match (left, right) { (Value::Number(left), Value::Number(right)) => Ok(Value::Number(left * right)), - _ => Err(SML_Error::BadOperation("Arithmetic only valid for boolean operands.".to_string())) + _ => Err(SML_Error::BadOperation("Arithmetic only valid for numerical operands.".to_string())) } }, Self::Divide => { match (left, right) { (Value::Number(left), Value::Number(right)) => Ok(Value::Number(left / right)), - _ => Err(SML_Error::BadOperation("Arithmetic only valid for boolean operands.".to_string())) + _ => Err(SML_Error::BadOperation("Arithmetic only valid for numerical operands.".to_string())) } }, Self::Power => { match (left, right) { (Value::Number(left), Value::Number(right)) => Ok(Value::Number(left.powf(*right))), - _ => Err(SML_Error::BadOperation("Arithmetic only valid for boolean operands.".to_string())) + _ => Err(SML_Error::BadOperation("Arithmetic only valid for numerical operands.".to_string())) } }, @@ -185,6 +197,26 @@ impl BinaryOperation { let right = right.as_bool(); Ok(Value::Bool(left || right)) }, + + // List ops + Self::Contains => { + match (left, right) { + (Value::List(left), value) => { + let rv = { + let mut rv = false; + for item in left.iter() { + if **item == *value { + rv = true; + break; + } + } + rv + }; + Ok(Value::Bool(rv)) + } + _ => Err(SML_Error::BadOperation("Invalid type. Syntax is ' contains '.".to_string())) + } + }, } } } diff --git a/src/refcount.rs b/src/refcount.rs new file mode 100644 index 0000000..65340eb --- /dev/null +++ b/src/refcount.rs @@ -0,0 +1,5 @@ +#[cfg(feature = "thread_safe")] +pub type RefCount = std::sync::Arc; + +#[cfg(not(feature = "thread_safe"))] +pub type RefCount = std::rc::Rc; diff --git a/src/state_machine.rs b/src/state_machine.rs index 45bd7aa..80132c7 100644 --- a/src/state_machine.rs +++ b/src/state_machine.rs @@ -1,15 +1,15 @@ use std::collections::HashMap; -use std::rc::Rc; use serde::{Serialize, de::DeserializeOwned}; use json::JsonValue; +use crate::refcount::RefCount; use crate::expression::Expression; use crate::state::{State, StateOp}; use crate::error::{SML_Error, SML_Result}; -type StateRef = Rc; +type StateRef = RefCount; #[derive(Clone, Debug)] @@ -24,7 +24,7 @@ pub struct StateMachine { impl StateMachine { pub fn new(default_head: Vec, states: HashMap, initial_state: StateRef) -> Self { let globals = json::object! { }; - let current_state = Some(Rc::clone(&initial_state)); + let current_state = Some(RefCount::clone(&initial_state)); Self { globals, default_head, states, current_state } } @@ -37,7 +37,7 @@ impl StateMachine { fn get_state(&self, name: &String) -> SML_Result { match self.states.get(name) { - Some(state) => Ok(Rc::clone(state)), + Some(state) => Ok(RefCount::clone(state)), None => Err(SML_Error::NonexistantState(name.clone())) } } @@ -56,7 +56,7 @@ impl StateMachine { Some(current_state) => { let i = serde_json::to_string(&i)?; let i = json::parse(&i)?; - let state = Rc::clone(¤t_state); + let state = RefCount::clone(¤t_state); let (o, state_op) = (*state).run(&i, &mut self.globals, &self.default_head)?; let o = o.to_string(); let o: O = serde_json::from_str(&o)?; @@ -85,3 +85,15 @@ impl StateMachine { Ok(g) } } + +#[cfg(all(test, feature = "thread_safe"))] +mod thread_safety_tests { + use super::StateMachine; + + fn is_send_sync() { } + + #[test] + fn test_send_sync() { + is_send_sync::(); + } +} diff --git a/src/value.rs b/src/value.rs index c1a49a8..392e4a4 100644 --- a/src/value.rs +++ b/src/value.rs @@ -8,6 +8,7 @@ pub enum Value { String(String), Number(f64), Bool(bool), + List(Vec>), } impl Value { @@ -22,8 +23,15 @@ impl Value { else if json.is_boolean() { Ok(Self::Bool(json.as_bool().unwrap())) } + else if json.is_array() { + let mut list = Vec::new(); + for item in json.members() { + list.push(Box::new(Value::new(item)?)); + } + Ok(Self::List(list)) + } else { - Err(SML_Error::JsonFormatError("Value expects a json number, string, or boolean. Got null, object, array, or empty.".to_string())) + Err(SML_Error::JsonFormatError("Value expects a json number, string, array, or boolean. Got null, object, or empty.".to_string())) } } @@ -31,7 +39,32 @@ impl Value { match self { Self::Bool(v) => *v, Self::Number(v) => *v != 0.0, - Self::String(v) => v.is_empty(), + Self::String(v) => !v.is_empty(), + Self::List(v) => !v.is_empty(), + } + } + + pub fn as_json(&self) -> JsonValue { + match &self { + Self::Bool(b) => JsonValue::Boolean(*b), + Self::String(s) => JsonValue::String(s.to_string()), + Self::Number(n) => JsonValue::Number((*n).into()), + Self::List(l) => { + JsonValue::Array(l.iter().map(|v| v.as_json()).collect()) + } + } + } +} + + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::String(s1), Self::String(s2)) => s1 == s2, + (Self::Bool(b1), Self::Bool(b2)) => b1 == b2, + (Self::Number(n1), Self::Number(n2)) => n1 == n2, + (Self::List(l1), Self::List(l2)) => l1 == l2, + _ => false } } }