Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type to player to all for multiple of same type of player #6

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/quadai/SAC/env_SAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class droneEnv(gym.Env):
def __init__(self, render_every_frame, mouse_target):
super(droneEnv, self).__init__()

path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path,"../")
self.render_every_frame = render_every_frame
# Makes the target follow the mouse
self.mouse_target = mouse_target
Expand All @@ -34,10 +36,10 @@ def __init__(self, render_every_frame, mouse_target):
self.screen = pygame.display.set_mode((800, 800))
self.FramePerSec = pygame.time.Clock()

self.player = pygame.image.load(os.path.join("assets/sprites/drone_old.png"))
self.player = pygame.image.load(os.path.join(path,"assets/sprites/drone_old.png"))
self.player.convert()

self.target = pygame.image.load(os.path.join("assets/sprites/target_old.png"))
self.target = pygame.image.load(os.path.join(path,"assets/sprites/target_old.png"))
self.target.convert()

pygame.font.init()
Expand Down
16 changes: 9 additions & 7 deletions src/quadai/balloon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import numpy as np
import pygame
from pygame.locals import *
from quadai.player import HumanPlayer, PIDPlayer, SACPlayer
from quadai.player import HumanPlayer, PIDPlayer, SACPlayer,Player

def_players = [HumanPlayer(), PIDPlayer(), SACPlayer()]

def correct_path(current_path):
"""
Expand All @@ -24,7 +25,7 @@ def correct_path(current_path):
return os.path.join(os.path.dirname(__file__), current_path)


def balloon():
def balloon(players : list[Player] = def_players):
"""
Runs the balloon game.
"""
Expand Down Expand Up @@ -150,8 +151,6 @@ def display_info(position):
time_limit = 100
respawn_timer_max = 3

players = [HumanPlayer(), PIDPlayer(), SACPlayer()]

# Generate 100 targets
targets = []
for i in range(100):
Expand Down Expand Up @@ -188,7 +187,7 @@ def display_info(position):
player.angular_acceleration = 0

# Calculate propeller force in function of input
if player.name == "DQN" or player.name == "PID":
if player.type == "DQN" or player.type == "PID":
thruster_left, thruster_right = player.act(
[
targets[player.target_counter][0] - player.x_position,
Expand All @@ -199,7 +198,7 @@ def display_info(position):
player.angular_speed,
]
)
elif player.name == "SAC":
elif player.type == "SAC":
angle_to_up = player.angle / 180 * pi
velocity = sqrt(player.x_speed**2 + player.y_speed**2)
angle_velocity = player.angular_speed
Expand Down Expand Up @@ -285,7 +284,7 @@ def display_info(position):
player.respawn_timer = respawn_timer_max
else:
# Display respawn timer
if player.name == "Human":
if player.type == "Human":
respawn_text = respawn_timer_font.render(
str(int(player.respawn_timer) + 1), True, (255, 255, 255)
)
Expand Down Expand Up @@ -385,10 +384,13 @@ def display_info(position):
# Print scores and who won
print("")
scores = []
dict_scores = {}
for player in players:
print(player.name + " collected : " + str(player.target_counter))
scores.append(player.target_counter)
dict_scores[player.name] = player.target_counter
winner = players[np.argmax(scores)].name

print("")
print("Winner is : " + winner + " !")
return dict_scores
14 changes: 10 additions & 4 deletions src/quadai/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@


class Player:
type: str = "Player"
name: str = "Player"
def __init__(self):
self.thruster_mean = 0.04
self.thruster_amplitude = 0.04
Expand All @@ -33,6 +35,7 @@ def __init__(self):
class HumanPlayer(Player):
def __init__(self):
self.name = "Human"
self.type = "Human"
self.alpha = 255
super().__init__()

Expand All @@ -56,6 +59,7 @@ def act(self, obs):
class PIDPlayer(Player):
def __init__(self):
self.name = "PID"
self.type = "PID"
self.alpha = 50
super().__init__()

Expand Down Expand Up @@ -95,6 +99,7 @@ def act(self, obs):
class HumanPlayer(Player):
def __init__(self):
self.name = "Human"
self.type = "Human"
self.alpha = 255
super().__init__()

Expand All @@ -116,13 +121,14 @@ def act(self, obs):


class SACPlayer(Player):
def __init__(self):
self.name = "SAC"
def __init__(self,name="SAC",model_path="models/sac_model_v2_5000000_steps.zip"):
self.name = name
self.type = "SAC"
self.alpha = 50
self.thruster_amplitude = 0.04
self.diff_amplitude = 0.003
model_path = "models/sac_model_v2_5000000_steps.zip"
model_path = os.path.join(os.path.dirname(__file__), model_path)
if not model_path.startswith("/"):
model_path = os.path.join(os.path.dirname(__file__), model_path)
self.path = model_path
super().__init__()

Expand Down