Skip to content

Commit 223e5ef

Browse files
authored
Add JSON representation of runnable graph to serialized representation (langchain-ai#17745)
Sent to LangSmith Thank you for contributing to LangChain! Checklist: - [ ] PR title: Please title your PR "package: description", where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] PR message: **Delete this entire template message** and replace it with the following bulleted list - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] Pass lint and test: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified to check that you're passing lint and testing. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ - [ ] Add tests and docs: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
1 parent 6e854ae commit 223e5ef

File tree

7 files changed

+51499
-1213
lines changed

7 files changed

+51499
-1213
lines changed

libs/core/langchain_core/load/serializable.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
from abc import ABC
2-
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
2+
from typing import (
3+
Any,
4+
Dict,
5+
List,
6+
Literal,
7+
Optional,
8+
TypedDict,
9+
Union,
10+
cast,
11+
)
12+
13+
from typing_extensions import NotRequired
314

415
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
516

@@ -9,6 +20,8 @@ class BaseSerialized(TypedDict):
920

1021
lc: int
1122
id: List[str]
23+
name: NotRequired[str]
24+
graph: NotRequired[Dict[str, Any]]
1225

1326

1427
class SerializedConstructor(BaseSerialized):

libs/core/langchain_core/runnables/base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737

3838
from langchain_core._api import beta_decorator
3939
from langchain_core.load.dump import dumpd
40-
from langchain_core.load.serializable import Serializable
40+
from langchain_core.load.serializable import (
41+
Serializable,
42+
SerializedConstructor,
43+
SerializedNotImplemented,
44+
)
4145
from langchain_core.pydantic_v1 import BaseModel, Field
4246
from langchain_core.runnables.config import (
4347
RunnableConfig,
@@ -1630,6 +1634,16 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
16301634
name: Optional[str] = None
16311635
"""The name of the runnable. Used for debugging and tracing."""
16321636

1637+
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
1638+
"""Serialize the runnable to JSON."""
1639+
dumped = super().to_json()
1640+
try:
1641+
dumped["name"] = self.get_name()
1642+
dumped["graph"] = self.get_graph().to_json()
1643+
except Exception:
1644+
pass
1645+
return dumped
1646+
16331647
def configurable_fields(
16341648
self, **kwargs: AnyConfigurableField
16351649
) -> RunnableSerializable[Input, Output]:

libs/core/langchain_core/runnables/graph.py

Lines changed: 103 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
import inspect
34
from dataclasses import dataclass, field
4-
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
5-
from uuid import uuid4
5+
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
6+
from uuid import UUID, uuid4
67

78
from langchain_core.pydantic_v1 import BaseModel
89
from langchain_core.runnables.graph_draw import draw
@@ -11,11 +12,20 @@
1112
from langchain_core.runnables.base import Runnable as RunnableType
1213

1314

15+
def is_uuid(value: str) -> bool:
16+
try:
17+
UUID(value)
18+
return True
19+
except ValueError:
20+
return False
21+
22+
1423
class Edge(NamedTuple):
1524
"""Edge in a graph."""
1625

1726
source: str
1827
target: str
28+
data: Optional[str] = None
1929

2030

2131
class Node(NamedTuple):
@@ -25,22 +35,108 @@ class Node(NamedTuple):
2535
data: Union[Type[BaseModel], RunnableType]
2636

2737

38+
def node_data_str(node: Node) -> str:
39+
from langchain_core.runnables.base import Runnable
40+
41+
if not is_uuid(node.id):
42+
return node.id
43+
elif isinstance(node.data, Runnable):
44+
try:
45+
data = str(node.data)
46+
if (
47+
data.startswith("<")
48+
or data[0] != data[0].upper()
49+
or len(data.splitlines()) > 1
50+
):
51+
data = node.data.__class__.__name__
52+
elif len(data) > 42:
53+
data = data[:42] + "..."
54+
except Exception:
55+
data = node.data.__class__.__name__
56+
else:
57+
data = node.data.__name__
58+
return data if not data.startswith("Runnable") else data[8:]
59+
60+
61+
def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
62+
from langchain_core.load.serializable import to_json_not_implemented
63+
from langchain_core.runnables.base import Runnable, RunnableSerializable
64+
65+
if isinstance(node.data, RunnableSerializable):
66+
return {
67+
"type": "runnable",
68+
"data": {
69+
"id": node.data.lc_id(),
70+
"name": node.data.get_name(),
71+
},
72+
}
73+
elif isinstance(node.data, Runnable):
74+
return {
75+
"type": "runnable",
76+
"data": {
77+
"id": to_json_not_implemented(node.data)["id"],
78+
"name": node.data.get_name(),
79+
},
80+
}
81+
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
82+
return {
83+
"type": "schema",
84+
"data": node.data.schema(),
85+
}
86+
else:
87+
return {
88+
"type": "unknown",
89+
"data": node_data_str(node),
90+
}
91+
92+
2893
@dataclass
2994
class Graph:
3095
"""Graph of nodes and edges."""
3196

3297
nodes: Dict[str, Node] = field(default_factory=dict)
3398
edges: List[Edge] = field(default_factory=list)
3499

100+
def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
101+
"""Convert the graph to a JSON-serializable format."""
102+
stable_node_ids = {
103+
node.id: i if is_uuid(node.id) else node.id
104+
for i, node in enumerate(self.nodes.values())
105+
}
106+
107+
return {
108+
"nodes": [
109+
{"id": stable_node_ids[node.id], **node_data_json(node)}
110+
for node in self.nodes.values()
111+
],
112+
"edges": [
113+
{
114+
"source": stable_node_ids[edge.source],
115+
"target": stable_node_ids[edge.target],
116+
"data": edge.data,
117+
}
118+
if edge.data is not None
119+
else {
120+
"source": stable_node_ids[edge.source],
121+
"target": stable_node_ids[edge.target],
122+
}
123+
for edge in self.edges
124+
],
125+
}
126+
35127
def __bool__(self) -> bool:
36128
return bool(self.nodes)
37129

38130
def next_id(self) -> str:
39131
return uuid4().hex
40132

41-
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
133+
def add_node(
134+
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
135+
) -> Node:
42136
"""Add a node to the graph and return it."""
43-
node = Node(id=self.next_id(), data=data)
137+
if id is not None and id in self.nodes:
138+
raise ValueError(f"Node with id {id} already exists")
139+
node = Node(id=id or self.next_id(), data=data)
44140
self.nodes[node.id] = node
45141
return node
46142

@@ -53,13 +149,13 @@ def remove_node(self, node: Node) -> None:
53149
if edge.source != node.id and edge.target != node.id
54150
]
55151

