Skip to content

Commit

Permalink
fix(langgraph): Fix state graph invoke typing (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Dec 12, 2024
1 parent 02fe24c commit 63c16e7
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 12 deletions.
10 changes: 7 additions & 3 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,18 @@ export class CompiledGraph<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Update = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = Record<string, any>
ConfigurableFieldType extends Record<string, any> = Record<string, any>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
InputType = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
OutputType = any
> extends Pregel<
Record<N | typeof START, PregelNode<State, Update>>,
Record<N | typeof START | typeof END | string, BaseChannel>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType & Record<string, any>,
Update,
State
InputType,
OutputType
> {
declare NodeType: N;

Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ export class CompiledStateGraph<
I extends StateDefinition = StateDefinition,
O extends StateDefinition = StateDefinition,
C extends StateDefinition = StateDefinition
> extends CompiledGraph<N, S, U, StateType<C>> {
> extends CompiledGraph<N, S, U, StateType<C>, UpdateType<I>, StateType<O>> {
declare builder: StateGraph<unknown, S, U, N, I, O, C>;

attachNode(key: typeof START, node?: never): void;
Expand Down
85 changes: 77 additions & 8 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3629,6 +3629,7 @@ graph TD;

const OutputAnnotation = Annotation.Root({
messages: Annotation<string[]>,
extraOutput: Annotation<string>,
});

const nodeA = (state: { hello: string; messages: string[] }) => {
Expand Down Expand Up @@ -3676,14 +3677,21 @@ graph TD;
.addEdge("b", "c")
.compile();

expect(
await graph.invoke({
hello: "there",
bye: "world",
messages: ["hello"],
})
).toEqual({
const res = await graph.invoke({
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error Output schema properties should not be part of input types
extraOutput: "bar",
});

// State graph should respect output typing
void res.extraOutput;

expect(res).toEqual({
messages: ["hello"],
// Will still be added to state despite typing
extraOutput: "bar",
});

const graphWithInput = new StateGraph({
Expand Down Expand Up @@ -3721,6 +3729,68 @@ graph TD;
})
)
).toEqual([{}, { b: { hello: "again" } }, {}]);

const res2 = await graph.invoke({
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error Output schema properties should not be part of input types
extraOutput: "bar",
});

// State graph should respect output typing
void res2.extraOutput;
// @ts-expect-error Output type should not have a field not in the output schema, even if in other state
void res2.hello;
// @ts-expect-error Output type should not have a field not in the output schema, even if in other state
void res2.random;

expect(res2).toEqual({
messages: ["hello"],
// Will still be added to state despite typing
extraOutput: "bar",
});

const InputStateAnnotation = Annotation.Root({
specialInputField: Annotation<string>,
});

const graphWithAllSchemas = new StateGraph({
input: InputStateAnnotation,
output: OutputAnnotation,
stateSchema: StateAnnotation,
})
.addNode("preA", async () => {
return {
bye: "world",
hello: "there",
messages: ["hello"],
};
})
.addNode("a", nodeA)
.addNode("b", nodeB)
.addNode("c", nodeC)
.addEdge(START, "preA")
.addEdge("preA", "a")
.addEdge("a", "b")
.addEdge("b", "c")
.compile();

const res3 = await graphWithAllSchemas.invoke({
// @ts-expect-error Input type should not contain fields outside input schema, even if in other states
hello: "there",
specialInputField: "foo",
});
expect(res3).toEqual({
messages: ["hello"],
});

// Extra output fields should be respected
void res3.extraOutput;
// @ts-expect-error Output type should not have a field not in the output schema, even if in other state
void res3.hello;
// @ts-expect-error Output type should not have a field not in the output schema, even if in other state
void res3.random;
});

it("should use a retry policy", async () => {
Expand Down Expand Up @@ -3876,7 +3946,6 @@ graph TD;
};
const res = await app.invoke(
{
// @ts-expect-error Messages is not in schema
messages: ["initial input"],
},
config
Expand Down

0 comments on commit 63c16e7

Please sign in to comment.