Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions include/mppi/controllers/controller.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,9 @@ public:
* @param rel_time
* @return
*/
virtual control_array getCurrentControl(state_array& state, double rel_time, state_array& target_nominal_state,
control_trajectory& c_traj, TEMPLATED_FEEDBACK_STATE& fb_state)
virtual control_array getCurrentControl(Eigen::Ref<state_array> state, double rel_time,
Eigen::Ref<state_array> target_nominal_state,
Eigen::Ref<control_trajectory> c_traj, TEMPLATED_FEEDBACK_STATE& fb_state)
{
// MPPI control
control_array u_ff = interpolateControls(rel_time, c_traj);
Expand All @@ -348,7 +349,8 @@ public:
* @param steps - number of dt's to slide control sequence forward
* Slide the control sequence forwards by 'steps'
*/
virtual void slideControlSequence(int steps) {
virtual void slideControlSequence(int steps)
{
// Save the control history
this->saveControlHistoryHelper(steps, this->control_, this->control_history_);

Expand All @@ -360,7 +362,7 @@ public:
* @param rel_time time since the solution was calculated
* @return
*/
virtual control_array interpolateControls(double rel_time, control_trajectory& c_traj)
virtual control_array interpolateControls(double rel_time, Eigen::Ref<control_trajectory> c_traj)
{
int lower_idx = (int)(rel_time / getDt());
int upper_idx = lower_idx + 1;
Expand All @@ -370,14 +372,10 @@ public:
control_array prev_cmd = c_traj.col(lower_idx);
control_array next_cmd = c_traj.col(upper_idx);
interpolated_control = (1 - alpha) * prev_cmd + alpha * next_cmd;

// printf("prev: %d %f, %f\n", lower_idx, prev_cmd[0], prev_cmd[1]);
// printf("next: %d %f, %f\n", upper_idx, next_cmd[0], next_cmd[1]);
// printf("smoother: %f\n", alpha);
return interpolated_control;
}

virtual state_array interpolateState(state_trajectory& s_traj, double rel_time)
virtual state_array interpolateState(Eigen::Ref<state_trajectory> s_traj, double rel_time)
{
int lower_idx = (int)(rel_time / getDt());
int upper_idx = lower_idx + 1;
Expand All @@ -392,8 +390,8 @@ public:
* @param rel_time
* @return
*/
virtual control_array interpolateFeedback(state_array& state, state_array& target_nominal_state, double rel_time,
TEMPLATED_FEEDBACK_STATE& fb_state)
virtual control_array interpolateFeedback(Eigen::Ref<state_array> state, Eigen::Ref<state_array> target_nominal_state,
double rel_time, TEMPLATED_FEEDBACK_STATE& fb_state)
{
return fb_controller_->interpolateFeedback_(state, target_nominal_state, rel_time, fb_state);
}
Expand Down
30 changes: 17 additions & 13 deletions include/mppi/core/base_plant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class BasePlant
std::atomic<bool> has_new_dynamics_params_{ false };
std::atomic<bool> has_new_cost_params_{ false };
std::atomic<bool> has_new_controller_params_{ false };
std::atomic<bool> enabled_{ false };
std::atomic<bool> has_received_state_{ false };

// Values needed
s_array init_state_ = s_array::Zero();
Expand Down Expand Up @@ -286,7 +286,7 @@ class BasePlant
* @param state the most recent state from state estimator
* @param time the time of the most recent state from the state estimator
*/
virtual void updateState(s_array& state, double time)
virtual void updateState(Eigen::Ref<s_array> state, double time)
{
// calculate and update all timing variables
double temp_last_state_update_time = last_used_state_update_time_;
Expand All @@ -295,8 +295,9 @@ class BasePlant

state_ = state;
state_time_ = time;
has_received_state_ = true;

if (last_used_state_update_time_ < 0)
if (num_iter_ == 0)
{
// we have not optimized yet so no reason to publish controls
return;
Expand Down Expand Up @@ -446,13 +447,16 @@ class BasePlant
double temp_last_state_time = getStateTime();
double temp_last_used_state_update_time = last_used_state_update_time_;

// If it is the first iteration and we have received state, we should not wait for timestamps to differ
bool skip_first_loop = num_iter_ == 0 && has_received_state_;

// wait for a new state to compute control sequence from
int counter = 0;
while (temp_last_used_state_update_time == temp_last_state_time && is_alive->load())
while (temp_last_used_state_update_time == temp_last_state_time && !skip_first_loop && is_alive->load())
{
usleep(50);
temp_last_state_time = getStateTime();
counter++;
// In case when runControlIteration is ran before getting state and state time is specifically 0
skip_first_loop = num_iter_ == 0 && has_received_state_;
}
if (!is_alive->load())
{
Expand Down Expand Up @@ -487,7 +491,7 @@ class BasePlant

// calculate how much we should slide the control sequence
double dt = temp_last_state_time - temp_last_used_state_update_time;
if (temp_last_used_state_update_time == -1)
if (num_iter_ == 0)
{ //
// should only happen on the first iteration
dt = 0;
Expand Down Expand Up @@ -518,21 +522,21 @@ class BasePlant
{
std::cerr << "ERROR: Nan in control inside plant" << std::endl;
std::cerr << control_traj << std::endl;
exit(-1);
throw std::runtime_error("Control Trajectory inside plant has a NaN");
}
s_traj state_traj = controller_->getTargetStateSeq();
if (!state_traj.allFinite())
{
std::cerr << "ERROR: Nan in state inside plant" << std::endl;
std::cerr << state_traj << std::endl;
exit(-1);
throw std::runtime_error("State Trajectory inside plant has a NaN");
}
o_traj output_traj = controller_->getTargetOutputSeq();
if (!state_traj.allFinite())
if (!output_traj.allFinite())
{
std::cerr << "ERROR: Nan in state inside plant" << std::endl;
std::cerr << state_traj << std::endl;
exit(-1);
std::cerr << "ERROR: Nan in output inside plant" << std::endl;
std::cerr << output_traj << std::endl;
throw std::runtime_error("Output Trajectory inside plant has a NaN");
}
optimization_duration_ = mppi::math::timeDiffms(std::chrono::steady_clock::now(), optimization_start);

Expand Down
14 changes: 10 additions & 4 deletions include/mppi/core/buffered_plant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ class BufferedPlant : public BasePlant<CONTROLLER_T>
buffer_.clearBuffers();
}

double getBufferDt() const
{
return buffer_dt_;
}

void setBufferDt(const double buff_dt)
{
buffer_dt_ = buff_dt;
}

protected:
Buffer<typename CONTROLLER_T::TEMPLATED_DYNAMICS> buffer_;

Expand All @@ -83,8 +93,4 @@ class BufferedPlant : public BasePlant<CONTROLLER_T>
double buffer_dt_ = 0.02; // the spacing between well sampled buffer positions
};

template class BufferMessage<Eigen::Vector3f>;
template class BufferMessage<Eigen::Quaternionf>;
template class BufferMessage<float>;

#endif // MPPIGENERIC_BUFFERED_PLANT_H
5 changes: 3 additions & 2 deletions tests/include/mppi_test/mock_classes/mock_controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ class MockController
MOCK_METHOD0(resetControls, void());
MOCK_METHOD1(computeFeedback, void(const Eigen::Ref<const state_array>& state));
MOCK_METHOD1(slideControlSequence, void(int stride));
MOCK_METHOD5(getCurrentControl,
control_array(state_array&, double, state_array&, control_trajectory&, TEMPLATED_FEEDBACK_STATE&));
MOCK_METHOD5(getCurrentControl, control_array(Eigen::Ref<state_array>, double, Eigen::Ref<state_array>,
Eigen::Ref<control_trajectory>, TEMPLATED_FEEDBACK_STATE&));
MOCK_METHOD2(computeControl, void(const Eigen::Ref<const state_array>& state, int optimization_stride));
MOCK_METHOD(control_trajectory, getControlSeq, (), (const, override));
MOCK_METHOD(state_trajectory, getTargetStateSeq, (), (const, override));
// MOCK_METHOD(output_trajectory, getTargetOutputSeq, (), (const, override));
MOCK_METHOD(TEMPLATED_FEEDBACK_STATE, getFeedbackState, (), (const, override));
MOCK_METHOD(control_array, getFeedbackControl,
(const Eigen::Ref<const state_array>&, const Eigen::Ref<const state_array>&, int), (override));
Expand Down
65 changes: 60 additions & 5 deletions tests/mppi_core/base_plant_tester.cu
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ TEST_F(BasePlantTest, Constructor)
EXPECT_EQ(plant->getHz(), 20);
EXPECT_EQ(plant->getTargetOptimizationStride(), 1);
EXPECT_EQ(plant->getNumIter(), 0);
EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), -1);
EXPECT_EQ(plant->getLastUsedPoseUpdateTime(), 0);
EXPECT_EQ(plant->getStatus(), 1);
EXPECT_EQ(mockController->getFeedbackEnabled(), false);
EXPECT_EQ(plant->hasNewCostParams(), false);
Expand Down Expand Up @@ -275,7 +275,6 @@ TEST_F(BasePlantTest, updateParametersAllTrue)

TEST_F(BasePlantTest, updateStateOutsideTimeTest)
{
mockController->setDt(DT);
plant->setLastTime(0);

EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);
Expand All @@ -293,7 +292,6 @@ TEST_F(BasePlantTest, updateStateOutsideTimeTest)

TEST_F(BasePlantTest, updateStateTest)
{
mockController->setDt(DT);
plant->setLastTime(0);

MockController::state_array state = MockController::state_array::Zero();
Expand All @@ -303,16 +301,73 @@ TEST_F(BasePlantTest, updateStateTest)
EXPECT_EQ(plant->pubControlCalled, 0);
EXPECT_EQ(plant->pubNominalStateCalled, 0);

EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1);
// Calling updateState() should not pub controls when none have been calculated as of yet
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);
plant->setLastUsedTime(11);
plant->updateState(state, 12);
EXPECT_EQ(plant->getState(), state);
EXPECT_EQ(plant->pubControlCalled, 1);
EXPECT_EQ(plant->pubControlCalled, 0);
EXPECT_EQ(plant->pubNominalStateCalled, 0);

// TODO in debug should pub nominal state
}

TEST_F(BasePlantTest, pubControlOnlyAfterControlAreCalculatedTest)
{
::testing::Sequence s1;
// Step 1: calling updateState() before controls are calculated should not call controller->getCurrentControl()
MockController::state_array state = MockController::state_array::Zero();
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0).InSequence(s1);
double curr_time = 0.0;
plant->updateState(state, curr_time);
EXPECT_EQ(plant->pubControlCalled, 0);
EXPECT_EQ(plant->pubNominalStateCalled, 0);
// ::testing::Mock::VerifyAndClearExpectations(mockController.get());

// Step 2: run control iteration inside plant
// Create valid outputs from gmock methods to prevent nan detection from triggering
MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS);
MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS);
EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1).InSequence(s1);
EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq));
EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq));
// EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1);
std::atomic<bool> is_alive(true);
plant->runControlIteration(&is_alive);

