From 0f5e116ba4b76c95140d71f00b71b21cb3ee6723 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Tue, 17 Sep 2024 10:05:44 -0700 Subject: [PATCH] feat(langgraph): Adds stores & managed values (memory) (#476) Co-authored-by: jacoblee93 --- docs/mkdocs.yml | 1 + examples/how-tos/shared-state.ipynb | 564 ++++++++++++++++++++ libs/langgraph/src/channels/base.ts | 17 +- libs/langgraph/src/constants.ts | 3 + libs/langgraph/src/graph/annotation.ts | 21 +- libs/langgraph/src/graph/state.ts | 18 +- libs/langgraph/src/managed/base.ts | 187 +++++++ libs/langgraph/src/managed/index.ts | 3 + libs/langgraph/src/managed/is_last_step.ts | 8 + libs/langgraph/src/managed/shared_value.ts | 118 +++++ libs/langgraph/src/pregel/algo.ts | 244 ++++++--- libs/langgraph/src/pregel/index.ts | 141 ++++- libs/langgraph/src/pregel/loop.ts | 42 +- libs/langgraph/src/pregel/types.ts | 11 +- libs/langgraph/src/pregel/validate.ts | 10 +- libs/langgraph/src/store/base.ts | 18 + libs/langgraph/src/store/batch.ts | 166 ++++++ libs/langgraph/src/store/index.ts | 3 + libs/langgraph/src/store/memory.ts | 34 ++ libs/langgraph/src/tests/pregel.test.ts | 570 ++++++++++++++++++++- libs/langgraph/src/tests/store.test.ts | 46 ++ libs/langgraph/src/utils.ts | 25 + libs/langgraph/src/web.ts | 2 + 23 files changed, 2140 insertions(+), 112 deletions(-) create mode 100644 examples/how-tos/shared-state.ipynb create mode 100644 libs/langgraph/src/managed/base.ts create mode 100644 libs/langgraph/src/managed/index.ts create mode 100644 libs/langgraph/src/managed/is_last_step.ts create mode 100644 libs/langgraph/src/managed/shared_value.ts create mode 100644 libs/langgraph/src/store/base.ts create mode 100644 libs/langgraph/src/store/batch.ts create mode 100644 libs/langgraph/src/store/index.ts create mode 100644 libs/langgraph/src/store/memory.ts create mode 100644 libs/langgraph/src/tests/store.test.ts diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c8ba8d64..c66cea34 100644 --- a/docs/mkdocs.yml +++ 
b/docs/mkdocs.yml @@ -105,6 +105,7 @@ nav: - Manage conversation history: "how-tos/manage-conversation-history.ipynb" - How to delete messages: "how-tos/delete-messages.ipynb" - Add summary of the conversation history: "how-tos/add-summary-conversation-history.ipynb" + - Share state between threads: "how-tos/shared-state.ipynb" - Human-in-the-loop: - Add breakpoints: "how-tos/breakpoints.ipynb" - Add dynamic breakpoints: "how-tos/dynamic_breakpoints.ipynb" diff --git a/examples/how-tos/shared-state.ipynb b/examples/how-tos/shared-state.ipynb new file mode 100644 index 00000000..6dc9763b --- /dev/null +++ b/examples/how-tos/shared-state.ipynb @@ -0,0 +1,564 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7240d5b5-9dac-4070-8a9e-2350fb01e0be", + "metadata": {}, + "source": [ + "# How to share state between threads\n", + "\n", + "By default, state in a graph is scoped to that thread.\n", + "LangGraph also allows you to specify a \"scope\" for a given key/value pair that exists between threads. This can be useful for storing information that is shared between threads. For instance, you may want to store information about a user's preferences expressed in one thread, and then use that information in another thread.\n", + "\n", + "In this notebook we will go through an example of how to construct and use such a graph.\n", + "\n", + "## Setup\n", + "\n", + "First, let's install the required packages and set our API keys\n", + "\n", + "```bash\n", + "npm install @langchain/openai @langchain/langgraph @langchain/core zod uuid\n", + "```\n", + "\n", + "Then set your enviroment variables for OpenAI:\n", + "\n", + "```typescript\n", + "process.env.OPENAI_API_KEY = \"your-openai-api-key\";\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "51b6817d", + "metadata": {}, + "source": [ + "
\n", + "

Set up LangSmith for LangGraph development

\n", + "

\n", + " Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started here. \n", + "

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "c4c550b5-1954-496b-8b9d-800361af17dc", + "metadata": {}, + "source": [ + "## Create graph\n", + "\n", + "In this example we will create a graph that will let us store information about a user's preferences. We will do so by defining a state key that will be scoped to a `user_id`, and allowing the model to populate this field as it deems fit (by providing the model with a tool to save information about the user).\n", + "\n", + " \n", + "
\n", + "

Typing shared state keys

\n", + "

\n", + " Shared state channels (keys) MUST be objects (see info channel in the AgentState example below)\n", + "

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "337ee88f", + "metadata": {}, + "outputs": [], + "source": [ + "import { z } from \"zod\";\n", + "import {\n", + " START,\n", + " END,\n", + " Annotation,\n", + " StateGraph,\n", + " MemoryStore,\n", + " SharedValue,\n", + " MemorySaver,\n", + "} from \"@langchain/langgraph\";\n", + "import { ChatOpenAI } from \"@langchain/openai\";\n", + "import {\n", + " type AIMessage,\n", + " type BaseMessage,\n", + " ToolMessage,\n", + "} from \"@langchain/core/messages\";\n", + "import { RunnableConfig } from \"@langchain/core/runnables\";\n", + "import { v4 as uuidv4 } from \"uuid\";\n", + "\n", + "const infoSchema = z.object({\n", + " fact: z.string().describe(\"The fact about the user\"),\n", + " topic: z.string().describe(\"The topic of the fact\"),\n", + "});\n", + "\n", + "const AgentAnnotation = Annotation.Root({\n", + " messages: Annotation({\n", + " reducer: (a, b) => a.concat(b),\n", + " default: () => [],\n", + " }),\n", + " // IMPORTANT\n", + " // This is how you define a shared state value.\n", + " // The string passed to `.on` is the key that will\n", + " // be used to store the value in the shared state.\n", + " info: SharedValue.on(\"user_id\"),\n", + "});\n", + "\n", + "const prompt = `You are helpful assistant.\n", + "\n", + "Here is what you know about the user:\n", + "\n", + "\n", + "{info}\n", + "\n", + "\n", + "Help out the user. If the user tells you any information about themselves, save the information using the \\`Info\\` tool.\n", + "\n", + "This means if the user provides any sort of fact about themselves, be it an opinion they have, a fact about themselves, etc. 
SAVE IT!\n", + "`;\n", + "\n", + "const model = new ChatOpenAI({\n", + " model: \"gpt-4o\",\n", + " temperature: 0,\n", + "}).bindTools([\n", + " {\n", + " name: \"Info\",\n", + " description: \"Save the information provided by the user\",\n", + " schema: infoSchema,\n", + " },\n", + "]);\n", + "\n", + "const callModel = async (\n", + " state: typeof AgentAnnotation.State\n", + "): Promise> => {\n", + " const facts = Object.values(state.info).map((d) => d.fact);\n", + " const info = facts.join(\"\\n\");\n", + " const systemMsg = prompt.replace(\"{info}\", info);\n", + " const response = await model.invoke([\n", + " { role: \"system\", content: systemMsg },\n", + " ...state.messages,\n", + " ]);\n", + " return { messages: [response] };\n", + "};\n", + "\n", + "const route = (state: typeof AgentAnnotation.State): string => {\n", + " const lastMessage = state.messages[state.messages.length - 1];\n", + " if (!(\"tool_calls\" in lastMessage)) {\n", + " throw new Error(\"Expected an AI message with tool calls.\");\n", + " }\n", + "\n", + " return (lastMessage as AIMessage).tool_calls?.length ? 
\"update_memory\" : END;\n", + "};\n", + "\n", + "const updateMemory = (\n", + " state: typeof AgentAnnotation.State\n", + "): Partial => {\n", + " const toolResponseMessages: ToolMessage[] = [];\n", + " const memories: Record> = {};\n", + "\n", + " const lastMessage = state.messages[state.messages.length - 1];\n", + " if (!(\"tool_calls\" in lastMessage)) {\n", + " throw new Error(\"Expected an AI message with tool calls.\");\n", + " }\n", + " const castLastMessage = lastMessage as AIMessage;\n", + "\n", + " castLastMessage.tool_calls?.forEach((tc) => {\n", + " toolResponseMessages.push(\n", + " new ToolMessage({\n", + " content: \"Saved!\",\n", + " tool_call_id: tc.id as string,\n", + " })\n", + " );\n", + " memories[uuidv4()] = {\n", + " fact: tc.args.fact,\n", + " topic: tc.args.topic,\n", + " };\n", + " });\n", + "\n", + " return { messages: toolResponseMessages, info: memories };\n", + "};\n", + "\n", + "const memory = new MemorySaver();\n", + "// IMPORTANT\n", + "// In order to use shared values, you must initialize a store like this:\n", + "const kv = new MemoryStore();\n", + "\n", + "const graph = new StateGraph(AgentAnnotation)\n", + " .addNode(\"call_model\", callModel)\n", + " .addNode(\"update_memory\", updateMemory)\n", + " .addEdge(\"update_memory\", END)\n", + " .addEdge(START, \"call_model\")\n", + " .addConditionalEdges(\"call_model\", route);\n", + "\n", + "const compiledGraph = graph.compile({\n", + " checkpointer: memory,\n", + " // Then, pass it to `.compile` like this:\n", + " store: kv,\n", + "});" + ] + }, + { + "cell_type": "markdown", + "id": "552d4e33-556d-4fa5-8094-2a076bc21529", + "metadata": {}, + "source": [ + "## Run graph on one thread\n", + "\n", + "We can now run the graph on one thread and give it some information" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "18bd8679-3a73-4033-bfb4-5093ac1f5d7f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " 
call_model: {\n", + " messages: [\n", + " AIMessage {\n", + " \"id\": \"chatcmpl-A8VfAQGfZZ4WBbG4zrbwzc34k1Ji3\",\n", + " \"content\": \"Hello! How can I assist you today?\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {\n", + " \"tokenUsage\": {\n", + " \"completionTokens\": 10,\n", + " \"promptTokens\": 137,\n", + " \"totalTokens\": 147\n", + " },\n", + " \"finish_reason\": \"stop\",\n", + " \"system_fingerprint\": \"fp_a5d11b2ef2\"\n", + " },\n", + " \"tool_calls\": [],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 137,\n", + " \"output_tokens\": 10,\n", + " \"total_tokens\": 147\n", + " }\n", + " }\n", + " ]\n", + " }\n", + "}\n", + "{\n", + " call_model: {\n", + " messages: [\n", + " AIMessage {\n", + " \"id\": \"chatcmpl-A8VfAh46qQIvHejVESnvqYwnAIor6\",\n", + " \"content\": \"\",\n", + " \"additional_kwargs\": {\n", + " \"tool_calls\": [\n", + " {\n", + " \"id\": \"call_IbF0aL78Xep9Xpz3UHLn0POR\",\n", + " \"type\": \"function\",\n", + " \"function\": \"[Object]\"\n", + " }\n", + " ]\n", + " },\n", + " \"response_metadata\": {\n", + " \"tokenUsage\": {\n", + " \"completionTokens\": 23,\n", + " \"promptTokens\": 159,\n", + " \"totalTokens\": 182\n", + " },\n", + " \"finish_reason\": \"tool_calls\",\n", + " \"system_fingerprint\": \"fp_25624ae3a5\"\n", + " },\n", + " \"tool_calls\": [\n", + " {\n", + " \"name\": \"Info\",\n", + " \"args\": {\n", + " \"fact\": \"The user likes pepperoni pizza\",\n", + " \"topic\": \"Food Preferences\"\n", + " },\n", + " \"type\": \"tool_call\",\n", + " \"id\": \"call_IbF0aL78Xep9Xpz3UHLn0POR\"\n", + " }\n", + " ],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 159,\n", + " \"output_tokens\": 23,\n", + " \"total_tokens\": 182\n", + " }\n", + " }\n", + " ]\n", + " }\n", + "}\n", + "{\n", + " update_memory: {\n", + " messages: [\n", + " ToolMessage {\n", + " \"content\": \"Saved!\",\n", + " \"additional_kwargs\": {},\n", 
+ " \"response_metadata\": {},\n", + " \"tool_call_id\": \"call_IbF0aL78Xep9Xpz3UHLn0POR\"\n", + " }\n", + " ],\n", + " info: { '2bf3fed4-028a-4c86-93fe-717cde64e8e7': [Object] }\n", + " }\n", + "}\n", + "{\n", + " call_model: {\n", + " messages: [\n", + " AIMessage {\n", + " \"id\": \"chatcmpl-A8VfBbqFPyFfN65f4Tk3k0kdtXHtJ\",\n", + " \"content\": \"\",\n", + " \"additional_kwargs\": {\n", + " \"tool_calls\": [\n", + " {\n", + " \"id\": \"call_ZkpyxLljOOfhp4oAMyO2R1Lh\",\n", + " \"type\": \"function\",\n", + " \"function\": \"[Object]\"\n", + " }\n", + " ]\n", + " },\n", + " \"response_metadata\": {\n", + " \"tokenUsage\": {\n", + " \"completionTokens\": 23,\n", + " \"promptTokens\": 208,\n", + " \"totalTokens\": 231\n", + " },\n", + " \"finish_reason\": \"tool_calls\",\n", + " \"system_fingerprint\": \"fp_a5d11b2ef2\"\n", + " },\n", + " \"tool_calls\": [\n", + " {\n", + " \"name\": \"Info\",\n", + " \"args\": {\n", + " \"fact\": \"The user just moved to San Francisco\",\n", + " \"topic\": \"Location\"\n", + " },\n", + " \"type\": \"tool_call\",\n", + " \"id\": \"call_ZkpyxLljOOfhp4oAMyO2R1Lh\"\n", + " }\n", + " ],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 208,\n", + " \"output_tokens\": 23,\n", + " \"total_tokens\": 231\n", + " }\n", + " }\n", + " ]\n", + " }\n", + "}\n", + "{\n", + " update_memory: {\n", + " messages: [\n", + " ToolMessage {\n", + " \"content\": \"Saved!\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {},\n", + " \"tool_call_id\": \"call_ZkpyxLljOOfhp4oAMyO2R1Lh\"\n", + " }\n", + " ],\n", + " info: { 'ee0085c3-0396-4be8-b0cd-869194aacd5b': [Object] }\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "const config = {\n", + " configurable: {\n", + " thread_id: \"1\",\n", + " // Notice we're specifying `user_id` here, which matches the key name we passed to `SharedValue.on()`\n", + " // Without this, our graph wouldn't be able to access the shared state value.\n", + " user_id: 
\"1\"\n", + " },\n", + " streamMode: \"updates\" as const\n", + "};\n", + "\n", + "// First let's just say hi to the AI\n", + "for await (const update of await compiledGraph.stream({\n", + " messages: [{ role: \"user\", content: \"hi\" }],\n", + "}, config)) {\n", + " console.log(update);\n", + "}\n", + "\n", + "// Let's continue the conversation (by passing the same config) and tell the AI we like pepperoni pizza\n", + "for await (const update of await compiledGraph.stream({\n", + " messages: [{ role: \"user\", content: \"i like pepperoni pizza\" }],\n", + "}, config)) {\n", + " console.log(update);\n", + "}\n", + "\n", + "// Let's continue the conversation even further (by passing the same config) and tell the AI we live in SF\n", + "for await (const update of await compiledGraph.stream({\n", + " messages: [{ role: \"user\", content: \"i also just moved to SF\" }],\n", + "}, config)) {\n", + " console.log(update);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b8c416fa-086a-491d-a7d3-57091f6413e3", + "metadata": {}, + "source": [ + "## Run graph on a different thread\n", + "\n", + "We can now run the graph on a different thread and see that it remembers facts about the user (specifically that the user likes pepperoni pizza and lives in SF):" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e240f025-ff8b-4d17-beb7-2420c0575dd9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " call_model: {\n", + " messages: [\n", + " AIMessage {\n", + " \"id\": \"chatcmpl-A8VgEKTv4D2VzGdGwq6FVv09HczZf\",\n", + " \"content\": \"Sure! Since you just moved to San Francisco, here are some popular restaurants you might enjoy:\\n\\n1. **Tony's Pizza Napoletana**\\n - **Cuisine:** Italian, Pizza\\n - **Location:** 1570 Stockton St, San Francisco, CA 94133\\n - **Why you might like it:** They have a great selection of pizzas, including pepperoni!\\n\\n2. 
**House of Prime Rib**\\n - **Cuisine:** American, Steakhouse\\n - **Location:** 1906 Van Ness Ave, San Francisco, CA 94109\\n - **Why you might like it:** If you're in the mood for a hearty meal, their prime rib is highly recommended.\\n\\n3. **Tartine Bakery**\\n - **Cuisine:** Bakery, Cafe\\n - **Location:** 600 Guerrero St, San Francisco, CA 94110\\n - **Why you might like it:** Perfect for a lighter meal or dessert, their pastries are famous.\\n\\n4. **La Taqueria**\\n - **Cuisine:** Mexican\\n - **Location:** 2889 Mission St, San Francisco, CA 94110\\n - **Why you might like it:** Known for their delicious tacos and burritos.\\n\\n5. **Swan Oyster Depot**\\n - **Cuisine:** Seafood\\n - **Location:** 1517 Polk St, San Francisco, CA 94109\\n - **Why you might like it:** Great spot for fresh seafood.\\n\\nWould you like more information on any of these, or do you have a specific type of cuisine in mind?\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {\n", + " \"tokenUsage\": {\n", + " \"completionTokens\": 324,\n", + " \"promptTokens\": 166,\n", + " \"totalTokens\": 490\n", + " },\n", + " \"finish_reason\": \"stop\",\n", + " \"system_fingerprint\": \"fp_25624ae3a5\"\n", + " },\n", + " \"tool_calls\": [],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 166,\n", + " \"output_tokens\": 324,\n", + " \"total_tokens\": 490\n", + " }\n", + " }\n", + " ]\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "const config2 = {\n", + " configurable: {\n", + " // Notice we have a new thread ID, but the same user ID.\n", + " // This allows us to access the shared state value.\n", + " thread_id: \"2\",\n", + " user_id: \"1\"\n", + " },\n", + " streamMode: \"updates\" as const\n", + "};\n", + "\n", + "for await (const update of await compiledGraph.stream({\n", + " messages: [{ role: \"user\", content: \"where and what should i eat for dinner? 
Can you list some restaurants?\" }],\n", + "}, config2)) {\n", + " console.log(update);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "091995d3", + "metadata": {}, + "source": [ + "Perfect! The AI recommended restaurants in SF, and included a pizza restaurant at the top of its list.\n", + "\n", + "Notice that the `messages` in this new thread do NOT contain the messages from the previous thread since we didn't store them as shared values across the `user_id`. However, the `info` we saved in the previous thread is still available, since we passed in the same `user_id` in this new thread.\n", + "\n", + "Let's now run the graph for another user to verify that the preferences of the first user are self-contained:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f9bf2c15", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " call_model: {\n", + " messages: [\n", + " AIMessage {\n", + " \"id\": \"chatcmpl-A8VgwX6MUbLwBdistYNC0LP1t6y7S\",\n", + " \"content\": \"Sure, I can help with that! To give you the best recommendations, could you please tell me your location or the city you're in? 
Additionally, do you have any preferences or dietary restrictions?\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {\n", + " \"tokenUsage\": {\n", + " \"completionTokens\": 40,\n", + " \"promptTokens\": 151,\n", + " \"totalTokens\": 191\n", + " },\n", + " \"finish_reason\": \"stop\",\n", + " \"system_fingerprint\": \"fp_a5d11b2ef2\"\n", + " },\n", + " \"tool_calls\": [],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 151,\n", + " \"output_tokens\": 40,\n", + " \"total_tokens\": 191\n", + " }\n", + " }\n", + " ]\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "// Once again, we're specifying a new `user_id` value here.\n", + "// Like the previous examples, this means the graph will not\n", + "// be able to access the memory saved from the previous run.\n", + "const config3 = {\n", + " configurable: {\n", + " thread_id: \"3\",\n", + " user_id: \"2\"\n", + " },\n", + " streamMode: \"updates\" as const\n", + "}\n", + "\n", + "for await (const update of await compiledGraph.stream({\n", + " messages: [{ role: \"user\", content: \"where and what should i eat for dinner? Can you list some restaurants?\" }],\n", + "}, config3)) {\n", + " console.log(update);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b7086cea", + "metadata": {}, + "source": [ + "Perfect! The graph has forgotten all of the previous preferences and has to ask the user for their location and dietary preferences." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "TypeScript", + "language": "typescript", + "name": "tslab" + }, + "language_info": { + "codemirror_mode": { + "mode": "typescript", + "name": "javascript", + "typescript": true + }, + "file_extension": ".ts", + "mimetype": "text/typescript", + "name": "typescript", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langgraph/src/channels/base.ts b/libs/langgraph/src/channels/base.ts index 61ac5598..d583dcfb 100644 --- a/libs/langgraph/src/channels/base.ts +++ b/libs/langgraph/src/channels/base.ts @@ -6,6 +6,10 @@ import { } from "@langchain/langgraph-checkpoint"; import { EmptyChannelError } from "../errors.js"; +export function isBaseChannel(obj: unknown): obj is BaseChannel { + return obj != null && (obj as BaseChannel).lg_is_channel === true; +} + export abstract class BaseChannel< ValueType = unknown, UpdateType = unknown, @@ -20,6 +24,9 @@ export abstract class BaseChannel< */ abstract lc_graph_name: string; + /** @ignore */ + lg_is_channel = true; + /** * Return a new identical channel, optionally initialized from a checkpoint. * Can be thought of as a "restoration" from a checkpoint which is a "snapshot" of the channel's state. 
@@ -75,11 +82,15 @@ export function emptyChannels>( channels: Cc, checkpoint: ReadonlyCheckpoint ): Cc { + const filteredChannels = Object.fromEntries( + Object.entries(channels).filter(([, value]) => isBaseChannel(value)) + ) as Cc; + const newChannels = {} as Cc; - for (const k in channels) { - if (Object.prototype.hasOwnProperty.call(channels, k)) { + for (const k in filteredChannels) { + if (Object.prototype.hasOwnProperty.call(filteredChannels, k)) { const channelValue = checkpoint.channel_values[k]; - newChannels[k] = channels[k].fromCheckpoint(channelValue); + newChannels[k] = filteredChannels[k].fromCheckpoint(channelValue); } } return newChannels; diff --git a/libs/langgraph/src/constants.ts b/libs/langgraph/src/constants.ts index a46337ec..2a6276cc 100644 --- a/libs/langgraph/src/constants.ts +++ b/libs/langgraph/src/constants.ts @@ -5,6 +5,9 @@ export const CONFIG_KEY_READ = "__pregel_read"; export const CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer"; export const CONFIG_KEY_RESUMING = "__pregel_resuming"; export const INTERRUPT = "__interrupt__"; +export const CONFIG_KEY_STORE = "__pregel_store"; +export const RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"; +export const RECURSION_LIMIT_DEFAULT = 25; export const TAG_HIDDEN = "langsmith:hidden"; diff --git a/libs/langgraph/src/graph/annotation.ts b/libs/langgraph/src/graph/annotation.ts index f2f8b396..3da245eb 100644 --- a/libs/langgraph/src/graph/annotation.ts +++ b/libs/langgraph/src/graph/annotation.ts @@ -2,6 +2,11 @@ import { RunnableLike } from "@langchain/core/runnables"; import { BaseChannel } from "../channels/base.js"; import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js"; import { LastValue } from "../channels/last_value.js"; +import { + isConfiguredManagedValue, + ManagedValueSpec, + type ConfiguredManagedValue, +} from "../managed/base.js"; export type SingleReducer = | { @@ -18,19 +23,23 @@ export type SingleReducer = | null; export interface 
StateDefinition { - [key: string]: BaseChannel | (() => BaseChannel); + [key: string]: BaseChannel | (() => BaseChannel) | ConfiguredManagedValue; } type ExtractValueType = C extends BaseChannel ? C["ValueType"] : C extends () => BaseChannel ? ReturnType["ValueType"] + : C extends ConfiguredManagedValue + ? V : never; type ExtractUpdateType = C extends BaseChannel ? C["UpdateType"] : C extends () => BaseChannel ? ReturnType["UpdateType"] + : C extends ConfiguredManagedValue + ? V : never; export type StateType = { @@ -43,7 +52,7 @@ export type UpdateType = { export type NodeType = RunnableLike< StateType, - UpdateType + UpdateType | Partial> >; /** @ignore */ @@ -150,9 +159,11 @@ export const Annotation: AnnotationFunction = function < ValueType, UpdateType = ValueType >( - annotation?: SingleReducer -): BaseChannel { - if (annotation) { + annotation?: SingleReducer | ConfiguredManagedValue +): BaseChannel | ManagedValueSpec { + if (isConfiguredManagedValue(annotation)) { + return annotation; + } else if (annotation) { return getChannel(annotation); } else { // @ts-expect-error - Annotation without reducer diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index 6514c1f4..68e471fc 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -5,7 +5,7 @@ import { RunnableLike, } from "@langchain/core/runnables"; import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint"; -import { BaseChannel } from "../channels/base.js"; +import { BaseChannel, isBaseChannel } from "../channels/base.js"; import { END, CompiledGraph, @@ -41,6 +41,8 @@ import { UpdateType, } from "./annotation.js"; import type { RetryPolicy } from "../pregel/utils.js"; +import { BaseStore } from "../store/base.js"; +import { isConfiguredManagedValue, ManagedValueSpec } from "../managed/base.js"; const ROOT = "__root__"; @@ -160,7 +162,7 @@ export class StateGraph< I extends StateDefinition = SD extends StateDefinition ? 
SD : StateDefinition, O extends StateDefinition = SD extends StateDefinition ? SD : StateDefinition > extends Graph> { - channels: Record = {}; + channels: Record = {}; // TODO: this doesn't dedupe edges as in py, so worth fixing at some point waitingEdges: Set<[N[], N]> = new Set(); @@ -246,7 +248,10 @@ export class StateGraph< } if (this.channels[key] !== undefined) { if (this.channels[key] !== channel) { - if (channel.lc_graph_name !== "LastValue") { + if ( + !isConfiguredManagedValue(channel) && + channel.lc_graph_name !== "LastValue" + ) { throw new Error( `Channel "${key}" already exists with a different type.` ); @@ -338,10 +343,12 @@ export class StateGraph< compile({ checkpointer, + store, interruptBefore, interruptAfter, }: { checkpointer?: BaseCheckpointSaver; + store?: BaseStore; interruptBefore?: N[] | All; interruptAfter?: N[] | All; } = {}): CompiledStateGraph { @@ -378,6 +385,7 @@ export class StateGraph< outputChannels, streamChannels, streamMode: "updates", + store, }); // attach nodes, edges and branches @@ -587,10 +595,6 @@ export class CompiledStateGraph< } } -function isBaseChannel(obj: unknown): obj is BaseChannel { - return obj != null && typeof (obj as BaseChannel).lc_graph_name === "string"; -} - function isStateDefinition(obj: unknown): obj is StateDefinition { return ( typeof obj === "object" && diff --git a/libs/langgraph/src/managed/base.ts b/libs/langgraph/src/managed/base.ts new file mode 100644 index 00000000..653e8b44 --- /dev/null +++ b/libs/langgraph/src/managed/base.ts @@ -0,0 +1,187 @@ +import { RunnableConfig } from "@langchain/core/runnables"; +import { RUNTIME_PLACEHOLDER } from "../constants.js"; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export interface ManagedValueParams extends Record {} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export abstract class ManagedValue { + runtime: boolean = false; + + config: RunnableConfig; + + private _promises: Promise[] = []; + + 
lg_is_managed_value = true; + + constructor(config: RunnableConfig, _params?: ManagedValueParams) { + this.config = config; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + static async initialize( + _config: RunnableConfig, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _args?: any + ): Promise> { + throw new Error("Not implemented"); + } + + abstract call(step: number): Value; + + async promises(): Promise { + return Promise.all(this._promises); + } + + protected addPromise(promise: Promise): void { + this._promises.push(promise); + } +} + +export abstract class WritableManagedValue< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + Value = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + Update = any +> extends ManagedValue { + abstract update(writes: Update[]): Promise; +} + +export const ChannelKeyPlaceholder = "__channel_key_placeholder__"; + +export type ManagedValueSpec = typeof ManagedValue | ConfiguredManagedValue; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export interface ConfiguredManagedValue { + cls: typeof ManagedValue; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + params: ManagedValueParams; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export class ManagedValueMapping extends Map> { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(entries?: Iterable<[string, ManagedValue]> | null) { + super(entries ? 
Array.from(entries) : undefined); + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + replaceRuntimeValues(step: number, values: Record | any): void { + if (this.size === 0 || !values) { + return; + } + + if (Array.from(this.values()).every((mv) => !mv.runtime)) { + return; + } + + if (typeof values === "object" && !Array.isArray(values)) { + for (const [key, value] of Object.entries(values)) { + for (const [chan, mv] of this.entries()) { + if (mv.runtime && mv.call(step) === value) { + // eslint-disable-next-line no-param-reassign + values[key] = { [RUNTIME_PLACEHOLDER]: chan }; + } + } + } + } else if (typeof values === "object" && "constructor" in values) { + for (const key of Object.getOwnPropertyNames( + Object.getPrototypeOf(values) + )) { + try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const value = (values as any)[key]; + for (const [chan, mv] of this.entries()) { + if (mv.runtime && mv.call(step) === value) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any, no-param-reassign + (values as any)[key] = { [RUNTIME_PLACEHOLDER]: chan }; + } + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (error: any) { + // Ignore if TypeError + if (error.name !== TypeError.name) { + throw error; + } + } + } + } + } + + replaceRuntimePlaceholders( + step: number, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + values: Record | any + ): void { + if (this.size === 0 || !values) { + return; + } + + if (Array.from(this.values()).every((mv) => !mv.runtime)) { + return; + } + + if (typeof values === "object" && !Array.isArray(values)) { + for (const [key, value] of Object.entries(values)) { + if ( + typeof value === "object" && + value !== null && + RUNTIME_PLACEHOLDER in value + ) { + const placeholder = value[RUNTIME_PLACEHOLDER]; + if (typeof placeholder === "string") { + // eslint-disable-next-line no-param-reassign + values[key] = this.get(placeholder)?.call(step); 
+ } + } + } + } else if (typeof values === "object" && "constructor" in values) { + for (const key of Object.getOwnPropertyNames( + Object.getPrototypeOf(values) + )) { + try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const value = (values as any)[key]; + if ( + typeof value === "object" && + value !== null && + RUNTIME_PLACEHOLDER in value + ) { + const managedValue = this.get(value[RUNTIME_PLACEHOLDER]); + if (managedValue) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any, no-param-reassign + (values as any)[key] = managedValue.call(step); + } + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (error: any) { + // Ignore if TypeError + if (error.name !== TypeError.name) { + throw error; + } + } + } + } + } +} + +export function isManagedValue(value: unknown): value is typeof ManagedValue { + if (typeof value === "object" && value && "lg_is_managed_value" in value) { + return true; + } + return false; +} + +export function isConfiguredManagedValue( + value: unknown +): value is ConfiguredManagedValue { + if ( + typeof value === "object" && + value && + "cls" in value && + "params" in value + ) { + return true; + } + return false; +} diff --git a/libs/langgraph/src/managed/index.ts b/libs/langgraph/src/managed/index.ts new file mode 100644 index 00000000..52d7d41d --- /dev/null +++ b/libs/langgraph/src/managed/index.ts @@ -0,0 +1,3 @@ +export * from "./base.js"; +export * from "./is_last_step.js"; +export * from "./shared_value.js"; diff --git a/libs/langgraph/src/managed/is_last_step.ts b/libs/langgraph/src/managed/is_last_step.ts new file mode 100644 index 00000000..0cead6cc --- /dev/null +++ b/libs/langgraph/src/managed/is_last_step.ts @@ -0,0 +1,8 @@ +import { RECURSION_LIMIT_DEFAULT } from "../constants.js"; +import { ManagedValue } from "./base.js"; + +export class IsLastStepManager extends ManagedValue { + call(step: number): boolean { + return step === (this.config.recursionLimit ?? 
RECURSION_LIMIT_DEFAULT) - 1; + } +} diff --git a/libs/langgraph/src/managed/shared_value.ts b/libs/langgraph/src/managed/shared_value.ts new file mode 100644 index 00000000..705032eb --- /dev/null +++ b/libs/langgraph/src/managed/shared_value.ts @@ -0,0 +1,118 @@ +import { RunnableConfig } from "@langchain/core/runnables"; +import { BaseStore, type Values } from "../store/base.js"; +import { + ChannelKeyPlaceholder, + ConfiguredManagedValue, + ManagedValue, + ManagedValueParams, + WritableManagedValue, +} from "./base.js"; +import { CONFIG_KEY_STORE } from "../constants.js"; +import { InvalidUpdateError } from "../errors.js"; + +type Value = Record; +type Update = Record; + +export interface SharedValueParams extends ManagedValueParams { + scope: string; + key: string; +} + +export class SharedValue extends WritableManagedValue { + scope: string; + + store: BaseStore | null; + + ns: string | null; + + value: Value = {}; + + constructor(config: RunnableConfig, params: SharedValueParams) { + super(config, params); + this.scope = params.scope; + this.store = config.configurable?.[CONFIG_KEY_STORE] || null; + + if (!this.store) { + this.ns = null; + } else if (config.configurable?.[this.scope]) { + const scopeValue = config.configurable[this.scope]; + const scopedValueString = + typeof scopeValue === "string" + ? 
scopeValue + : JSON.stringify(scopeValue); + this.ns = `scoped:${this.scope}:${params.key}:${scopedValueString}`; + } else { + throw new Error( + `Scope ${this.scope} for shared state key not in config.configurable` + ); + } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + static async initialize( + config: RunnableConfig, + args: SharedValueParams + ): Promise> { + const instance = new this(config, args); + await instance.loadStore(); + return instance as unknown as ManagedValue; + } + + static on(scope: string): ConfiguredManagedValue { + return { + cls: SharedValue, + params: { + scope, + key: ChannelKeyPlaceholder, + }, + }; + } + + call(_step: number): Value { + return { ...this.value }; + } + + private processUpdate( + values: Update[] + ): Array<[string, string, Values | null]> { + const writes: Array<[string, string, Values | null]> = []; + + for (const vv of values) { + for (const [k, v] of Object.entries(vv)) { + if (v === null) { + if (k in this.value) { + delete this.value[k]; + if (this.ns) { + writes.push([this.ns, k, null]); + } + } + } else if (typeof v !== "object" || v === null) { + throw new InvalidUpdateError("Received a non-object value"); + } else { + this.value[k] = v as Values; + if (this.ns) { + writes.push([this.ns, k, v as Values]); + } + } + } + } + + return writes; + } + + async update(values: Update[]): Promise { + if (!this.store) { + this.processUpdate(values); + } else { + await this.store.put(this.processUpdate(values)); + } + } + + private async loadStore(): Promise { + if (this.store && this.ns) { + const saved = await this.store.list([this.ns]); + this.value = saved[this.ns] || {}; + } + return false; + } +} diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 4b34ac2c..529fe71c 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -20,6 +20,7 @@ import { BaseChannel, createCheckpoint, emptyChannels, + isBaseChannel, } from 
"../channels/base.js"; import { PregelNode } from "./read.js"; import { readChannel, readChannels } from "./io.js"; @@ -40,6 +41,7 @@ import { import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; import { _getIdMetadata, getNullChannelVersion } from "./utils.js"; +import { ManagedValueMapping } from "../managed/base.js"; /** * Construct a type with a set of properties K of type T @@ -89,30 +91,72 @@ export function shouldInterrupt( return anyChannelUpdated && anyTriggeredNodeInInterruptNodes; } -export function _localRead>( +export function _localRead>( + step: number, checkpoint: ReadonlyCheckpoint, channels: Cc, + managed: ManagedValueMapping, task: WritesProtocol, select: Array | keyof Cc, fresh: boolean = false ): Record | unknown { - if (fresh) { - const newCheckpoint = createCheckpoint(checkpoint, channels, -1); - // create a new copy of channels - const newChannels = emptyChannels(channels, newCheckpoint); - // Note: _applyWrites contains side effects + let managedKeys: Array = []; + let updated = new Set(); + + if (!Array.isArray(select)) { + for (const [c] of task.writes) { + if (c === select) { + updated = new Set([c]); + break; + } + } + updated = updated || new Set(); + } else { + managedKeys = select.filter((k) => managed.get(k as string)) as Array< + keyof Cc + >; + select = select.filter((k) => !managed.get(k as string)) as Array; + updated = new Set( + select.filter((c) => task.writes.some(([key, _]) => key === c)) + ); + } + + let values: Record; + + if (fresh && updated.size > 0) { + const localChannels = Object.fromEntries( + Object.entries(channels).filter(([k, _]) => updated.has(k as keyof Cc)) + ) as Partial; + + const newCheckpoint = createCheckpoint(checkpoint, localChannels as Cc, -1); + const newChannels = emptyChannels(localChannels as Cc, newCheckpoint); + _applyWrites(copyCheckpoint(newCheckpoint), newChannels, [task]); - return 
readChannels(newChannels, select); + values = readChannels({ ...channels, ...newChannels }, select); } else { - return readChannels(channels, select); + values = readChannels(channels, select); + } + + if (managedKeys.length > 0) { + for (const k of managedKeys) { + const managedValue = managed.get(k as string); + if (managedValue) { + const resultOfManagedCall = managedValue.call(step); + values[k as string] = resultOfManagedCall; + } + } } + + return values; } export function _localWrite( + step: number, // eslint-disable-next-line @typescript-eslint/no-explicit-any - commit: (writes: [string, any][]) => void, + commit: (writes: [string, any][]) => any, processes: Record, channels: Record, + managed: ManagedValueMapping, // eslint-disable-next-line @typescript-eslint/no-explicit-any writes: [string, any][] ) { @@ -130,7 +174,9 @@ export function _localWrite( `Invalid node name ${value.node} in packet` ); } - } else if (!(chan in channels)) { + // replace any runtime values with placeholders + managed.replaceRuntimeValues(step, value.args); + } else if (!(chan in channels) && !managed.get(chan)) { console.warn(`Skipping write for channel '${chan}' which has no readers`); } } @@ -143,7 +189,11 @@ export function _applyWrites>( tasks: WritesProtocol[], // eslint-disable-next-line @typescript-eslint/no-explicit-any getNextVersion?: (version: any, channel: BaseChannel) => any -): void { +): Record { + // Filter out non instances of BaseChannel + const onlyChannels = Object.fromEntries( + Object.entries(channels).filter(([_, value]) => isBaseChannel(value)) + ) as Cc; // Update seen versions for (const task of tasks) { if (checkpoint.versions_seen[task.name] === undefined) { @@ -173,11 +223,11 @@ export function _applyWrites>( ); for (const chan of channelsToConsume) { - if (channels[chan].consume()) { + if (chan in onlyChannels && onlyChannels[chan].consume()) { if (getNextVersion !== undefined) { checkpoint.channel_versions[chan] = getNextVersion( maxVersion, - 
channels[chan] + onlyChannels[chan] ); } } @@ -193,6 +243,7 @@ export function _applyWrites>( keyof Cc, PendingWriteValue[] >; + const pendingWritesByManaged = {} as Record; for (const task of tasks) { for (const [chan, val] of task.writes) { if (chan === TASKS) { @@ -200,12 +251,18 @@ export function _applyWrites>( node: (val as Send).node, args: (val as Send).args, }); - } else { + } else if (chan in onlyChannels) { if (chan in pendingWriteValuesByChannel) { pendingWriteValuesByChannel[chan].push(val); } else { pendingWriteValuesByChannel[chan] = [val]; } + } else { + if (chan in pendingWritesByManaged) { + pendingWritesByManaged[chan].push(val); + } else { + pendingWritesByManaged[chan] = [val]; + } } } } @@ -221,10 +278,10 @@ export function _applyWrites>( const updatedChannels: Set = new Set(); // Apply writes to channels for (const [chan, vals] of Object.entries(pendingWriteValuesByChannel)) { - if (chan in channels) { + if (chan in onlyChannels) { let updated; try { - updated = channels[chan].update(vals); + updated = onlyChannels[chan].update(vals); // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (e: any) { if (e.name === InvalidUpdateError.unminifiable_name) { @@ -240,7 +297,7 @@ export function _applyWrites>( if (updated && getNextVersion !== undefined) { checkpoint.channel_versions[chan] = getNextVersion( maxVersion, - channels[chan] + onlyChannels[chan] ); } updatedChannels.add(chan); @@ -248,17 +305,20 @@ export function _applyWrites>( } // Channels that weren't updated in this step are notified of a new step - for (const chan of Object.keys(channels)) { + for (const chan of Object.keys(onlyChannels)) { if (!updatedChannels.has(chan)) { - const updated = channels[chan].update([]); + const updated = onlyChannels[chan].update([]); if (updated && getNextVersion !== undefined) { checkpoint.channel_versions[chan] = getNextVersion( maxVersion, - channels[chan] + onlyChannels[chan] ); } } } + + // Return managed values writes to be 
applied externally + return pendingWritesByManaged; } export type NextTaskExtraFields = { @@ -275,6 +335,7 @@ export function _prepareNextTasks< checkpoint: ReadonlyCheckpoint, processes: Nn, channels: Cc, + managed: ManagedValueMapping, config: RunnableConfig, forExecution: false, extra: NextTaskExtraFields @@ -287,6 +348,7 @@ export function _prepareNextTasks< checkpoint: ReadonlyCheckpoint, processes: Nn, channels: Cc, + managed: ManagedValueMapping, config: RunnableConfig, forExecution: true, extra: NextTaskExtraFields @@ -299,6 +361,7 @@ export function _prepareNextTasks< checkpoint: ReadonlyCheckpoint, processes: Nn, channels: Cc, + managed: ManagedValueMapping, config: RunnableConfig, forExecution: boolean, extra: NextTaskExtraFields @@ -342,6 +405,7 @@ export function _prepareNextTasks< const node = proc.getNode(); if (node !== undefined) { const writes: [keyof Cc, unknown][] = []; + managed.replaceRuntimePlaceholders(step, packet.args); tasks.push({ name: packet.node, input: packet.args, @@ -356,22 +420,33 @@ export function _prepareNextTasks< runName: packet.node, callbacks: manager?.getChild(`graph:step:${step}`), configurable: { - [CONFIG_KEY_SEND]: _localWrite.bind( - undefined, - (items: [keyof Cc, unknown][]) => writes.push(...items), - processes, - channels - ), - [CONFIG_KEY_READ]: _localRead.bind( - undefined, - checkpoint, - channels, - { - name: packet.node, - writes: writes as Array<[string, unknown]>, - triggers, - } - ), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [CONFIG_KEY_SEND]: (writes_: [string, any][]) => + _localWrite( + step, + (items: [keyof Cc, unknown][]) => writes.push(...items), + processes, + channels, + managed, + writes_ + ), + [CONFIG_KEY_READ]: ( + select_: Array | keyof Cc, + fresh_: boolean = false + ) => + _localRead( + step, + checkpoint, + channels, + managed, + { + name: packet.node, + writes: writes as Array<[string, unknown]>, + triggers, + }, + select_, + fresh_ + ), }, } ), @@ -408,7 +483,7 
@@ export function _prepareNextTasks< .sort(); // If any of the channels read by this process were updated if (triggers.length > 0) { - const val = _procInput(proc, channels, forExecution); + const val = _procInput(step, proc, managed, channels, forExecution); if (val === undefined) { continue; } @@ -448,22 +523,33 @@ export function _prepareNextTasks< runName: name, callbacks: manager?.getChild(`graph:step:${step}`), configurable: { - [CONFIG_KEY_SEND]: _localWrite.bind( - undefined, - (items: [keyof Cc, unknown][]) => writes.push(...items), - processes, - channels - ), - [CONFIG_KEY_READ]: _localRead.bind( - undefined, - checkpoint, - channels, - { - name, - writes: writes as Array<[string, unknown]>, - triggers, - } - ), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [CONFIG_KEY_SEND]: (writes_: [string, any][]) => + _localWrite( + step, + (items: [keyof Cc, unknown][]) => writes.push(...items), + processes, + channels, + managed, + writes_ + ), + [CONFIG_KEY_READ]: ( + select_: Array | keyof Cc, + fresh_: boolean = false + ) => + _localRead( + step, + checkpoint, + channels, + managed, + { + name, + writes: writes as Array<[string, unknown]>, + triggers, + }, + select_, + fresh_ + ), [CONFIG_KEY_CHECKPOINTER]: checkpointer, [CONFIG_KEY_RESUMING]: isResuming, checkpoint_id: checkpoint.id, @@ -484,15 +570,45 @@ export function _prepareNextTasks< } function _procInput( + step: number, proc: PregelNode, + managed: ManagedValueMapping, channels: StrRecord, forExecution: boolean ) { // eslint-disable-next-line @typescript-eslint/no-explicit-any let val: any; - // If all trigger channels subscribed by this process are not empty - // then invoke the process with the values of all non-empty channels - if (Array.isArray(proc.channels)) { + + if (typeof proc.channels === "object" && !Array.isArray(proc.channels)) { + val = {}; + for (const [k, chan] of Object.entries(proc.channels)) { + if (proc.triggers.includes(chan)) { + try { + val[k] = 
readChannel(channels, chan, false); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.name === EmptyChannelError.unminifiable_name) { + return undefined; + } else { + throw e; + } + } + } else if (chan in channels) { + try { + val[k] = readChannel(channels, chan, true); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.name === EmptyChannelError.unminifiable_name) { + continue; + } else { + throw e; + } + } + } else { + val[k] = managed.get(k)?.call(step); + } + } + } else if (Array.isArray(proc.channels)) { let successfulRead = false; for (const chan of proc.channels) { try { @@ -509,21 +625,7 @@ function _procInput( } } if (!successfulRead) { - return; - } - } else if (typeof proc.channels === "object") { - val = {}; - for (const [k, chan] of Object.entries(proc.channels)) { - try { - val[k] = readChannel(channels, chan, !proc.triggers.includes(chan)); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - if (e.name === EmptyChannelError.unminifiable_name) { - continue; - } else { - throw e; - } - } + return undefined; } } else { throw new Error( diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index abb86f9e..b8e44109 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -25,6 +25,7 @@ import { BaseChannel, createCheckpoint, emptyChannels, + isBaseChannel, } from "../channels/base.js"; import { PregelNode } from "./read.js"; import { validateGraph, validateKeys } from "./validate.js"; @@ -40,6 +41,7 @@ import { CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_SEND, + CONFIG_KEY_STORE, ERROR, INTERRUPT, } from "../constants.js"; @@ -65,6 +67,15 @@ import { import { _coerceToDict, getNewChannelVersions, RetryPolicy } from "./utils.js"; import { PregelLoop } from "./loop.js"; import { executeTasksWithRetry } from "./retry.js"; +import { BaseStore } from 
"../store/base.js"; +import { + ChannelKeyPlaceholder, + isConfiguredManagedValue, + ManagedValue, + ManagedValueMapping, + type ManagedValueSpec, +} from "../managed/base.js"; +import { patchConfigurable } from "../utils.js"; type WriteValue = Runnable | RunnableFunc | unknown; @@ -165,7 +176,7 @@ export class Channel { */ export interface PregelOptions< Nn extends StrRecord, - Cc extends StrRecord + Cc extends StrRecord > extends RunnableConfig { /** The stream mode for the graph run. Default is ["values"]. */ streamMode?: StreamMode | StreamMode[]; @@ -188,7 +199,7 @@ export type PregelOutputType = any; export class Pregel< Nn extends StrRecord, - Cc extends StrRecord + Cc extends StrRecord > extends Runnable> implements PregelInterface @@ -226,6 +237,8 @@ export class Pregel< retryPolicy?: RetryPolicy; + store?: BaseStore; + constructor(fields: PregelParams) { super(fields); @@ -247,6 +260,7 @@ export class Pregel< this.debug = fields.debug ?? this.debug; this.checkpointer = fields.checkpointer; this.retryPolicy = fields.retryPolicy; + this.store = fields.store; if (this.autoValidate) { this.validate(); @@ -254,7 +268,7 @@ export class Pregel< } validate(): this { - validateGraph({ + validateGraph({ nodes: this.nodes, channels: this.channels, outputChannels: this.outputChannels, @@ -295,17 +309,26 @@ export class Pregel< const saved = await this.checkpointer.getTuple(config); const checkpoint = saved ? saved.checkpoint : emptyCheckpoint(); - const channels = emptyChannels(this.channels, checkpoint); + const channels = emptyChannels( + this.channels as Record, + checkpoint + ); + const { managed } = await this.prepareSpecs(config); + const nextTasks = _prepareNextTasks( checkpoint, this.nodes, channels, + managed, saved !== undefined ? saved.config : config, false, { step: saved ? (saved.metadata?.step ?? 
-1) + 1 : -1 } ); return { - values: readChannels(channels, this.streamChannelsAsIs), + values: readChannels( + channels, + this.streamChannelsAsIs as string | string[] + ), next: nextTasks.map((task) => task.name), tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? []), metadata: saved?.metadata, @@ -325,18 +348,28 @@ export class Pregel< if (!this.checkpointer) { throw new GraphValueError("No checkpointer set"); } + const { managed } = await this.prepareSpecs(config); + for await (const saved of this.checkpointer.list(config, options)) { - const channels = emptyChannels(this.channels, saved.checkpoint); + const channels = emptyChannels( + this.channels as Record, + saved.checkpoint + ); + const nextTasks = _prepareNextTasks( saved.checkpoint, this.nodes, channels, + managed, saved.config, false, { step: -1 } ); yield { - values: readChannels(channels, this.streamChannelsAsIs), + values: readChannels( + channels, + this.streamChannelsAsIs as string | string[] + ), next: nextTasks.map((task) => task.name), tasks: tasksWithWrites(nextTasks, saved.pendingWrites ?? []), metadata: saved.metadata, @@ -442,7 +475,12 @@ export class Pregel< ); } // update channels - const channels = emptyChannels(this.channels, checkpoint); + const channels = emptyChannels( + this.channels as Record, + checkpoint + ); + const { managed } = await this.prepareSpecs(config); + // run all writers of the chosen node const writers = this.nodes[asNode].getWriters(); if (!writers.length) { @@ -469,13 +507,20 @@ export class Pregel< configurable: { [CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) => task.writes.push(...items), - [CONFIG_KEY_READ]: _localRead.bind( - undefined, - checkpoint, - channels, - // TODO: Why does keyof StrRecord allow number and symbol? 
- task as PregelExecutableTask - ), + [CONFIG_KEY_READ]: ( + select_: Array | keyof Cc, + fresh_: boolean = false + ) => + _localRead( + step, + checkpoint, + channels, + managed, + // TODO: Why does keyof StrRecord allow number and symbol? + task as PregelExecutableTask, + select_ as string | string[], + fresh_ + ), }, }) ); @@ -611,6 +656,56 @@ export class Pregel< return super.stream(input, options); } + protected async prepareSpecs(config: RunnableConfig) { + const configForManaged = patchConfigurable(config, { + [CONFIG_KEY_STORE]: this.store, + }); + const channelSpecs: Record = {}; + const managedSpecs: Record = {}; + + for (const [name, spec] of Object.entries(this.channels)) { + if (isBaseChannel(spec)) { + channelSpecs[name] = spec; + } else { + managedSpecs[name] = spec; + } + } + const managed = new ManagedValueMapping( + await Object.entries(managedSpecs).reduce( + async (accPromise, [key, value]) => { + const acc = await accPromise; + let initializedValue; + + if (isConfiguredManagedValue(value)) { + if ( + "key" in value.params && + value.params.key === ChannelKeyPlaceholder + ) { + value.params.key = key; + } + initializedValue = await value.cls.initialize( + configForManaged, + value.params + ); + } else { + initializedValue = await value.initialize(configForManaged); + } + + if (initializedValue !== undefined) { + acc.push([key, initializedValue]); + } + + return acc; + }, + Promise.resolve([] as [string, ManagedValue][]) + ) + ); + return { + channelSpecs, + managed, + }; + } + override async *_streamIterator( input: PregelInputType, options?: Partial> @@ -652,6 +747,9 @@ export class Pregel< interruptAfter, checkpointer, ] = this._defaults(inputConfig); + + const { channelSpecs, managed } = await this.prepareSpecs(config); + let loop; try { loop = await PregelLoop.initialize({ @@ -659,9 +757,11 @@ export class Pregel< config, checkpointer, nodes: this.nodes, - channelSpecs: this.channels, + channelSpecs, + managed, outputKeys, streamKeys: 
this.streamChannelsAsIs as string | string[], + store: this.store, }); while ( await loop.tick({ @@ -775,7 +875,14 @@ export class Pregel< await runManager?.handleChainError(e); throw e; } finally { - await Promise.all(loop?.checkpointerPromises ?? []); + // Call `.stop()` again incase it was not called in the loop, e.g due to an error. + if (loop) { + loop.store?.stop(); + } + await Promise.all([ + loop?.checkpointerPromises ?? [], + ...Array.from(managed.values()).map((mv) => mv.promises()), + ]); } } diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index c7e182e6..54a5953c 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -46,6 +46,9 @@ import { mapDebugTaskResults, } from "./debug.js"; import { PregelNode } from "./read.js"; +import { BaseStore } from "../store/base.js"; +import { AsyncBatchedStore } from "../store/batch.js"; +import { ManagedValueMapping, WritableManagedValue } from "../managed/base.js"; const INPUT_DONE = Symbol.for("INPUT_DONE"); const INPUT_RESUMING = Symbol.for("INPUT_RESUMING"); @@ -60,6 +63,8 @@ export type PregelLoopInitializeParams = { streamKeys: string | string[]; nodes: Record; channelSpecs: Record; + managed: ManagedValueMapping; + store?: BaseStore; }; type PregelLoopParams = { @@ -73,11 +78,13 @@ type PregelLoopParams = { checkpointPendingWrites: CheckpointPendingWrite[]; checkpointConfig: RunnableConfig; channels: Record; + managed: ManagedValueMapping; step: number; stop: number; outputKeys: string | string[]; streamKeys: string | string[]; nodes: Record; + store?: AsyncBatchedStore; }; export class PregelLoop { @@ -95,6 +102,8 @@ export class PregelLoop { channels: Record; + managed: ManagedValueMapping; + protected checkpoint: Checkpoint; protected checkpointConfig: RunnableConfig; @@ -136,6 +145,8 @@ export class PregelLoop { protected _checkpointerChainedPromise: Promise = Promise.resolve(); + store?: AsyncBatchedStore; + constructor(params: 
PregelLoopParams) { this.input = params.input; this.config = params.config; @@ -154,6 +165,7 @@ export class PregelLoop { this.checkpointMetadata = params.checkpointMetadata; this.checkpointPreviousVersions = params.checkpointPreviousVersions; this.channels = params.channels; + this.managed = params.managed; this.checkpointPendingWrites = params.checkpointPendingWrites; this.step = params.step; this.stop = params.stop; @@ -162,6 +174,7 @@ export class PregelLoop { this.streamKeys = params.streamKeys; this.nodes = params.nodes; this.skipDoneTasks = this.config.configurable?.checkpoint_id === undefined; + this.store = params.store; } static async initialize(params: PregelLoopInitializeParams) { @@ -195,6 +208,15 @@ export class PregelLoop { const stop = step + (params.config.recursionLimit ?? DEFAULT_LOOP_LIMIT) + 1; const checkpointPreviousVersions = { ...checkpoint.channel_versions }; + + const store = params.store + ? new AsyncBatchedStore(params.store) + : undefined; + if (store) { + // Start the store. This is a batch store, so it will run continuously + store.start(); + } + return new PregelLoop({ input: params.input, config: params.config, @@ -203,6 +225,7 @@ export class PregelLoop { checkpointMetadata, checkpointConfig, channels, + managed: params.managed, step, stop, checkpointPreviousVersions, @@ -210,6 +233,7 @@ export class PregelLoop { outputKeys: params.outputKeys ?? [], streamKeys: params.streamKeys ?? [], nodes: params.nodes, + store, }); } @@ -232,6 +256,14 @@ export class PregelLoop { this.checkpointerPromises.push(this._checkpointerChainedPromise); } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + protected async updateManagedValues(key: string, values: any[]) { + const mv = this.managed.get(key); + if (mv && "update" in mv && typeof mv.update === "function") { + await (mv as WritableManagedValue).update(values); + } + } + /** * Put writes for a task, to be read by the next tick. 
* @param taskId @@ -288,6 +320,9 @@ export class PregelLoop { interruptBefore: string[] | All; manager?: CallbackManagerForChainRun; }): Promise { + if (this.store && !this.store.isRunning) { + this.store?.start(); + } const { inputKeys = [], interruptAfter = [], @@ -304,12 +339,15 @@ export class PregelLoop { } else if (this.tasks.every((task) => task.writes.length > 0)) { const writes = this.tasks.flatMap((t) => t.writes); // All tasks have finished - _applyWrites( + const myWrites = _applyWrites( this.checkpoint, this.channels, this.tasks, this.checkpointerGetNextVersion ); + for (const [key, values] of Object.entries(myWrites)) { + await this.updateManagedValues(key, values); + } // produce values output const valuesOutput = await gatherIterator( prefixGenerator( @@ -346,6 +384,7 @@ export class PregelLoop { this.checkpoint, this.nodes, this.channels, + this.managed, this.config, true, { @@ -453,6 +492,7 @@ export class PregelLoop { this.checkpoint, this.nodes, this.channels, + this.managed, this.config, true, { step: this.step } diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index cc6ba30c..7ce7d675 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -9,6 +9,8 @@ import type { BaseChannel } from "../channels/base.js"; import type { PregelNode } from "./read.js"; import { RetryPolicy } from "./utils.js"; import { Interrupt } from "../constants.js"; +import { BaseStore } from "../store/base.js"; +import { type ManagedValueSpec } from "../managed/base.js"; export type StreamMode = "values" | "updates" | "debug"; @@ -21,7 +23,7 @@ type StrRecord = { export interface PregelInterface< Nn extends StrRecord, - Cc extends StrRecord + Cc extends StrRecord > { nodes: Nn; @@ -68,11 +70,16 @@ export interface PregelInterface< checkpointer?: BaseCheckpointSaver; retryPolicy?: RetryPolicy; + + /** + * Memory store to use for SharedValues. 
+ */ + store?: BaseStore; } export type PregelParams< Nn extends StrRecord, - Cc extends StrRecord + Cc extends StrRecord > = Omit, "streamChannelsAsIs">; export interface PregelTaskDescription { diff --git a/libs/langgraph/src/pregel/validate.ts b/libs/langgraph/src/pregel/validate.ts index 9a013fc6..7f86e161 100644 --- a/libs/langgraph/src/pregel/validate.ts +++ b/libs/langgraph/src/pregel/validate.ts @@ -2,6 +2,7 @@ import { All } from "@langchain/langgraph-checkpoint"; import { BaseChannel } from "../channels/index.js"; import { INTERRUPT } from "../constants.js"; import { PregelNode } from "./read.js"; +import { type ManagedValueSpec } from "../managed/base.js"; export class GraphValidationError extends Error { constructor(message?: string) { @@ -12,7 +13,7 @@ export class GraphValidationError extends Error { export function validateGraph< Nn extends Record, - Cc extends Record + Cc extends Record >({ nodes, channels, @@ -113,10 +114,9 @@ export function validateGraph< } } -export function validateKeys>( - keys: keyof Cc | Array, - channels: Cc -): void { +export function validateKeys< + Cc extends Record +>(keys: keyof Cc | Array, channels: Cc): void { if (Array.isArray(keys)) { for (const key of keys) { if (!(key in channels)) { diff --git a/libs/langgraph/src/store/base.ts b/libs/langgraph/src/store/base.ts new file mode 100644 index 00000000..24b7ed88 --- /dev/null +++ b/libs/langgraph/src/store/base.ts @@ -0,0 +1,18 @@ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type Values = Record; + +export abstract class BaseStore { + abstract list( + prefixes: string[] + ): Promise>>; + + abstract put(writes: Array<[string, string, Values | null]>): Promise; + + stop(): void { + // no-op if not implemented. + } + + start(): void { + // no-op if not implemented. 
+ } +} diff --git a/libs/langgraph/src/store/batch.ts b/libs/langgraph/src/store/batch.ts new file mode 100644 index 00000000..58a049a6 --- /dev/null +++ b/libs/langgraph/src/store/batch.ts @@ -0,0 +1,166 @@ +import { BaseStore, type Values } from "./base.js"; + +/** + * A list operation to be processed in batch. + */ +interface ListOp { + /** + * An array of prefixes to list. + * @type {string[]} + */ + prefixes: string[]; +} + +/** + * A put operation to be processed in batch. + */ +interface PutOp { + /** + * An array of write operations to be performed. + * @type {Array<[string, string, Values | null]>} + */ + writes: Array<[string, string, Values | null]>; +} + +type QueueItem = { + /** + * A function to resolve the promise. This function should be called when the operation is complete. + * @param {any | undefined} value The value to resolve the promise with. + * @returns {void} + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + resolve: (value?: any) => void; + /** + * A function to reject the promise. This function should be called when the operation fails. + * @param {any | undefined} reason The reason for rejecting the promise. + * @returns {void} + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + reject: (reason?: any) => void; + /** + * The operation to be processed. This can be either a list or put operation. + */ + op: ListOp | PutOp; +}; + +/** + * AsyncBatchedStore extends BaseStore to provide batched operations for list and put methods. + * It queues operations and processes them in batches for improved efficiency. This store is + * designed to run for the full duration of the process, or until `stop()` is called. + */ +export class AsyncBatchedStore extends BaseStore { + /** + * The store to batch operations for. + * @type {BaseStore} + */ + private store: BaseStore; + + /** + * A queue of operations to be processed in batch. 
+ * @type {QueueItem[]} + */ + private queue: QueueItem[] = []; + + /** + * Whether or not the batched processing is currently running. + * @type {boolean} + * @default {false} + */ + private running = false; + + get isRunning(): boolean { + return this.running; + } + + constructor(store: BaseStore) { + super(); + this.store = store; + } + + /** + * Queues a list operation to be processed in batch. + * @param {string[]} prefixes An array of prefixes to list. + * @returns {Promise>>} A promise that resolves with the list results. + */ + async list( + prefixes: string[] + ): Promise>> { + return new Promise((resolve, reject) => { + this.queue.push({ resolve, reject, op: { prefixes } }); + }); + } + + /** + * Queues a put operation to be processed in batch. + * @param {Array<[string, string, Values | null]>} writes An array of write operations to be performed. + * @returns {Promise} A promise that resolves when the put operation is complete. + */ + async put(writes: Array<[string, string, Values | null]>): Promise { + return new Promise((resolve, reject) => { + this.queue.push({ resolve, reject, op: { writes } }); + }); + } + + /** + * Start running the batched processing of operations. + * This process will run continuously until the store is stopped, + * which can be done by calling the `stop()` method. + */ + start() { + this.running = true; + void this.processBatchQueue(); + } + + /** + * Stops the batched processing of operations. + */ + stop() { + this.running = false; + } + + /** + * Runs the task that processes queued operations in batches. + * This method runs continuously until the store is stopped, + * or the process is terminated. + * @returns {Promise} A promise that resolves when the task is complete. 
+ */ + private async processBatchQueue(): Promise { + while (this.running) { + await new Promise((resolve) => { + setTimeout(resolve, 0); + }); + if (this.queue.length === 0) continue; + + const taken = this.queue.splice(0); + + const lists = taken.filter((item) => "prefixes" in item.op); + if (lists.length > 0) { + try { + const allPrefixes = lists.flatMap( + (item) => (item.op as ListOp).prefixes + ); + const results = await this.store.list(allPrefixes); + lists.forEach((item) => { + const { prefixes } = item.op as ListOp; + item.resolve( + Object.fromEntries(prefixes.map((p) => [p, results[p] || {}])) + ); + }); + } catch (e) { + lists.forEach((item) => item.reject(e)); + } + } + + const puts = taken.filter((item) => "writes" in item.op); + if (puts.length > 0) { + try { + const allWrites = puts.flatMap((item) => (item.op as PutOp).writes); + await this.store.put(allWrites); + puts.forEach((item) => item.resolve()); + } catch (e) { + puts.forEach((item) => item.reject(e)); + } + } + } + } +} diff --git a/libs/langgraph/src/store/index.ts b/libs/langgraph/src/store/index.ts new file mode 100644 index 00000000..df367124 --- /dev/null +++ b/libs/langgraph/src/store/index.ts @@ -0,0 +1,3 @@ +export * from "./base.js"; +export * from "./batch.js"; +export * from "./memory.js"; diff --git a/libs/langgraph/src/store/memory.ts b/libs/langgraph/src/store/memory.ts new file mode 100644 index 00000000..4ec232a6 --- /dev/null +++ b/libs/langgraph/src/store/memory.ts @@ -0,0 +1,34 @@ +import { BaseStore, type Values } from "./base.js"; + +export class MemoryStore extends BaseStore { + private data: Map> = new Map(); + + async list( + prefixes: string[] + ): Promise>> { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result: Record> = {}; + for (const prefix of prefixes) { + if (this.data.has(prefix)) { + result[prefix] = Object.fromEntries(this.data.get(prefix)!); + } else { + result[prefix] = {}; + } + } + return Promise.resolve(result); + } + + 
async put(writes: Array<[string, string, Values | null]>): Promise { + for (const [namespace, key, value] of writes) { + if (!this.data.has(namespace)) { + this.data.set(namespace, new Map()); + } + const namespaceMap = this.data.get(namespace)!; + if (value === null) { + namespaceMap.delete(key); + } else { + namespaceMap.set(key, value); + } + } + } +} diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 8bbbcfa8..2a66c95c 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3,7 +3,7 @@ /* eslint-disable no-instanceof/no-instanceof */ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable import/no-extraneous-dependencies */ -import { it, expect, jest, describe } from "@jest/globals"; +import { it, expect, jest, describe, beforeEach } from "@jest/globals"; import { RunnableConfig, RunnableLambda, @@ -23,6 +23,7 @@ import { } from "@langchain/core/messages"; import { ToolCall } from "@langchain/core/messages/tool"; import { + BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuple, @@ -70,6 +71,10 @@ import { NodeInterrupt, } from "../errors.js"; import { ERROR, INTERRUPT, Send, TASKS } from "../constants.js"; +import { ManagedValueMapping } from "../managed/base.js"; +import { SharedValue } from "../managed/shared_value.js"; +import { MemoryStore } from "../store/memory.js"; +import { MessagesAnnotation } from "../graph/messages_annotation.js"; describe("Channel", () => { describe("writeTo", () => { @@ -574,6 +579,7 @@ describe("_localRead", () => { channel1, channel2, }; + const managed = new ManagedValueMapping(); // eslint-disable-next-line @typescript-eslint/no-explicit-any const writes: Array<[string, any]> = []; @@ -581,8 +587,10 @@ describe("_localRead", () => { // call method / assertions expect( _localRead( + 0, checkpoint, channels, + managed, { name: "test", writes, triggers: [] }, "channel1", false @@ -590,8 +598,10 @@ 
describe("_localRead", () => { ).toBe(1); expect( _localRead( + 0, checkpoint, channels, + managed, { name: "test", writes, triggers: [] }, ["channel1", "channel2"], false @@ -626,12 +636,15 @@ describe("_localRead", () => { ["channel1", 100], ["channel2", 200], ]; + const managed = new ManagedValueMapping(); // call method / assertions expect( _localRead( + 0, checkpoint, channels, + managed, { name: "test", writes, triggers: [] }, "channel1", true @@ -639,8 +652,10 @@ describe("_localRead", () => { ).toBe(100); expect( _localRead( + 0, checkpoint, channels, + managed, { name: "test", writes, triggers: [] }, ["channel1", "channel2"], true @@ -791,12 +806,14 @@ describe("_prepareNextTasks", () => { channel1, channel2, }; + const managed = new ManagedValueMapping(); // call method / assertions const taskDescriptions = _prepareNextTasks( checkpoint, processes, channels, + managed, { configurable: { thread_id: "foo" } }, false, { step: -1 } @@ -911,12 +928,14 @@ describe("_prepareNextTasks", () => { channel5, channel6, }; + const managed = new ManagedValueMapping(); // call method / assertions const tasks = _prepareNextTasks( checkpoint, processes, channels, + managed, { configurable: { thread_id: "foo" } }, true, { step: -1 } @@ -4648,3 +4667,552 @@ it("StateGraph branch then node", async () => { market: "FR", }); }); + +describe("StateGraph start branch then end", () => { + let checkpointer: BaseCheckpointSaver; + + const GraphAnnotation = Annotation.Root({ + my_key: Annotation({ + reducer: (a: string, b: string) => a + b, + }), + market: Annotation(), + shared: SharedValue.on("assistant_id"), + }); + + beforeEach(() => { + checkpointer = new MemorySaver(); + }); + + const assertSharedValue = ( + data: typeof GraphAnnotation.State, + config: RunnableConfig + ): Partial => { + expect(data).toHaveProperty("shared"); + const threadId = config.configurable?.thread_id; + if (threadId) { + if (threadId === "1") { + expect(data.shared).toEqual({}); + return { shared: { 
"1": { hello: "world" } } }; + } else if (threadId === "2") { + expect(data.shared).toEqual({ "1": { hello: "world" } }); + } else if (threadId === "3") { + // Should not contain a value because the "assistant_id" is different + expect(data.shared).toEqual({}); + } + } + return {}; + }; + + const toolTwoSlow = ( + data: typeof GraphAnnotation.State, + config: any + ): Partial => { + return { my_key: " slow", ...assertSharedValue(data, config) }; + }; + + const toolTwoFast = ( + data: typeof GraphAnnotation.State, + config: any + ): Partial => { + return { my_key: " fast", ...assertSharedValue(data, config) }; + }; + + it("should handle start branch then end", async () => { + const toolTwoGraph = new StateGraph(GraphAnnotation); + const debug = false; + + toolTwoGraph + .addNode("tool_two_slow", toolTwoSlow) + .addNode("tool_two_fast", toolTwoFast) + .addConditionalEdges(START, (s) => + s.market === "DE" ? "tool_two_slow" : "tool_two_fast" + ) + .addEdge("tool_two_slow", END) + .addEdge("tool_two_fast", END); + + let toolTwo = toolTwoGraph.compile(); + + expect( + await toolTwo.invoke({ my_key: "value", market: "DE" }, { debug }) + ).toEqual({ + my_key: "value slow", + market: "DE", + }); + + expect( + await toolTwo.invoke({ my_key: "value", market: "US" }, { debug }) + ).toEqual({ + my_key: "value fast", + market: "US", + }); + + toolTwo = toolTwoGraph.compile({ + store: new MemoryStore(), + checkpointer, + interruptBefore: ["tool_two_fast", "tool_two_slow"] as any[], + }); + + // Will throw an error if a checkpointer is passed but `configurable` isn't. 
+ await expect( + toolTwo.invoke({ my_key: "value", market: "DE" }) + ).rejects.toThrow(/thread_id/); + + const thread1 = { + configurable: { thread_id: "1", assistant_id: "a" }, + debug, + }; + + expect( + await toolTwo.invoke({ my_key: "value ⛰️", market: "DE" }, thread1) + ).toEqual({ + my_key: "value ⛰️", + market: "DE", + }); + + const checkpoints = []; + if (toolTwo.checkpointer) { + for await (const checkpoint of toolTwo.checkpointer.list(thread1)) { + checkpoints.push(checkpoint); + } + } + + expect(checkpoints.map((c: any) => c.metadata)).toEqual([ + { + source: "loop", + step: 0, + writes: null, + }, + { + source: "input", + step: -1, + writes: { __start__: { my_key: "value ⛰️", market: "DE" } }, + }, + ]); + + expect(await toolTwo.getState(thread1)).toMatchObject({ + values: { my_key: "value ⛰️", market: "DE" }, + tasks: [{ name: "tool_two_slow" }], + next: ["tool_two_slow"], + metadata: { source: "loop", step: 0, writes: null }, + }); + + expect(await toolTwo.invoke(null, thread1)).toEqual({ + my_key: "value ⛰️ slow", + market: "DE", + }); + + expect(await toolTwo.getState(thread1)).toMatchObject({ + values: { + my_key: "value ⛰️ slow", + market: "DE", + }, + tasks: [], + next: [], + metadata: { + source: "loop", + step: 1, + writes: { + tool_two_slow: { + my_key: " slow", + }, + }, + }, + }); + + const thread2 = { + configurable: { thread_id: "2", assistant_id: "a" }, + debug, + }; + expect( + await toolTwo.invoke( + { + my_key: "value", + market: "US", + }, + thread2 + ) + ).toEqual({ + my_key: "value", + market: "US", + }); + + expect(await toolTwo.getState(thread2)).toMatchObject({ + values: { + my_key: "value", + market: "US", + }, + tasks: [{ name: "tool_two_fast" }], + next: ["tool_two_fast"], + metadata: { source: "loop", step: 0, writes: null }, + }); + + expect(await toolTwo.invoke(null, thread2)).toEqual({ + my_key: "value fast", + market: "US", + }); + + expect(await toolTwo.getState(thread2)).toMatchObject({ + values: { + my_key: "value 
fast", + market: "US", + }, + tasks: [], + next: [], + metadata: { + source: "loop", + step: 1, + writes: { tool_two_fast: { my_key: " fast" } }, + }, + }); + + const thread3 = { configurable: { thread_id: "3", assistant_id: "b" } }; + expect( + await toolTwo.invoke({ my_key: "value", market: "US" }, thread3) + ).toEqual({ + my_key: "value", + market: "US", + }); + + expect(await toolTwo.getState(thread3)).toMatchObject({ + values: { my_key: "value", market: "US" }, + tasks: [{ name: "tool_two_fast" }], + next: ["tool_two_fast"], + metadata: { source: "loop", step: 0, writes: null }, + }); + + await toolTwo.updateState(thread3, { my_key: "key" }); + + expect(await toolTwo.getState(thread3)).toMatchObject({ + values: { my_key: "valuekey", market: "US" }, + tasks: [{ name: "tool_two_fast" }], + next: ["tool_two_fast"], + metadata: { + source: "update", + step: 1, + writes: { [START]: { my_key: "key" } }, + }, + }); + + expect(await toolTwo.invoke(null, thread3)).toEqual({ + my_key: "valuekey fast", + market: "US", + }); + + expect(await toolTwo.getState(thread3)).toMatchObject({ + values: { my_key: "valuekey fast", market: "US" }, + tasks: [], + next: [], + metadata: { + source: "loop", + step: 2, + writes: { tool_two_fast: { my_key: " fast" } }, + }, + }); + }); +}); + +describe("Managed Values (context) can be passed through state", () => { + let store: MemoryStore; + let checkpointer: MemorySaver; + let threadId = ""; + let iter = 0; + + beforeEach(() => { + iter += 1; + threadId = iter.toString(); + store = new MemoryStore(); + checkpointer = new MemorySaver(); + }); + + const AgentAnnotation = Annotation.Root({ + ...MessagesAnnotation.spec, + sharedStateKey: SharedValue.on("assistant_id"), + }); + + it("should be passed through state but not stored in checkpointer", async () => { + const nodeOne = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + 
expect(config.configurable?.thread_id).toEqual(threadId); + + expect(data.sharedStateKey).toEqual({}); + + return { + sharedStateKey: { + sharedStateValue: { + value: "shared", + }, + }, + messages: [new AIMessage("hello")], + }; + }; + + const nodeTwo = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + + expect(data.sharedStateKey).toEqual({ + sharedStateValue: { + value: "shared", + }, + }); + + const storeData: Map< + string, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + Map> + // @ts-expect-error protected property, API not yet built for accessing values. + > = store.data; + expect(storeData.size).toEqual(1); + + // Namespace is scoped: + const namespace = "scoped:assistant_id:sharedStateKey:a"; + const scopedData = storeData.get(namespace); + expect(scopedData).toBeDefined(); + expect(scopedData?.size).toEqual(1); + const sharedValue = scopedData?.get("sharedStateValue"); + + expect(sharedValue).toEqual({ + value: "shared", + }); + + return { + sharedStateKey: { + sharedStateValue: { + value: "updated", + }, + }, + }; + }; + + const nodeThree = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + + expect(data.sharedStateKey).toEqual({ + sharedStateValue: { + value: "updated", + }, + }); + + // Return entire state so the result of `.invoke` can be verified. 
+ return data; + }; + + const workflow = new StateGraph(AgentAnnotation) + .addNode("nodeOne", nodeOne) + .addNode("nodeTwo", nodeTwo) + .addNode("nodeThree", nodeThree) + .addEdge(START, "nodeOne") + .addEdge("nodeOne", "nodeTwo") + .addEdge("nodeTwo", "nodeThree") + .addEdge("nodeThree", END); + + const app = workflow.compile({ + store, + checkpointer, + interruptBefore: ["nodeTwo", "nodeThree"], + }); + + const config = { configurable: { thread_id: threadId, assistant_id: "a" } }; + + // Invoke the first time to cause `nodeOne` to be executed. + await app.invoke( + { + messages: [ + new HumanMessage({ + content: "what is weather in sf", + }), + ], + }, + config + ); + + // Get state and verify shared value is not present + const currentState1 = await app.getState(config); + expect(currentState1.next).toEqual(["nodeTwo"]); + expect(currentState1.values).toHaveProperty("messages"); + expect(currentState1.values).not.toHaveProperty("sharedStateKey"); + + // Invoke a second time to cause `nodeTwo` to be executed. + await app.invoke(null, config); + + const currentState2 = await app.getState(config); + expect(currentState2.next).toEqual(["nodeThree"]); + expect(currentState2.values).toHaveProperty("messages"); + expect(currentState2.values).not.toHaveProperty("sharedStateKey"); + + // Invoke the final time to cause `nodeThree` to be executed. 
+ const result = await app.invoke(null, config); + + const currentState3 = await app.getState(config); + expect(currentState3.next).toEqual([]); + expect(currentState3.values).toHaveProperty("messages"); + expect(currentState3.values).not.toHaveProperty("sharedStateKey"); + + expect(result).not.toHaveProperty("sharedStateKey"); + expect(Object.keys(result)).toEqual(["messages"]); + }); + + it("can not access shared values from other 'on' keys", async () => { + const nodeOne = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + expect(config.configurable?.thread_id).toBe(threadId); + expect(config.configurable?.assistant_id).toBe("a"); + + expect(data.sharedStateKey).toEqual({}); + + return { + sharedStateKey: { + valueForA: { + value: "assistant_id a", + }, + }, + }; + }; + + const nodeTwo = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + expect(config.configurable?.thread_id).toBe(threadId); + expect(config.configurable?.assistant_id).toBe("b"); + + expect(data.sharedStateKey).toEqual({}); + + return { + sharedStateKey: { + valueForB: { + value: "assistant_id b", + }, + }, + }; + }; + + const nodeThree = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + + expect(config.configurable?.thread_id).toBe(threadId); + expect(config.configurable?.assistant_id).toBe("a"); + + expect(data.sharedStateKey).toEqual({ + valueForA: { + value: "assistant_id a", + }, + }); + + return {}; + }; + + const nodeFour = async ( + data: typeof AgentAnnotation.State, + config?: RunnableConfig + ): Promise> => { + if (!config) { + throw new Error("config is undefined"); + } + + expect(config.configurable?.thread_id).toBe(threadId); + 
expect(config.configurable?.assistant_id).toBe("b"); + + expect(data.sharedStateKey).toEqual({ + valueForB: { + value: "assistant_id b", + }, + }); + + return {}; + }; + + const workflow = new StateGraph(AgentAnnotation) + .addNode("nodeOne", nodeOne) + .addNode("nodeTwo", nodeTwo) + .addNode("nodeThree", nodeThree) + .addNode("nodeFour", nodeFour) + .addEdge(START, "nodeOne") + .addEdge("nodeOne", "nodeTwo") + .addEdge("nodeTwo", "nodeThree") + .addEdge("nodeThree", "nodeFour") + .addEdge("nodeFour", END); + + const app = workflow.compile({ + store, + checkpointer, + interruptBefore: ["nodeTwo", "nodeThree", "nodeFour"], + }); + + const input = { + messages: [ + new HumanMessage({ + content: "what is weather in sf", + id: "1", + }), + ], + }; + + // Invoke once, passing in config with `assistant_id` set to `a`. + // This will cause the shared value to be set in the state. + // After we'll update the config to have `assistant_id` set to `b`, + // and verify that the shared value set under `assistant_id` `a` is not accessible. + // Finally, we'll repeat for `b` after switching back to `a`. + const config1 = { + configurable: { thread_id: threadId, assistant_id: "a" }, + }; + await app.invoke(input, config1); + + const currentState1 = await app.getState(config1); + expect(currentState1.next).toEqual(["nodeTwo"]); + expect(currentState1.values).toEqual(input); + + // Will resume the graph, execute `nodeTwo` then interrupt again. + const config2 = { + configurable: { thread_id: threadId, assistant_id: "b" }, + }; + await app.invoke(null, config2); + + const currentState2 = await app.getState(config2); + expect(currentState2.next).toEqual(["nodeThree"]); + expect(currentState1.values).toEqual(input); + + // Will resume the graph, execute `nodeThree` then finish. 
+ const config3 = { + configurable: { thread_id: threadId, assistant_id: "a" }, + }; + await app.invoke(null, config3); + + const currentState3 = await app.getState(config3); + expect(currentState3.next).toEqual(["nodeFour"]); + expect(currentState1.values).toEqual(input); + + // Finally, resume the graph with `assistant_id` set to `b`, and verify that the shared value is accessible. + const config4 = { + configurable: { thread_id: threadId, assistant_id: "b" }, + }; + await app.invoke(null, config4); + }); +}); diff --git a/libs/langgraph/src/tests/store.test.ts b/libs/langgraph/src/tests/store.test.ts new file mode 100644 index 00000000..17787d65 --- /dev/null +++ b/libs/langgraph/src/tests/store.test.ts @@ -0,0 +1,46 @@ +import { describe, it, expect, jest } from "@jest/globals"; +import { BaseStore, type Values } from "../store/base.js"; +import { AsyncBatchedStore } from "../store/batch.js"; + +describe("AsyncBatchedStore", () => { + it("should batch concurrent calls", async () => { + const listMock = jest.fn(); + + class MockStore extends BaseStore { + async list( + prefixes: string[] + ): Promise>> { + listMock(prefixes); + return Object.fromEntries( + prefixes.map((prefix) => [prefix, { [prefix]: { value: 1 } }]) + ); + } + + async put( + _writes: Array<[string, string, Values | null]> + ): Promise { + // Not used in this test + } + } + + const store = new AsyncBatchedStore(new MockStore()); + + // Start the store + store.start(); + + // Concurrent calls are batched + const results = await Promise.all([ + store.list(["a", "b"]), + store.list(["c", "d"]), + ]); + + expect(results).toEqual([ + { a: { a: { value: 1 } }, b: { b: { value: 1 } } }, + { c: { c: { value: 1 } }, d: { d: { value: 1 } } }, + ]); + + expect(listMock.mock.calls).toEqual([[["a", "b", "c", "d"]]]); + + store.stop(); + }); +}); diff --git a/libs/langgraph/src/utils.ts b/libs/langgraph/src/utils.ts index c25b51df..194606a1 100644 --- a/libs/langgraph/src/utils.ts +++ 
b/libs/langgraph/src/utils.ts @@ -143,3 +143,28 @@ export function gatherIteratorSync(i: Iterable): Array { } return out; } + +export function patchConfigurable( + config: RunnableConfig | undefined, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + patch: Record +): RunnableConfig { + if (!config) { + return { + configurable: patch, + }; + } else if (!("configurable" in config)) { + return { + ...config, + configurable: patch, + }; + } else { + return { + ...config, + configurable: { + ...config.configurable, + ...patch, + }, + }; + } +} diff --git a/libs/langgraph/src/web.ts b/libs/langgraph/src/web.ts index 4f41cb18..ba5490bf 100644 --- a/libs/langgraph/src/web.ts +++ b/libs/langgraph/src/web.ts @@ -41,3 +41,5 @@ export { emptyCheckpoint, BaseCheckpointSaver, } from "@langchain/langgraph-checkpoint"; +export * from "./store/index.js"; +export * from "./managed/index.js";