1+ import asyncio
2+ import json
3+ import websockets
4+ import argparse
5+ import logging
6+ from datetime import datetime
7+
8+ # Configure logging
9+ logging .basicConfig (
10+ level = logging .INFO ,
11+ format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12+ )
13+ logger = logging .getLogger (__name__ )
14+
15+ # Default values
16+ DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws"
17+ DEFAULT_STARTING_URL = "http://128.105.145.205:7770/"
18+ DEFAULT_GOAL = "search running shoes, click on the first result"
19+
20+ async def connect_and_test_search (
21+ ws_url : str ,
22+ starting_url : str ,
23+ goal : str ,
24+ search_algorithm : str = "bfs" ,
25+ max_depth : int = 3
26+ ):
27+ """
28+ Connect to the WebSocket endpoint and test the tree search functionality.
29+
30+ Args:
31+ ws_url: WebSocket URL to connect to
32+ starting_url: URL to start the search from
33+ goal: Goal to achieve
34+ search_algorithm: Search algorithm to use (bfs or dfs)
35+ max_depth: Maximum depth for the search tree
36+ """
37+ logger .info (f"Connecting to WebSocket at { ws_url } " )
38+
39+ async with websockets .connect (ws_url ) as websocket :
40+ logger .info ("Connected to WebSocket" )
41+
42+ # Wait for connection established message
43+ response = await websocket .recv ()
44+ data = json .loads (response )
45+ if data .get ("type" ) == "connection_established" :
46+ logger .info (f"Connection established with ID: { data .get ('connection_id' )} " )
47+
48+ # Send search request
49+ request = {
50+ "type" : "start_search" ,
51+ "agent_type" : "MCTSAgent" ,
52+ "starting_url" : starting_url ,
53+ "goal" : goal ,
54+ "search_algorithm" : search_algorithm ,
55+ "max_depth" : max_depth
56+ }
57+
58+ logger .info (f"Sending search request: { request } " )
59+ await websocket .send (json .dumps (request ))
60+
61+ # Process responses
62+ while True :
63+ try :
64+ response = await websocket .recv ()
65+ data = json .loads (response )
66+
67+ # Log the message type and some key information
68+ msg_type = data .get ("type" , "unknown" )
69+
70+ if msg_type == "status_update" :
71+ logger .info (f"Status update: { data .get ('status' )} - { data .get ('message' )} " )
72+
73+ elif msg_type == "iteration_start" :
74+ logger .info (f"Iteration start: { data .get ('iteration' )} " )
75+
76+ elif msg_type == "step_start" :
77+ logger .info (f"Step start: { data .get ('step' )} - { data .get ('step_name' )} " )
78+
79+ elif msg_type == "node_update" :
80+ node_id = data .get ("node_id" )
81+ status = data .get ("status" )
82+ logger .info (f"Node update: { node_id } - { status } " )
83+
84+ # If node was scored, log the score
85+ if status == "scored" :
86+ logger .info (f"Node score: { data .get ('score' )} " )
87+
88+ elif msg_type == "trajectory_update" :
89+ logger .info (f"Trajectory update received with { data .get ('trajectory' )} " )
90+
91+ elif msg_type == "tree_update" :
92+ logger .info (f"Tree update received with { data .get ('tree' )} " )
93+
94+ elif msg_type == "best_path_update" :
95+ logger .info (f"Best path update: score={ data .get ('score' )} , path length={ len (data .get ('path' , []))} " )
96+
97+ elif msg_type == "search_complete" :
98+ status = data .get ("status" )
99+ score = data .get ("score" , "N/A" )
100+ path_length = len (data .get ("path" , []))
101+
102+ logger .info (f"Search complete: { status } , score={ score } , path length={ path_length } " )
103+ logger .info ("Path actions:" )
104+
105+ for i , node in enumerate (data .get ("path" , [])):
106+ logger .info (f" { i + 1 } . { node .get ('action' )} " )
107+
108+ # Exit the loop when search is complete
109+ break
110+
111+ elif msg_type == "error" :
112+ logger .error (f"Error: { data .get ('message' )} " )
113+ break
114+
115+ else :
116+ logger .info (f"Received message of type { msg_type } " )
117+ logger .info (f"Message: { data } " )
118+
119+ except websockets .exceptions .ConnectionClosed :
120+ logger .warning ("WebSocket connection closed" )
121+ break
122+ except Exception as e :
123+ logger .error (f"Error processing message: { e } " )
124+ break
125+
126+ logger .info ("Test completed" )
127+
128+ def parse_arguments ():
129+ """Parse command line arguments"""
130+ parser = argparse .ArgumentParser (description = "Test the tree search WebSocket functionality" )
131+
132+ parser .add_argument ("--ws-url" , type = str , default = DEFAULT_WS_URL ,
133+ help = f"WebSocket URL (default: { DEFAULT_WS_URL } )" )
134+
135+ parser .add_argument ("--starting-url" , type = str , default = DEFAULT_STARTING_URL ,
136+ help = f"Starting URL for the search (default: { DEFAULT_STARTING_URL } )" )
137+
138+ parser .add_argument ("--goal" , type = str , default = DEFAULT_GOAL ,
139+ help = f"Goal to achieve (default: { DEFAULT_GOAL } )" )
140+
141+ parser .add_argument ("--algorithm" , type = str , choices = ["bfs" , "dfs" , "lats" , "mcts" ], default = "mcts" ,
142+ help = "Search algorithm to use (default: lats)" )
143+
144+ parser .add_argument ("--max-depth" , type = int , default = 3 ,
145+ help = "Maximum depth for the search tree (default: 3)" )
146+
147+ return parser .parse_args ()
148+
149+ async def main ():
150+ """Main entry point"""
151+ args = parse_arguments ()
152+
153+ logger .info ("Starting tree search WebSocket test" )
154+ logger .info (f"WebSocket URL: { args .ws_url } " )
155+ logger .info (f"Starting URL: { args .starting_url } " )
156+ logger .info (f"Goal: { args .goal } " )
157+ logger .info (f"Algorithm: { args .algorithm } " )
158+ logger .info (f"Max depth: { args .max_depth } " )
159+
160+ await connect_and_test_search (
161+ ws_url = args .ws_url ,
162+ starting_url = args .starting_url ,
163+ goal = args .goal ,
164+ search_algorithm = args .algorithm ,
165+ max_depth = args .max_depth
166+ )
167+
168+ if __name__ == "__main__" :
169+ asyncio .run (main ())
0 commit comments