-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 122659b
Showing
40 changed files
with
9,279 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
PYTHONPATH=src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
# Intellij | ||
*.iml | ||
/.idea | ||
|
||
# Build output | ||
/build | ||
|
||
# Gradle files | ||
/.gradle | ||
|
||
# VSCode | ||
.vscode/ | ||
*.code-workspace | ||
|
||
# -------------------- | ||
|
||
# Trained models | ||
models_folder | ||
# Tensorboard logs | ||
bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[TYPECHECK] | ||
generated-members=QuickChats.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2021 RLGym | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Lucy | ||
|
||
Lucy is the effort of two MSc students at the Aristotle University of Thessaloniki as part of our MSc theses | ||
in Reinforcement Learning. Lucy was largely inspired by Necto, built on a similar architecture and trained using | ||
novel reward functions. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
@echo off | ||
:: This file is taken from chocolatey: | ||
:: https://github.com/chocolatey/choco/blob/master/src/chocolatey.resources/redirects/RefreshEnv.cmd | ||
:: | ||
:: RefreshEnv.cmd | ||
:: | ||
:: Batch file to read environment variables from registry and | ||
:: set session variables to these values. | ||
:: | ||
:: With this batch file, there should be no need to reload command | ||
:: environment every time you want environment changes to propagate | ||
|
||
::echo "RefreshEnv.cmd only works from cmd.exe, please install the Chocolatey Profile to take advantage of refreshenv from PowerShell" | ||
echo | set /p dummy="Refreshing environment variables from registry for cmd.exe. Please wait..." | ||
|
||
goto main | ||
|
||
:: Set one environment variable from registry key | ||
:SetFromReg | ||
"%WinDir%\System32\Reg" QUERY "%~1" /v "%~2" > "%TEMP%\_envset.tmp" 2>NUL | ||
for /f "usebackq skip=2 tokens=2,*" %%A IN ("%TEMP%\_envset.tmp") do ( | ||
echo/set "%~3=%%B" | ||
) | ||
goto :EOF | ||
|
||
:: Get a list of environment variables from registry | ||
:GetRegEnv | ||
"%WinDir%\System32\Reg" QUERY "%~1" > "%TEMP%\_envget.tmp" | ||
for /f "usebackq skip=2" %%A IN ("%TEMP%\_envget.tmp") do ( | ||
if /I not "%%~A"=="Path" ( | ||
call :SetFromReg "%~1" "%%~A" "%%~A" | ||
) | ||
) | ||
goto :EOF | ||
|
||
:main | ||
echo/@echo off >"%TEMP%\_env.cmd" | ||
|
||
:: Slowly generating final file | ||
call :GetRegEnv "HKLM\System\CurrentControlSet\Control\Session Manager\Environment" >> "%TEMP%\_env.cmd" | ||
call :GetRegEnv "HKCU\Environment">>"%TEMP%\_env.cmd" >> "%TEMP%\_env.cmd" | ||
|
||
:: Special handling for PATH - mix both User and System | ||
call :SetFromReg "HKLM\System\CurrentControlSet\Control\Session Manager\Environment" Path Path_HKLM >> "%TEMP%\_env.cmd" | ||
call :SetFromReg "HKCU\Environment" Path Path_HKCU >> "%TEMP%\_env.cmd" | ||
|
||
:: Caution: do not insert space-chars before >> redirection sign | ||
echo/set "Path=%%Path_HKLM%%;%%Path_HKCU%%" >> "%TEMP%\_env.cmd" | ||
|
||
:: Cleanup | ||
del /f /q "%TEMP%\_envset.tmp" 2>nul | ||
del /f /q "%TEMP%\_envget.tmp" 2>nul | ||
|
||
:: capture user / architecture | ||
SET "OriginalUserName=%USERNAME%" | ||
SET "OriginalArchitecture=%PROCESSOR_ARCHITECTURE%" | ||
|
||
:: Set these variables | ||
call "%TEMP%\_env.cmd" | ||
|
||
:: reset user / architecture | ||
SET "USERNAME=%OriginalUserName%" | ||
SET "PROCESSOR_ARCHITECTURE=%OriginalArchitecture%" | ||
|
||
echo | set /p dummy="Finished." | ||
echo . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from rlgym.utils.state_setters import DefaultState | ||
|
||
from experiment.lucy_match_params import LucyReward, LucyTerminalConditions, LucyObs, LucyAction | ||
from lucy_utils.load_evaluate import load_and_evaluate | ||
|
||
if __name__ == '__main__': | ||
load_and_evaluate("../models_folder/Perceiver/model_449280000_steps.zip", | ||
2, | ||
LucyTerminalConditions(15), | ||
LucyObs(), | ||
DefaultState(), # we use the default state for evaluation | ||
LucyAction(), | ||
LucyReward(0.995) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from rlgym_tools.sb3_utils import SB3MultipleInstanceEnv | ||
from rlgym_tools.sb3_utils.sb3_instantaneous_fps_callback import SB3InstantaneousFPSCallback | ||
from stable_baselines3.common.callbacks import CheckpointCallback | ||
from stable_baselines3.common.vec_env import VecMonitor | ||
|
||
from lucy_match_params import LucyReward, LucyTerminalConditions, LucyObs, LucyState, LucyAction | ||
from lucy_utils.algorithms import DeviceAlternatingPPO | ||
from lucy_utils.models import PerceiverNet | ||
from lucy_utils.multi_instance_utils import config, make_matches | ||
from lucy_utils.policies import ActorCriticAttnPolicy | ||
from lucy_utils.rewards.sb3_log_reward import SB3NamedLogRewardCallback | ||
|
||
models_folder = "models_folder/" | ||
|
||
if __name__ == '__main__': | ||
num_instances = 8 | ||
agents_per_match = 2 * 2 # self-play | ||
n_steps, batch_size, gamma, fps, save_freq = config(num_instances=num_instances, | ||
avg_agents_per_match=agents_per_match, | ||
target_steps=256_000, | ||
target_batch_size=0.5, | ||
callback_save_freq=10) | ||
|
||
matches = make_matches(logged_reward_cls=lambda log=False: LucyReward(gamma, log), | ||
terminal_conditions=lambda: LucyTerminalConditions(fps), | ||
obs_builder_cls=lambda: LucyObs(stack_size=5), | ||
action_parser_cls=LucyAction, | ||
state_setter_cls=LucyState, | ||
sizes=[agents_per_match // 2] * num_instances # self-play, hence // 2 | ||
) | ||
|
||
env = SB3MultipleInstanceEnv(match_func_or_matches=matches) | ||
env = VecMonitor(env) | ||
|
||
policy_kwargs = dict(network_classes=PerceiverNet, | ||
net_arch=[dict( | ||
# minus one for the key padding mask | ||
query_dims=env.observation_space.shape[-1] - 1, | ||
# minus eight for the previous action | ||
kv_dims=env.observation_space.shape[-1] - 1 - 8, | ||
# the rest is default arguments | ||
)] * 2, # *2 because actor and critic will share the same architecture | ||
action_stack_size=5) | ||
|
||
# model = DeviceAlternatingPPO.load("./models_folder/Perceiver/model_743680000_steps.zip", env) | ||
model = DeviceAlternatingPPO(policy=ActorCriticAttnPolicy, | ||
env=env, | ||
learning_rate=1e-4, | ||
n_steps=n_steps, | ||
gamma=gamma, | ||
batch_size=batch_size, | ||
tensorboard_log="./bin", | ||
policy_kwargs=policy_kwargs, | ||
verbose=1, | ||
) | ||
|
||
callbacks = [SB3InstantaneousFPSCallback(), | ||
SB3NamedLogRewardCallback(), | ||
CheckpointCallback(save_freq, | ||
save_path=models_folder + "Perceiver", | ||
name_prefix="model")] | ||
model.learn(total_timesteps=1_000_000_000, | ||
callback=callbacks, | ||
tb_log_name="PPO_Perceiver2_4x256", | ||
# reset_num_timesteps=False | ||
) | ||
model.save(models_folder + "Perceiver_final") | ||
|
||
env.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from pathlib import Path | ||
|
||
from rlgym.utils.reward_functions import common_rewards | ||
from rlgym.utils.state_setters import RandomState, DefaultState | ||
from rlgym.utils.terminal_conditions import common_conditions | ||
from rlgym_tools.extra_action_parsers.kbm_act import KBMAction | ||
from rlgym_tools.extra_state_setters.goalie_state import GoaliePracticeState | ||
from rlgym_tools.extra_state_setters.replay_setter import ReplaySetter | ||
from rlgym_tools.extra_state_setters.symmetric_setter import KickoffLikeSetter | ||
from rlgym_tools.extra_state_setters.weighted_sample_setter import WeightedSampleSetter | ||
|
||
from lucy_utils import rewards | ||
from lucy_utils.build_reward import build_logged_reward | ||
from lucy_utils.obs import GraphAttentionObs | ||
|
||
_f_reward_weight_args = ((rewards.SignedLiuDistanceBallToGoalReward, 8), | ||
(common_rewards.VelocityBallToGoalReward, 2), | ||
(rewards.BallYCoordinateReward, 1), | ||
(common_rewards.VelocityPlayerToBallReward, 0.5), | ||
(rewards.LiuDistancePlayerToBallReward, 0.5), | ||
(rewards.DistanceWeightedAlignBallGoal, 0.65, dict(defense=0.5, offense=0.5)), | ||
(common_rewards.SaveBoostReward, 0.5)) | ||
""" | ||
Potential: reward class, weight (, kwargs) | ||
""" | ||
|
||
_r_reward_name_weight_args = ((rewards.EventReward, "Goal", 1, dict(goal=10, team_goal=4, concede=-10)), | ||
(rewards.EventReward, "Shot", 1, dict(shot=1)), | ||
(rewards.EventReward, "Save", 1, dict(save=3)), | ||
(rewards.EventReward, "Touch", 1, dict(touch=0.05)), | ||
(rewards.EventReward, "Demo", 1, dict(demo=2, demoed=-2))) | ||
""" | ||
Event: reward class, reward name, weight, kwargs | ||
""" | ||
|
||
|
||
def _get_reward(gamma: float, log: bool = False): | ||
return build_logged_reward(_f_reward_weight_args, _r_reward_name_weight_args, 0.3, gamma, log) | ||
|
||
|
||
def _get_terminal_conditions(fps): | ||
return [common_conditions.TimeoutCondition(fps * 300), | ||
common_conditions.NoTouchTimeoutCondition(fps * 45), | ||
common_conditions.GoalScoredCondition()] | ||
|
||
|
||
def _get_state(): | ||
replay_path = str(Path(__file__).parent / "../replay-samples/2v2/states.npy") | ||
# Following Necto logic | ||
return WeightedSampleSetter.from_zipped( | ||
# replay setter uses carball, no warnings for numpy==1.21.5 | ||
(ReplaySetter(replay_path), 0.7), | ||
(RandomState(True, True, False), 0.15), | ||
(DefaultState(), 0.05), | ||
(KickoffLikeSetter(), 0.05), | ||
(GoaliePracticeState(first_defender_in_goal=True), 0.05) | ||
) | ||
|
||
|
||
LucyReward = _get_reward | ||
LucyTerminalConditions = _get_terminal_conditions | ||
LucyState = _get_state | ||
LucyObs = GraphAttentionObs | ||
LucyAction = KBMAction |
Oops, something went wrong.