From 63c8318d7bff3b18418a51ffe7a2c452e1128ff6 Mon Sep 17 00:00:00 2001 From: Bogdan Vlahov Date: Mon, 12 May 2025 11:57:27 -0400 Subject: [PATCH 1/2] Fix base plant logic issues and unit tests The Base Plant logic was originally based upon timestamps were initialized at -1 to indicate no data had been updated. This would cause issues when combined with ROS so we changed the initialization value of timestamps to 0. This commit fixes the resulting logic bugs that relied on timestamps of -1 to determine whether or not to run various code. Basic things include no calls to pubControl() until there a control sequence has been calculated, and waiting for a new state logic. - Change explicit calls to exit() into throwing exceptions - Existing unit tests have been updated - New unit tests written to explicitly test the new logic adjustments. --- include/mppi/controllers/controller.cuh | 20 +++--- include/mppi/core/base_plant.hpp | 30 +++++---- .../mppi_test/mock_classes/mock_controller.h | 5 +- tests/mppi_core/base_plant_tester.cu | 65 +++++++++++++++++-- 4 files changed, 89 insertions(+), 31 deletions(-) diff --git a/include/mppi/controllers/controller.cuh b/include/mppi/controllers/controller.cuh index c356d107..9beb8b0d 100644 --- a/include/mppi/controllers/controller.cuh +++ b/include/mppi/controllers/controller.cuh @@ -326,8 +326,9 @@ public: * @param rel_time * @return */ - virtual control_array getCurrentControl(state_array& state, double rel_time, state_array& target_nominal_state, - control_trajectory& c_traj, TEMPLATED_FEEDBACK_STATE& fb_state) + virtual control_array getCurrentControl(Eigen::Ref state, double rel_time, + Eigen::Ref target_nominal_state, + Eigen::Ref c_traj, TEMPLATED_FEEDBACK_STATE& fb_state) { // MPPI control control_array u_ff = interpolateControls(rel_time, c_traj); @@ -348,7 +349,8 @@ public: * @param steps - number of dt's to slide control sequence forward * Slide the control sequence forwards by 'steps' */ - virtual void slideControlSequence(int steps) { + virtual void slideControlSequence(int steps) + { // Save the control history this->saveControlHistoryHelper(steps, this->control_, this->control_history_); @@ -360,7 +362,7 @@ public: * @param rel_time time since the solution was calculated * @return */ - virtual control_array interpolateControls(double rel_time, control_trajectory& c_traj) + virtual control_array interpolateControls(double rel_time, Eigen::Ref c_traj) { int lower_idx = (int)(rel_time / getDt()); int upper_idx = lower_idx + 1; @@ -370,14 +372,10 @@ public: control_array prev_cmd = c_traj.col(lower_idx); control_array next_cmd = c_traj.col(upper_idx); interpolated_control = (1 - alpha) * prev_cmd + alpha * next_cmd; - - // printf("prev: %d %f, %f\n", lower_idx, prev_cmd[0], prev_cmd[1]); - // printf("next: %d %f, %f\n", upper_idx, next_cmd[0], next_cmd[1]); - // printf("smoother: %f\n", alpha); return interpolated_control; } - virtual state_array interpolateState(state_trajectory& s_traj, double rel_time) + virtual state_array interpolateState(Eigen::Ref s_traj, double rel_time) { int lower_idx = (int)(rel_time / getDt()); int upper_idx = lower_idx + 1; @@ -392,8 +390,8 @@ public: * @param rel_time * @return */ - virtual control_array interpolateFeedback(state_array& state, state_array& target_nominal_state, double rel_time, - TEMPLATED_FEEDBACK_STATE& fb_state) + virtual control_array interpolateFeedback(Eigen::Ref state, Eigen::Ref target_nominal_state, + double rel_time, TEMPLATED_FEEDBACK_STATE& fb_state) { return fb_controller_->interpolateFeedback_(state, target_nominal_state, rel_time, fb_state); } diff --git a/include/mppi/core/base_plant.hpp b/include/mppi/core/base_plant.hpp index 5203f97e..5a9ca886 100644 --- a/include/mppi/core/base_plant.hpp +++ b/include/mppi/core/base_plant.hpp @@ -65,7 +65,7 @@ class BasePlant std::atomic has_new_dynamics_params_{ false }; std::atomic has_new_cost_params_{ false }; std::atomic has_new_controller_params_{ false }; - std::atomic enabled_{ false }; + std::atomic has_received_state_{ false }; // Values needed s_array init_state_ = s_array::Zero(); @@ -286,7 +286,7 @@ class BasePlant * @param state the most recent state from state estimator * @param time the time of the most recent state from the state estimator */ - virtual void updateState(s_array& state, double time) + virtual void updateState(Eigen::Ref state, double time) { // calculate and update all timing variables double temp_last_state_update_time = last_used_state_update_time_; @@ -295,8 +295,9 @@ class BasePlant state_ = state; state_time_ = time; + has_received_state_ = true; - if (last_used_state_update_time_ < 0) + if (num_iter_ == 0) { // we have not optimized yet so no reason to publish controls return; @@ -446,13 +447,16 @@ class BasePlant double temp_last_state_time = getStateTime(); double temp_last_used_state_update_time = last_used_state_update_time_; + // If it is the first iteration and we have received state, we should not wait for timestamps to differ + bool skip_first_loop = num_iter_ == 0 && has_received_state_; + // wait for a new state to compute control sequence from - int counter = 0; - while (temp_last_used_state_update_time == temp_last_state_time && is_alive->load()) + while (temp_last_used_state_update_time == temp_last_state_time && !skip_first_loop && is_alive->load()) { usleep(50); temp_last_state_time = getStateTime(); - counter++; + // In case when runControlIteration is ran before getting state and state time is specifically 0 + skip_first_loop = num_iter_ == 0 && has_received_state_; } if (!is_alive->load()) { @@ -487,7 +491,7 @@ class BasePlant // calculate how much we should slide the control sequence double dt = temp_last_state_time - temp_last_used_state_update_time; - if (temp_last_used_state_update_time == -1) + if (num_iter_ == 0) { // // should only happen on the first iteration dt = 0; @@ -518,21 +522,21 @@ class BasePlant { std::cerr << "ERROR: Nan in control inside plant" << std::endl; std::cerr << control_traj << std::endl; - exit(-1); + throw std::runtime_error("Control Trajectory inside plant has a NaN"); } s_traj state_traj = controller_->getTargetStateSeq(); if (!state_traj.allFinite()) { std::cerr << "ERROR: Nan in state inside plant" << std::endl; std::cerr << state_traj << std::endl; - exit(-1); + throw std::runtime_error("State Trajectory inside plant has a NaN"); } o_traj output_traj = controller_->getTargetOutputSeq(); - if (!state_traj.allFinite()) + if (!output_traj.allFinite()) { - std::cerr << "ERROR: Nan in state inside plant" << std::endl; - std::cerr << state_traj << std::endl; - exit(-1); + std::cerr << "ERROR: Nan in output inside plant" << std::endl; + std::cerr << output_traj << std::endl; + throw std::runtime_error("Output Trajectory inside plant has a NaN"); } optimization_duration_ = mppi::math::timeDiffms(std::chrono::steady_clock::now(), optimization_start); diff --git a/tests/include/mppi_test/mock_classes/mock_controller.h b/tests/include/mppi_test/mock_classes/mock_controller.h index 3d6f423e..ff1029ce 100644 --- a/tests/include/mppi_test/mock_classes/mock_controller.h +++ b/tests/include/mppi_test/mock_classes/mock_controller.h @@ -19,11 +19,12 @@ class MockController MOCK_METHOD0(resetControls, void()); MOCK_METHOD1(computeFeedback, void(const Eigen::Ref& state)); MOCK_METHOD1(slideControlSequence, void(int stride)); - MOCK_METHOD5(getCurrentControl, - control_array(state_array&, double, state_array&, control_trajectory&, TEMPLATED_FEEDBACK_STATE&)); + MOCK_METHOD5(getCurrentControl, control_array(Eigen::Ref, double, Eigen::Ref, + Eigen::Ref, TEMPLATED_FEEDBACK_STATE&)); MOCK_METHOD2(computeControl, void(const Eigen::Ref& state, int optimization_stride)); MOCK_METHOD(control_trajectory, getControlSeq, (), (const, override)); MOCK_METHOD(state_trajectory, getTargetStateSeq, (), (const, override)); + // MOCK_METHOD(output_trajectory, getTargetOutputSeq, (), (const, override)); MOCK_METHOD(TEMPLATED_FEEDBACK_STATE, getFeedbackState, (), (const, override)); MOCK_METHOD(control_array, getFeedbackControl, (const Eigen::Ref&, const Eigen::Ref&, int), (override)); diff --git a/tests/mppi_core/base_plant_tester.cu b/tests/mppi_core/base_plant_tester.cu index ef189666..8e5c908f 100644 --- a/tests/mppi_core/base_plant_tester.cu +++ b/tests/mppi_core/base_plant_tester.cu @@ -181,7 +181,7 @@ TEST_F(BasePlantTest, Constructor) EXPECT_EQ(plant->getHz(), 20); EXPECT_EQ(plant->getTargetOptimizationStride(), 1); EXPECT_EQ(plant->getNumIter(), 0); - EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), -1); + EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), 0); EXPECT_EQ(plant->getStatus(), 1); EXPECT_EQ(mockController->getFeedbackEnabled(), false); EXPECT_EQ(plant->hasNewCostParams(), false); @@ -275,7 +275,6 @@ TEST_F(BasePlantTest, updateParametersAllTrue) TEST_F(BasePlantTest, updateStateOutsideTimeTest) { - mockController->setDt(DT); plant->setLastTime(0); EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); @@ -293,7 +292,6 @@ TEST_F(BasePlantTest, updateStateOutsideTimeTest) TEST_F(BasePlantTest, updateStateTest) { - mockController->setDt(DT); plant->setLastTime(0); MockController::state_array state = MockController::state_array::Zero(); @@ -303,16 +301,73 @@ TEST_F(BasePlantTest, updateStateTest) EXPECT_EQ(plant->pubControlCalled, 0); EXPECT_EQ(plant->pubNominalStateCalled, 0); - EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1); + // Calling updateState() should not pub controls when none have been calculated as of yet + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedTime(11); plant->updateState(state, 12); EXPECT_EQ(plant->getState(), state); - EXPECT_EQ(plant->pubControlCalled, 1); + EXPECT_EQ(plant->pubControlCalled, 0); EXPECT_EQ(plant->pubNominalStateCalled, 0); // TODO in debug should pub nominal state } +TEST_F(BasePlantTest, pubControlOnlyAfterControlAreCalculatedTest) +{ + ::testing::Sequence s1; + // Step 1: calling updateState() before controls are calculated should not call controller->getCurrentControl() + MockController::state_array state = MockController::state_array::Zero(); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0).InSequence(s1); + double curr_time = 0.0; + plant->updateState(state, curr_time); + EXPECT_EQ(plant->pubControlCalled, 0); + EXPECT_EQ(plant->pubNominalStateCalled, 0); + // ::testing::Mock::VerifyAndClearExpectations(mockController.get()); + + // Step 2: run control iteration inside plant + // Create valid outputs from gmock methods to prevent nan detection from triggering + MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS); + MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS); + EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1).InSequence(s1); + EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq)); + EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq)); + // EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1); + std::atomic is_alive(true); + plant->runControlIteration(&is_alive); + + // Step 3: calling updateState() now should use controller->getCurrentControl() + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1).InSequence(s1); + curr_time++; + plant->updateState(state, curr_time); + EXPECT_EQ(plant->pubControlCalled, 1); + EXPECT_EQ(plant->pubNominalStateCalled, 0); +} + +TEST_F(BasePlantTest, EnsureReceivingStateCompletesRunControlIterationTest) +{ + std::atomic is_alive(true); + // Create valid outputs from gmock methods to prevent nan detection from triggering + MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS); + MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS); + EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq)); + EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq)); + // EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1); + std::thread new_thread(&MockTestPlant::runControlIteration, plant.get(), &is_alive); + // Wait some period of time and then call updateState() + std::cout << "Wait for new state" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); + MockController::state_array state = MockController::state_array::Zero(); + double curr_time = 0.0; + plant->updateState(state, curr_time); + std::cout << "State sent" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // is_alive.store(false); + new_thread.join(); + +} + TEST_F(BasePlantTest, runControlIterationStoppedTest) { EXPECT_CALL(*mockController, slideControlSequence(testing::_)).Times(0); From 8e3fcccfbffa9193acec4a7a6569348f89f93ac4 Mon Sep 17 00:00:00 2001 From: Bogdan Vlahov Date: Mon, 12 May 2025 12:06:27 -0400 Subject: [PATCH 2/2] Fix BufferedPlant tests to new BasePlant logic - Add get/set methods for buffer_dt value - Remove template instantiation for specific buffer types as this can cause "target already defined" errors --- include/mppi/core/buffered_plant.hpp | 14 ++++++++---- tests/mppi_core/buffered_plant_tester.cu | 27 +++++++++++++++++------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/include/mppi/core/buffered_plant.hpp b/include/mppi/core/buffered_plant.hpp index 7039961e..fe4ec405 100644 --- a/include/mppi/core/buffered_plant.hpp +++ b/include/mppi/core/buffered_plant.hpp @@ -75,6 +75,16 @@ class BufferedPlant : public BasePlant buffer_.clearBuffers(); } + double getBufferDt() const + { + return buffer_dt_; + } + + void setBufferDt(const double buff_dt) + { + buffer_dt_ = buff_dt; + } + protected: Buffer buffer_; @@ -83,8 +93,4 @@ class BufferedPlant : public BasePlant double buffer_dt_ = 0.02; // the spacing between well sampled buffer positions }; -template class BufferMessage; -template class BufferMessage; -template class BufferMessage; - #endif // MPPIGENERIC_BUFFERED_PLANT_H diff --git a/tests/mppi_core/buffered_plant_tester.cu b/tests/mppi_core/buffered_plant_tester.cu index 2fc02c7f..17e3db05 100644 --- a/tests/mppi_core/buffered_plant_tester.cu +++ b/tests/mppi_core/buffered_plant_tester.cu @@ -122,10 +122,6 @@ public: { return this->buffer_tau_; } - double getBufferDt() - { - return this->buffer_dt_; - } void setLastUsedUpdateTime(double time) { this->last_used_state_update_time_ = time; @@ -198,6 +194,13 @@ TEST_F(BufferedPlantTest, Constructor) EXPECT_FLOAT_EQ(plant->getBufferDt(), 0.02); } +TEST_F(BufferedPlantTest, setBufferDt) +{ + double new_buffer_dt = 30.0; + plant->setBufferDt(new_buffer_dt); + EXPECT_FLOAT_EQ(plant->getBufferDt(), new_buffer_dt); +} + TEST_F(BufferedPlantTest, interpNew) { Eigen::Vector3f pos = Eigen::Vector3f::Ones(); @@ -208,7 +211,9 @@ TEST_F(BufferedPlantTest, interpNew) MockDynamics::state_array state = MockDynamics::state_array::Random(); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedUpdateTime(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); @@ -380,7 +385,9 @@ TEST_F(BufferedPlantTest, updateOdometry) MockDynamics::state_array state = MockDynamics::state_array::Random(); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->setLastUsedUpdateTime(0.0); plant->updateOdometry(pos, quat, vel, omega, 0.0); @@ -448,7 +455,9 @@ TEST_F(BufferedPlantTest, getInterpState) plant->setLastUsedUpdateTime(0.0); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); plant->updateControls(u, 0.0); @@ -509,7 +518,9 @@ TEST_F(BufferedPlantTest, getInterpBuffer) plant->setLastUsedUpdateTime(0.0); EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state)); - EXPECT_CALL(*mockController, getDt()).Times(2); + // Controls never calculated so no calls to controller in updateState() + EXPECT_CALL(*mockController, getDt()).Times(0); + EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0); plant->updateOdometry(pos, quat, vel, omega, 0.0); plant->updateControls(u, 0.0);