diff --git a/include/mppi/controllers/controller.cuh b/include/mppi/controllers/controller.cuh index c356d107..26a4c306 100644 --- a/include/mppi/controllers/controller.cuh +++ b/include/mppi/controllers/controller.cuh @@ -837,6 +837,7 @@ public: { bool change_seed = p.seed_ != params_.seed_; bool change_num_timesteps = p.num_timesteps_ != params_.num_timesteps_; + bool change_dt = p.dt_ != params_.dt_; // bool change_std_dev = p.control_std_dev_ != params_.control_std_dev_; params_ = p; if (change_num_timesteps) @@ -847,6 +848,10 @@ public: { setSeedCUDARandomNumberGen(params_.seed_); } + if (change_dt) + { + fb_controller_->setDt(p.dt_); + } } int getNumIters() const diff --git a/include/mppi/core/base_plant.hpp b/include/mppi/core/base_plant.hpp index 5203f97e..9b5a0954 100644 --- a/include/mppi/core/base_plant.hpp +++ b/include/mppi/core/base_plant.hpp @@ -41,6 +41,8 @@ class BasePlant using COST_PARAMS_T = typename COST_T::COST_PARAMS_T; using TEMPLATED_CONTROLLER = CONTROLLER_T; using CONTROLLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_PARAMS; + using SAMPLER_T = typename CONTROLLER_T::TEMPLATED_SAMPLING; + using SAMPLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_SAMPLING_PARAMS; // Feedback related aliases using FB_STATE_T = typename CONTROLLER_T::TEMPLATED_FEEDBACK::TEMPLATED_FEEDBACK_STATE; @@ -61,10 +63,13 @@ class BasePlant std::mutex cost_params_guard_; CONTROLLER_PARAMS_T controller_params_; std::mutex controller_params_guard_; + SAMPLER_PARAMS_T sampler_params_; + std::mutex sampler_params_guard_; std::atomic has_new_dynamics_params_{ false }; std::atomic has_new_cost_params_{ false }; std::atomic has_new_controller_params_{ false }; + std::atomic has_new_sampler_params_{ false }; std::atomic enabled_{ false }; // Values needed @@ -332,6 +337,10 @@ class BasePlant { return has_new_controller_params_; }; + virtual bool hasNewSamplerParams() + { + return has_new_sampler_params_; + }; virtual DYN_PARAMS_T getNewDynamicsParams(bool set_flag = false) { @@ -348,6 +357,11 @@ class BasePlant has_new_controller_params_ = set_flag; return controller_params_; } + virtual SAMPLER_PARAMS_T getNewSamplerParams(bool set_flag = false) + { + has_new_sampler_params_ = set_flag; + return sampler_params_; + } virtual void setDynamicsParams(const DYN_PARAMS_T& params) { @@ -367,6 +381,12 @@ class BasePlant controller_params_ = params; has_new_controller_params_ = true; } + virtual void setSamplerParams(const SAMPLER_PARAMS_T& params) + { + std::lock_guard guard(sampler_params_guard_); + sampler_params_ = params; + has_new_sampler_params_ = true; + } virtual void setLogger(const mppi::util::MPPILoggerPtr& logger) { @@ -423,6 +443,14 @@ class BasePlant CONTROLLER_PARAMS_T controller_params = getNewControllerParams(); controller_->setParams(controller_params); } + // Update sampler params + if (hasNewSamplerParams()) + { + std::lock_guard guard(sampler_params_guard_); + changed = true; + SAMPLER_PARAMS_T sampler_params = getNewSamplerParams(); + controller_->setSamplingParams(sampler_params); + } return changed; } diff --git a/include/mppi/utils/logger.hpp b/include/mppi/utils/logger.hpp index 144723a9..dbeae2df 100644 --- a/include/mppi/utils/logger.hpp +++ b/include/mppi/utils/logger.hpp @@ -3,10 +3,10 @@ * Created by Bogdan on 11/01/2023 */ -#include #include -#include #include +#include +#include namespace mppi { @@ -89,91 +89,127 @@ class MPPILogger } /** - * @brief Log debug messages to the output stream in green if the log level is set for DEBUG - * @param fmt Format string (if additional arguments are passed) or message to display + * @brief Log debug messages using virtual debug_impl() method + * + * @tparam ...Args variadic template type of args used in the format string + * @param fmt format string used in printf + * @param args extra args used by the format string fmt + */ + template + void debug(const char* fmt, Args const&... args) + { + std::string message = format_string(fmt, args...); + this->debug_impl(message); + } + + /** + * @brief Log info messages using virtual info_impl() method + * + * @tparam ...Args variadic template type of args used in the format string + * @param fmt format string used in printf + * @param args extra args used by the format string fmt + */ + template + void info(const char* fmt, Args const&... args) + { + std::string message = format_string(fmt, args...); + this->info_impl(message); + } + + /** + * @brief Log warning messages using virtual warning_impl() method + * + * @tparam ...Args variadic template type of args used in the format string + * @param fmt format string used in printf + * @param args extra args used by the format string fmt */ - virtual void debug(const char* fmt, ...) + template + void warning(const char* fmt, Args const&... args) + { + std::string message = format_string(fmt, args...); + this->warning_impl(message); + } + + /** + * @brief Log errror messages using virtual errror_impl() method + * + * @tparam ...Args variadic template type of args used in the format string + * @param fmt format string used in printf + * @param args extra args used by the format string fmt + */ + template + void error(const char* fmt, Args const&... args) + { + std::string message = format_string(fmt, args...); + this->error_impl(message); + } + +protected: + LOG_LEVEL log_level_ = GLOBAL_LOG_LEVEL; + std::FILE* output_stream_ = stdout; + + virtual void debug_impl(const std::string& message) { if (log_level_ <= LOG_LEVEL::DEBUG) { - std::va_list argptr; - va_start(argptr, fmt); - surround_fprintf(output_stream_, GREEN, RESET, fmt, argptr); - va_end(argptr); + surround_fprintf(output_stream_, GREEN, RESET, message); } } - /** - * @brief Log info messages to the output stream in cyan if the log level is set for INFO - * @param fmt Format string (if additional arguments are passed) or message to display - */ - virtual void info(const char* fmt, ...) + virtual void info_impl(const std::string& message) { if (log_level_ <= LOG_LEVEL::INFO) { - std::va_list argptr; - va_start(argptr, fmt); - surround_fprintf(output_stream_, CYAN, RESET, fmt, argptr); - va_end(argptr); + surround_fprintf(output_stream_, CYAN, RESET, message); } } - /** - * @brief Log debug messages to the output stream in yellow if the log level is set for WARNING - * @param fmt Format string (if additional arguments are passed) or message to display - */ - virtual void warning(const char* fmt, ...) + virtual void warning_impl(const std::string& message) { if (log_level_ <= LOG_LEVEL::WARNING) { - std::va_list argptr; - va_start(argptr, fmt); - surround_fprintf(output_stream_, YELLOW, RESET, fmt, argptr); - va_end(argptr); + surround_fprintf(output_stream_, YELLOW, RESET, message); } } - /** - * @brief Log debug messages to the output stream in red if the log level is set for ERROR - * @param fmt Format string (if additional arguments are passed) or message to display - */ - virtual void error(const char* fmt, ...) + virtual void error_impl(const std::string& message) { if (log_level_ <= LOG_LEVEL::ERROR) { - std::va_list argptr; - va_start(argptr, fmt); - surround_fprintf(output_stream_, RED, RESET, fmt, argptr); - va_end(argptr); + surround_fprintf(output_stream_, RED, RESET, message); } } -protected: - LOG_LEVEL log_level_ = GLOBAL_LOG_LEVEL; - std::FILE* output_stream_ = stdout; + /** + * @brief Print message to stream with coloring defined by prefix + * + * @param fstream where the message will be printed to + * @param prefix prefix string to print before message. Expected to be a color code + * @param suffix suffix string to print after message. Expected to be a color reset code + * @param message actual message to be printed + */ + virtual void surround_fprintf(std::FILE* fstream, const char* prefix, const char* suffix, const std::string& message) + { + std::fprintf(fstream, "%s%s%s", prefix, message.c_str(), suffix); + } /** - * @brief Prints a colored output to a provided fstream. It does this by first creating the formatted string - * as a std::vector so that it can be used as an input to fprintf with a different format string + * @brief create a string out of format string and variable number of additional arguments + * + * @tparam ...Args variadic template type for extra arguments passed to format_string() + * @param fmt format string defining how to display additional arguments + * @param args additional arguments for formatting * - * @param fstream file stream to write output to - * @param color color code to use on provided string - * @param fmt format string - * @param ... extra variables for format string + * @return std::string containing formatted text */ - virtual void surround_fprintf(std::FILE* fstream, const char* prefix, const char* suffix, const char* fmt, - std::va_list args) + template + std::string format_string(const char* fmt, Args const&... args) { - // introducing a second copy of the args as calling vsnprintf leaves args in an indeterminate state - std::va_list args_cpy; - va_copy(args_cpy, args); - // figure out size of formatted string, also uses up args - std::vector buf(1 + std::vsnprintf(nullptr, 0, fmt, args)); - // Fill buffer with formatted string using copy of the args - std::vsnprintf(buf.data(), buf.size(), fmt, args_cpy); - va_end(args_cpy); - // print formatted string but colored - std::fprintf(fstream, "%s%s%s", prefix, buf.data(), suffix); + // figure out size of formatted string + std::vector buf(1 + std::snprintf(nullptr, 0, fmt, args...)); + // Fill buffer with formatted string + std::snprintf(buf.data(), buf.size(), fmt, args...); + return std::string(buf.data()); } };