diff --git a/r2ai/local/r2ai/interpreter.py b/r2ai/local/r2ai/interpreter.py
index 1d1eb1123..573e0eb2d 100644
--- a/r2ai/local/r2ai/interpreter.py
+++ b/r2ai/local/r2ai/interpreter.py
@@ -87,6 +87,8 @@ def messages_to_prompt(self,messages):
         formatted_messages = template_uncensored(self,messages)
     elif "falcon" in self.model.lower():
         formatted_messages = template_falcon(self,messages)
+    elif "tinyllama" in self.model.lower():
+        formatted_messages = template_tinyllama(self,messages)
     else:
         formatted_messages = template_llama(self,messages)
 
@@ -156,6 +158,29 @@ def template_falcon(self,messages):
         formatted_messages += f"{message['role'].capitalize()}: {message['content']}"
     return formatted_messages.strip()
 
+def template_tinyllama(self,messages):
+    # TinyLlama ChatML-style prompt template
+    # (as documented for TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF)
+    self.terminator = "<|im_end|>"
+    # Extract the system prompt and initialize the formatted string with it.
+    system_prompt = messages[0]['content'].strip()
+    if system_prompt != "":
+        formatted_messages = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+    else:
+        formatted_messages = ""
+    # Loop starting from the first user message
+    for item in messages[1:]:
+        role = item['role']
+        content = item['content']
+        if role == 'user' or role == 'function':
+            # Function results are fed back to the model as user turns
+            formatted_messages += f"<|im_start|>user\n{content}<|im_end|>\n"
+        elif role == 'assistant':
+            formatted_messages += f"<|im_start|>assistant\n{content}<|im_end|>\n"
+    # Leave the prompt open so the model completes the assistant turn
+    formatted_messages += "<|im_start|>assistant\n"
+    return formatted_messages
+
 def template_llama(self,messages):
     # Llama prompt template
     # Extracting the system prompt and initializing the formatted string with it.
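
Note that because the dispatch in messages_to_prompt compares against self.model.lower(), a single lowercase "tinyllama" check already covers model names spelled "TinyLlama"; a second elif on the mixed-case string could never match. The following is a minimal sketch (not part of the patch) of how the new template could be exercised in isolation; the SimpleNamespace stub standing in for the interpreter object and the sample messages are hypothetical, chosen only to show the ChatML-style prompt the function emits.

    from types import SimpleNamespace

    from r2ai.interpreter import template_tinyllama

    # Hypothetical stand-in for the interpreter: template_tinyllama only
    # touches self.terminator, so a bare namespace is enough here.
    ctx = SimpleNamespace(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF",
                          terminator=None)

    messages = [
        {"role": "system", "content": "You are a helpful reverse-engineering assistant."},
        {"role": "user", "content": "What does this opcode do?"},
    ]

    print(template_tinyllama(ctx, messages))
    # Expected prompt:
    # <|im_start|>system
    # You are a helpful reverse-engineering assistant.<|im_end|>
    # <|im_start|>user
    # What does this opcode do?<|im_end|>
    # <|im_start|>assistant

Ending the prompt with an open <|im_start|>assistant turn and setting self.terminator to <|im_end|> means generation stops exactly when the model closes its own turn, mirroring the prompt format published on the TinyLlama-1.1B-Chat-v0.3-GGUF model card.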