// Step 3: calling updateState() now should use controller->getCurrentControl()
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(1).InSequence(s1);
curr_time++;
plant->updateState(state, curr_time);
EXPECT_EQ(plant->pubControlCalled, 1);
EXPECT_EQ(plant->pubNominalStateCalled, 0);
}

TEST_F(BasePlantTest, EnsureReceivingStateCompletesRunControlIterationTest)
{
std::atomic<bool> is_alive(true);
// Create valid outputs from gmock methods to prevent nan detection from triggering
MockController::control_trajectory valid_control_seq = MockController::control_trajectory::Zero(MockDynamics::CONTROL_DIM, NUM_TIMESTEPS);
MockController::state_trajectory valid_state_seq = MockController::state_trajectory::Zero(MockDynamics::STATE_DIM, NUM_TIMESTEPS);
EXPECT_CALL(*mockController, computeControl(testing::_, testing::_)).Times(1);
EXPECT_CALL(*mockController, getControlSeq()).Times(1).WillRepeatedly(testing::Return(valid_control_seq));
EXPECT_CALL(*mockController, getTargetStateSeq()).Times(1).WillRepeatedly(testing::Return(valid_state_seq));
// EXPECT_CALL(*mockController, getTargetOutputSeq()).Times(1);
std::thread new_thread(&MockTestPlant::runControlIteration, plant.get(), &is_alive);
// Wait some period of time and then call updateState()
std::cout << "Wait for new state" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(300));
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);
MockController::state_array state = MockController::state_array::Zero();
double curr_time = 0.0;
plant->updateState(state, curr_time);
std::cout << "State sent" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// is_alive.store(false);
new_thread.join();

}

