From 15b06666a0b69390b5b23b66294fd978b52ec26f Mon Sep 17 00:00:00 2001 From: Flim de Jong Date: Wed, 20 Nov 2024 16:29:29 +0100 Subject: [PATCH] Initial version working on ray cluster (does not return observations yet) --- docker/runner/ray-cluster-combined.yaml | 71 ++++++++-- roboteam_ai/src/RL/RL_Ray/train.py | 52 ++++---- roboteam_ai/src/RL/env2.py | 47 ++++--- roboteam_ai/src/RL/src/changeGameState.py | 148 ++++----------------- roboteam_ai/src/RL/src/getState.py | 108 +++++++++------ roboteam_ai/src/RL/src/resetRefereeAPI.py | 42 +----- roboteam_ai/src/RL/src/websocketHandler.py | 45 +++++++ 7 files changed, 262 insertions(+), 251 deletions(-) create mode 100644 roboteam_ai/src/RL/src/websocketHandler.py diff --git a/docker/runner/ray-cluster-combined.yaml b/docker/runner/ray-cluster-combined.yaml index 97f2169ca..a054f98eb 100644 --- a/docker/runner/ray-cluster-combined.yaml +++ b/docker/runner/ray-cluster-combined.yaml @@ -22,7 +22,7 @@ spec: - key: kubernetes.io/hostname operator: In values: - - multinode-demo # Schedule on control-plane + - ray # Schedule on control-plane containers: - name: ray-head image: roboteamtwente/ray:development @@ -78,6 +78,8 @@ spec: labels: app: ray-worker spec: + hostNetwork: true + dnsPolicy: ClusterFirstWithHostNet affinity: nodeAffinity: requiredDuringSchedulingIgnoredDuringExecution: @@ -86,7 +88,13 @@ spec: - key: kubernetes.io/hostname operator: In values: - - multinode-demo-m02 # Schedule on worker node + - ray-m02 # Schedule on worker node + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchLabels: + app: ray-worker + topologyKey: "kubernetes.io/hostname" volumes: - name: gradle-cache emptyDir: {} @@ -101,8 +109,8 @@ spec: cpu: 500m memory: 1Gi limits: - cpu: 1000m - memory: 2Gi + cpu: 2000m + memory: 4Gi env: - name: LD_LIBRARY_PATH value: /home/roboteam/build/release/lib @@ -112,10 +120,10 @@ spec: # Game Controller - name: ssl-game-controller image: robocupssl/ssl-game-controller:latest - args: ["-address", "0.0.0.0:8081"] # Changed from :8081 to explicitly bind to all interfaces + args: ["-address", "0.0.0.0:8081"] ports: - containerPort: 8081 - protocol: TCP # Explicitly set protocol + protocol: TCP # Primary AI - name: roboteam-primary-ai @@ -193,11 +201,11 @@ spec: - "/home/roboteam/external/framework/build/bin/simulator-cli" ports: - containerPort: 10300 - protocol: UDP # Simulator control port + protocol: UDP - containerPort: 10301 - protocol: TCP # Presumably TCP ports + protocol: TCP - containerPort: 5558 - protocol: TCP # ZMQ port + protocol: TCP env: - name: LD_LIBRARY_PATH value: /home/roboteam/build/release/lib @@ -244,4 +252,47 @@ spec: - name: gc-interface port: 8081 targetPort: 8081 - nodePort: 30081 # Game controller interface \ No newline at end of file + nodePort: 30081 # Game controller interface + - name: redis + port: 6379 + targetPort: 6379 + nodePort: 30679 # Choose an available port + +--- +apiVersion: v1 +kind: Service +metadata: + name: roboteam-ray-cluster-head-svc +spec: + type: ClusterIP + selector: + app: ray-head + ports: + - name: redis + port: 6379 + targetPort: 6379 + - name: gcs + port: 10001 + targetPort: 10001 + - name: dashboard + port: 8265 + targetPort: 8265 + - name: serve + port: 8000 + targetPort: 8000 + +--- +apiVersion: v1 +kind: Service +metadata: + name: roboteam-ray-worker-svc +spec: + selector: + app: ray-worker + ports: + - name: game-controller + port: 8081 + targetPort: 8081 + - name: simulator + port: 5558 + targetPort: 5558 \ No newline at end of file diff --git a/roboteam_ai/src/RL/RL_Ray/train.py b/roboteam_ai/src/RL/RL_Ray/train.py index af9cc94ce..51c7254c2 100644 --- a/roboteam_ai/src/RL/RL_Ray/train.py +++ b/roboteam_ai/src/RL/RL_Ray/train.py @@ -27,23 +27,23 @@ def verify_imports(): def main(): verify_imports() - # if not ray.is_initialized(): - # ray.init( - # address="ray://192.168.49.2:31001", - # ignore_reinit_error=True, - # runtime_env={ - # "env_vars": { - # "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION": "0", - - # }, - # # "pip": [ - # # "numpy==1.24.3", - # # "pyzmq==26.2.0" - # # ] - # } - # ) - - ray.init() + if not ray.is_initialized(): + ray.init( + address=f"ray://192.168.49.2:31001", + ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION": "0", + + }, + # "pip": [ + # "numpy==1.24.3", + # "pyzmq==26.2.0" + # ] + } + ) + + # ray.init() # We can set env_config here def env_creator(env_config): @@ -65,7 +65,13 @@ def env_creator(env_config): .resources(num_gpus=0) .env_runners( num_env_runners=1, + num_envs_per_env_runner=1, + sample_timeout_s=None ) +# .api_stack( +# enable_rl_module_and_learner=True, +# enable_env_runner_and_connector_v2=True +# ) .debugging( log_level="DEBUG", seed=42 @@ -78,17 +84,15 @@ def env_creator(env_config): algo = config.build() for i in range(10): + print(f"\nStarting iteration {i}") result = algo.train() result.pop("config") + print("\nTraining metrics:") + print(f"Episode Reward Mean: {result.get('episode_reward_mean', 'N/A')}") + print(f"Episode Length Mean: {result.get('episode_len_mean', 'N/A')}") + print(f"Total Timesteps: {result.get('timesteps_total', 'N/A')}") pprint(result) - if i % 5 == 0: - # Use save instead of save_to_path - checkpoint_dir = f"checkpoint_{i}" - os.makedirs(checkpoint_dir, exist_ok=True) - algo.save(checkpoint_dir) - print(f"Checkpoint saved in directory {checkpoint_dir}") - if __name__ == "__main__": main() diff --git a/roboteam_ai/src/RL/env2.py b/roboteam_ai/src/RL/env2.py index 9a55da945..912118b75 100644 --- a/roboteam_ai/src/RL/env2.py +++ b/roboteam_ai/src/RL/env2.py @@ -44,9 +44,8 @@ def __init__(self, config=None): self.blue_score = 0 # Initialize blue score to zero # Initialize the observation space - self.observation_space = spaces.Box(low=0, high=self.MAX_ROBOTS_US, shape=(15,), dtype=np.int32) + self.observation_space = spaces.Box(low=float('-inf'), high=float('inf'), shape=(1,15), dtype=np.float64) - # Action space: [attackers, defenders] # Wallers will be automatically calculated self.action_space = spaces.MultiDiscrete([self.MAX_ROBOTS_US + 1, self.MAX_ROBOTS_US + 1]) @@ -170,26 +169,38 @@ def get_observation(self): """ get_observation is meant to get the observation space (kinda like the state) """ - # Get the robot grid representation - self.robot_grid, self.is_yellow_dribbling, self.is_blue_dribbling = get_robot_state() # Matrix of 4 by 2 + 2 booleans + self.robot_grid, self.is_yellow_dribbling, self.is_blue_dribbling = get_robot_state() print(f"Robot grid: {self.robot_grid}") print(f"Yellow dribbling: {self.is_yellow_dribbling}, Blue dribbling: {self.is_blue_dribbling}") # Get the ball location - self.ball_position, self.ball_quadrant = get_ball_state() # x,y coordinates, quadrant - + self.ball_position, self.ball_quadrant = get_ball_state() print(f"Ball position: {self.ball_position}, Ball quadrant: {self.ball_quadrant}") - robot_positions_flat = self.robot_grid.flatten() + # Convert and flatten robot positions to float64 + robot_positions_flat = self.robot_grid.astype(np.float64).flatten() # 8 elements + + # Use ball quadrant for observation + ball_quadrant = np.array([float(self.ball_quadrant)], dtype=np.float64) # 1 element + + # Convert dribbling status to float64 + is_yellow_dribbling = np.array([float(self.is_yellow_dribbling)], dtype=np.float64) # 1 element - # Convert `ball_position` (scalar) and `is_yellow_dribbling` (boolean) to compatible formats - ball_position = np.array([self.ball_quadrant]) # 1 element - is_yellow_dribbling = np.array([int(self.is_yellow_dribbling)]) # Convert boolean to int (0 or 1) + # Combine all parts into the observation array with padding + observation = np.concatenate([ + robot_positions_flat, # 8 elements + ball_quadrant, # 1 element + is_yellow_dribbling, # 1 element + np.zeros(5, dtype=np.float64) # 5 elements to reach total of 15 + ]) - # Combine all parts into a single 15-element observation array - # Pad with zeros if you need additional elements - observation = np.concatenate([robot_positions_flat, ball_position, is_yellow_dribbling, np.zeros(5)]) + # Reshape to match expected shape (1, 15) + observation = observation.reshape(1, 15) + + # Verify shape and dtype + assert observation.shape == (1, 15), f"Observation shape {observation.shape} != (1, 15)" + assert observation.dtype == np.float64, f"Observation dtype {observation.dtype} != float64" return observation, self.calculate_reward() @@ -230,7 +241,7 @@ def step(self, action): observation_space, _ = self.reset() truncated = self.is_truncated() # Determine if the episode was truncated, too much time or a yellow card - time.sleep(0.25) # DELAY FOR STEPS (ADJUST LATER) + time.sleep(0.1) # DELAY FOR STEPS (ADJUST LATER) return observation_space, reward, done, truncated, {} @@ -266,20 +277,26 @@ def reset(self, seed=None,**kwargs): """ # Teleport ball to middle position + print("Teleporting ball...") teleport_ball(0,0) # Reset referee state + print("Resetting referee state...") reset_referee_state() # Set blue team on right side + initiates kickoff + print("Starting game...") start_game() + print("Getting observation...") + observation, _ = self.get_observation() + # Reset shaped_reward_given boolean self.shaped_reward_given = False self.is_yellow_dribbling = False self.is_blue_dribbling = False - observation, _ = self.get_observation() + print("Reset complete!") return observation,{} diff --git a/roboteam_ai/src/RL/src/changeGameState.py b/roboteam_ai/src/RL/src/changeGameState.py index 89bdb5bfe..27b069ee6 100644 --- a/roboteam_ai/src/RL/src/changeGameState.py +++ b/roboteam_ai/src/RL/src/changeGameState.py @@ -1,129 +1,39 @@ -import sys -import os -import websockets -import asyncio -import json -import time +from . websocketHandler import run_websocket_command def set_team_state(team, on_positive_half): - """ - Set team state including which half they play on. - - Args: - team (str): Either "BLUE" or "YELLOW" - on_positive_half (bool): True if team should play on positive half - """ - uri = "ws://localhost:8081/api/control" - - async def _async_set_team(): - try: - async with websockets.connect(uri) as websocket: - state_msg = { - "change": { - "update_team_state_change": { - "for_team": team, - "on_positive_half": on_positive_half - } - } - } - - print(f"Setting {team} team to play on {'positive' if on_positive_half else 'negative'} half...") - await websocket.send(json.dumps(state_msg)) - - try: - response = await asyncio.wait_for(websocket.recv(), timeout=2.0) - #print("Received response:", json.loads(response)) - except asyncio.TimeoutError: - print("No response received in 2 seconds") - - except websockets.exceptions.ConnectionClosed as e: - print(f"WebSocket connection closed: {e}") - except Exception as e: - print(f"Error: {e}") - - # Run the async function synchronously - loop = asyncio.get_event_loop() - loop.run_until_complete(_async_set_team()) + state_msg = { + "change": { + "update_team_state_change": { + "for_team": team, + "on_positive_half": on_positive_half + } + } + } + return run_websocket_command(state_msg) def send_referee_command(command_type, team=None): - """ - Send a referee command using websockets. - - Args: - command_type (str): One of: - HALT, STOP, NORMAL_START, FORCE_START, DIRECT, - KICKOFF, PENALTY, TIMEOUT, BALL_PLACEMENT - team (str, optional): For team-specific commands, either "YELLOW" or "BLUE" - """ - uri = "ws://localhost:8081/api/control" - - async def _async_send_command(): - try: - async with websockets.connect(uri) as websocket: - command_msg = { - "change": { - "new_command_change": { - "command": { - "type": command_type, - "for_team": team if team else "UNKNOWN" - } - } - } + command_msg = { + "change": { + "new_command_change": { + "command": { + "type": command_type, + "for_team": team if team else "UNKNOWN" } - - print(f"Sending referee command: {command_type} {'for ' + team if team else ''}...") - await websocket.send(json.dumps(command_msg)) - - try: - response = await asyncio.wait_for(websocket.recv(), timeout=2.0) - #print("Received response:", json.loads(response)) - except asyncio.TimeoutError: - print("No response received in 2 seconds") - - except websockets.exceptions.ConnectionClosed as e: - print(f"WebSocket connection closed: {e}") - except Exception as e: - print(f"Error: {e}") - - loop = asyncio.get_event_loop() - loop.run_until_complete(_async_send_command()) + } + } + } + return run_websocket_command(command_msg) def set_first_kickoff_team(team): - """ - Set which team takes the first kickoff. - - Args: - team (str): Either "BLUE" or "YELLOW" - """ - uri = "ws://localhost:8081/api/control" - - async def _async_set_kickoff(): - try: - async with websockets.connect(uri) as websocket: - config_msg = { - "change": { - "update_config_change": { - "first_kickoff_team": team - } - } - } - - print(f"Setting first kickoff team to {team}...") - await websocket.send(json.dumps(config_msg)) - - try: - response = await asyncio.wait_for(websocket.recv(), timeout=2.0) - #print("Received response:", json.loads(response)) - except asyncio.TimeoutError: - print("No response received in 2 seconds") - - except websockets.exceptions.ConnectionClosed as e: - print(f"WebSocket connection closed: {e}") - except Exception as e: - print(f"Error: {e}") - - loop = asyncio.get_event_loop() - loop.run_until_complete(_async_set_kickoff()) + """Set which team takes first kickoff.""" + config_msg = { + "change": { + "update_config_change": { + "first_kickoff_team": team + } + } + } + return run_websocket_command(config_msg) # Simple command functions def halt(): diff --git a/roboteam_ai/src/RL/src/getState.py b/roboteam_ai/src/RL/src/getState.py index 99881306e..e040f19ce 100644 --- a/roboteam_ai/src/RL/src/getState.py +++ b/roboteam_ai/src/RL/src/getState.py @@ -5,6 +5,7 @@ import numpy as np import socket import struct +from . websocketHandler import run_websocket_command # Make sure to go back to the main roboteam directory @@ -20,50 +21,67 @@ # from roboteam_networking.proto.ssl_gc_api_pb2 import Output as RefereeState # Alias for referee state from roboteam_networking.proto.messages_robocup_ssl_referee_pb2 import * # Alias for referee state -# Function to get the ball state -def get_ball_state(): - ball_position = np.zeros(2) # [x, y] - # Instead of -1, quadrant 4 if ball is in the center - ball_quadrant = 4 +IS_IN_K8S = True - CENTER_THRESHOLD = 0.01 # Define the center threshold +def get_zmq_address(): + """Get the appropriate ZMQ address based on environment""" + if IS_IN_K8S: + host = "roboteam-ray-worker-svc" + print("Running in Kubernetes, using service DNS") + else: + host = "localhost" + print("Running locally") + return f"tcp://{host}:5558" - context = zmq.Context() - socket_world = context.socket(zmq.SUB) - socket_world.setsockopt_string(zmq.SUBSCRIBE, "") - socket_world.connect("tcp://127.0.0.1:5558") # Connect to the simulation socket - - try: - message = socket_world.recv() - state = RoboState.FromString(message) - - if not len(state.processed_vision_packets): - return ball_position, ball_quadrant - - world = state.last_seen_world - - if world.HasField("ball"): - ball_position[0] = world.ball.pos.x - ball_position[1] = world.ball.pos.y - - print("x",ball_position[0]) - print("y",ball_position[1]) - - if abs(ball_position[0]) <= CENTER_THRESHOLD and abs(ball_position[1]) <= CENTER_THRESHOLD: - ball_quadrant = 4 # Center - elif ball_position[0] < 0: - ball_quadrant = 0 if ball_position[1] > 0 else 2 - else: - ball_quadrant = 1 if ball_position[1] > 0 else 3 - except DecodeError: - print("Failed to decode protobuf message") - except zmq.ZMQError as e: - print(f"ZMQ Error: {e}") - finally: - socket_world.close() - context.term() - - return ball_position, ball_quadrant +# Function to get the ball state +def get_ball_state(): + ball_position = np.zeros(2) # [x, y] + # Instead of -1, quadrant 4 if ball is in the center + ball_quadrant = 4 + + CENTER_THRESHOLD = 0.01 # Define the center threshold + + context = zmq.Context() + socket_world = context.socket(zmq.SUB) + socket_world.setsockopt_string(zmq.SUBSCRIBE, "") + + zmq_address = get_zmq_address() + print(f"Connecting to ZMQ at: {zmq_address}") + socket_world.connect(zmq_address) + + try: + print("Waiting for ZMQ message...") + message = socket_world.recv() + print("Received ZMQ message") + state = RoboState.FromString(message) + + if not len(state.processed_vision_packets): + return ball_position, ball_quadrant + + world = state.last_seen_world + + if world.HasField("ball"): + ball_position[0] = world.ball.pos.x + ball_position[1] = world.ball.pos.y + + print("x",ball_position[0]) + print("y",ball_position[1]) + + if abs(ball_position[0]) <= CENTER_THRESHOLD and abs(ball_position[1]) <= CENTER_THRESHOLD: + ball_quadrant = 4 # Center + elif ball_position[0] < 0: + ball_quadrant = 0 if ball_position[1] > 0 else 2 + else: + ball_quadrant = 1 if ball_position[1] > 0 else 3 + except DecodeError: + print("Failed to decode protobuf message") + except zmq.ZMQError as e: + print(f"ZMQ Error: {e}") + finally: + socket_world.close() + context.term() + + return ball_position, ball_quadrant # Function to get the robot state def get_robot_state(): @@ -74,14 +92,16 @@ def get_robot_state(): context = zmq.Context() socket_world = context.socket(zmq.SUB) socket_world.setsockopt_string(zmq.SUBSCRIBE, "") - socket_world.connect("tcp://127.0.0.1:5558") + + zmq_address = get_zmq_address() + print(f"Connecting to ZMQ at: {zmq_address}") + socket_world.connect(zmq_address) try: message = socket_world.recv() state = RoboState.FromString(message) # print(state) - if not len(state.processed_vision_packets): return grid_array, yellow_team_dribbling, blue_team_dribbling diff --git a/roboteam_ai/src/RL/src/resetRefereeAPI.py b/roboteam_ai/src/RL/src/resetRefereeAPI.py index 2fc1dedbc..3c2108919 100644 --- a/roboteam_ai/src/RL/src/resetRefereeAPI.py +++ b/roboteam_ai/src/RL/src/resetRefereeAPI.py @@ -1,41 +1,5 @@ -import sys -import os -import websockets -import asyncio -import json +from . websocketHandler import run_websocket_command def reset_referee_state(): - """ - Synchronous function to reset the referee state - """ - uri = "ws://localhost:8081/api/control" - - async def _async_reset(): - try: - async with websockets.connect(uri) as websocket: - # Create JSON message - reset_msg = { - "reset_match": True - } - - print("Sending JSON reset command...") - await websocket.send(json.dumps(reset_msg)) - - try: - response = await asyncio.wait_for(websocket.recv(), timeout=2.0) - #print("Received response:", json.loads(response)) - except asyncio.TimeoutError: - print("No response received in 2 seconds") - - except websockets.exceptions.ConnectionClosed as e: - print(f"WebSocket connection closed: {e}") - except Exception as e: - print(f"Error: {e}") - - # Run the async function synchronously - loop = asyncio.get_event_loop() - loop.run_until_complete(_async_reset()) - -if __name__ == "__main__": - print("Connecting to game controller...") - asyncio.get_event_loop().run_until_complete(reset_referee_state()) \ No newline at end of file + reset_msg = {"reset_match": True} + return run_websocket_command(reset_msg) \ No newline at end of file diff --git a/roboteam_ai/src/RL/src/websocketHandler.py b/roboteam_ai/src/RL/src/websocketHandler.py new file mode 100644 index 000000000..d7d66126a --- /dev/null +++ b/roboteam_ai/src/RL/src/websocketHandler.py @@ -0,0 +1,45 @@ +import os +import websockets +import asyncio +import json + +IS_IN_K8S = True # We run it locally. + +def get_websocket_uri(): + """Get the appropriate URI based on the environment""" + if IS_IN_K8S: + host = "roboteam-ray-worker-svc" + print("Running in Kubernetes, using service DNS") + else: + host = "localhost" + print("Running locally") + + return f"ws://{host}:8081/api/control" + +async def send_websocket_message(message, timeout=2.0): + """Generic function to send websocket messages""" + uri = get_websocket_uri() + try: + async with websockets.connect(uri) as websocket: + await websocket.send(json.dumps(message)) + return await asyncio.wait_for(websocket.recv(), timeout=timeout) + except Exception as e: + print(f"Websocket error: {e}") + raise + +def run_websocket_command(message): + """ + Synchronous wrapper to run websocket commands + This is the main file + """ + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + return loop.run_until_complete(send_websocket_message(message)) + finally: + loop.close() + asyncio.set_event_loop(None) \ No newline at end of file