diff --git a/src/quadai/SAC/env_SAC.py b/src/quadai/SAC/env_SAC.py index a484350..9546521 100644 --- a/src/quadai/SAC/env_SAC.py +++ b/src/quadai/SAC/env_SAC.py @@ -25,6 +25,8 @@ class droneEnv(gym.Env): def __init__(self, render_every_frame, mouse_target): super(droneEnv, self).__init__() + path = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(path,"../") self.render_every_frame = render_every_frame # Makes the target follow the mouse self.mouse_target = mouse_target @@ -34,10 +36,10 @@ def __init__(self, render_every_frame, mouse_target): self.screen = pygame.display.set_mode((800, 800)) self.FramePerSec = pygame.time.Clock() - self.player = pygame.image.load(os.path.join("assets/sprites/drone_old.png")) + self.player = pygame.image.load(os.path.join(path,"assets/sprites/drone_old.png")) self.player.convert() - self.target = pygame.image.load(os.path.join("assets/sprites/target_old.png")) + self.target = pygame.image.load(os.path.join(path,"assets/sprites/target_old.png")) self.target.convert() pygame.font.init() diff --git a/src/quadai/balloon.py b/src/quadai/balloon.py index c4d213c..ec290dd 100644 --- a/src/quadai/balloon.py +++ b/src/quadai/balloon.py @@ -14,8 +14,9 @@ import numpy as np import pygame from pygame.locals import * -from quadai.player import HumanPlayer, PIDPlayer, SACPlayer +from quadai.player import HumanPlayer, PIDPlayer, SACPlayer,Player +def_players = [HumanPlayer(), PIDPlayer(), SACPlayer()] def correct_path(current_path): """ @@ -24,7 +25,7 @@ def correct_path(current_path): return os.path.join(os.path.dirname(__file__), current_path) -def balloon(): +def balloon(players : list[Player] = def_players): """ Runs the balloon game. """ @@ -150,8 +151,6 @@ def display_info(position): time_limit = 100 respawn_timer_max = 3 - players = [HumanPlayer(), PIDPlayer(), SACPlayer()] - # Generate 100 targets targets = [] for i in range(100): @@ -188,7 +187,7 @@ def display_info(position): player.angular_acceleration = 0 # Calculate propeller force in function of input - if player.name == "DQN" or player.name == "PID": + if player.type == "DQN" or player.type == "PID": thruster_left, thruster_right = player.act( [ targets[player.target_counter][0] - player.x_position, @@ -199,7 +198,7 @@ def display_info(position): player.angular_speed, ] ) - elif player.name == "SAC": + elif player.type == "SAC": angle_to_up = player.angle / 180 * pi velocity = sqrt(player.x_speed**2 + player.y_speed**2) angle_velocity = player.angular_speed @@ -285,7 +284,7 @@ def display_info(position): player.respawn_timer = respawn_timer_max else: # Display respawn timer - if player.name == "Human": + if player.type == "Human": respawn_text = respawn_timer_font.render( str(int(player.respawn_timer) + 1), True, (255, 255, 255) ) @@ -385,10 +384,13 @@ def display_info(position): # Print scores and who won print("") scores = [] + dict_scores = {} for player in players: print(player.name + " collected : " + str(player.target_counter)) scores.append(player.target_counter) + dict_scores[player.name] = player.target_counter winner = players[np.argmax(scores)].name print("") print("Winner is : " + winner + " !") + return dict_scores diff --git a/src/quadai/player.py b/src/quadai/player.py index e09ed07..04351b3 100644 --- a/src/quadai/player.py +++ b/src/quadai/player.py @@ -18,6 +18,8 @@ class Player: + type: str = "Player" + name: str = "Player" def __init__(self): self.thruster_mean = 0.04 self.thruster_amplitude = 0.04 @@ -33,6 +35,7 @@ def __init__(self): class HumanPlayer(Player): def __init__(self): self.name = "Human" + self.type = "Human" self.alpha = 255 super().__init__() @@ -56,6 +59,7 @@ def act(self, obs): class PIDPlayer(Player): def __init__(self): self.name = "PID" + self.type = "PID" self.alpha = 50 super().__init__() @@ -95,6 +99,7 @@ def act(self, obs): class HumanPlayer(Player): def __init__(self): self.name = "Human" + self.type = "Human" self.alpha = 255 super().__init__() @@ -116,13 +121,14 @@ def act(self, obs): class SACPlayer(Player): - def __init__(self): - self.name = "SAC" + def __init__(self,name="SAC",model_path="models/sac_model_v2_5000000_steps.zip"): + self.name = name + self.type = "SAC" self.alpha = 50 self.thruster_amplitude = 0.04 self.diff_amplitude = 0.003 - model_path = "models/sac_model_v2_5000000_steps.zip" - model_path = os.path.join(os.path.dirname(__file__), model_path) + if not model_path.startswith("/"): + model_path = os.path.join(os.path.dirname(__file__), model_path) self.path = model_path super().__init__()