nodes replaced by lifecycle nodes
mgonzs13 committed Aug 20, 2024
1 parent 90fb49d commit cec55ef
Showing 14 changed files with 302 additions and 164 deletions.
2 changes: 1 addition & 1 deletion llama_cpp_vendor/CMakeLists.txt
@@ -11,7 +11,7 @@ find_package(ament_cmake REQUIRED)
FetchContent_Declare(
llama
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG b3604
GIT_TAG b3609
)

FetchContent_MakeAvailable(llama)
3 changes: 3 additions & 0 deletions llama_ros/CMakeLists.txt
@@ -9,6 +9,7 @@ endif()
find_package(ament_cmake REQUIRED)
find_package(rclcpp REQUIRED)
find_package(rclcpp_action REQUIRED)
find_package(rclcpp_lifecycle REQUIRED)
find_package(llama_msgs REQUIRED)
find_package(llama_cpp_vendor REQUIRED)
find_package(Threads REQUIRED)
@@ -56,6 +57,7 @@ target_link_libraries(llama_node
ament_target_dependencies(llama_node
rclcpp
rclcpp_action
rclcpp_lifecycle
llama_msgs
llama_cpp_vendor
)
@@ -78,6 +80,7 @@ target_link_libraries(llava_node
ament_target_dependencies(llava_node
rclcpp
rclcpp_action
rclcpp_lifecycle
llama_msgs
llama_cpp_vendor
cv_bridge
2 changes: 1 addition & 1 deletion llama_ros/include/llama_ros/llama.hpp
@@ -142,7 +142,7 @@ class Llama {

std::vector<token_prob> get_probs();
struct completion_output sample();
void update_sampling_params(const struct llama_sampling_params &params);
void update_sampling_context(const struct llama_sampling_params &params);

private:
// lock
26 changes: 22 additions & 4 deletions llama_ros/include/llama_ros/llama_node.hpp
@@ -25,6 +25,7 @@

#include <rclcpp/rclcpp.hpp>
#include <rclcpp_action/rclcpp_action.hpp>
#include <rclcpp_lifecycle/lifecycle_node.hpp>

#include <memory>
#include <string>
@@ -39,20 +40,37 @@

namespace llama_ros {

class LlamaNode : public rclcpp::Node {
using CallbackReturn =
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn;

class LlamaNode : public rclcpp_lifecycle::LifecycleNode {

using GenerateResponse = llama_msgs::action::GenerateResponse;
using GoalHandleGenerateResponse =
rclcpp_action::ServerGoalHandle<GenerateResponse>;

public:
LlamaNode(bool load_llama = true);
LlamaNode();

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_configure(const rclcpp_lifecycle::State &);
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_activate(const rclcpp_lifecycle::State &);
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_deactivate(const rclcpp_lifecycle::State &);
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_cleanup(const rclcpp_lifecycle::State &);
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_shutdown(const rclcpp_lifecycle::State &);

protected:
std::shared_ptr<Llama> llama;
llama_utils::GptParams gpt_params;
std::unique_ptr<Llama> llama;
std::unique_ptr<llama_utils::GptParams> gpt_params;
std::shared_ptr<GoalHandleGenerateResponse> goal_handle_;

virtual void create_llama();
void destroy_llama();

virtual bool goal_empty(std::shared_ptr<const GenerateResponse::Goal> goal);
virtual void
execute(const std::shared_ptr<GoalHandleGenerateResponse> goal_handle);
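
Note: LlamaNode now derives from rclcpp_lifecycle::LifecycleNode instead of rclcpp::Node, so model loading moves out of the constructor and into the standard managed-node transitions. A minimal sketch of that pattern (not from this commit; names are illustrative):

#include <rclcpp_lifecycle/lifecycle_node.hpp>

class MinimalLifecycleNode : public rclcpp_lifecycle::LifecycleNode {
public:
  using CallbackReturn =
      rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn;

  MinimalLifecycleNode() : rclcpp_lifecycle::LifecycleNode("minimal_node") {}

  // configure: declare and read parameters, allocate cheap resources
  CallbackReturn on_configure(const rclcpp_lifecycle::State &) override {
    return CallbackReturn::SUCCESS;
  }

  // activate: bring up the expensive resources (here, the llama model)
  CallbackReturn on_activate(const rclcpp_lifecycle::State &) override {
    return CallbackReturn::SUCCESS;
  }
};

Returning FAILURE from a callback aborts the transition, which gives the node a recoverable error path that constructor-time loading did not have.
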
10 changes: 8 additions & 2 deletions llama_ros/include/llama_utils/gpt_params.hpp
@@ -25,6 +25,7 @@

#include <memory>
#include <rclcpp/rclcpp.hpp>
#include <rclcpp_lifecycle/lifecycle_node.hpp>

#include "common.h"
#include "llama.h"
@@ -36,12 +37,17 @@ namespace llama_utils {
class GptParams {

public:
GptParams();
std::shared_ptr<struct gpt_params> load_params(rclcpp::Node *node);
GptParams(rclcpp_lifecycle::LifecycleNode::SharedPtr node);

void declare_params();
std::shared_ptr<struct gpt_params> get_params();
bool
update_sampling_params(const llama_msgs::msg::SamplingConfig &sampling_config,
int n_vocab, llama_token token_eos);

rclcpp_lifecycle::LifecycleNode::SharedPtr node;

// params
bool debug;
struct llava_ros::llava_params llava_params;
std::shared_ptr<struct gpt_params> params;
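
Note: GptParams now keeps the lifecycle node it was constructed with and splits the old load_params() into declare_params() / get_params(), so parameters can be declared once in on_configure and re-read on later transitions. A hypothetical usage sketch, assuming the node is already owned by a shared_ptr:

auto node = std::make_shared<llama_ros::LlamaNode>();
auto gpt_params = std::make_unique<llama_utils::GptParams>(node);
gpt_params->declare_params();            // declare the ROS 2 parameters once
auto params = gpt_params->get_params();  // returns shared_ptr<struct gpt_params>
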
3 changes: 1 addition & 2 deletions llama_ros/include/llava_ros/llava_node.hpp
@@ -49,8 +49,7 @@ class LlavaNode : public llama_ros::LlamaNode {
bool url = false);

protected:
std::shared_ptr<Llava> llava;

void create_llama();
bool goal_empty(std::shared_ptr<const GenerateResponse::Goal> goal) override;
void execute(
const std::shared_ptr<GoalHandleGenerateResponse> goal_handle) override;
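
Note: the separate llava member is gone; LlavaNode overrides the create_llama() hook so the base class owns the model through its unique_ptr<Llama>, and the lifecycle callbacks in LlamaNode stay unchanged. A hedged sketch of the override; the constructor arguments are assumed, not taken from this commit:

// assumes Llava derives from Llama, so the base pointer can own it
void LlavaNode::create_llama() {
  this->llama = std::make_unique<llava_ros::Llava>(
      this->gpt_params->params, this->gpt_params->llava_params,
      this->gpt_params->debug);
}
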
1 change: 1 addition & 0 deletions llama_ros/package.xml
@@ -14,6 +14,7 @@

<depend>rclcpp</depend>
<depend>rclcpp_action</depend>
<depend>rclcpp_lifecycle</depend>
<depend>cv_bridge</depend>
<depend>llama_msgs</depend>
<depend>llama_cpp_vendor</depend>
10 changes: 9 additions & 1 deletion llama_ros/src/llama_main.cpp
@@ -45,7 +45,15 @@ int main(int argc, char *argv[]) {
sigaction(SIGINT, &sigint_action, NULL);

rclcpp::init(argc, argv);
rclcpp::spin(std::make_shared<LlamaNode>());

auto node = std::make_shared<LlamaNode>();
node->configure();
node->activate();

rclcpp::executors::SingleThreadedExecutor executor;
executor.add_node(node->get_node_base_interface());
executor.spin();

rclcpp::shutdown();
return 0;
}
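
Note: main() now self-manages the node, calling configure() and activate() before spinning, so the default binary behaves like the old unmanaged node. The same transitions can also be driven externally through the standard lifecycle services; a hedged sketch of such a client, assuming the default node name "llama_node":

#include <lifecycle_msgs/msg/transition.hpp>
#include <lifecycle_msgs/srv/change_state.hpp>
#include <rclcpp/rclcpp.hpp>

int main(int argc, char *argv[]) {
  rclcpp::init(argc, argv);
  auto node = rclcpp::Node::make_shared("llama_lifecycle_client");
  auto client = node->create_client<lifecycle_msgs::srv::ChangeState>(
      "/llama_node/change_state");
  client->wait_for_service();

  // request the configure transition; repeat with TRANSITION_ACTIVATE next
  auto request = std::make_shared<lifecycle_msgs::srv::ChangeState::Request>();
  request->transition.id =
      lifecycle_msgs::msg::Transition::TRANSITION_CONFIGURE;
  auto future = client->async_send_request(request);
  rclcpp::spin_until_future_complete(node, future);

  rclcpp::shutdown();
  return 0;
}
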
5 changes: 3 additions & 2 deletions llama_ros/src/llama_ros/llama.cpp
@@ -273,7 +273,7 @@ response_output Llama::generate_response(const std::string &input_prompt,
llama_set_embeddings(this->ctx, false);

// load params
this->update_sampling_params(this->params->sparams);
this->update_sampling_context(this->params->sparams);

// load prompt
this->load_prompt(input_prompt, true, true);
@@ -671,7 +671,8 @@ struct completion_output Llama::sample() {
return result;
}

void Llama::update_sampling_params(const struct llama_sampling_params &params) {
void Llama::update_sampling_context(
const struct llama_sampling_params &params) {

this->ctx_sampling->params = params;

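
Note: the rename from update_sampling_params to update_sampling_context is apt; the method rewrites the live llama.cpp sampling context (this->ctx_sampling->params), not the node's ROS parameters. A hedged usage sketch; the field names follow the llama_sampling_params struct in llama.cpp's common library around tag b3609 and are an assumption here:

struct llama_sampling_params sparams;  // defaults from llama.cpp's common lib
sparams.temp = 0.2f;                   // sample closer to greedy
sparams.top_k = 40;                    // keep only the 40 best candidates
llama->update_sampling_context(sparams);
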
105 changes: 95 additions & 10 deletions llama_ros/src/llama_ros/llama_node.cpp
@@ -35,13 +35,52 @@ using namespace llama_ros;
using std::placeholders::_1;
using std::placeholders::_2;

LlamaNode::LlamaNode(bool load_llama) : rclcpp::Node("llama_node") {
LlamaNode::LlamaNode()
: rclcpp_lifecycle::LifecycleNode("llama_node"), gpt_params(nullptr) {
RCLCPP_INFO(this->get_logger(), "%s started", this->get_name());
}

void LlamaNode::create_llama() {
this->llama = std::make_unique<Llama>(this->gpt_params->params,
this->gpt_params->debug);
}

void LlamaNode::destroy_llama() {
this->llama.reset();
this->llama = nullptr;
}

if (load_llama) {
this->llama = std::make_shared<Llama>(this->gpt_params.load_params(this),
this->gpt_params.debug);
/*
*****************************
* LIFECYCLE *
*****************************
*/
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
LlamaNode::on_configure(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Configuring...", this->get_name());

if (this->gpt_params == nullptr) {
this->gpt_params =
std::make_unique<llama_utils::GptParams>(this->shared_from_this());
this->gpt_params->declare_params();
}

this->gpt_params->get_params();
RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
LlamaNode::on_activate(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Activating...", this->get_name());

// create llama
this->create_llama();

// services
this->tokenize_service_ = this->create_service<llama_msgs::srv::Tokenize>(
"tokenize",
@@ -61,7 +100,53 @@ LlamaNode::LlamaNode(bool load_llama) : rclcpp::Node("llama_node") {
std::bind(&LlamaNode::handle_cancel, this, _1),
std::bind(&LlamaNode::handle_accepted, this, _1));

RCLCPP_INFO(this->get_logger(), "%s started", this->get_name());
RCLCPP_INFO(get_logger(), "[%s] Activated", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
LlamaNode::on_deactivate(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Deactivating...", this->get_name());

this->destroy_llama();

this->tokenize_service_.reset();
this->tokenize_service_ = nullptr;

this->generate_embeddings_service_.reset();
this->generate_embeddings_service_ = nullptr;

this->goal_handle_ = nullptr;
this->generate_response_action_server_.reset();
this->generate_response_action_server_ = nullptr;

RCLCPP_INFO(get_logger(), "[%s] Deactivated", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
LlamaNode::on_cleanup(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Cleaning up...", this->get_name());
RCLCPP_INFO(get_logger(), "[%s] Cleaned up", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
LlamaNode::on_shutdown(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Shutting down...", this->get_name());
RCLCPP_INFO(get_logger(), "[%s] Shutted down", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

/*
@@ -78,7 +163,7 @@ void LlamaNode::tokenize_service_callback(

/*
*****************************
* EMBEEDINGS SERVICE *
* EMBEDDINGS SERVICE *
*****************************
*/
void LlamaNode::generate_embeddings_service_callback(
@@ -144,7 +229,7 @@ void LlamaNode::execute(
return;
}

if (this->gpt_params.debug) {
if (this->gpt_params->debug) {
RCLCPP_INFO(this->get_logger(), "Prompt received:\n%s", prompt.c_str());
}

@@ -155,9 +240,9 @@

// update sampling params of gpt_params
auto sampling_config = goal_handle->get_goal()->sampling_config;
this->gpt_params.update_sampling_params(sampling_config,
this->llama->get_n_vocab(),
this->llama->get_token_eos());
this->gpt_params->update_sampling_params(sampling_config,
this->llama->get_n_vocab(),
this->llama->get_token_eos());

// call llama
struct response_output output = this->llama->generate_response(
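
Note: on_deactivate() frees the model and tears down the services and action server, so a goal arriving while the node is not active would otherwise touch a null Llama. A hedged sketch of a guard the goal callback could apply (not part of this commit):

#include <lifecycle_msgs/msg/state.hpp>

rclcpp_action::GoalResponse
LlamaNode::handle_goal(const rclcpp_action::GoalUUID &,
                       std::shared_ptr<const GenerateResponse::Goal>) {
  // reject work unless the node has completed the activate transition
  if (this->get_current_state().id() !=
      lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE) {
    return rclcpp_action::GoalResponse::REJECT;
  }
  return rclcpp_action::GoalResponse::ACCEPT_AND_EXECUTE;
}
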
