Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Jun 4, 2024
1 parent ed8da74 commit dd9d4fa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/agents/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .llm import create_tool_use_llm
from ._openai import OpenAIChatCompletion
from ._cohere import CohereChatCompletion
from ._cohere import CohereChatCompletion, CohereChatCompletionV2
from ._llamacpp import LlamaCppChatCompletion

__all__ = [
"create_tool_use_llm",
"CohereChatCompletion",
"OpenAIChatCompletion",
"LlamaCppChatCompletion",
"CohereChatCompletionV2",
]
67 changes: 65 additions & 2 deletions src/agents/llms/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import time

from langchain_cohere import ChatCohere
import cohere

from typing import List, Optional, Any, Dict

Expand All @@ -14,8 +13,41 @@
logger = logging.getLogger(__name__)


def _format_cohere_to_openai(output: cohere.NonStreamedChatResponse):
    """Convert a Cohere non-streamed chat response into an OpenAI-style
    ``ChatCompletion`` object.

    Maps the Cohere message content, tool calls, finish reason, and token
    accounting onto the corresponding OpenAI response fields.
    """
    meta = output.meta

    # Re-wrap each raw tool-call dict from the Cohere response as a ToolCall.
    calls = [
        ToolCall(id=item["id"], type=item["type"], function=item["function"])
        for item in output.tool_calls
    ]

    assistant_message = Message(
        role="assistant", content=output.content, tool_calls=calls
    )
    choice = Choice(
        index=0,
        logprobs=None,
        message=assistant_message,
        finish_reason=output.finish_reason,
    )
    token_usage = Usage(
        prompt_tokens=meta.tokens.input_tokens,
        completion_tokens=meta.tokens.output_tokens,
        total_tokens=meta.tokens.total_tokens,
    )
    return ChatCompletion(
        id=output.id,
        object="",
        created=int(time.time()),
        model="Cohere",
        choices=[choice],
        usage=token_usage,
    )


class CohereChatCompletion:
    def __init__(self, **kwargs):
        """Initialize the Cohere chat-completion wrapper.

        Lazily imports ``langchain_cohere`` so the package remains importable
        when the optional dependency is absent.

        Raises:
            ImportError: if ``langchain_cohere`` is not installed or is an
                incompatible version.
        """
        try:
            from langchain_cohere import ChatCohere
        except ImportError:
            # Catch only ImportError: a bare `except:` would also intercept
            # KeyboardInterrupt/SystemExit and mislabel unrelated failures
            # as a version problem.
            logger.error("Make sure that you have correct cohere version installed.")
            raise
        self.llm = ChatCohere()
        self.tool_registry = ToolRegistry()
Expand Down Expand Up @@ -60,3 +92,34 @@ def chat_completion(

def run_tools(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
return self.tool_registry.call_tools(chat_completion)


class CohereChatCompletionV2:
    """Chat client using the native ``cohere`` SDK that can also emit
    OpenAI-style ``ChatCompletion`` responses."""

    def __init__(self, model="command-r", preamble: Optional[str] = None, **kwargs):
        """Create the client.

        Args:
            model: Cohere model identifier to use for every request.
            preamble: optional system preamble sent with each chat turn.
        """
        # NOTE(review): cohere.Client() picks up the API key from its default
        # environment configuration — confirm the deployment provides it.
        self.model = model
        self.preamble = preamble
        self.client = cohere.Client()

    def chat(
        self,
        message: str,
        chat_history: Optional[List[dict]] = None,
        force_single_step=False,
        **kwargs,
    ) -> cohere.NonStreamedChatResponse:
        """Send one message to Cohere and return the raw SDK response.

        Fix: extra keyword arguments (e.g. ``temperature``) are now forwarded
        to ``client.chat`` instead of being accepted and silently dropped.
        """
        return self.client.chat(
            model=self.model,
            message=message,
            chat_history=chat_history,
            preamble=self.preamble,
            force_single_step=force_single_step,
            **kwargs,
        )

    def chat_completion(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> ChatCompletion:
        """Run one chat turn from OpenAI-style *messages* and return an
        OpenAI-style ``ChatCompletion``.

        The last message's ``content`` is sent as the new user message; all
        earlier messages are passed as chat history.
        """
        # NOTE(review): history entries are OpenAI-style role/content dicts;
        # the Cohere SDK expects its own history schema — verify that callers
        # supply a compatible shape.
        message = messages[-1]["content"]
        chat_history = messages[:-1]
        output = self.chat(message=message, chat_history=chat_history, **kwargs)
        return _format_cohere_to_openai(output)

0 comments on commit dd9d4fa

Please sign in to comment.