Skip to content

Commit

Permalink
use_llama_template renamed to use_default_template
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Jan 13, 2025
1 parent d5f07be commit 0171812
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def get_curr_temperature(city: str) -> int:
"""Get the current temperature of a city"""
return randint(20, 30)

chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)
chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_default_template=True)

messages = [
HumanMessage(
Expand Down
8 changes: 4 additions & 4 deletions llama_ros/llama_ros/langchain/chat_llama_ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@

class ChatLlamaROS(BaseChatModel, LlamaROSCommon):

use_llama_template: bool = False

use_default_template: bool = False
use_gguf_template: bool = True

jinja_env: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
Expand Down Expand Up @@ -140,7 +139,7 @@ def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str:

tools_grammar = kwargs.get("tools_grammar", None)

if self.use_llama_template:
if self.use_default_template:
chat_template = DEFAULT_TEMPLATE
else:
chat_template = self.model_metadata.tokenizer.chat_template
Expand Down Expand Up @@ -172,7 +171,7 @@ def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str:
Detokenize.Request(tokens=[self.model_metadata.tokenizer.bos_token_id])
).text

if self.use_gguf_template or self.use_llama_template:
if self.use_gguf_template or self.use_default_template:
formatted_prompt = self.jinja_env.from_string(chat_template).render(
messages=messages,
add_generation_prompt=True,
Expand Down Expand Up @@ -355,6 +354,7 @@ def _generate(

result, status = self.llama_client.generate_response(goal_action)
response = result.response
print(response.text)

if status != GoalStatus.STATUS_SUCCEEDED:
return ""
Expand Down

0 comments on commit 0171812

Please sign in to comment.