diff --git a/Inc/ST-LIB_LOW/StateMachine/HeapStateOrder.hpp b/Inc/ST-LIB_LOW/StateMachine/HeapStateOrder.hpp index 6d71b8d7..502ee577 100644 --- a/Inc/ST-LIB_LOW/StateMachine/HeapStateOrder.hpp +++ b/Inc/ST-LIB_LOW/StateMachine/HeapStateOrder.hpp @@ -24,13 +24,12 @@ template class HeapStateOrder : public HeapOrder { } void process() override { - if (callback != nullptr && state_machine.is_on && - state_machine.get_current_state_id() == state) + if (callback != nullptr && state_machine.get_current_state_id() == state) callback(); } void parse(OrderProtocol* socket, uint8_t* data) override { - if (state_machine.is_on && state_machine.get_current_state_id() == state) + if (state_machine.get_current_state_id() == state) HeapOrder::parse(data); } }; diff --git a/Inc/ST-LIB_LOW/StateMachine/StackStateOrder.hpp b/Inc/ST-LIB_LOW/StateMachine/StackStateOrder.hpp index 83cffb9c..98c6d6d5 100644 --- a/Inc/ST-LIB_LOW/StateMachine/StackStateOrder.hpp +++ b/Inc/ST-LIB_LOW/StateMachine/StackStateOrder.hpp @@ -25,13 +25,12 @@ class StackStateOrder : public StackOrder { } void process() override { - if (this->callback != nullptr && state_machine.is_on && - state_machine.get_current_state_id() == state) + if (this->callback != nullptr && state_machine.get_current_state_id() == state) this->callback(); } void parse(OrderProtocol* socket, uint8_t* data) override { - if (state_machine.is_on && state_machine.get_current_state_id() == state) + if (state_machine.get_current_state_id() == state) StackOrder::parse(data); } }; diff --git a/Inc/ST-LIB_LOW/StateMachine/StateMachine.hpp b/Inc/ST-LIB_LOW/StateMachine/StateMachine.hpp index 0e038f57..2bffc305 100644 --- a/Inc/ST-LIB_LOW/StateMachine/StateMachine.hpp +++ b/Inc/ST-LIB_LOW/StateMachine/StateMachine.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #ifdef STLIB_ETH #include "StateMachine/StateOrder.hpp" @@ -19,8 +20,6 @@ using ms = std::chrono::milliseconds; using us = std::chrono::microseconds; using s = std::chrono::seconds; -template using FixedVector = StaticVector; - template concept IsEnum = std::is_enum_v; @@ -58,19 +57,22 @@ concept are_transitions = (std::same_as> && ...); template class State { private: - FixedVector cyclic_actions = {}; - FixedVector on_enter_actions = {}; - FixedVector on_exit_actions = {}; + StaticVector cyclic_actions = {}; + StaticVector on_enter_actions = {}; + StaticVector on_exit_actions = {}; StateEnum state = {}; - FixedVector, NTransitions> transitions = {}; + StaticVector, NTransitions> transitions = {}; public: - [[no_unique_address]] FixedVector state_orders_ids = {}; + [[no_unique_address]] StaticVector state_orders_ids = {}; static constexpr size_t transition_count = NTransitions; template requires are_transitions consteval State(StateEnum state, T... transitions) : state(state) { + if (((transitions.target == state) || ...)) { + ErrorHandler("Current state cannot be the target of a transition"); + } (this->transitions.push_back(transitions), ...); } @@ -195,7 +197,13 @@ template is_on = false; } - constexpr void add_state_order(uint16_t id) { state_orders_ids.push_back(id); } + constexpr void add_state_order(uint16_t id) { +#ifdef STLIB_ETH + state_orders_ids.push_back(id); +#else + (void)id; +#endif + } template consteval TimedAction* add_cyclic_action(Callback action, TimeUnit period) { @@ -231,13 +239,24 @@ concept IsState = is_state::value; template concept are_states = (IsState && ...); +template +class StateMachine; + +template struct is_state_machine : std::false_type {}; + +template +struct is_state_machine> + : std::true_type {}; + +template +concept IsStateMachineClass = is_state_machine>::value; + /// Interface for State Machines to allow other classes to interact with the state machine without /// knowing its implementation class IStateMachine { public: virtual constexpr ~IStateMachine() = default; virtual void check_transitions() = 0; - virtual void set_on(bool is_on) = 0; virtual void force_change_state(size_t state) = 0; virtual size_t get_current_state_id() const = 0; constexpr bool operator==(const IStateMachine&) const = default; @@ -246,49 +265,103 @@ class IStateMachine { virtual void enter() = 0; virtual void exit() = 0; virtual void start() = 0; - template friend class StateMachine; + template friend class StateMachine; }; -template +template struct NestedMachineBinding { + StateEnum state; + NestedSMType* machine; + + constexpr bool operator==(const NestedMachineBinding&) const = default; +}; + +template struct is_nested_machine_binding : std::false_type {}; + +template +struct is_nested_machine_binding> : std::true_type {}; + +template +concept IsNestedMachineBinding = is_nested_machine_binding::value; + +namespace StateMachineHelper { + +template +static consteval auto add_nesting(const State& state, NestedSMType& machine) { + return NestedMachineBinding{state.get_state(), &machine}; +} + +template + requires(IsNestedMachineBinding && ...) +static consteval auto add_nested_machines(Bindings... bindings) { + constexpr std::size_t count = sizeof...(Bindings); + if constexpr (count > 1) { + auto states = std::array{bindings.state...}; + for (std::size_t i = 0; i < count; ++i) { + for (std::size_t j = i + 1; j < count; ++j) { + if (states[i] == states[j]) { + ErrorHandler("Duplicate state found in add_nested_machines"); + } + } + } + } + return std::make_tuple(bindings...); +} + +} // namespace StateMachineHelper + +template class StateMachine : public IStateMachine { -private: - struct NestedPair { - StateEnum state; - IStateMachine* machine; - constexpr bool operator==(const NestedPair&) const = default; - }; + static_assert((IsEnum), "StateEnum must be an enum type"); + static_assert( + (IsStateMachineClass && ...), + "All nested machines must be of type StateMachine" + ); - StateEnum current_state; + template friend class StateMachine; -public: - constexpr ~StateMachine() override = default; + StateEnum current_state; + std::tuple...> nested_machines; - void force_change_state(size_t state) override { - StateEnum new_state = static_cast(state); + void perform_state_change(StateEnum new_state) { if (current_state == new_state) { return; } + + exit(); + std::apply( + [this](auto&... nested) { + (void)((nested.state == this->current_state && nested.machine != nullptr + ? (nested.machine->exit(), true) + : false) || + ...); + }, + nested_machines + ); + #ifdef STLIB_ETH remove_state_orders(); #endif - exit(); current_state = new_state; enter(); + std::apply( + [this](auto&... nested) { + (void)((nested.state == this->current_state && nested.machine != nullptr + ? (nested.machine->enter(), true) + : false) || + ...); + }, + nested_machines + ); + #ifdef STLIB_ETH refresh_state_orders(); #endif } - size_t get_current_state_id() const override { return static_cast(current_state); } - - bool is_on = true; - void set_on(bool is_on) override { this->is_on = is_on; } - private: - FixedVector, NStates> states; - FixedVector, NTransitions> transitions = {}; + StaticVector, NStates> states; + StaticVector, NTransitions> transitions = {}; std::array, NStates> transitions_assoc = {}; - FixedVector nested_state_machine = {}; constexpr bool operator==(const StateMachine&) const = default; @@ -316,12 +389,17 @@ class StateMachine : public IStateMachine { public: template ... S> - consteval StateMachine(StateEnum initial_state, S... states) : current_state(initial_state) { - // Sort states by their enum value + consteval StateMachine( + StateEnum initial_state, + const std::tuple...>& nested_machines_tuple, + S... states_input + ) + : current_state(initial_state), nested_machines(nested_machines_tuple) { + using StateType = State; std::array sorted_states; size_t index = 0; - ((sorted_states[index++] = StateType(states)), ...); + ((sorted_states[index++] = StateType(states_input)), ...); for (size_t i = 0; i < sorted_states.size(); i++) { for (size_t j = 0; j < sorted_states.size() - 1; j++) { @@ -333,6 +411,15 @@ class StateMachine : public IStateMachine { } } + // Check for duplicate states + for (size_t i = 0; i < sorted_states.size() - 1; i++) { + for (size_t j = i + 1; j < sorted_states.size(); j++) { + if (sorted_states[i].get_state() == sorted_states[j].get_state()) { + ErrorHandler("Duplicate state found in StateMachine constructor"); + } + } + } + // Check that states are contiguous and start from 0 for (size_t i = 0; i < sorted_states.size(); i++) { if (static_cast(sorted_states[i].get_state()) != i) { @@ -350,82 +437,51 @@ class StateMachine : public IStateMachine { offset += s.get_transitions().size(); } } + constexpr ~StateMachine() override = default; void check_transitions() override { auto& [i, n] = transitions_assoc[static_cast(current_state)]; + for (auto index = i; index < i + n; ++index) { const auto& t = transitions[index]; if (t.predicate()) { - exit(); - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->exit(); - break; - } - } -#ifdef STLIB_ETH - remove_state_orders(); -#endif - current_state = t.target; - enter(); - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->enter(); - break; - } - } -#ifdef STLIB_ETH - refresh_state_orders(); -#endif + perform_state_change(t.target); break; } } - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->check_transitions(); - break; - } - } + std::apply( + [this](auto&... nested) { + (void)((nested.state == this->current_state && nested.machine != nullptr + ? (nested.machine->check_transitions(), true) + : false) || + ...); + }, + nested_machines + ); } void start() override { enter(); - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->start(); - break; - } - } + std::apply( + [this](auto&... nested) { + (void)((nested.state == this->current_state && nested.machine != nullptr + ? (nested.machine->start(), true) + : false) || + ...); + }, + nested_machines + ); } - template void force_change_state(const State& state) { - StateEnum new_state = state.get_state(); - if (current_state == new_state) { - return; - } + void force_change_state(size_t state) override { + perform_state_change(static_cast(state)); + } - exit(); - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->exit(); - break; - } - } -#ifdef STLIB_ETH - remove_state_orders(); -#endif - current_state = new_state; - enter(); - for (auto& nested : nested_state_machine) { - if (nested.state == current_state) { - nested.machine->enter(); - break; - } - } -#ifdef STLIB_ETH - refresh_state_orders(); -#endif + size_t get_current_state_id() const override { return static_cast(current_state); } + + template void force_change_state(const State& state) { + perform_state_change(state.get_state()); } template @@ -473,22 +529,6 @@ class StateMachine : public IStateMachine { ErrorHandler("Error: The state is not added to the state machine"); } - template - constexpr void - add_state_machine(IStateMachine& state_machine, const State& state) { - for (auto& nested : nested_state_machine) { - if (nested.state == state.get_state()) { - ErrorHandler( - "Only one Nested State Machine can be added per state, tried to add to state: " - "%d", - static_cast(state.get_state()) - ); - return; - } - } - nested_state_machine.push_back({state.get_state(), &state_machine}); - } - StateEnum get_current_state() const { return current_state; } inline void refresh_state_orders() { @@ -552,6 +592,39 @@ consteval auto make_state_machine(StateEnum initial_state, States... states) { return StateMachine( initial_state, + std::tuple<>{}, + states... + ); +} +/* @brief Helper function to create a StateMachine instance + * + * @tparam States Variadic template parameter pack representing the states + * @param initial_state The initial state enum value + * @param nested_machines Tuple of NestedMachineBinding representing the nested state machines to + * its corresponding state + * @param states The states to be included in the state machine + * @return A StateMachine instance initialized with the provided initial state and states, as well + * as the nested state machines + */ + +template + requires are_states +consteval auto make_state_machine( + StateEnum initial_state, + std::tuple nested_machines, + States... states +) { + constexpr size_t number_of_states = sizeof...(states); + constexpr size_t number_of_transitions = + (std::remove_reference_t::transition_count + ... + 0); + + return StateMachine< + StateEnum, + number_of_states, + number_of_transitions, + typename std::remove_pointer().machine)>::type...>( + initial_state, + nested_machines, states... ); } diff --git a/Tests/CMakeLists.txt b/Tests/CMakeLists.txt index 0c882b8a..5262628e 100644 --- a/Tests/CMakeLists.txt +++ b/Tests/CMakeLists.txt @@ -21,6 +21,7 @@ add_executable(${STLIB_TEST_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/../Src/HALAL/Models/DMA/DMA2.cpp ${CMAKE_CURRENT_LIST_DIR}/Time/scheduler_test.cpp ${CMAKE_CURRENT_LIST_DIR}/Time/timer_wrapper_test.cpp + ${CMAKE_CURRENT_LIST_DIR}/StateMachine/state_machine_test.cpp ${CMAKE_CURRENT_LIST_DIR}/adc_test.cpp ${CMAKE_CURRENT_LIST_DIR}/spi2_test.cpp ${CMAKE_CURRENT_LIST_DIR}/dma2_test.cpp diff --git a/Tests/StateMachine/state_machine_test.cpp b/Tests/StateMachine/state_machine_test.cpp new file mode 100644 index 00000000..c05ff95b --- /dev/null +++ b/Tests/StateMachine/state_machine_test.cpp @@ -0,0 +1,280 @@ +#include +#include "ST-LIB_LOW/StateMachine/StateMachine.hpp" +#include "HALAL/Services/Time/Scheduler.hpp" + +enum class MasterState { A, B, C }; + +enum class SubState { S1, S2 }; + +static int a_enter_count = 0; +static int a_exit_count = 0; +static int a_cyclic_count = 0; + +static int b_enter_count = 0; +static int b_exit_count = 0; +static int b_cyclic_count = 0; + +static int c_enter_count = 0; +static int c_exit_count = 0; + +static int s1_enter_count = 0; +static int s1_exit_count = 0; +static int s1_cyclic_count = 0; + +static int s2_enter_count = 0; +static int s2_exit_count = 0; + +static bool condition_a_to_b = false; +static bool condition_b_to_c = false; +static bool condition_c_to_a = false; +static bool condition_s1_to_s2 = false; + +static void reset_test_state() { + a_enter_count = 0; + a_exit_count = 0; + a_cyclic_count = 0; + + b_enter_count = 0; + b_exit_count = 0; + b_cyclic_count = 0; + + c_enter_count = 0; + c_exit_count = 0; + + s1_enter_count = 0; + s1_exit_count = 0; + s1_cyclic_count = 0; + + s2_enter_count = 0; + s2_exit_count = 0; + + condition_a_to_b = false; + condition_b_to_c = false; + condition_c_to_a = false; + condition_s1_to_s2 = false; + + Scheduler::active_task_count_ = 0; + Scheduler::free_bitmap_ = 0xFFFF'FFFF; + Scheduler::ready_bitmap_ = 0; + Scheduler::sorted_task_ids_ = 0; + Scheduler::global_tick_us_ = 0; + Scheduler::current_interval_us_ = 0; + + TIM2_BASE->CNT = 0; + TIM2_BASE->ARR = 0; + TIM2_BASE->SR = 0; + TIM2_BASE->CR1 = 0; + TIM2_BASE->DIER = 0; +} + +static void tick_scheduler(int ticks) { + TIM2_BASE->PSC = 2; + for (int i = 0; i < ticks; i++) { + for (int j = 0; j <= TIM2_BASE->PSC; j++) + TIM2_BASE->inc_cnt_and_check(1); + Scheduler::update(); + } +} + +static constexpr auto state_s1 = + make_state(SubState::S1, Transition{SubState::S2, []() { + return condition_s1_to_s2; + }}); +static constexpr auto state_s2 = make_state(SubState::S2); + +static inline auto test_nested_machine = []() consteval { + auto sm = make_state_machine(SubState::S1, state_s1, state_s2); + using namespace std::chrono_literals; + + sm.add_enter_action([]() { s1_enter_count++; }, state_s1); + sm.add_exit_action([]() { s1_exit_count++; }, state_s1); + sm.add_cyclic_action([]() { s1_cyclic_count++; }, 10ms, state_s1); + + sm.add_enter_action([]() { s2_enter_count++; }, state_s2); + sm.add_exit_action([]() { s2_exit_count++; }, state_s2); + + return sm; +}(); + +static constexpr auto state_a = + make_state(MasterState::A, Transition{MasterState::B, []() { + return condition_a_to_b; + }}); +static constexpr auto state_b = + make_state(MasterState::B, Transition{MasterState::C, []() { + return condition_b_to_c; + }}); +static constexpr auto state_c = + make_state(MasterState::C, Transition{MasterState::A, []() { + return condition_c_to_a; + }}); + +static inline auto test_machine = []() consteval { + auto nested = StateMachineHelper::add_nesting(state_b, test_nested_machine); + auto sm = make_state_machine( + MasterState::A, + StateMachineHelper::add_nested_machines(nested), + state_a, + state_b, + state_c + ); + using namespace std::chrono_literals; + + sm.add_enter_action([]() { a_enter_count++; }, state_a); + sm.add_exit_action([]() { a_exit_count++; }, state_a); + sm.add_cyclic_action([]() { a_cyclic_count++; }, 10ms, state_a); + + sm.add_enter_action([]() { b_enter_count++; }, state_b); + sm.add_exit_action([]() { b_exit_count++; }, state_b); + sm.add_cyclic_action([]() { b_cyclic_count++; }, 20ms, state_b); + + sm.add_enter_action([]() { c_enter_count++; }, state_c); + sm.add_exit_action([]() { c_exit_count++; }, state_c); + + return sm; +}(); + +class StateMachineTest : public ::testing::Test { +protected: + void SetUp() override { + // Reset everything before tests + reset_test_state(); + + test_machine.force_change_state((size_t)MasterState::A); + test_nested_machine.force_change_state((size_t)SubState::S1); + + test_machine.force_change_state((size_t)MasterState::A); + test_nested_machine.force_change_state((size_t)SubState::S1); + test_machine.get_states()[0].unregister_all_timed_actions(); + test_machine.get_states()[1].unregister_all_timed_actions(); + test_machine.get_states()[2].unregister_all_timed_actions(); + test_nested_machine.get_states()[0].unregister_all_timed_actions(); + test_nested_machine.get_states()[1].unregister_all_timed_actions(); + + reset_test_state(); + } +}; + +TEST_F(StateMachineTest, StartTriggersEnterActions) { + test_machine.start(); + EXPECT_EQ(a_enter_count, 1); + EXPECT_EQ(a_exit_count, 0); + EXPECT_EQ(b_enter_count, 0); + EXPECT_EQ(test_machine.get_current_state(), MasterState::A); +} + +TEST_F(StateMachineTest, BasicTransition) { + test_machine.start(); + a_enter_count = 0; + + condition_a_to_b = true; + test_machine.check_transitions(); + + EXPECT_EQ(test_machine.get_current_state(), MasterState::B); + EXPECT_EQ(a_exit_count, 1); + EXPECT_EQ(b_enter_count, 1); + EXPECT_EQ(s1_enter_count, 1); // Nested machine should also enter its initial state +} + +TEST_F(StateMachineTest, NestedTransition) { + test_machine.start(); + condition_a_to_b = true; + test_machine.check_transitions(); + + EXPECT_EQ(test_machine.get_current_state(), MasterState::B); + EXPECT_EQ(test_nested_machine.get_current_state(), SubState::S1); + + condition_s1_to_s2 = true; + test_machine.check_transitions(); + + EXPECT_EQ(test_nested_machine.get_current_state(), SubState::S2); + EXPECT_EQ(s1_exit_count, 1); + EXPECT_EQ(s2_enter_count, 1); +} + +TEST_F(StateMachineTest, MasterStateChangeExitsNested) { + test_machine.start(); + condition_a_to_b = true; + test_machine.check_transitions(); + + EXPECT_EQ(test_nested_machine.get_current_state(), SubState::S1); + s1_enter_count = 0; + b_exit_count = 0; + + condition_b_to_c = true; + test_machine.check_transitions(); + + EXPECT_EQ(test_machine.get_current_state(), MasterState::C); + EXPECT_EQ(b_exit_count, 1); + EXPECT_EQ(s1_exit_count, 1); + EXPECT_EQ(c_enter_count, 1); +} + +TEST_F(StateMachineTest, CyclicActionsRun) { + test_machine.start(); + + Scheduler::start(); + + tick_scheduler(100); + tick_scheduler(10000); + + EXPECT_GE(a_cyclic_count, 1); + EXPECT_EQ(b_cyclic_count, 0); + + condition_a_to_b = true; + test_machine.check_transitions(); + + a_cyclic_count = 0; + + tick_scheduler(20000); + + EXPECT_EQ(a_cyclic_count, 0); // A cyclic shouldn't run + EXPECT_GE(b_cyclic_count, 1); // B cyclic should run + EXPECT_GE(s1_cyclic_count, 1); // Nested S1 cyclic should run +} + +template struct constant_eval {}; + +template +concept CanCompile = requires { typename constant_eval; }; + +struct DuplicateNestedCheck { + static consteval bool invoke() { + auto sm_nested_1 = make_state_machine(SubState::S1, state_s1, state_s2); + auto sm_nested_2 = make_state_machine(SubState::S1, state_s1, state_s2); + + auto nested1 = StateMachineHelper::add_nesting(state_a, sm_nested_1); + auto nested2 = StateMachineHelper::add_nesting(state_a, sm_nested_2); + + auto sm = make_state_machine( + MasterState::A, + StateMachineHelper::add_nested_machines(nested1, nested2), + state_a, + state_b + ); + return true; + } +}; + +struct ValidNestedCheck { + static consteval bool invoke() { + auto sm_nested_1 = make_state_machine(SubState::S1, state_s1, state_s2); + auto sm_nested_2 = make_state_machine(SubState::S1, state_s1, state_s2); + + auto nested1 = StateMachineHelper::add_nesting(state_a, sm_nested_1); + auto nested2 = StateMachineHelper::add_nesting(state_b, sm_nested_2); + + auto sm = make_state_machine( + MasterState::A, + StateMachineHelper::add_nested_machines(nested1, nested2), + state_a, + state_b + ); + return true; + } +}; + +TEST(StateMachineCompileCheck, ValidatesSFINAEOntoS_M) { + static_assert(CanCompile, "Valid nested mapping should compile."); + static_assert(!CanCompile, "Duplicate state mappings must not compile."); +}