3434openai_client = OpenAI ()
3535
3636
37- ## TODO: remove account reset websocket message
38- ## browser setup message, ok to leave there in the _reset_browser method
39-
40-
4137class BaseAgent :
4238 # no need to pass an initial playwright_manager to the agent class
4339 def __init__ (
@@ -381,6 +377,10 @@ async def websocket_search_complete(self, status, score, path, websocket=None):
381377 "path" : path ,
382378 "timestamp" : datetime .utcnow ().isoformat ()
383379 })
380+ else :
381+ print (f"Search complete: { GREEN } { status } { RESET } " )
382+ print (f"Search score: { GREEN } { score } { RESET } " )
383+ print (f"Search path: { GREEN } { path } { RESET } " )
384384
385385 # shared, not implemented, BFS, DFS and LATS has its own node selection logic
386386 async def node_selection (self , node , websocket = None ):
@@ -485,31 +485,19 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
485485 node = node .parent
486486
487487 # shared
488- async def simulation (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 , websocket = None ) -> tuple [float , LATSNode ]:
488+ async def simulation (self , node : LATSNode , websocket = None ) -> tuple [float , LATSNode ]:
489489 depth = node .depth
490+ num_simulations = self .config .num_simulations
491+ max_depth = self .config .max_depth
490492 print ("print the trajectory" )
491493 print_trajectory (node )
492494 print ("print the entire tree" )
493495 print_entire_tree (self .root_node )
494- # if websocket:
495- # tree_data = self._get_tree_data()
496- # await self.websocket_tree_update(type="tree_update_simulation", tree_data=tree_data, websocket=websocket)
497- # await websocket.send_json({
498- # "type": "tree_update",
499- # "tree": tree_data,
500- # "timestamp": datetime.utcnow().isoformat()
501- # })
502- # trajectory_data = self._get_trajectory_data(node)
503- # await websocket.send_json({
504- # "type": "trajectory_update",
505- # "trajectory": trajectory_data,
506- # "timestamp": datetime.utcnow().isoformat()
507- # })
508- return await self .rollout (node , max_depth = max_depth , websocket = websocket )
496+ return await self .rollout (node , websocket = websocket )
509497
510498 # refactor simulation, rollout, send_completion_request methods
511499 # TODO: check, score as reward and then update value of the starting node?
512- async def rollout (self , node : LATSNode , max_depth : int = 2 , websocket = None )-> tuple [float , LATSNode ]:
500+ async def rollout (self , node : LATSNode , websocket = None )-> tuple [float , LATSNode ]:
513501 # Reset browser state
514502 await self ._reset_browser ()
515503 path = self .get_path_to_root (node )
@@ -540,23 +528,14 @@ async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tu
540528 "action" : n .action ,
541529 "feedback" : n .feedback
542530 })
543- ## call the prompt agent
544531 print ("current depth: " , len (path ) - 1 )
545532 print ("max depth: " , self .config .max_depth )
546533
547- ## find a better name for this
548534 trajectory , terminal_node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory , websocket = websocket )
549535 print ("print the trajectory" )
550536 print_trajectory (terminal_node )
551537 print ("print the entire tree" )
552538 print_entire_tree (self .root_node )
553- # if websocket:
554- # trajectory_data = self._get_trajectory_data(node)
555- # await websocket.send_json({
556- # "type": "trajectory_update",
557- # "trajectory": trajectory_data,
558- # "timestamp": datetime.utcnow().isoformat()
559- # })
560539
561540 page = await self .playwright_manager .get_page ()
562541 page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
@@ -583,12 +562,6 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
583562 print ("print the entire tree" )
584563 print_entire_tree (self .root_node )
585564 if websocket :
586- # tree_data = self._get_tree_data()
587- # await websocket.send_json({
588- # "type": "tree_update",
589- # "tree": tree_data,
590- # "timestamp": datetime.utcnow().isoformat()
591- # })
592565 trajectory_data = self ._get_trajectory_data (node )
593566 await websocket .send_json ({
594567 "type" : "trajectory_update" ,
@@ -684,15 +657,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
684657 path = self .get_path_to_root (node )
685658
686659 # Execute path
687- for n in path [1 :]: # Skip root node
688- # if websocket:
689- # await websocket.send_json({
690- # "type": "replaying_action",
691- # "node_id": id(n),
692- # "action": n.action,
693- # "timestamp": datetime.utcnow().isoformat()
694- # })
695-
660+ for n in path [1 :]: # Skip root node
696661 success = await playwright_step_execution (
697662 n ,
698663 self .goal ,
@@ -702,12 +667,6 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
702667 )
703668 if not success :
704669 n .is_terminal = True
705- # if websocket:
706- # await websocket.send_json({
707- # "type": "replay_failed",
708- # "node_id": id(n),
709- # "timestamp": datetime.utcnow().isoformat()
710- # })
711670 return []
712671
713672 if not n .feedback :
@@ -716,26 +675,13 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
716675 n .natural_language_description ,
717676 self .playwright_manager ,
718677 )
719- # if websocket:
720- # await websocket.send_json({
721- # "type": "feedback_generated",
722- # "node_id": id(n),
723- # "feedback": n.feedback,
724- # "timestamp": datetime.utcnow().isoformat()
725- # })
726678
727679 time .sleep (3 )
728680 page = await self .playwright_manager .get_page ()
729681 page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
730682
731683 messages = [{"role" : "user" , "content" : f"Action is: { n .action } " } for n in path [1 :]]
732-
733- # if websocket:
734- # await websocket.send_json({
735- # "type": "generating_actions",
736- # "node_id": id(node),
737- # "timestamp": datetime.utcnow().isoformat()
738- # })
684+
739685
740686 next_actions = await extract_top_actions (
741687 [{"natural_language_description" : n .natural_language_description , "action" : n .action , "feedback" : n .feedback } for n in path [1 :]],
@@ -779,23 +725,8 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
779725 action ["element" ] = element
780726 except Exception as e :
781727 action ["element" ] = None
782- # if websocket:
783- # await websocket.send_json({
784- # "type": "element_location_failed",
785- # "action": action["action"],
786- # "error": str(e),
787- # "timestamp": datetime.utcnow().isoformat()
788- # })
789728 children .append (action )
790729
791730 if not children :
792- node .is_terminal = True
793- # if websocket:
794- # await websocket.send_json({
795- # "type": "node_terminal",
796- # "node_id": id(node),
797- # "reason": "no_valid_actions",
798- # "timestamp": datetime.utcnow().isoformat()
799- # })
800-
731+ node .is_terminal = True
801732 return children
0 commit comments