@@ -22,91 +22,91 @@ async def run(self, websocket=None) -> list[LATSNode]:
2222 return best_node
2323
2424 async def lats_search (self , websocket = None ):
25- terminal_nodes = []
25+ terminal_nodes = []
2626
27- for i in range (self .config .iterations ):
28- await self .websocket_iteration_start (i , websocket = websocket )
29-
30- print (f"Iteration { i } /{ self .config .iterations } ..." )
27+ for i in range (self .config .iterations ):
28+ await self .websocket_iteration_start (i , websocket = websocket )
29+
30+ print (f"Iteration { i } /{ self .config .iterations } ..." )
3131
32- # Step 1: Node Selection
33- ## TODO: move websocket node selection into node_selection method
34- print (f"{ GREEN } Step 1: node selection{ RESET } " )
35- await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
36- node = await self .node_selection (self .root_node )
37- await self .websocket_node_selection (node , websocket = websocket )
32+ # Step 1: Node Selection
33+ ## TODO: move websocket node selection into node_selection method
34+ print (f"{ GREEN } Step 1: node selection{ RESET } " )
35+ await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
36+ node = await self .node_selection (self .root_node )
37+ await self .websocket_node_selection (node , websocket = websocket )
3838
39- if node is None :
40- print ("All paths lead to terminal nodes with reward 0. Ending search." )
41- break
39+ if node is None :
40+ print ("All paths lead to terminal nodes with reward 0. Ending search." )
41+ break
4242
43- # Step 2: Node Expansion
44- print (f"{ GREEN } Step 2: node expansion{ RESET } " )
45- await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
46- await self .node_expansion (node , websocket )
47- if node is None :
48- # all the nodes are terminal, stop the search
49- print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
50- break
51- tree_data = self ._get_tree_data ()
52- if websocket :
53- await self .websocket_tree_update (type = "tree_update_node_expansion" , websocket = websocket , tree_data = tree_data )
54- else :
55- print_entire_tree (self .root_node )
43+ # Step 2: Node Expansion
44+ print (f"{ GREEN } Step 2: node expansion{ RESET } " )
45+ await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
46+ await self .node_expansion (node , websocket )
47+ if node is None :
48+ # all the nodes are terminal, stop the search
49+ print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
50+ break
51+ tree_data = self ._get_tree_data ()
52+ if websocket :
53+ await self .websocket_tree_update (type = "tree_update_node_expansion" , websocket = websocket , tree_data = tree_data )
54+ else :
55+ print_entire_tree (self .root_node )
5656
5757
58- # Step 3: Evaluation
59- print (f"{ GREEN } Step 3: node chilren evaluation{ RESET } " )
60- await self .websocket_step_start (step = 3 , step_name = "node_children_evaluation" , websocket = websocket )
61- await self .node_children_evaluation (node )
62- tree_data = self ._get_tree_data ()
63- if websocket :
64- await self .websocket_tree_update (type = "tree_update_node_children_evaluation" , websocket = websocket , tree_data = tree_data )
65- else :
66- print ("after evaluation" )
67- print_entire_tree (self .root_node )
58+ # Step 3: Evaluation
59+ print (f"{ GREEN } Step 3: node chilren evaluation{ RESET } " )
60+ await self .websocket_step_start (step = 3 , step_name = "node_children_evaluation" , websocket = websocket )
61+ await self .node_children_evaluation (node )
62+ tree_data = self ._get_tree_data ()
63+ if websocket :
64+ await self .websocket_tree_update (type = "tree_update_node_children_evaluation" , websocket = websocket , tree_data = tree_data )
65+ else :
66+ print ("after evaluation" )
67+ print_entire_tree (self .root_node )
6868
6969
70- # Step 4: Simulation
71- print (f"{ GREEN } Step 4: simulation{ RESET } " )
72- await self .websocket_step_start (step = 4 , step_name = "simulation" , websocket = websocket )
73- selected_node = max (node .children , key = lambda child : child .value )
74- await self .websocket_node_selection (selected_node , websocket = websocket , type = "node_selected_for_simulation" )
75- reward , terminal_node = await self .simulation (selected_node , websocket = websocket )
76- terminal_nodes .append (terminal_node )
77- await self .websocket_simulation_result (reward , terminal_node , websocket = websocket )
70+ # Step 4: Simulation
71+ print (f"{ GREEN } Step 4: simulation{ RESET } " )
72+ await self .websocket_step_start (step = 4 , step_name = "simulation" , websocket = websocket )
73+ selected_node = max (node .children , key = lambda child : child .value )
74+ await self .websocket_node_selection (selected_node , websocket = websocket , type = "node_selected_for_simulation" )
75+ reward , terminal_node = await self .simulation (selected_node , websocket = websocket )
76+ terminal_nodes .append (terminal_node )
77+ await self .websocket_simulation_result (reward , terminal_node , websocket = websocket )
7878
79- if reward == 1 :
80- await self .websocket_search_complete ("success" , reward , terminal_node .get_trajectory (), websocket = websocket )
81- return terminal_node
79+ if reward == 1 :
80+ await self .websocket_search_complete ("success" , reward , terminal_node .get_trajectory (), websocket = websocket )
81+ return terminal_node
8282
83- # Step 5: Backpropagation
84- print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
85- await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
86- self .backpropagate (terminal_node , reward )
87- tree_data = self ._get_tree_data ()
88- if websocket :
89- await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
90- else :
91- print ("after backpropagation" )
92- print_entire_tree (self .root_node )
83+ # Step 5: Backpropagation
84+ print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
85+ await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
86+ self .backpropagate (terminal_node , reward )
87+ tree_data = self ._get_tree_data ()
88+ if websocket :
89+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
90+ else :
91+ print ("after backpropagation" )
92+ print_entire_tree (self .root_node )
9393
94- # Find best node
95- all_nodes_list = collect_all_nodes (self .root_node )
96- all_nodes_list .extend (terminal_nodes )
97-
98- ## temp change: if reward is the same, choose the deeper node
99- best_child = max (all_nodes_list , key = lambda x : (x .reward , x .depth ))
94+ # Find best node
95+ all_nodes_list = collect_all_nodes (self .root_node )
96+ all_nodes_list .extend (terminal_nodes )
97+
98+ ## temp change: if reward is the same, choose the deeper node
99+ best_child = max (all_nodes_list , key = lambda x : (x .reward , x .depth ))
100+
101+ if best_child .value >= 0.75 :
102+ print ("Successful trajectory found" )
103+ await self .websocket_search_complete ("success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
104+ else :
105+ print ("Unsuccessful trajectory found" )
106+ await self .websocket_search_complete ("partial_success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
107+ await self .playwright_manager .close ()
100108
101- if best_child .value >= 0.75 :
102- print ("Successful trajectory found" )
103- await self .websocket_search_complete ("success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
104- else :
105- print ("Unsuccessful trajectory found" )
106- await self .websocket_search_complete ("partial_success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
107- await self .playwright_manager .close ()
108-
109- return best_child if best_child is not None else self .root_node
109+ return best_child if best_child is not None else self .root_node
110110
111111 async def node_selection (self , node : LATSNode , websocket = None ) -> Optional [LATSNode ]:
112112 if node .is_terminal :
0 commit comments