diff --git a/docs/docs/integrations/chat/premai.ipynb b/docs/docs/integrations/chat/premai.ipynb index dded643754389..5a7b8f2bde47c 100644 --- a/docs/docs/integrations/chat/premai.ipynb +++ b/docs/docs/integrations/chat/premai.ipynb @@ -82,9 +82,9 @@ "outputs": [], "source": [ "# By default it will use the model which was deployed through the platform\n", - "# in my case it will is \"claude-3-haiku\"\n", + "# in my case it will is \"gpt-4o\"\n", "\n", - "chat = ChatPremAI(project_id=8)" + "chat = ChatPremAI(project_id=1234, model_name=\"gpt-4o\")" ] }, { @@ -107,7 +107,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "I am an artificial intelligence created by Anthropic. I'm here to help with a wide variety of tasks, from research and analysis to creative projects and open-ended conversation. I have general knowledge and capabilities, but I'm not a real person - I'm an AI assistant. Please let me know if you have any other questions!\n" + "I am an AI language model created by OpenAI, designed to assist with answering questions and providing information based on the context provided. How can I help you today?\n" ] } ], @@ -133,7 +133,7 @@ { "data": { "text/plain": [ - "AIMessage(content=\"I am an artificial intelligence created by Anthropic. My purpose is to assist and converse with humans in a friendly and helpful way. I have a broad knowledge base that I can use to provide information, answer questions, and engage in discussions on a wide range of topics. Please let me know if you have any other questions - I'm here to help!\")" + "AIMessage(content=\"I'm your friendly assistant! 
How can I help you today?\", response_metadata={'document_chunks': [{'repository_id': 1985, 'document_id': 1306, 'chunk_id': 173899, 'document_name': '[D] Difference between sparse and dense informati…', 'similarity_score': 0.3209080100059509, 'content': \"with the difference or anywhere\\nwhere I can read about it?\\n\\n\\n 17 9\\n\\n\\n u/ScotiabankCanada • Promoted\\n\\n\\n Accelerate your study permit process\\n with Scotiabank's Student GIC\\n Program. We're here to help you tur…\\n\\n\\n startright.scotiabank.com Learn More\\n\\n\\n Add a Comment\\n\\n\\nSort by: Best\\n\\n\\n DinosParkour • 1y ago\\n\\n\\n Dense Retrieval (DR) m\"}]}, id='run-510bbd0e-3f8f-4095-9b1f-c2d29fd89719-0')" ] }, "execution_count": 5, @@ -160,10 +160,18 @@ "execution_count": 6, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/anindya/prem/langchain/libs/community/langchain_community/chat_models/premai.py:355: UserWarning: WARNING: Parameter top_p is not supported in kwargs.\n", + " warnings.warn(f\"WARNING: Parameter {key} is not supported in kwargs.\")\n" + ] + }, { "data": { "text/plain": [ - "AIMessage(content='I am an artificial intelligence created by Anthropic')" + "AIMessage(content=\"Hello! I'm your friendly assistant. How can I\", response_metadata={'document_chunks': [{'repository_id': 1985, 'document_id': 1306, 'chunk_id': 173899, 'document_name': '[D] Difference between sparse and dense informati…', 'similarity_score': 0.3209080100059509, 'content': \"with the difference or anywhere\\nwhere I can read about it?\\n\\n\\n 17 9\\n\\n\\n u/ScotiabankCanada • Promoted\\n\\n\\n Accelerate your study permit process\\n with Scotiabank's Student GIC\\n Program. 
We're here to help you tur…\\n\\n\\n startright.scotiabank.com Learn More\\n\\n\\n Add a Comment\\n\\n\\nSort by: Best\\n\\n\\n DinosParkour • 1y ago\\n\\n\\n Dense Retrieval (DR) m\"}]}, id='run-c4b06b98-4161-4cca-8495-fd2fc98fa8f8-0')" ] }, "execution_count": 6, @@ -195,13 +203,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "query = \"what is the diameter of individual Galaxy\"\n", + "query = \"Which models are used for dense retrieval\"\n", "repository_ids = [\n", - " 1991,\n", + " 1985,\n", "]\n", "repositories = dict(ids=repository_ids, similarity_threshold=0.3, limit=3)" ] @@ -219,9 +227,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dense retrieval models typically include:\n", + "\n", + "1. **BERT-based Models**: Such as DPR (Dense Passage Retrieval) which uses BERT for encoding queries and passages.\n", + "2. **ColBERT**: A model that combines BERT with late interaction mechanisms.\n", + "3. **ANCE (Approximate Nearest Neighbor Negative Contrastive Estimation)**: Uses BERT and focuses on efficient retrieval.\n", + "4. **TCT-ColBERT**: A variant of ColBERT that uses a two-tower\n", + "{\n", + " \"document_chunks\": [\n", + " {\n", + " \"repository_id\": 1985,\n", + " \"document_id\": 1306,\n", + " \"chunk_id\": 173899,\n", + " \"document_name\": \"[D] Difference between sparse and dense informati\\u2026\",\n", + " \"similarity_score\": 0.3209080100059509,\n", + " \"content\": \"with the difference or anywhere\\nwhere I can read about it?\\n\\n\\n 17 9\\n\\n\\n u/ScotiabankCanada \\u2022 Promoted\\n\\n\\n Accelerate your study permit process\\n with Scotiabank's Student GIC\\n Program. 
We're here to help you tur\\u2026\\n\\n\\n startright.scotiabank.com Learn More\\n\\n\\n Add a Comment\\n\\n\\nSort by: Best\\n\\n\\n DinosParkour \\u2022 1y ago\\n\\n\\n Dense Retrieval (DR) m\"\n", + " }\n", + " ]\n", + "}\n" + ] + } + ], "source": [ "import json\n", "\n", @@ -262,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -288,7 +321,7 @@ "outputs": [], "source": [ "template_id = \"78069ce8-xxxxx-xxxxx-xxxx-xxx\"\n", - "response = chat.invoke([human_message], template_id=template_id)\n", + "response = chat.invoke([human_messages], template_id=template_id)\n", "print(response.content)" ] }, @@ -310,14 +343,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Hello! As an AI language model, I don't have feelings or a physical state, but I'm functioning properly and ready to assist you with any questions or tasks you might have. How can I help you today?" + "It looks like your message got cut off. If you need information about Dense Retrieval (DR) or any other topic, please provide more details or clarify your question." ] } ], @@ -338,14 +371,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Hello! As an AI language model, I don't have feelings or a physical form, but I'm functioning properly and ready to assist you. How can I help you today?" + "Woof! 🐾 How can I help you today? Want to play fetch or maybe go for a walk 🐶🦴" ] } ], @@ -365,6 +398,275 @@ " sys.stdout.write(chunk.content)\n", " sys.stdout.flush()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tool/Function Calling\n", + "\n", + "LangChain PremAI supports tool/function calling. 
Tool/function calling allows a model to respond to a given prompt by generating output that matches a user-defined schema. \n", + "\n", + "- You can learn all about tool calling in detail [in our documentation here](https://docs.premai.io/get-started/function-calling).\n", + "- You can learn more about LangChain tool calling in [this part of the docs](https://python.langchain.com/v0.1/docs/modules/model_io/chat/function_calling).\n", + "\n", + "**NOTE:**\n", + "The current version of LangChain ChatPremAI does not support function/tool calling with streaming. Streaming support along with function calling will come soon. \n", + "\n", + "#### Passing tools to model\n", + "\n", + "In order to pass tools and let the LLM choose the tool it needs to call, we need to pass a tool schema. A tool schema is the function definition along with a proper docstring describing what the function does, what each argument of the function is, etc. Below are some simple arithmetic functions with their schema. \n", + "\n", + "**NOTE:** When defining a function/tool schema, do not forget to add descriptions for the function arguments; otherwise it will throw an error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "from langchain_core.tools import tool\n", + "\n", + "\n", + "# Define the schema for function arguments\n", + "class OperationInput(BaseModel):\n", + " a: int = Field(description=\"First number\")\n", + " b: int = Field(description=\"Second number\")\n", + "\n", + "\n", + "# Now define the function where schema for argument will be OperationInput\n", + "@tool(\"add\", args_schema=OperationInput, return_direct=True)\n", + "def add(a: int, b: int) -> int:\n", + " \"\"\"Adds a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", + " return a + b\n", + "\n", + "\n", + "@tool(\"multiply\", args_schema=OperationInput, return_direct=True)\n", + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiplies a and b.\n", + "\n", + " Args:\n", + " a: first int\n", + " b: second int\n", + " \"\"\"\n", + " return a * b" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Binding tool schemas with our LLM\n", + "\n", + "We will now use the `bind_tools` method to convert the functions above into tools and bind them to the model. This means the tool information is passed every time we invoke the model. " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "tools = [add, multiply]\n", + "llm_with_tools = chat.bind_tools(tools)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After this, we get the response from the model, which is now bound with the tools. " + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What is 3 * 12? 
Also, what is 11 + 49?\"\n", + "\n", + "messages = [HumanMessage(query)]\n", + "ai_msg = llm_with_tools.invoke(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see, when our chat model is bound with tools, it calls the correct set of tools, in sequence, based on the given prompt. " + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'multiply',\n", + " 'args': {'a': 3, 'b': 12},\n", + " 'id': 'call_A9FL20u12lz6TpOLaiS6rFa8'},\n", + " {'name': 'add',\n", + " 'args': {'a': 11, 'b': 49},\n", + " 'id': 'call_MPKYGLHbf39csJIyb5BZ9xIk'}]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ai_msg.tool_calls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We append the message shown above to the message list, which acts as context and makes the LLM aware of which functions it has called. " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "messages.append(ai_msg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tool calling happens in two phases:\n", + "\n", + "1. in our first call, we gathered all the tools that the LLM decided to call, so that it can use their results as added context to give a more accurate and hallucination-free answer. \n", + "\n", + "2. 
in our second call, we parse the set of tools chosen by the LLM and run them (in our case, the functions we defined, with the arguments extracted by the LLM) and pass the results back to the LLM" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import ToolMessage\n", + "\n", + "for tool_call in ai_msg.tool_calls:\n", + " selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n", + " tool_output = selected_tool.invoke(tool_call[\"args\"])\n", + " messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we call the LLM (bound with the tools) with the function responses added to its context. " + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The final answers are:\n", + "\n", + "- 3 * 12 = 36\n", + "- 11 + 49 = 60\n" + ] + } + ], + "source": [ + "response = llm_with_tools.invoke(messages)\n", + "print(response.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining tool schemas: Pydantic class\n", + "\n", + "Above we showed how to define a schema using the `tool` decorator; however, we can equivalently define the schema using Pydantic. 
Pydantic is useful when your tool inputs are more complex:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n", + "\n", + "\n", + "class add(BaseModel):\n", + " \"\"\"Add two integers together.\"\"\"\n", + "\n", + " a: int = Field(..., description=\"First integer\")\n", + " b: int = Field(..., description=\"Second integer\")\n", + "\n", + "\n", + "class multiply(BaseModel):\n", + " \"\"\"Multiply two integers together.\"\"\"\n", + "\n", + " a: int = Field(..., description=\"First integer\")\n", + " b: int = Field(..., description=\"Second integer\")\n", + "\n", + "\n", + "tools = [add, multiply]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can bind them to chat models and directly get the result:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[multiply(a=3, b=12), add(a=11, b=49)]" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain = llm_with_tools | PydanticToolsParser(tools=[multiply, add])\n", + "chain.invoke(query)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, as done above, we parse this and run this functions and call the LLM once again to get the result." 
+ ] } ], "metadata": { @@ -383,7 +685,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.19" } }, "nbformat": 4, diff --git a/docs/docs/integrations/providers/premai.md b/docs/docs/integrations/providers/premai.md index e0592ee7666c6..7bf88d1fd0208 100644 --- a/docs/docs/integrations/providers/premai.md +++ b/docs/docs/integrations/providers/premai.md @@ -38,7 +38,7 @@ import getpass if "PREMAI_API_KEY" not in os.environ: os.environ["PREMAI_API_KEY"] = getpass.getpass("PremAI API Key:") -chat = ChatPremAI(project_id=8) +chat = ChatPremAI(project_id=1234, model_name="gpt-4o") ``` ### Chat Completions @@ -50,7 +50,8 @@ The first one will give us a static result. Whereas the second one will stream t ```python human_message = HumanMessage(content="Who are you?") -chat.invoke([human_message]) +response = chat.invoke([human_message]) +print(response.content) ``` You can provide system prompt here like this: @@ -84,8 +85,8 @@ Repositories are also supported in langchain premai. Here is how you can do it. ```python -query = "what is the diameter of individual Galaxy" -repository_ids = [1991, ] +query = "Which models are used for dense retrieval" +repository_ids = [1985,] repositories = dict( ids=repository_ids, similarity_threshold=0.3, @@ -100,6 +101,8 @@ First we start by defining our repository with some repository ids. Make sure th Now, we connect the repository with our chat object to invoke RAG based generations. ```python +import json + response = chat.invoke(query, max_tokens=100, repositories=repositories) print(response.content) @@ -109,25 +112,22 @@ print(json.dumps(response.response_metadata, indent=4)) This is how an output looks like. ```bash -The diameters of individual galaxies range from 80,000-150,000 light-years. +Dense retrieval models typically include: + +1. **BERT-based Models**: Such as DPR (Dense Passage Retrieval) which uses BERT for encoding queries and passages. +2. 
**ColBERT**: A model that combines BERT with late interaction mechanisms. +3. **ANCE (Approximate Nearest Neighbor Negative Contrastive Estimation)**: Uses BERT and focuses on efficient retrieval. +4. **TCT-ColBERT**: A variant of ColBERT that uses a two-tower { "document_chunks": [ { - "repository_id": 19xx, - "document_id": 13xx, - "chunk_id": 173xxx, - "document_name": "Kegy 202 Chapter 2", - "similarity_score": 0.586126983165741, - "content": "n thousands\n of light-years. The diameters of individual\n galaxies range from 80,000-150,000 light\n " - }, - { - "repository_id": 19xx, - "document_id": 13xx, - "chunk_id": 173xxx, - "document_name": "Kegy 202 Chapter 2", - "similarity_score": 0.4815782308578491, - "content": " for development of galaxies. A galaxy contains\n a large number of stars. Galaxies spread over\n vast distances that are measured in thousands\n " - }, + "repository_id": 1985, + "document_id": 1306, + "chunk_id": 173899, + "document_name": "[D] Difference between sparse and dense informati\u2026", + "similarity_score": 0.3209080100059509, + "content": "with the difference or anywhere\nwhere I can read about it?\n\n\n 17 9\n\n\n u/ScotiabankCanada \u2022 Promoted\n\n\n Accelerate your study permit process\n with Scotiabank's Student GIC\n Program. We're here to help you tur\u2026\n\n\n startright.scotiabank.com Learn More\n\n\n Add a Comment\n\n\nSort by: Best\n\n\n DinosParkour \u2022 1y ago\n\n\n Dense Retrieval (DR) m" + } ] } ``` @@ -264,4 +264,164 @@ doc_result[:5] 0.0008162345038726926, -0.004556538071483374, 0.02918623760342598, - -0.02547479420900345] \ No newline at end of file + -0.02547479420900345] + +## Tool/Function Calling + +LangChain PremAI supports tool/function calling. Tool/function calling allows a model to respond to a given prompt by generating output that matches a user-defined schema. + +- You can learn all about tool calling in details [in our documentation here](https://docs.premai.io/get-started/function-calling). 
+- You can learn more about LangChain tool calling in [this part of the docs](https://python.langchain.com/v0.1/docs/modules/model_io/chat/function_calling). + +**NOTE:** + +> The current version of LangChain ChatPremAI does not support function/tool calling with streaming. Streaming support along with function calling will come soon. + +### Passing tools to model + +In order to pass tools and let the LLM choose the tool it needs to call, we need to pass a tool schema. A tool schema is the function definition along with a proper docstring describing what the function does, what each argument of the function is, etc. Below are some simple arithmetic functions with their schema. + +**NOTE:** +> When defining a function/tool schema, do not forget to add descriptions for the function arguments; otherwise it will throw an error. + +```python +from langchain_core.tools import tool +from langchain_core.pydantic_v1 import BaseModel, Field + +# Define the schema for function arguments +class OperationInput(BaseModel): + a: int = Field(description="First number") + b: int = Field(description="Second number") + + +# Now define the function where schema for argument will be OperationInput +@tool("add", args_schema=OperationInput, return_direct=True) +def add(a: int, b: int) -> int: + """Adds a and b. + + Args: + a: first int + b: second int + """ + return a + b + + +@tool("multiply", args_schema=OperationInput, return_direct=True) +def multiply(a: int, b: int) -> int: + """Multiplies a and b. + + Args: + a: first int + b: second int + """ + return a * b +``` + +### Binding tool schemas with our LLM + +We will now use the `bind_tools` method to convert the functions above into tools and bind them to the model. This means the tool information is passed every time we invoke the model. + +```python +tools = [add, multiply] +llm_with_tools = chat.bind_tools(tools) +``` + +After this, we get the response from the model, which is now bound with the tools. 
+ +```python +query = "What is 3 * 12? Also, what is 11 + 49?" + +messages = [HumanMessage(query)] +ai_msg = llm_with_tools.invoke(messages) +``` + +As we can see, when our chat model is bound with tools, it calls the correct set of tools, in sequence, based on the given prompt. + +```python +ai_msg.tool_calls +``` +**Output** + +```python +[{'name': 'multiply', + 'args': {'a': 3, 'b': 12}, + 'id': 'call_A9FL20u12lz6TpOLaiS6rFa8'}, + {'name': 'add', + 'args': {'a': 11, 'b': 49}, + 'id': 'call_MPKYGLHbf39csJIyb5BZ9xIk'}] +``` + +We append the message shown above to the message list, which acts as context and makes the LLM aware of which functions it has called. + +```python +messages.append(ai_msg) +``` + +Tool calling happens in two phases: + +1. In our first call, we gathered all the tools that the LLM decided to call, so that it can use their results as added context to give a more accurate and hallucination-free answer. + +2. In our second call, we parse the set of tools chosen by the LLM and run them (in our case, the functions we defined, with the arguments extracted by the LLM) and pass the results back to the LLM. + +```python +from langchain_core.messages import ToolMessage + +for tool_call in ai_msg.tool_calls: + selected_tool = {"add": add, "multiply": multiply}[tool_call["name"].lower()] + tool_output = selected_tool.invoke(tool_call["args"]) + messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"])) +``` + +Finally, we call the LLM (bound with the tools) with the function responses added to its context. + +```python +response = llm_with_tools.invoke(messages) +print(response.content) +``` +**Output** + +```txt +The final answers are: + +- 3 * 12 = 36 +- 11 + 49 = 60 +``` + +### Defining tool schemas: Pydantic class `Optional` + +Above we showed how to define a schema using the `tool` decorator; however, we can equivalently define the schema using Pydantic. 
Pydantic is useful when your tool inputs are more complex: + +```python +from langchain_core.output_parsers.openai_tools import PydanticToolsParser + +class add(BaseModel): + """Add two integers together.""" + + a: int = Field(..., description="First integer") + b: int = Field(..., description="Second integer") + + +class multiply(BaseModel): + """Multiply two integers together.""" + + a: int = Field(..., description="First integer") + b: int = Field(..., description="Second integer") + + +tools = [add, multiply] +``` + +Now, we can bind them to chat models and directly get the result: + +```python +chain = llm_with_tools | PydanticToolsParser(tools=[multiply, add]) +chain.invoke(query) +``` + +**Output** + +```txt +[multiply(a=3, b=12), add(a=11, b=49)] +``` + +Now, as done above, we parse this and run this functions and call the LLM once again to get the result. \ No newline at end of file diff --git a/libs/community/langchain_community/chat_models/premai.py b/libs/community/langchain_community/chat_models/premai.py index 311e94e763e8f..1b91bca2c6a26 100644 --- a/libs/community/langchain_community/chat_models/premai.py +++ b/libs/community/langchain_community/chat_models/premai.py @@ -12,6 +12,7 @@ Iterator, List, Optional, + Sequence, Tuple, Type, Union, @@ -20,6 +21,7 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.messages import ( @@ -33,6 +35,7 @@ HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import ( @@ -41,7 +44,10 @@ Field, SecretStr, ) +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import 
get_from_dict_or_env, pre_init +from langchain_core.utils.function_calling import convert_to_openai_tool if TYPE_CHECKING: from premai.api.chat_completions.v1_chat_completions_create import ( @@ -51,6 +57,19 @@ logger = logging.getLogger(__name__) +TOOL_PROMPT_HEADER = """ +Given the set of tools you used and the response, provide the final answer\n +""" + +INTERMEDIATE_TOOL_RESULT_TEMPLATE = """ +{json} +""" + +SINGLE_TOOL_PROMPT_TEMPLATE = """ +tool id: {tool_id} +tool_response: {tool_response} +""" + class ChatPremAPIError(Exception): """Error with the `PremAI` API.""" @@ -91,8 +110,22 @@ def _response_to_result( raise ChatPremAPIError(f"ChatResponse must have a content: {content}") if role == "assistant": + tool_calls = choice.message["tool_calls"] + if tool_calls is None: + tools = [] + else: + tools = [ + { + "id": tool_call["id"], + "name": tool_call["function"]["name"], + "args": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] generations.append( - ChatGeneration(text=content, message=AIMessage(content=content)) + ChatGeneration( + text=content, message=AIMessage(content=content, tool_calls=tools) + ) ) elif role == "user": generations.append( @@ -156,41 +189,65 @@ def _messages_to_prompt_dict( system_prompt: Optional[str] = None examples_and_messages: List[Dict[str, Any]] = [] - if template_id is not None: - params: Dict[str, str] = {} - for input_msg in input_messages: - if isinstance(input_msg, SystemMessage): - system_prompt = str(input_msg.content) + for input_msg in input_messages: + if isinstance(input_msg, SystemMessage): + system_prompt = str(input_msg.content) + + elif isinstance(input_msg, HumanMessage): + if template_id is None: + examples_and_messages.append( + {"role": "user", "content": str(input_msg.content)} + ) else: + params: Dict[str, str] = {} assert (input_msg.id is not None) and (input_msg.id != ""), ValueError( "When using prompt template there should be id associated ", "with each HumanMessage", ) 
params[str(input_msg.id)] = str(input_msg.content) - - examples_and_messages.append( - {"role": "user", "template_id": template_id, "params": params} - ) - - for input_msg in input_messages: - if isinstance(input_msg, AIMessage): examples_and_messages.append( - {"role": "assistant", "content": str(input_msg.content)} + {"role": "user", "template_id": template_id, "params": params} ) - else: - for input_msg in input_messages: - if isinstance(input_msg, SystemMessage): - system_prompt = str(input_msg.content) - elif isinstance(input_msg, HumanMessage): - examples_and_messages.append( - {"role": "user", "content": str(input_msg.content)} - ) - elif isinstance(input_msg, AIMessage): + elif isinstance(input_msg, AIMessage): + if input_msg.tool_calls is None or len(input_msg.tool_calls) == 0: examples_and_messages.append( {"role": "assistant", "content": str(input_msg.content)} ) else: - raise ChatPremAPIError("No such role explicitly exists") + ai_msg_to_json = { + "id": input_msg.id, + "content": input_msg.content, + "response_metadata": input_msg.response_metadata, + "tool_calls": input_msg.tool_calls, + } + examples_and_messages.append( + { + "role": "assistant", + "content": INTERMEDIATE_TOOL_RESULT_TEMPLATE.format( + json=ai_msg_to_json, + ), + } + ) + elif isinstance(input_msg, ToolMessage): + pass + + else: + raise ChatPremAPIError("No such role explicitly exists") + + # do a seperate search for tool calls + tool_prompt = "" + for input_msg in input_messages: + if isinstance(input_msg, ToolMessage): + tool_id = input_msg.tool_call_id + tool_result = input_msg.content + tool_prompt += SINGLE_TOOL_PROMPT_TEMPLATE.format( + tool_id=tool_id, tool_response=tool_result + ) + if tool_prompt != "": + prompt = TOOL_PROMPT_HEADER + prompt += tool_prompt + examples_and_messages.append({"role": "user", "content": prompt}) + return system_prompt, examples_and_messages @@ -289,7 +346,6 @@ def _default_params(self) -> Dict[str, Any]: def _get_all_kwargs(self, **kwargs: Any) -> 
Dict[str, Any]: kwargs_to_ignore = [ "top_p", - "tools", "frequency_penalty", "presence_penalty", "logit_bias", @@ -392,6 +448,14 @@ def _stream( except Exception as _: continue + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) + def create_prem_retry_decorator( llm: ChatPremAI, diff --git a/libs/community/tests/unit_tests/chat_models/test_premai.py b/libs/community/tests/unit_tests/chat_models/test_premai.py index 8318fc890a608..b4275fca47bb6 100644 --- a/libs/community/tests/unit_tests/chat_models/test_premai.py +++ b/libs/community/tests/unit_tests/chat_models/test_premai.py @@ -3,12 +3,16 @@ from typing import cast import pytest -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture from langchain_community.chat_models import ChatPremAI -from langchain_community.chat_models.premai import _messages_to_prompt_dict +from langchain_community.chat_models.premai import ( + SINGLE_TOOL_PROMPT_TEMPLATE, + TOOL_PROMPT_HEADER, + _messages_to_prompt_dict, +) @pytest.mark.requires("premai") @@ -36,13 +40,20 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None: AIMessage(content="AI message #1"), HumanMessage(content="User message #2"), AIMessage(content="AI message #2"), + ToolMessage(content="Tool Message #1", tool_call_id="test_tool"), + AIMessage(content="AI message #3"), ] ) + expected_tool_message = SINGLE_TOOL_PROMPT_TEMPLATE.format( + tool_id="test_tool", tool_response="Tool Message #1" + ) expected = [ {"role": "user", "content": "User message #1"}, {"role": "assistant", "content": "AI message #1"}, {"role": "user", 
"content": "User message #2"}, {"role": "assistant", "content": "AI message #2"}, + {"role": "assistant", "content": "AI message #3"}, + {"role": "user", "content": TOOL_PROMPT_HEADER + expected_tool_message}, ] assert system_message == "System Prompt"