Skip to content

Commit

Permalink
feat(langgraph): Allow tools to return Commands and update graph state (
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Dec 12, 2024
1 parent 80a7e81 commit 3ef8ace
Show file tree
Hide file tree
Showing 14 changed files with 1,081 additions and 120 deletions.
771 changes: 771 additions & 0 deletions examples/how-tos/update-state-from-tools.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"devDependencies": {
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.22",
"@langchain/core": "^0.3.23",
"@langchain/groq": "^0.1.2",
"@langchain/langgraph": "workspace:*",
"@langchain/mistralai": "^0.1.1",
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"@jest/globals": "^29.5.0",
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.22",
"@langchain/core": "^0.3.23",
"@langchain/langgraph-checkpoint-postgres": "workspace:*",
"@langchain/langgraph-checkpoint-sqlite": "workspace:*",
"@langchain/openai": "^0.3.11",
Expand Down
28 changes: 25 additions & 3 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,14 @@ export type CommandParams<R> = {
export class Command<R = unknown> {
lg_name = "Command";

resume?: R;
lc_direct_tool_output = true;

graph?: string;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
update?: Record<string, any>;
update?: Record<string, any> | [string, any][] = [];

resume?: R;

goto: string | Send | (string | Send)[] = [];

Expand All @@ -178,8 +180,28 @@ export class Command<R = unknown> {
this.goto = Array.isArray(args.goto) ? args.goto : [args.goto];
}
}

_updateAsTuples(): [string, unknown][] {
if (
this.update &&
typeof this.update === "object" &&
!Array.isArray(this.update)
) {
return Object.entries(this.update);
} else if (
Array.isArray(this.update) &&
this.update.every(
(t): t is [string, unknown] =>
Array.isArray(t) && t.length === 2 && typeof t[0] === "string"
)
) {
return this.update;
} else {
return [["__root__", this.update]];
}
}
}

export function _isCommand(x: unknown): x is Command {
export function isCommand(x: unknown): x is Command {
return typeof x === "object" && !!x && (x as Command).lg_name === "Command";
}
121 changes: 81 additions & 40 deletions libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ import {
import {
ChannelWrite,
ChannelWriteEntry,
ChannelWriteTupleEntry,
PASSTHROUGH,
SKIP_WRITE,
} from "../pregel/write.js";
import { ChannelRead, PregelNode } from "../pregel/read.js";
import { NamedBarrierValue } from "../channels/named_barrier_value.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
import { RunnableCallable } from "../utils.js";
import {
_isCommand,
isCommand,
_isSend,
CHECKPOINT_NAMESPACE_END,
CHECKPOINT_NAMESPACE_SEPARATOR,
Expand Down Expand Up @@ -503,65 +503,106 @@ export class CompiledStateGraph<
attachNode(key: N, node: StateGraphNodeSpec<S, U>): void;

attachNode(key: N | typeof START, node?: StateGraphNodeSpec<S, U>): void {
const stateKeys = Object.keys(this.builder.channels);
let outputKeys: string[];
if (key === START) {
// Get input schema keys excluding managed values
outputKeys = Object.entries(
this.builder._schemaDefinitions.get(this.builder._inputDefinition)
)
.filter(([_, v]) => !isConfiguredManagedValue(v))
.map(([k]) => k);
} else {
outputKeys = Object.keys(this.builder.channels);
}

function _getRoot(input: unknown): unknown {
if (_isCommand(input)) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _getRoot(input: unknown): [string, any][] | null {
if (isCommand(input)) {
if (input.graph === Command.PARENT) {
return SKIP_WRITE;
return null;
}
return input._updateAsTuples();
} else if (
Array.isArray(input) &&
input.length > 0 &&
input.some((i) => isCommand(i))
) {
const updates: [string, unknown][] = [];
for (const i of input) {
if (isCommand(i)) {
if (i.graph === Command.PARENT) {
continue;
}
updates.push(...i._updateAsTuples());
} else {
updates.push([ROOT, i]);
}
}
return input.update;
return updates;
} else if (input != null) {
return [[ROOT, input]];
}
return input;
return null;
}

// to avoid name collision below
const nodeKey = key;

function getStateKey(key: keyof U, input: U): unknown {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _getUpdates(input: U): [string, any][] | null {
if (!input) {
return SKIP_WRITE;
} else if (_isCommand(input)) {
return null;
} else if (isCommand(input)) {
if (input.graph === Command.PARENT) {
return SKIP_WRITE;
return null;
}
return getStateKey(key, input.update as U);
} else if (typeof input !== "object" || Array.isArray(input)) {
return input._updateAsTuples();
} else if (
Array.isArray(input) &&
input.length > 0 &&
input.some(isCommand)
) {
const updates: [string, unknown][] = [];
for (const item of input) {
if (isCommand(item)) {
if (item.graph === Command.PARENT) {
continue;
}
updates.push(...item._updateAsTuples());
} else {
const itemUpdates = _getUpdates(item);
if (itemUpdates) {
updates.push(...(itemUpdates ?? []));
}
}
}
return updates;
} else if (typeof input === "object" && !Array.isArray(input)) {
return Object.entries(input).filter(([k]) => outputKeys.includes(k));
} else {
const typeofInput = Array.isArray(input) ? "array" : typeof input;
throw new InvalidUpdateError(
`Expected node "${nodeKey.toString()}" to return an object, received ${typeofInput}`,
{
lc_error_code: "INVALID_GRAPH_NODE_RETURN_VALUE",
}
);
} else {
return key in input ? input[key] : SKIP_WRITE;
}
}

// state updaters
const stateWriteEntries: ChannelWriteEntry[] = stateKeys.map((key) =>
key === ROOT
? {
channel: key,
value: PASSTHROUGH,
skipNone: true,
mapper: new RunnableCallable({
func: _getRoot,
trace: false,
recurse: false,
}),
}
: {
channel: key,
value: PASSTHROUGH,
mapper: new RunnableCallable({
func: getStateKey.bind(null, key as keyof U),
trace: false,
recurse: false,
}),
}
);
const stateWriteEntries: (ChannelWriteTupleEntry | ChannelWriteEntry)[] = [
{
value: PASSTHROUGH,
mapper: new RunnableCallable({
func:
outputKeys.length && outputKeys[0] === ROOT
? _getRoot
: _getUpdates,
trace: false,
recurse: false,
}),
},
];

// add node and output channel
if (key === START) {
Expand Down Expand Up @@ -768,7 +809,7 @@ function _controlBranch(value: any): (string | Send)[] {
if (_isSend(value)) {
return [value];
}
if (!_isCommand(value)) {
if (!isCommand(value)) {
return [];
}
if (value.graph === Command.PARENT) {
Expand Down
34 changes: 24 additions & 10 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ import {
BaseStore,
} from "@langchain/langgraph-checkpoint";

import { END, START, StateGraph, CompiledStateGraph } from "../graph/index.js";
import {
END,
START,
StateGraph,
CompiledStateGraph,
AnnotationRoot,
} from "../graph/index.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { ToolNode } from "./tool_node.js";
import { LangGraphRunnableConfig } from "../pregel/runnable_types.js";
Expand Down Expand Up @@ -147,13 +153,14 @@ export type MessageModifier =
| ((messages: BaseMessage[]) => Promise<BaseMessage[]>)
| Runnable;

export type CreateReactAgentParams = {
export type CreateReactAgentParams<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>
> = {
/** The chat model that can utilize OpenAI-style tool calling. */
llm: BaseChatModel;
/** A list of tools or a ToolNode. */
tools:
| ToolNode<typeof MessagesAnnotation.State>
| (StructuredToolInterface | RunnableToolLike)[];
tools: ToolNode | (StructuredToolInterface | RunnableToolLike)[];
/**
* @deprecated
* Use stateModifier instead. stateModifier works the same as
Expand Down Expand Up @@ -208,6 +215,7 @@ export type CreateReactAgentParams = {
* - Runnable: This runnable should take in full graph state and the output is then passed to the language model.
*/
stateModifier?: StateModifier;
stateSchema?: A;
/** An optional checkpoint saver to persist the agent's state. */
checkpointSaver?: BaseCheckpointSaver;
/** An optional list of node names to interrupt before running. */
Expand Down Expand Up @@ -260,18 +268,24 @@ export type CreateReactAgentParams = {
* ```
*/

export function createReactAgent(
params: CreateReactAgentParams
export function createReactAgent<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>
>(
params: CreateReactAgentParams<A>
): CompiledStateGraph<
(typeof MessagesAnnotation)["State"],
(typeof MessagesAnnotation)["Update"],
typeof START | "agent" | "tools"
typeof START | "agent" | "tools",
typeof MessagesAnnotation.spec & A["spec"],
typeof MessagesAnnotation.spec & A["spec"]
> {
const {
llm,
tools,
messageModifier,
stateModifier,
stateSchema,
checkpointSaver,
interruptBefore,
interruptAfter,
Expand Down Expand Up @@ -314,9 +328,9 @@ export function createReactAgent(
return { messages: [await modelRunnable.invoke(state, config)] };
};

const workflow = new StateGraph(MessagesAnnotation)
const workflow = new StateGraph(stateSchema ?? MessagesAnnotation)
.addNode("agent", callModel)
.addNode("tools", new ToolNode<AgentState>(toolClasses))
.addNode("tools", new ToolNode(toolClasses))
.addEdge(START, "agent")
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
Expand Down
20 changes: 18 additions & 2 deletions libs/langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { RunnableCallable } from "../utils.js";
import { END } from "../graph/graph.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { isGraphInterrupt } from "../errors.js";
import { isCommand } from "../constants.js";

export type ToolNodeOptions = {
name?: string;
Expand Down Expand Up @@ -173,7 +174,10 @@ export class ToolNode<T = any> extends RunnableCallable<T, T> {
{ ...call, type: "tool_call" },
config
);
if (isBaseMessage(output) && output._getType() === "tool") {
if (
(isBaseMessage(output) && output._getType() === "tool") ||
isCommand(output)
) {
return output;
} else {
return new ToolMessage({
Expand Down Expand Up @@ -203,7 +207,19 @@ export class ToolNode<T = any> extends RunnableCallable<T, T> {
}) ?? []
);

return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
// Preserve existing behavior for non-command tool outputs for backwards compatibility
if (!outputs.some(isCommand)) {
return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
}

// Handle mixed Command and non-Command outputs
const combinedOutputs = outputs.map((output) => {
if (isCommand(output)) {
return output;
}
return Array.isArray(input) ? [output] : { messages: [output] };
});
return combinedOutputs as T;
}
}

Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/src/pregel/loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {
} from "../channels/base.js";
import { PregelExecutableTask, StreamMode } from "./types.js";
import {
_isCommand,
isCommand,
CHECKPOINT_NAMESPACE_SEPARATOR,
Command,
CONFIG_KEY_CHECKPOINT_MAP,
Expand Down Expand Up @@ -738,8 +738,8 @@ export class PregelLoop {
Object.keys(this.checkpoint.channel_versions).length !== 0 &&
(this.config.configurable?.[CONFIG_KEY_RESUMING] !== undefined ||
this.input === null ||
_isCommand(this.input));
if (_isCommand(this.input)) {
isCommand(this.input));
if (isCommand(this.input)) {
const writes: { [key: string]: PendingWrite[] } = {};
// group writes by task id
for (const [tid, key, value] of mapCommand(
Expand Down
Loading

0 comments on commit 3ef8ace

Please sign in to comment.