Skip to content

Commit

Permalink
Merge pull request #81 from stratosphereips/check-EOF-in-communication
Browse files Browse the repository at this point in the history
Check eof in communication
  • Loading branch information
ondrej-lukas authored Feb 6, 2025
2 parents 3693e0e + ece5132 commit 18324c4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
10 changes: 4 additions & 6 deletions agents/attackers/random/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# with the path fixed, we can import now
from base_agent import BaseAgent
from agent_utils import generate_valid_actions
from datetime import datetime


class RandomAttackerAgent(BaseAgent):

Expand Down Expand Up @@ -63,14 +61,14 @@ def select_action(self, observation:Observation)->Action:
parser.add_argument("--port", help="Port where the game server is", default=9000, type=int, action='store', required=False)
parser.add_argument("--episodes", help="Sets number of episodes to play or evaluate", default=100, type=int)
parser.add_argument("--test_each", help="Evaluate performance during testing every this number of episodes.", default=10, type=int)
parser.add_argument("--logdir", help="Folder to store logs", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs"))
parser.add_argument("--logdir", help="Folder to store logs", default=path.join(path.dirname(path.abspath(__file__)), "logs"))
parser.add_argument("--evaluate", help="Evaluate the agent and report, instead of playing the game only once.", default=True)
parser.add_argument("--mlflow_url", help="URL for mlflow tracking server. If not provided, mlflow will store locally.", default=None)
args = parser.parse_args()

if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
logging.basicConfig(filename=os.path.join(args.logdir, "random_agent.log"), filemode='w', format='%(asctime)s %(name)s %(levelname)s %(message)s', datefmt='%H:%M:%S',level=logging.INFO)
if not path.exists(args.logdir):
makedirs(args.logdir)
logging.basicConfig(filename=path.join(args.logdir, "random_agent.log"), filemode='w', format='%(asctime)s %(name)s %(levelname)s %(message)s', datefmt='%H:%M:%S',level=logging.INFO)

# Create agent
agent = RandomAttackerAgent(args.host, args.port,"Attacker", seed=42)
Expand Down
20 changes: 16 additions & 4 deletions agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# This is used so the agent can see the environment and game components
sys.path.append(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))
from AIDojoCoordinator.game_components import Action, GameState, Observation, ActionType, GameStatus,AgentInfo
from AIDojoCoordinator.game_components import Action, GameState, Observation, ActionType, GameStatus, AgentInfo, ProtocolConfig

class BaseAgent(ABC):
"""
Expand Down Expand Up @@ -86,7 +86,19 @@ def _receive_data(socket)->tuple:
Receive data from server
"""
# Receive data from the server
data = socket.recv(8192).decode()
data = b"" # Initialize an empty byte string

while True:
chunk = socket.recv(ProtocolConfig.BUFFER_SIZE) # Receive a chunk
if not chunk: # If no more data, break (connection closed)
break
data += chunk
if ProtocolConfig.END_OF_MESSAGE in data: # Check if EOF marker is present
break
if ProtocolConfig.END_OF_MESSAGE not in data:
raise ConnectionError("Unfinished connection.")
data = data.replace(ProtocolConfig.END_OF_MESSAGE, b"") # Remove EOF marker
data = data.decode()
self._logger.debug(f"Data received from env: {data}")
# extract data from string representation
data_dict = json.loads(data)
Expand Down Expand Up @@ -124,12 +136,12 @@ def register(self)->Observation:
except Exception as e:
self._logger.error(f'Exception in register(): {e}')

def request_game_reset(self)->Observation:
def request_game_reset(self, request_trajectory=False)->Observation:
"""
Method for requesting restart of the game.
"""
self._logger.debug("Requesting game reset")
status, observation_dict, message = self.communicate(Action(ActionType.ResetGame, {}))
status, observation_dict, message = self.communicate(Action(ActionType.ResetGame, parameters={"request_trajectory":request_trajectory}))
if status:
self._logger.debug('\tReset successful')
return Observation(GameState.from_dict(observation_dict["state"]), observation_dict["reward"], observation_dict["end"], message)
Expand Down

0 comments on commit 18324c4

Please sign in to comment.