@@ -141,10 +141,10 @@ async def evaluate_selected_path(self, path) -> None:
141141 "feedback" : n .feedback
142142 })
143143
144- # Score the trajectory
145- # TODO: if node is terminal, score is 0?
146- # if node.is_terminal:
147- # score = 0
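144+ ## Fix for the MCTS agent only: an empty trajectory has nothing to score, so return a score of 0 before building the LLM prompt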
145+ if len ( trajectory ) == 0 :
146+ score = 0
147+ return score
148148 prompt = create_llm_prompt (trajectory , self .goal )
149149 print (f"prompt: { prompt } " )
150150 result = score_trajectory_with_openai (
@@ -230,8 +230,10 @@ async def reflection_backtracking(self, path) -> List[LATSNode]:
230230 print ("Suggested improvements:" )
231231 for improvement in reflection_result ["suggested_improvements" ]:
232232 print (f"- { improvement } " )
233+ print (f"current_node: { current_node .action } " )
234+ print (f"current_node: { current_node .natural_language_description } " )
233235
234- return path
236+ return path , current_node
235237
236238 async def mcts_search (self , websocket = None ) -> Optional [LATSNode ]:
237239 best_score = float ('-inf' )
@@ -249,17 +251,22 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
249251 # "node selection" combines selection and partial simulation
250252 print (f"{ GREEN } Step 1: Node Selection{ RESET } " )
251253 await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
252- node = await self .node_selection (self .root_node , websocket )
254+ selected_node = await self .node_selection (self .root_node , websocket )
255+ tree_data = self ._get_tree_data ()
256+ if websocket :
257+ await self .websocket_tree_update (type = "tree_update_node_selection" , websocket = websocket , tree_data = tree_data )
258+ else :
259+ print_entire_tree (self .root_node )
253260
254- if node is None :
261+ if selected_node is None :
255262 logger .warning ("All paths lead to terminal nodes. Ending search." )
256263 break
257264
258265 # Step 2: Node Expansion
259266 print (f"{ GREEN } Step 2: Node Expansion{ RESET } " )
260267 await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
261- await self .node_expansion (node , websocket )
262- if node is None :
268+ await self .node_expansion (selected_node , websocket )
269+ if selected_node is None :
263270 # all the nodes are terminal, stop the search
264271 print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
265272 break
@@ -274,29 +281,34 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
274281 # TODO: implement simulation using openai
275282 print (f"{ GREEN } Step 3: Simulation{ RESET } " )
276283 await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
277- path = self .get_path_to_root (node )
284+ path = self .get_path_to_root (selected_node )
278285 score = await self .evaluate_selected_path (path )
279286 # change to reward later?
280287 if score > best_score :
281288 best_score = score
282289 best_path = path
290+ best_node = selected_node
283291 print (f"\n New best path found!" )
284- print (f"Previous best score: { best_score :.3f} " )
285- print (f"New best score: { score :.3f} " )
292+ print (f"best score: { best_score :.3f} " )
293+ print (f"best node: { best_node .action } " )
294+ print (f"best node: { best_node .natural_language_description } " )
295+ print (f"best path: { best_path } " )
286296
287297
288298 ## Step 4: reflection backtracking
289299 print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
290300 await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
291301 if score >= self .config .reflection_score :
292302 # Convert path to serializable trajectory
293- trajectory = [node .action for node in path if node .action is not None ]
294- await self .websocket_search_complete ("success" , score , trajectory , websocket = websocket )
295- return node
303+ # trajectory = [node.action for node in path if node.action is not None]
304+ await self .websocket_search_complete ("success" , score , selected_node . get_trajectory () , websocket = websocket )
305+ return selected_node
296306
297307 print (f"path: { path } " )
298- path = await self .reflection_backtracking (path )
308+ path , current_node = await self .reflection_backtracking (path )
299309 print (f"path: { path } " )
310+ print (f"current_node: { current_node .action } " )
311+ print (f"current_node: { current_node .natural_language_description } " )
300312
301313 # Step 5: backpropagation
302314 print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
@@ -308,8 +320,12 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
308320 print (f"Node { node .action } :" )
309321 print (f" Visits: { node .visits } " )
310322 print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
323+ if websocket :
324+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
325+ else :
326+ print_entire_tree (self .root_node )
311327 if best_node :
312328 # Convert node to serializable trajectory
313- trajectory = [n .action for n in self .get_path_to_root (best_node ) if n .action is not None ]
329+ # trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
314330 await self .websocket_search_complete ("partial_success" , best_node .value , best_node .get_trajectory (), websocket = websocket )
315331 return best_node
0 commit comments