Skip to content

Commit af0e7e8

Browse files
committed
improve node expansion del session id
1 parent 05bc00f commit af0e7e8

File tree

2 files changed

+187
-4
lines changed

2 files changed

+187
-4
lines changed

visual-tree-search-app/components/MessageLogPanel.tsx

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,32 @@ interface ParsedMessage {
8080
iteration?: number;
8181
session_id?: string;
8282
node_action?: string;
83+
children?: ChildNodeData[];
84+
children_count?: number;
85+
node_info?: {
86+
action?: string;
87+
description?: string;
88+
depth?: number;
89+
value?: number;
90+
visits?: number;
91+
};
8392
}
8493

8594
interface PathStep {
8695
natural_language_description: string;
8796
action: string;
8897
}
8998

99+
interface ChildNodeData {
100+
id: number;
101+
parent_id: number;
102+
action: string;
103+
description: string;
104+
is_terminal: boolean;
105+
prob: number;
106+
depth: number;
107+
}
108+
90109
const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
91110
messages,
92111
messagesEndRef,
@@ -194,6 +213,8 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
194213
case 'tree_update_simulation':
195214
case 'trajectory_update':
196215
case 'removed_simulation':
216+
case 'tree_update_node_expansion':
217+
case 'tree_update_node_evaluation':
197218
return "bg-gradient-to-r from-cyan-50 to-cyan-100 dark:from-cyan-900/20 dark:to-cyan-800/20 border-cyan-200 dark:border-cyan-800";
198219

199220
case 'iteration_start':
@@ -205,6 +226,12 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
205226
case 'node_created':
206227
case 'node_simulated':
207228
case 'node_terminal':
229+
case 'node_expansion_start':
230+
case 'node_expansion_complete':
231+
case 'evaluation_start':
232+
case 'child_evaluated':
233+
case 'node_evaluation_start':
234+
case 'node_evaluation_complete':
208235
return "bg-gradient-to-r from-green-50 to-green-100 dark:from-green-900/20 dark:to-green-800/20 border-green-200 dark:border-green-800";
209236

210237
case 'simulation_result':
@@ -326,6 +353,18 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
326353
return <Flag className="h-4 w-4 text-green-500" />;
327354
case 'simulation_result':
328355
return <Info className="h-4 w-4 text-amber-500" />;
356+
case 'node_expansion_start':
357+
return <Expand className="h-4 w-4 text-purple-500" />;
358+
case 'node_expansion_complete':
359+
return <CheckCircle className="h-4 w-4 text-green-500" />;
360+
case 'evaluation_start':
361+
return <Brain className="h-4 w-4 text-blue-500" />;
362+
case 'child_evaluated':
363+
return <Star className="h-4 w-4 text-amber-500" />;
364+
case 'node_evaluation_start':
365+
return <Brain className="h-4 w-4 text-blue-500" />;
366+
case 'node_evaluation_complete':
367+
return <CheckCircle className="h-4 w-4 text-green-500" />;
329368
default:
330369
return <Info className="h-4 w-4 text-slate-500" />;
331370
}
@@ -412,6 +451,8 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
412451
case 'tree_update_simulation':
413452
case 'trajectory_update':
414453
case 'removed_simulation':
454+
case 'tree_update_node_expansion':
455+
case 'tree_update_node_evaluation':
415456
return "bg-cyan-100 dark:bg-cyan-800/30 text-cyan-600 dark:text-cyan-400";
416457

417458
case 'iteration_start':
@@ -423,6 +464,12 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
423464
case 'node_created':
424465
case 'node_simulated':
425466
case 'node_terminal':
467+
case 'node_expansion_start':
468+
case 'node_expansion_complete':
469+
case 'evaluation_start':
470+
case 'child_evaluated':
471+
case 'node_evaluation_start':
472+
case 'node_evaluation_complete':
426473
return "bg-green-100 dark:bg-green-800/30 text-green-600 dark:text-green-400";
427474

428475
case 'simulation_result':
@@ -451,6 +498,11 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
451498
};
452499

