diff --git a/README.md b/README.md index f9a130b4..5d7d213b 100644 --- a/README.md +++ b/README.md @@ -822,7 +822,7 @@ def get_curr_temperature(city: str) -> int: """Get the current temperature of a city""" return randint(20, 30) -chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True) +chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_default_template=True) messages = [ HumanMessage( diff --git a/llama_ros/llama_ros/langchain/chat_llama_ros.py b/llama_ros/llama_ros/langchain/chat_llama_ros.py index 6cb96c99..6198fcd8 100644 --- a/llama_ros/llama_ros/langchain/chat_llama_ros.py +++ b/llama_ros/llama_ros/langchain/chat_llama_ros.py @@ -104,8 +104,7 @@ class ChatLlamaROS(BaseChatModel, LlamaROSCommon): - use_llama_template: bool = False - + use_default_template: bool = False use_gguf_template: bool = True jinja_env: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( @@ -140,7 +139,7 @@ def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str: tools_grammar = kwargs.get("tools_grammar", None) - if self.use_llama_template: + if self.use_default_template: chat_template = DEFAULT_TEMPLATE else: chat_template = self.model_metadata.tokenizer.chat_template @@ -172,7 +171,7 @@ def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str: Detokenize.Request(tokens=[self.model_metadata.tokenizer.bos_token_id]) ).text - if self.use_gguf_template or self.use_llama_template: + if self.use_gguf_template or self.use_default_template: formatted_prompt = self.jinja_env.from_string(chat_template).render( messages=messages, add_generation_prompt=True,