-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbuild_router_graph.py
60 lines (48 loc) · 2.1 KB
/
build_router_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import asyncio
import logging

from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph

from langgraph_mcp import mcp_wrapper as mcp
from langgraph_mcp.configuration import Configuration
from langgraph_mcp.retriever import make_retriever
from langgraph_mcp.state import BuilderState
async def build_router(state: BuilderState, *, config: RunnableConfig):
    """
    Build the router by gathering routing descriptions from MCP servers and storing them in the retriever.

    Every server listed under ``mcpServers`` in the configured MCP server config
    is queried concurrently via ``mcp.apply(..., mcp.RoutingDescription())``.
    Each result becomes one ``Document`` whose ``metadata["id"]`` is the server
    name. The documents are then written to the retriever produced by
    ``make_retriever(config)``.

    This node is best-effort. Any exception raised while reading the
    configuration, querying servers, or writing to the retriever is logged with
    its traceback and reported as ``{"status": "failure"}``. Such exceptions are
    not propagated.

    Parameters:
        state (BuilderState): The current state of the router builder (not read here).
        config (RunnableConfig): The configuration for the router builder.

    Returns:
        dict: ``{"status": "success"}`` or ``{"status": "failure"}``.
    """
    status = "failure"
    try:
        # Config parsing lives inside the try so a missing/invalid config follows
        # the same "failure" contract as every other error in this node.
        configuration = Configuration.from_runnable_config(config)
        mcp_servers = configuration.mcp_server_config["mcpServers"]

        # Query all servers concurrently; gather returns results in input order.
        routing_descriptions = await asyncio.gather(
            *[
                mcp.apply(server_name, server_config, mcp.RoutingDescription())
                for server_name, server_config in mcp_servers.items()
            ]
        )

        # NOTE(review): assumes each result is a (server_name, description) pair,
        # as the unpacking below requires -- confirm in mcp_wrapper.RoutingDescription.
        documents = [
            Document(page_content=description, metadata={"id": server_name})
            for server_name, description in routing_descriptions
        ]

        # With no servers configured there is nothing to index. Skip opening the
        # retriever instead of issuing an empty write, which some vector stores
        # may reject.
        if documents:
            with make_retriever(config) as retriever:
                if configuration.retriever_provider == "milvus":
                    # Milvus path: synchronous add with server names as explicit ids.
                    retriever.add_documents(
                        documents, ids=[doc.metadata["id"] for doc in documents]
                    )
                else:
                    await retriever.aadd_documents(documents)
        status = "success"
    except Exception:
        # Top-level node boundary: preserve the best-effort status, but keep the
        # full traceback in the logs rather than printing a bare message to stdout.
        logging.getLogger(__name__).exception("Failed to build router")
    return {"status": status}
# Single-node graph: the run enters at "build_router". No outgoing edge is
# declared for that node.
builder = StateGraph(state_schema=BuilderState, config_schema=Configuration)
builder.add_node("build_router", build_router)
builder.set_entry_point("build_router")

# Compiled graph exposed at module level under a human-readable name.
graph = builder.compile()
graph.name = "BuildRouterGraph"