56-
def add_edge(self, source: Node, target: Node) -> Edge:
152+
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
57153
"""Add an edge to the graph and return it."""
58154
if source.id not in self.nodes:
59155
raise ValueError(f"Source node {source.id} not in graph")
60156
if target.id not in self.nodes:
61157
raise ValueError(f"Target node {target.id} not in graph")
62-
edge = Edge(source=source.id, target=target.id)
158+
edge = Edge(source=source.id, target=target.id, data=data)
63159
self.edges.append(edge)
64160
return edge
65161

@@ -117,28 +213,8 @@ def trim_last_node(self) -> None:
117213
self.remove_node(last_node)
118214

119215
def draw_ascii(self) -> str:
120-
from langchain_core.runnables.base import Runnable
121-
122-
def node_data(node: Node) -> str:
123-
if isinstance(node.data, Runnable):
124-
try:
125-
data = str(node.data)
126-
if (
127-
data.startswith("<")
128-
or data[0] != data[0].upper()
129-
or len(data.splitlines()) > 1
130-
):
131-
data = node.data.__class__.__name__
132-
elif len(data) > 42:
133-
data = data[:42] + "..."
134-
except Exception:
135-
data = node.data.__class__.__name__
136-
else:
137-
data = node.data.__name__
138-
return data if not data.startswith("Runnable") else data[8:]
139-
140216
return draw(
141-
{node.id: node_data(node) for node in self.nodes.values()},
217+
{node.id: node_data_str(node) for node in self.nodes.values()},
142218
[(edge.source, edge.target) for edge in self.edges],
143219
)
144220

0 commit comments

Comments
 (0)