From 163edfae7380926bc927edf62086b0e91f99d220 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 30 Dec 2024 14:06:57 -0800 Subject: [PATCH] fix(langgraph): Serialize command objects passed as input into remote graph (#766) --- libs/langgraph/package.json | 2 +- libs/langgraph/src/constants.ts | 37 ++++++++++++++-- libs/langgraph/src/pregel/remote.ts | 18 ++++++-- libs/langgraph/src/tests/remote.test.ts | 57 ++++++++++++++++++++++++- yarn.lock | 10 ++--- 5 files changed, 110 insertions(+), 14 deletions(-) diff --git a/libs/langgraph/package.json b/libs/langgraph/package.json index f660287f8..6bcb0c188 100644 --- a/libs/langgraph/package.json +++ b/libs/langgraph/package.json @@ -32,7 +32,7 @@ "license": "MIT", "dependencies": { "@langchain/langgraph-checkpoint": "~0.0.13", - "@langchain/langgraph-sdk": "~0.0.21", + "@langchain/langgraph-sdk": "~0.0.32", "uuid": "^10.0.0", "zod": "^3.23.8" }, diff --git a/libs/langgraph/src/constants.ts b/libs/langgraph/src/constants.ts index 186199ed3..595880d8f 100644 --- a/libs/langgraph/src/constants.ts +++ b/libs/langgraph/src/constants.ts @@ -112,6 +112,13 @@ export class Send implements SendInterface { // eslint-disable-next-line @typescript-eslint/no-explicit-any constructor(public node: string, public args: any) {} + + toJSON() { + return { + node: this.node, + args: this.args, + }; + } } export function _isSend(x: unknown): x is Send { @@ -138,11 +145,12 @@ export type CommandParams = { * - GraphCommand.PARENT: closest parent graph */ graph?: string; + /** * Update to apply to the graph's state. */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - update?: Record; + update?: Record | [string, unknown][]; + /** * Can be one of the following: * - name of the node to navigate to next (any node that belongs to the specified `graph`) @@ -222,8 +230,7 @@ export class Command { graph?: string; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - update?: Record | [string, any][] = []; + update?: Record | [string, unknown][]; resume?: R; @@ -259,6 +266,28 @@ export class Command { return [["__root__", this.update]]; } } + + toJSON() { + let serializedGoto; + if (typeof this.goto === "string") { + serializedGoto = this.goto; + } else if (_isSend(this.goto)) { + serializedGoto = this.goto.toJSON(); + } else { + serializedGoto = this.goto.map((innerGoto) => { + if (typeof innerGoto === "string") { + return innerGoto; + } else { + return innerGoto.toJSON(); + } + }); + } + return { + update: this.update, + resume: this.resume, + goto: serializedGoto, + }; + } } export function isCommand(x: unknown): x is Command { diff --git a/libs/langgraph/src/pregel/remote.ts b/libs/langgraph/src/pregel/remote.ts index f54ea45f3..1e6bec719 100644 --- a/libs/langgraph/src/pregel/remote.ts +++ b/libs/langgraph/src/pregel/remote.ts @@ -42,6 +42,7 @@ import { CHECKPOINT_NAMESPACE_SEPARATOR, CONFIG_KEY_STREAM, INTERRUPT, + isCommand, } from "../constants.js"; export type RemoteGraphParams = Omit< @@ -399,8 +400,8 @@ export class RemoteGraph< const streamSubgraphs = options?.subgraphs ?? streamProtocolInstance !== undefined; - const interruptBefore = this.interruptBefore ?? options?.interruptBefore; - const interruptAfter = this.interruptAfter ?? options?.interruptAfter; + const interruptBefore = options?.interruptBefore ?? this.interruptBefore; + const interruptAfter = options?.interruptAfter ?? this.interruptAfter; const { updatedStreamModes, reqSingle, reqUpdates } = getStreamModes( options?.streamMode @@ -413,11 +414,22 @@ export class RemoteGraph< ]), ]; + let command; + let serializedInput; + if (isCommand(input)) { + // TODO: Remove cast when SDK type fix gets merged + command = input.toJSON() as Record; + serializedInput = undefined; + } else { + serializedInput = _serializeInputs(input); + } + for await (const chunk of this.client.runs.stream( sanitizedConfig.configurable.thread_id as string, this.graphId, { - input: _serializeInputs(input), + command, + input: serializedInput, config: sanitizedConfig, streamMode: extendedStreamModes, interruptBefore: interruptBefore as string[], diff --git a/libs/langgraph/src/tests/remote.test.ts b/libs/langgraph/src/tests/remote.test.ts index a9542e1e9..54dd2e954 100644 --- a/libs/langgraph/src/tests/remote.test.ts +++ b/libs/langgraph/src/tests/remote.test.ts @@ -3,7 +3,7 @@ import { jest } from "@jest/globals"; import { Client } from "@langchain/langgraph-sdk"; import { RemoteGraph } from "../pregel/remote.js"; import { gatherIterator } from "../utils.js"; -import { INTERRUPT } from "../constants.js"; +import { Command, INTERRUPT, Send } from "../constants.js"; import { GraphInterrupt } from "../errors.js"; describe("RemoteGraph", () => { @@ -474,4 +474,59 @@ describe("RemoteGraph", () => { ); expect(result).toEqual({ messages: [{ type: "human", content: "world" }] }); }); + + test("invoke with a Command serializes properly", async () => { + const client = new Client({}); + let streamArgs; + jest + .spyOn((client as any).runs, "stream") + .mockImplementation(async function* (...args) { + streamArgs = args; + const chunks = [ + { event: "values", data: { chunk: "data1" } }, + { event: "values", data: { chunk: "data2" } }, + { + event: "values", + data: { messages: [{ type: "human", content: "world" }] }, + }, + ]; + for (const chunk of chunks) { + yield chunk; + } + }); + + const remotePregel = new RemoteGraph({ + client, + graphId: "test_graph_id", + }); + + const config = { configurable: { thread_id: "thread_1" } }; + const result = await remotePregel.invoke( + new Command({ + goto: ["one", new Send("foo", { baz: "qux" })], + resume: "bar", + update: { foo: "bar" }, + }), + config + ); + expect(result).toEqual({ messages: [{ type: "human", content: "world" }] }); + expect(streamArgs).toEqual([ + "thread_1", + "test_graph_id", + { + command: { + update: { foo: "bar" }, + resume: "bar", + goto: ["one", { node: "foo", args: { baz: "qux" } }], + }, + input: undefined, + config: expect.anything(), + streamMode: ["values", "updates"], + interruptBefore: undefined, + interruptAfter: undefined, + streamSubgraphs: false, + ifNotExists: "create", + }, + ]); + }); }); diff --git a/yarn.lock b/yarn.lock index 3621f3ba3..e2bfe8390 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1832,15 +1832,15 @@ __metadata: languageName: unknown linkType: soft -"@langchain/langgraph-sdk@npm:~0.0.21": - version: 0.0.21 - resolution: "@langchain/langgraph-sdk@npm:0.0.21" +"@langchain/langgraph-sdk@npm:~0.0.32": + version: 0.0.32 + resolution: "@langchain/langgraph-sdk@npm:0.0.32" dependencies: "@types/json-schema": ^7.0.15 p-queue: ^6.6.2 p-retry: 4 uuid: ^9.0.0 - checksum: 5dabe873b6cf3fbaa66b445c7032f1188d0c3eb7872457d0a8131521992a1e999e73514eeb9a344c10ee2315992e7130739de32257e2b3107a6dfe5f4027d6a2 + checksum: dd0e3fd1f3880e1ddff65108f338e75365bdff60f8ca5dc4e1b1e1be79fa216f1e599fa5d7f62ab9b44778787426f20b82db4e97463d6f3f4eda43f81666bfe9 languageName: node linkType: hard @@ -1855,7 +1855,7 @@ __metadata: "@langchain/langgraph-checkpoint": ~0.0.13 "@langchain/langgraph-checkpoint-postgres": "workspace:*" "@langchain/langgraph-checkpoint-sqlite": "workspace:*" - "@langchain/langgraph-sdk": ~0.0.21 + "@langchain/langgraph-sdk": ~0.0.32 "@langchain/openai": ^0.3.11 "@langchain/scripts": ">=0.1.3 <0.2.0" "@swc/core": ^1.3.90