From cc0099a134fc638ae20456766112fa3bcc3cbb8b Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Wed, 18 Sep 2024 14:09:53 -0700 Subject: [PATCH] fix(langgraph): Fix bug in getting state with managed values (#493) --- libs/langgraph/src/managed/base.ts | 16 ++++++ libs/langgraph/src/pregel/index.ts | 21 ++++++-- libs/langgraph/src/tests/pregel.test.ts | 71 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/src/managed/base.ts b/libs/langgraph/src/managed/base.ts index 653e8b440..21ea3e60e 100644 --- a/libs/langgraph/src/managed/base.ts +++ b/libs/langgraph/src/managed/base.ts @@ -185,3 +185,19 @@ export function isConfiguredManagedValue( } return false; } + +/** + * No-op class used when getting state values, as managed values should never be returned + * in get state calls. + */ +export class NoopManagedValue extends ManagedValue { + call() {} + + static async initialize( + config: RunnableConfig, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _args?: any + ): Promise { + return Promise.resolve(new NoopManagedValue(config)); + } +} diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 45f94f1eb..0bd2e2fc3 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -73,6 +73,7 @@ import { isConfiguredManagedValue, ManagedValue, ManagedValueMapping, + NoopManagedValue, type ManagedValueSpec, } from "../managed/base.js"; import { patchConfigurable } from "../utils.js"; @@ -316,7 +317,8 @@ export class Pregel< this.channels as Record, checkpoint ); - const { managed } = await this.prepareSpecs(config); + // Pass `skipManaged: true` as managed values should not be returned in get state calls. + const { managed } = await this.prepareSpecs(config, { skipManaged: true }); const nextTasks = _prepareNextTasks( checkpoint, @@ -351,7 +353,8 @@ export class Pregel< if (!this.checkpointer) { throw new GraphValueError("No checkpointer set"); } - const { managed } = await this.prepareSpecs(config); + // Pass `skipManaged: true` as managed values should not be returned in get state calls. + const { managed } = await this.prepareSpecs(config, { skipManaged: true }); for await (const saved of this.checkpointer.list(config, options)) { const channels = emptyChannels( @@ -668,7 +671,14 @@ export class Pregel< return super.stream(input, options); } - protected async prepareSpecs(config: RunnableConfig) { + protected async prepareSpecs( + config: RunnableConfig, + options?: { + // Equivalent to the `skip_context` option in Python, but renamed + // to `managed` since JS does not implement the `Context` class. + skipManaged?: boolean; + } + ) { const configForManaged = patchConfigurable(config, { [CONFIG_KEY_STORE]: this.store, }); @@ -678,6 +688,11 @@ export class Pregel< for (const [name, spec] of Object.entries(this.channels)) { if (isBaseChannel(spec)) { channelSpecs[name] = spec; + } else if (options?.skipManaged) { + managedSpecs[name] = { + cls: NoopManagedValue, + params: { config: {} }, + }; } else { managedSpecs[name] = spec; } diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index c68fa48f3..899b1c33d 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -5299,4 +5299,75 @@ describe("Managed Values (context) can be passed through state", () => { }; await app.invoke(null, config4); }); + + it("can get state when state has shared values", async () => { + const nodeOne = (_: typeof AgentAnnotation.State) => { + return { + messages: [ + { + role: "assistant", + content: "no-op", + }, + ], + sharedStateKey: { + data: { + value: "shared", + }, + }, + }; + }; + + const nodeTwo = (_: typeof AgentAnnotation.State) => { + // no-op + return {}; + }; + + const workflow = new StateGraph(AgentAnnotation) + .addNode("nodeOne", nodeOne) + .addNode("nodeTwo", nodeTwo) + .addEdge(START, "nodeOne") + .addEdge("nodeOne", "nodeTwo") + .addEdge("nodeTwo", END); + + const app = workflow.compile({ + store, + checkpointer, + interruptBefore: ["nodeTwo"], + }); + + const config: Record> = { + configurable: { thread_id: threadId, assistant_id: "a" }, + }; + + // Execute the graph. This will run `nodeOne` which sets the shared value, + // then is interrupted before executing `nodeTwo`. + await app.invoke( + { + messages: [ + { + role: "user", + content: "no-op", + }, + ], + }, + config + ); + + // Remove the "assistant_id" from the config and attempt to fetch the state. + // Since a `noop` managed value class is used when getting state, it should work + // even though the shared value key is not present. + if (config.configurable.assistant_id) { + delete config.configurable.assistant_id; + } + // Expect it does not throw an error complaining that the `assistant_id` key + // is not found in the config. + expect(await app.getState(config)).toBeTruthy(); + + // Re-running without re-setting the `assistant_id` key in the config should throw an error. + await expect(app.invoke(null, config)).rejects.toThrow(/assistant_id/); + + // Re-set the `assistant_id` key in the config and attempt to fetch the state. + config.configurable.assistant_id = "a"; + await app.invoke(null, config); + }); });