Skip to content

Commit

Permalink
fix(langgraph): Serialize command objects passed as input into remote…
Browse files Browse the repository at this point in the history
… graph (#766)
  • Loading branch information
jacoblee93 authored Dec 30, 2024
1 parent 3483450 commit 163edfa
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 14 deletions.
2 changes: 1 addition & 1 deletion libs/langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"license": "MIT",
"dependencies": {
"@langchain/langgraph-checkpoint": "~0.0.13",
"@langchain/langgraph-sdk": "~0.0.21",
"@langchain/langgraph-sdk": "~0.0.32",
"uuid": "^10.0.0",
"zod": "^3.23.8"
},
Expand Down
37 changes: 33 additions & 4 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ export class Send implements SendInterface {

// eslint-disable-next-line @typescript-eslint/no-explicit-any
constructor(public node: string, public args: any) {}

toJSON() {
return {
node: this.node,
args: this.args,
};
}
}

export function _isSend(x: unknown): x is Send {
Expand All @@ -138,11 +145,12 @@ export type CommandParams<R> = {
* - GraphCommand.PARENT: closest parent graph
*/
graph?: string;

/**
* Update to apply to the graph's state.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
update?: Record<string, any>;
update?: Record<string, unknown> | [string, unknown][];

/**
* Can be one of the following:
* - name of the node to navigate to next (any node that belongs to the specified `graph`)
Expand Down Expand Up @@ -222,8 +230,7 @@ export class Command<R = unknown> {

graph?: string;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
update?: Record<string, any> | [string, any][] = [];
update?: Record<string, unknown> | [string, unknown][];

resume?: R;

Expand Down Expand Up @@ -259,6 +266,28 @@ export class Command<R = unknown> {
return [["__root__", this.update]];
}
}

toJSON() {
let serializedGoto;
if (typeof this.goto === "string") {
serializedGoto = this.goto;
} else if (_isSend(this.goto)) {
serializedGoto = this.goto.toJSON();
} else {
serializedGoto = this.goto.map((innerGoto) => {
if (typeof innerGoto === "string") {
return innerGoto;
} else {
return innerGoto.toJSON();
}
});
}
return {
update: this.update,
resume: this.resume,
goto: serializedGoto,
};
}
}

export function isCommand(x: unknown): x is Command {
Expand Down
18 changes: 15 additions & 3 deletions libs/langgraph/src/pregel/remote.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import {
CHECKPOINT_NAMESPACE_SEPARATOR,
CONFIG_KEY_STREAM,
INTERRUPT,
isCommand,
} from "../constants.js";

export type RemoteGraphParams = Omit<
Expand Down Expand Up @@ -399,8 +400,8 @@ export class RemoteGraph<
const streamSubgraphs =
options?.subgraphs ?? streamProtocolInstance !== undefined;

const interruptBefore = this.interruptBefore ?? options?.interruptBefore;
const interruptAfter = this.interruptAfter ?? options?.interruptAfter;
const interruptBefore = options?.interruptBefore ?? this.interruptBefore;
const interruptAfter = options?.interruptAfter ?? this.interruptAfter;

const { updatedStreamModes, reqSingle, reqUpdates } = getStreamModes(
options?.streamMode
Expand All @@ -413,11 +414,22 @@ export class RemoteGraph<
]),
];

let command;
let serializedInput;
if (isCommand(input)) {
// TODO: Remove cast when SDK type fix gets merged
command = input.toJSON() as Record<string, unknown>;
serializedInput = undefined;
} else {
serializedInput = _serializeInputs(input);
}

for await (const chunk of this.client.runs.stream(
sanitizedConfig.configurable.thread_id as string,
this.graphId,
{
input: _serializeInputs(input),
command,
input: serializedInput,
config: sanitizedConfig,
streamMode: extendedStreamModes,
interruptBefore: interruptBefore as string[],
Expand Down
57 changes: 56 additions & 1 deletion libs/langgraph/src/tests/remote.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { jest } from "@jest/globals";
import { Client } from "@langchain/langgraph-sdk";
import { RemoteGraph } from "../pregel/remote.js";
import { gatherIterator } from "../utils.js";
import { INTERRUPT } from "../constants.js";
import { Command, INTERRUPT, Send } from "../constants.js";
import { GraphInterrupt } from "../errors.js";

describe("RemoteGraph", () => {
Expand Down Expand Up @@ -474,4 +474,59 @@ describe("RemoteGraph", () => {
);
expect(result).toEqual({ messages: [{ type: "human", content: "world" }] });
});

test("invoke with a Command serializes properly", async () => {
const client = new Client({});
let streamArgs;
jest
.spyOn((client as any).runs, "stream")
.mockImplementation(async function* (...args) {
streamArgs = args;
const chunks = [
{ event: "values", data: { chunk: "data1" } },
{ event: "values", data: { chunk: "data2" } },
{
event: "values",
data: { messages: [{ type: "human", content: "world" }] },
},
];
for (const chunk of chunks) {
yield chunk;
}
});

const remotePregel = new RemoteGraph({
client,
graphId: "test_graph_id",
});

const config = { configurable: { thread_id: "thread_1" } };
const result = await remotePregel.invoke(
new Command({
goto: ["one", new Send("foo", { baz: "qux" })],
resume: "bar",
update: { foo: "bar" },
}),
config
);
expect(result).toEqual({ messages: [{ type: "human", content: "world" }] });
expect(streamArgs).toEqual([
"thread_1",
"test_graph_id",
{
command: {
update: { foo: "bar" },
resume: "bar",
goto: ["one", { node: "foo", args: { baz: "qux" } }],
},
input: undefined,
config: expect.anything(),
streamMode: ["values", "updates"],
interruptBefore: undefined,
interruptAfter: undefined,
streamSubgraphs: false,
ifNotExists: "create",
},
]);
});
});
10 changes: 5 additions & 5 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1832,15 +1832,15 @@ __metadata:
languageName: unknown
linkType: soft

"@langchain/langgraph-sdk@npm:~0.0.21":
version: 0.0.21
resolution: "@langchain/langgraph-sdk@npm:0.0.21"
"@langchain/langgraph-sdk@npm:~0.0.32":
version: 0.0.32
resolution: "@langchain/langgraph-sdk@npm:0.0.32"
dependencies:
"@types/json-schema": ^7.0.15
p-queue: ^6.6.2
p-retry: 4
uuid: ^9.0.0
checksum: 5dabe873b6cf3fbaa66b445c7032f1188d0c3eb7872457d0a8131521992a1e999e73514eeb9a344c10ee2315992e7130739de32257e2b3107a6dfe5f4027d6a2
checksum: dd0e3fd1f3880e1ddff65108f338e75365bdff60f8ca5dc4e1b1e1be79fa216f1e599fa5d7f62ab9b44778787426f20b82db4e97463d6f3f4eda43f81666bfe9
languageName: node
linkType: hard

Expand All @@ -1855,7 +1855,7 @@ __metadata:
"@langchain/langgraph-checkpoint": ~0.0.13
"@langchain/langgraph-checkpoint-postgres": "workspace:*"
"@langchain/langgraph-checkpoint-sqlite": "workspace:*"
"@langchain/langgraph-sdk": ~0.0.21
"@langchain/langgraph-sdk": ~0.0.32
"@langchain/openai": ^0.3.11
"@langchain/scripts": ">=0.1.3 <0.2.0"
"@swc/core": ^1.3.90
Expand Down

0 comments on commit 163edfa

Please sign in to comment.