TEST_F(BasePlantTest, runControlIterationStoppedTest)
{
EXPECT_CALL(*mockController, slideControlSequence(testing::_)).Times(0);
Expand Down
27 changes: 19 additions & 8 deletions tests/mppi_core/buffered_plant_tester.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ public:
{
return this->buffer_tau_;
}
double getBufferDt()
{
return this->buffer_dt_;
}
void setLastUsedUpdateTime(double time)
{
this->last_used_state_update_time_ = time;
Expand Down Expand Up @@ -198,6 +194,13 @@ TEST_F(BufferedPlantTest, Constructor)
EXPECT_FLOAT_EQ(plant->getBufferDt(), 0.02);
}

TEST_F(BufferedPlantTest, setBufferDt)
{
double new_buffer_dt = 30.0;
plant->setBufferDt(new_buffer_dt);
EXPECT_FLOAT_EQ(plant->getBufferDt(), new_buffer_dt);
}

TEST_F(BufferedPlantTest, interpNew)
{
Eigen::Vector3f pos = Eigen::Vector3f::Ones();
Expand All @@ -208,7 +211,9 @@ TEST_F(BufferedPlantTest, interpNew)
MockDynamics::state_array state = MockDynamics::state_array::Random();

EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state));
EXPECT_CALL(*mockController, getDt()).Times(2);
// Controls never calculated so no calls to controller in updateState()
EXPECT_CALL(*mockController, getDt()).Times(0);
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);

