From 63c16e7b26b4b8b7b198f7098b5b1b46fc48930d Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Wed, 11 Dec 2024 16:02:26 -0800 Subject: [PATCH] fix(langgraph): Fix state graph invoke typing (#735) --- libs/langgraph/src/graph/graph.ts | 10 ++- libs/langgraph/src/graph/state.ts | 2 +- libs/langgraph/src/tests/pregel.test.ts | 85 ++++++++++++++++++++++--- 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 215c5434..aabb19a8 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -492,14 +492,18 @@ export class CompiledGraph< // eslint-disable-next-line @typescript-eslint/no-explicit-any Update = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any - ConfigurableFieldType extends Record = Record + ConfigurableFieldType extends Record = Record, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + InputType = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + OutputType = any > extends Pregel< Record>, Record, // eslint-disable-next-line @typescript-eslint/no-explicit-any ConfigurableFieldType & Record, - Update, - State + InputType, + OutputType > { declare NodeType: N; diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index a1fd168a..d4b96e6d 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -495,7 +495,7 @@ export class CompiledStateGraph< I extends StateDefinition = StateDefinition, O extends StateDefinition = StateDefinition, C extends StateDefinition = StateDefinition -> extends CompiledGraph> { +> extends CompiledGraph, UpdateType, StateType> { declare builder: StateGraph; attachNode(key: typeof START, node?: never): void; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 2d5644cd..532c4688 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3629,6 +3629,7 @@ graph TD; const OutputAnnotation = Annotation.Root({ messages: Annotation, + extraOutput: Annotation, }); const nodeA = (state: { hello: string; messages: string[] }) => { @@ -3676,14 +3677,21 @@ graph TD; .addEdge("b", "c") .compile(); - expect( - await graph.invoke({ - hello: "there", - bye: "world", - messages: ["hello"], - }) - ).toEqual({ + const res = await graph.invoke({ + hello: "there", + bye: "world", + messages: ["hello"], + // @ts-expect-error Output schema properties should not be part of input types + extraOutput: "bar", + }); + + // State graph should respect output typing + void res.extraOutput; + + expect(res).toEqual({ messages: ["hello"], + // Will still be added to state despite typing + extraOutput: "bar", }); const graphWithInput = new StateGraph({ @@ -3721,6 +3729,68 @@ graph TD; }) ) ).toEqual([{}, { b: { hello: "again" } }, {}]); + + const res2 = await graph.invoke({ + hello: "there", + bye: "world", + messages: ["hello"], + // @ts-expect-error Output schema properties should not be part of input types + extraOutput: "bar", + }); + + // State graph should respect output typing + void res2.extraOutput; + // @ts-expect-error Output type should not have a field not in the output schema, even if in other state + void res2.hello; + // @ts-expect-error Output type should not have a field not in the output schema, even if in other state + void res2.random; + + expect(res2).toEqual({ + messages: ["hello"], + // Will still be added to state despite typing + extraOutput: "bar", + }); + + const InputStateAnnotation = Annotation.Root({ + specialInputField: Annotation, + }); + + const graphWithAllSchemas = new StateGraph({ + input: InputStateAnnotation, + output: OutputAnnotation, + stateSchema: StateAnnotation, + }) + .addNode("preA", async () => { + return { + bye: "world", + hello: "there", + messages: ["hello"], + }; + }) + .addNode("a", nodeA) + .addNode("b", nodeB) + .addNode("c", nodeC) + .addEdge(START, "preA") + .addEdge("preA", "a") + .addEdge("a", "b") + .addEdge("b", "c") + .compile(); + + const res3 = await graphWithAllSchemas.invoke({ + // @ts-expect-error Input type should not contain fields outside input schema, even if in other states + hello: "there", + specialInputField: "foo", + }); + expect(res3).toEqual({ + messages: ["hello"], + }); + + // Extra output fields should be respected + void res3.extraOutput; + // @ts-expect-error Output type should not have a field not in the output schema, even if in other state + void res3.hello; + // @ts-expect-error Output type should not have a field not in the output schema, even if in other state + void res3.random; }); it("should use a retry policy", async () => { @@ -3876,7 +3946,6 @@ graph TD; }; const res = await app.invoke( { - // @ts-expect-error Messages is not in schema messages: ["initial input"], }, config