-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e4912c3
commit b0a2ed1
Showing
53 changed files
with
368 additions
and
3,170 deletions.
There are no files selected for viewing
File renamed without changes.
88 changes: 88 additions & 0 deletions
88
rlgym_tools/action_parsers/advanced_lookup_table_action.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from rlgym.rocket_league.action_parsers import LookupTableAction | ||
|
||
|
||
class AdvancedLookupTableAction(LookupTableAction): | ||
def __init__(self, throttle_bins: Any = 3, | ||
steer_bins: Any = 3, | ||
torque_bins: Any = 3, | ||
flip_bins: Any = 8, | ||
include_stall: bool = False): | ||
super().__init__() | ||
self._lookup_table = self.make_lookup_table(throttle_bins, steer_bins, torque_bins, flip_bins, include_stall) | ||
|
||
@staticmethod | ||
def make_lookup_table(throttle_bins: Any = 3, | ||
steer_bins: Any = 3, | ||
torque_bins: Any = 3, | ||
flip_bins: Any = 8, | ||
include_stalls: bool = False): | ||
# Parse bins | ||
def parse_bin(b, endpoint=True): | ||
if isinstance(b, int): | ||
b = np.linspace(-1, 1, b, endpoint=endpoint) | ||
else: | ||
b = np.array(b) | ||
return b | ||
|
||
throttle_bins = parse_bin(throttle_bins) | ||
steer_bins = parse_bin(steer_bins) | ||
torque_bins = parse_bin(torque_bins) | ||
flip_bins = (parse_bin(flip_bins, endpoint=False) + 1) * np.pi # Split a circle into equal segments in [0, 2pi) | ||
|
||
actions = [] | ||
|
||
# Ground | ||
pitch = roll = jump = 0 | ||
for throttle in throttle_bins: | ||
for steer in steer_bins: | ||
for boost in (0, 1): | ||
for handbrake in (0, 1): | ||
if boost == 1 and throttle != 1: | ||
continue | ||
yaw = steer | ||
actions.append([throttle, steer, pitch, yaw, roll, jump, boost, handbrake]) | ||
|
||
# Aerial | ||
jump = handbrake = 0 | ||
for pitch in torque_bins: | ||
for yaw in torque_bins: | ||
for roll in torque_bins: | ||
if pitch == roll == 0 and np.isclose(yaw, steer_bins).any(): | ||
continue # Duplicate with ground | ||
magnitude = max(abs(pitch), abs(yaw), abs(roll)) | ||
if magnitude < 1: | ||
continue # Duplicate rotation direction, only keep max magnitude | ||
for boost in (0, 1): | ||
throttle = boost | ||
steer = yaw | ||
actions.append([throttle, steer, pitch, yaw, roll, jump, boost, handbrake]) | ||
|
||
# Flips and jumps | ||
jump = handbrake = 1 # Enable handbrake for potential wavedashes | ||
yaw = steer = 0 # Only need roll for sideflip | ||
angles = [np.nan] + [v for v in flip_bins] | ||
for angle in angles: | ||
if np.isnan(angle): | ||
pitch = roll = 0 # Empty jump | ||
else: | ||
pitch = np.sin(angle) | ||
roll = np.cos(angle) | ||
# Project to square of diameter 2 because why not | ||
magnitude = max(abs(pitch), abs(roll)) | ||
pitch /= magnitude | ||
roll /= magnitude | ||
for boost in (0, 1): | ||
throttle = boost | ||
actions.append([throttle, steer, pitch, yaw, roll, jump, boost, handbrake]) | ||
if include_stalls: | ||
# Add actions for stalling | ||
actions.append([0, 0, 0, 1, -1, 1, 0, 1]) | ||
actions.append([0, 0, 0, -1, 1, 1, 0, 1]) | ||
|
||
actions = np.round(np.array(actions), 3) # Convert to numpy and remove floating point errors | ||
assert len(np.unique(actions, axis=0)) == len(actions), 'Duplicate actions found' | ||
|
||
return actions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Dict, Any, Tuple | ||
|
||
import numpy as np | ||
from rlgym.api import ActionParser, AgentID, ActionType, EngineActionType, SpaceType | ||
from rlgym.rocket_league.api import GameState | ||
|
||
|
||
class DelayedAction(ActionParser[AgentID, np.ndarray, np.ndarray, GameState, Tuple[AgentID, np.ndarray]]): | ||
def __init__(self, parser: ActionParser, action_queue_size: int = 1): | ||
""" | ||
DelayedAction maintains a queue of actions to execute and adds parsed actions to the queue. | ||
:param parser: the action parser to parse actions that are then added to the queue. | ||
""" | ||
super().__init__() | ||
self.parser = parser | ||
self.action_queue_size = action_queue_size | ||
self.action_queue = {} | ||
self.is_initial = True | ||
|
||
def get_action_space(self, agent: AgentID) -> SpaceType: | ||
return self.parser.get_action_space(agent) | ||
|
||
def reset(self, initial_state: GameState, shared_info: Dict[str, Any]) -> None: | ||
self.parser.reset(initial_state, shared_info) | ||
self.action_queue = {k: [] for k in initial_state.cars.keys()} | ||
self.is_initial = True | ||
shared_info["action_queue"] = self.action_queue | ||
|
||
def parse_actions(self, actions: Dict[AgentID, ActionType], state: GameState, shared_info: Dict[str, Any]) \ | ||
-> Dict[AgentID, EngineActionType]: | ||
parsed_actions = self.parser.parse_actions(actions, state, shared_info) | ||
returned_actions = {} | ||
if self.is_initial: | ||
for agent, action in parsed_actions.items(): | ||
self.action_queue[agent] = [action] * self.action_queue_size | ||
self.is_initial = False | ||
else: | ||
for agent, action in parsed_actions.items(): | ||
self.action_queue[agent].append(action) | ||
returned_actions[agent] = self.action_queue[agent].pop(0) | ||
shared_info["action_queue"] = self.action_queue | ||
return returned_actions |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import List, Dict, Any | ||
|
||
from rlgym.api import DoneCondition, AgentID | ||
from rlgym.rocket_league.api import GameState | ||
from rlgym.rocket_league.common_values import BLUE_TEAM, BALL_RADIUS | ||
from rlgym.rocket_league.common_values import TICKS_PER_SECOND | ||
|
||
|
||
class GameCondition(DoneCondition[AgentID, GameState]): | ||
def __init__(self, game_duration_seconds: int, seconds_per_goal_forfeit=None, max_overtime_seconds=None): | ||
self.game_duration_seconds = game_duration_seconds | ||
self.seconds_left = game_duration_seconds | ||
self.is_overtime = False | ||
self.seconds_per_goal_forfeit = seconds_per_goal_forfeit | ||
self.max_overtime_seconds = max_overtime_seconds | ||
self.scoreline = (0, 0) | ||
self.prev_state = None | ||
|
||
def reset(self, initial_state: GameState, shared_info: Dict[str, Any]) -> None: | ||
self.seconds_left = self.game_duration_seconds | ||
self.is_overtime = False | ||
self.scoreline = (0, 0) | ||
self.prev_state = initial_state | ||
shared_info["scoreboard"] = {"scoreline": self.scoreline, | ||
"is_overtime": self.is_overtime, | ||
"seconds_left": self.seconds_left, | ||
"go_to_kickoff": True} | ||
|
||
def is_done(self, agents: List[AgentID], state: GameState, shared_info: Dict[str, Any]) -> Dict[AgentID, bool]: | ||
ticks_passed = state.tick_count - self.prev_state.tick_count | ||
self.seconds_left -= ticks_passed / TICKS_PER_SECOND | ||
self.seconds_left = max(0, self.seconds_left) | ||
dones = {agent: False for agent in agents} | ||
go_to_kickoff = False | ||
if state.goal_scored: | ||
if self.is_overtime or self.seconds_left <= 0: | ||
dones = {agent: True for agent in agents} | ||
if state.scoring_team == BLUE_TEAM: | ||
self.scoreline = (self.scoreline[0] + 1, self.scoreline[1]) | ||
else: | ||
self.scoreline = (self.scoreline[0], self.scoreline[1] + 1) | ||
elif self.seconds_left <= 0: | ||
prev_ball = self.prev_state.ball | ||
next_z = prev_ball.position[2] + ticks_passed * prev_ball.velocity[2] / TICKS_PER_SECOND | ||
if next_z - BALL_RADIUS < 0: # Ball would be below the ground | ||
if self.scoreline[0] != self.scoreline[1]: | ||
dones = {agent: True for agent in agents} | ||
else: | ||
go_to_kickoff = True | ||
self.is_overtime = True | ||
elif self.seconds_per_goal_forfeit is not None: | ||
goal_diff = abs(self.scoreline[0] - self.scoreline[1]) | ||
if goal_diff >= 3: | ||
seconds_per_goal = self.seconds_left / goal_diff | ||
if seconds_per_goal < self.seconds_per_goal_forfeit: # Forfeit if it's not realistic to catch up | ||
dones = {agent: True for agent in agents} | ||
|
||
self.prev_state = state | ||
shared_info["scoreboard"] = {"scoreline": self.scoreline, | ||
"is_overtime": self.is_overtime, | ||
"seconds_left": self.seconds_left, | ||
"go_to_kickoff": go_to_kickoff} | ||
return dones |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.