Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/mppi/controllers/controller.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ public:
{
bool change_seed = p.seed_ != params_.seed_;
bool change_num_timesteps = p.num_timesteps_ != params_.num_timesteps_;
bool change_dt = p.dt_ != params_.dt_;
// bool change_std_dev = p.control_std_dev_ != params_.control_std_dev_;
params_ = p;
if (change_num_timesteps)
Expand All @@ -847,6 +848,10 @@ public:
{
setSeedCUDARandomNumberGen(params_.seed_);
}
if (change_dt)
{
fb_controller_->setDt(p.dt_);
}
}

int getNumIters() const
Expand Down
28 changes: 28 additions & 0 deletions include/mppi/core/base_plant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BasePlant
using COST_PARAMS_T = typename COST_T::COST_PARAMS_T;
using TEMPLATED_CONTROLLER = CONTROLLER_T;
using CONTROLLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_PARAMS;
using SAMPLER_T = typename CONTROLLER_T::TEMPLATED_SAMPLING;
using SAMPLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_SAMPLING_PARAMS;

// Feedback related aliases
using FB_STATE_T = typename CONTROLLER_T::TEMPLATED_FEEDBACK::TEMPLATED_FEEDBACK_STATE;
Expand All @@ -61,10 +63,13 @@ class BasePlant
std::mutex cost_params_guard_;
CONTROLLER_PARAMS_T controller_params_;
std::mutex controller_params_guard_;
SAMPLER_PARAMS_T sampler_params_;
std::mutex sampler_params_guard_;

std::atomic<bool> has_new_dynamics_params_{ false };
std::atomic<bool> has_new_cost_params_{ false };
std::atomic<bool> has_new_controller_params_{ false };
std::atomic<bool> has_new_sampler_params_{ false };
std::atomic<bool> enabled_{ false };

// Values needed
Expand Down Expand Up @@ -332,6 +337,10 @@ class BasePlant
{
return has_new_controller_params_;
};
virtual bool hasNewSamplerParams()
{
return has_new_sampler_params_;
};

virtual DYN_PARAMS_T getNewDynamicsParams(bool set_flag = false)
{
Expand All @@ -348,6 +357,11 @@ class BasePlant
has_new_controller_params_ = set_flag;
return controller_params_;
}
virtual SAMPLER_PARAMS_T getNewSamplerParams(bool set_flag = false)
{
has_new_sampler_params_ = set_flag;
return sampler_params_;
}

virtual void setDynamicsParams(const DYN_PARAMS_T& params)
{
Expand All @@ -367,6 +381,12 @@ class BasePlant
controller_params_ = params;
has_new_controller_params_ = true;
}
virtual void setSamplerParams(const SAMPLER_PARAMS_T& params)
{
std::lock_guard<std::mutex> guard(sampler_params_guard_);
sampler_params_ = params;
has_new_sampler_params_ = true;
}

virtual void setLogger(const mppi::util::MPPILoggerPtr& logger)
{
Expand Down Expand Up @@ -423,6 +443,14 @@ class BasePlant
CONTROLLER_PARAMS_T controller_params = getNewControllerParams();
controller_->setParams(controller_params);
}
// Update sampler params
if (hasNewSamplerParams())
{
std::lock_guard<std::mutex> guard(sampler_params_guard_);
changed = true;
SAMPLER_PARAMS_T sampler_params = getNewSamplerParams();
controller_->setSamplingParams(sampler_params);
}
return changed;
}

Expand Down