plant->setLastUsedUpdateTime(0);
plant->updateOdometry(pos, quat, vel, omega, 0.0);
Expand Down Expand Up @@ -380,7 +385,9 @@ TEST_F(BufferedPlantTest, updateOdometry)
MockDynamics::state_array state = MockDynamics::state_array::Random();

EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state));
EXPECT_CALL(*mockController, getDt()).Times(2);
// Controls never calculated so no calls to controller in updateState()
EXPECT_CALL(*mockController, getDt()).Times(0);
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);

plant->setLastUsedUpdateTime(0.0);
plant->updateOdometry(pos, quat, vel, omega, 0.0);
Expand Down Expand Up @@ -448,7 +455,9 @@ TEST_F(BufferedPlantTest, getInterpState)

plant->setLastUsedUpdateTime(0.0);
EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state));
EXPECT_CALL(*mockController, getDt()).Times(2);
// Controls never calculated so no calls to controller in updateState()
EXPECT_CALL(*mockController, getDt()).Times(0);
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);

plant->updateOdometry(pos, quat, vel, omega, 0.0);
plant->updateControls(u, 0.0);
Expand Down Expand Up @@ -509,7 +518,9 @@ TEST_F(BufferedPlantTest, getInterpBuffer)

plant->setLastUsedUpdateTime(0.0);
EXPECT_CALL(mockDynamics, stateFromMap(testing::_)).Times(2).WillRepeatedly(testing::Return(state));
EXPECT_CALL(*mockController, getDt()).Times(2);
// Controls never calculated so no calls to controller in updateState()
EXPECT_CALL(*mockController, getDt()).Times(0);
EXPECT_CALL(*mockController, getCurrentControl(testing::_, testing::_, testing::_, testing::_, testing::_)).Times(0);

plant->updateOdometry(pos, quat, vel, omega, 0.0);
plant->updateControls(u, 0.0);
Expand Down