453500
const formatMessageContent = (message: ParsedMessage) => {
501+
// Ignore tree update messages
502+
if (message.type === 'tree_update_node_expansion' || message.type === 'tree_update_node_evaluation') {
503+
return null;
504+
}
505+
454506
switch (message.type) {
455507
case 'reflection_backtracking':
456508
return (
@@ -483,7 +535,9 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
483535
<div className="flex items-center gap-2 animate-fadeIn">
484536
{getIcon(message)}
485537
<div className="animate-slideIn">
486-
<div className="text-green-600 dark:text-green-400">{message.info}</div>
538+
<div className="text-green-600 dark:text-green-400">
539+
{message.info?.replace(/Session [a-f0-9-]+ terminated successfully/, 'Session terminated successfully')}
540+
</div>
487541
</div>
488542
</div>
489543
);
@@ -708,6 +762,69 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({
708762
</div>
709763
);
710764

765+
case 'node_expansion_start':
766+
return (
767+
<div className="flex items-center gap-2 animate-fadeIn">
768+
{getIcon(message)}
769+
<div className="animate-slideIn">
770+
<div className="text-purple-600 dark:text-purple-400">
771+
{message.node_info?.description || 'Starting node expansion'}
772+
</div>
773+
{message.node_info?.action && message.node_info.action !== 'ROOT' && (
774+
<div className="text-xs text-slate-500 dark:text-slate-400">
775+
Action: {message.node_info.action}
776+
</div>
777+
)}
778+
</div>
779+
</div>
780+
);
781+
case 'node_expansion_complete':
782+
return (
783+
<div className="flex items-center gap-2 animate-fadeIn">
784+
{getIcon(message)}
785+
<div className="animate-slideIn">
786+
<div className="text-green-600 dark:text-green-400">
787+
{message.node_info?.description || 'Node expansion complete'}
788+
</div>
789+
{message.node_info?.action && message.node_info.action !== 'ROOT' && (
790+
<div className="text-xs text-slate-500 dark:text-slate-400">
791+
Action: {message.node_info.action}
792+
</div>
793+
)}
794+
{message.children && message.children.length > 0 && (
795+
<div className="mt-1">
796+
{message.children.map((child: ChildNodeData, index: number) => (
797+
<div
798+
key={index}
799+
className="text-xs text-slate-500 dark:text-slate-400 pl-2 border-l-2 border-slate-200 dark:border-slate-700"
800+
>
801+
<div className="text-indigo-600 dark:text-indigo-400">
802+
{child.description || 'No description'}
803+
</div>
804+
<div className="text-slate-500 dark:text-slate-400">
805+
Action: {child.action || 'None'}
806+
</div>
807+
{child.prob && (
808+
<div className="text-slate-500 dark:text-slate-400">
809+
Probability: {child.prob.toFixed(3)}
810+
</div>
811+
)}
812+
</div>
813+
))}
814+
</div>
815+
)}
816+
</div>
817+
</div>
818+
);
819+
case 'evaluation_start':
820+
return `Starting evaluation of ${message.children_count} children for node ${message.node_id}`;
821+
case 'child_evaluated':
822+
return `Child node ${message.node_id} evaluated with score ${message.score}`;
823+
case 'node_evaluation_start':
824+
return `Starting evaluation of node ${message.node_id}`;
825+
case 'node_evaluation_complete':
826+
return `Node ${message.node_id} evaluated with score ${message.score}`;
827+
711828
default:
712829
return (
713830
<div className="flex items-center gap-2 animate-fadeIn">

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,17 @@ def __init__(
6868
subsets=self.agent_type, strict=False, multiaction=False, demo_mode="default"
6969
)
7070
self.root_node = LATSNode(
71-
natural_language_description=None,
72-
action=None,
73-
prob=None,
71+
natural_language_description="Root Node",
72+
action="ROOT",
73+
prob=1.0,
7474
element=None,
7575
goal=self.goal,
7676
parent=None
7777
)
78+
self.root_node.value = 0.0
79+
self.root_node.visits = 0
80+
self.root_node.depth = 0
81+
self.root_node.is_terminal = False
7882
self.goal_finished = False
7983
self.result_node = None
8084
self.reset_url = os.environ["ACCOUNT_RESET_URL"]
@@ -388,7 +392,23 @@ async def node_selection(self, node, websocket = None):
388392

389393

390394
async def node_expansion(self, node: LATSNode, websocket = None) -> None:
395+
if websocket:
396+
node_info = {
397+
"action": node.action if node.action else "ROOT",
398+
"description": node.natural_language_description if node.natural_language_description else "Root Node",
399+
"value": node.value if hasattr(node, 'value') else 0.0,
400+
"visits": node.visits if hasattr(node, 'visits') else 0,
401+
"depth": node.depth if hasattr(node, 'depth') else 0,
402+
"is_terminal": node.is_terminal if hasattr(node, 'is_terminal') else False
403+
}
404+
await websocket.send_json({
405+
"type": "node_expansion_start",
406+
"node_id": id(node),
407+
"node_info": node_info,
408+
"timestamp": datetime.utcnow().isoformat()
409+
})
391410
children_state = await self.generate_children(node, websocket)
411+
children_data = []
392412
for child_state in children_state:
393413
child = LATSNode(
394414
natural_language_description=child_state["natural_language_description"],
@@ -401,12 +421,36 @@ async def node_expansion(self, node: LATSNode, websocket = None) -> None:
401421
if child.depth == self.config.max_depth:
402422
child.is_terminal = True
403423
node.children.append(child)
424+
children_data.append({
425+
"id": id(child),
426+
"parent_id": id(node),
427+
"action": child.action,
428+
"description": child.natural_language_description,
429+
"is_terminal": child.is_terminal,
430+
"prob": child.prob,
431+
"depth": child.depth
432+
})
404433
await self.websocket_node_created(child, node, websocket=websocket)
434+
if websocket:
435+
await websocket.send_json({
436+
"type": "node_expansion_complete",
437+
"node_id": id(node),
438+
"node_info": node_info,
439+
"children": children_data,
440+
"timestamp": datetime.utcnow().isoformat()
441+
})
405442

406443

407444
# node evaluation
408445
# change the node evaluation to use the new prompt
409446
async def node_children_evaluation(self, node: LATSNode) -> None:
447+
if websocket:
448+
await websocket.send_json({
449+
"type": "evaluation_start",
450+
"node_id": id(node),
451+
"children_count": len(node.children),
452+
"timestamp": datetime.utcnow().isoformat()
453+
})
410454
scores = []
411455
print(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}")
412456
for i, child in enumerate(node.children):
@@ -423,13 +467,27 @@ async def node_children_evaluation(self, node: LATSNode) -> None:
423467
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model)
424468
score = result["overall_score"]
425469
scores.append(score)
470+
if websocket:
471+
await websocket.send_json({
472+
"type": "child_evaluated",
473+
"node_id": id(child),
474+
"parent_id": id(node),
475+
"score": score,
476+
"timestamp": datetime.utcnow().isoformat()
477+
})
426478

427479
for child, score in zip(node.children, scores):
428480
child.value = score
429481
# child.reward = score
430482

431483
async def node_evaluation(self, node: LATSNode) -> None:
432484
"""Evaluate the current node and assign its score."""
485+
if websocket:
486+
await websocket.send_json({
487+
"type": "node_evaluation_start",
488+
"node_id": id(node),
489+
"timestamp": datetime.utcnow().isoformat()
490+
})
433491
try:
434492
# Get the path from root to this node
435493
path = self.get_path_to_root(node)
@@ -468,6 +526,14 @@ async def node_evaluation(self, node: LATSNode) -> None:
468526
node.value = score
469527
# node.reward = score
470528

529+
if websocket:
530+
await websocket.send_json({
531+
"type": "node_evaluation_complete",
532+
"node_id": id(node),
533+
"score": score,
534+
"trajectory": trajectory,
535+
"timestamp": datetime.utcnow().isoformat()
536+
})
471537

472538
except Exception as e:
473539
error_msg = f"Error in node evaluation: {str(e)}"

0 commit comments

Comments
 (0)