Updated experiment for Nexto
Enkhai committed Apr 24, 2023
1 parent 3309290 · commit 6af55d9
Showing 1 changed file with 45 additions and 35 deletions.

experiment/experiment.py (+45 −35)
@@ -1,39 +1,42 @@
-from lucy_utils.algorithms import DeviceAlternatingAuxPPO
+from lucy_utils.algorithms import DeviceAlternatingPPO
 from lucy_utils.models import PerceiverNet
-from lucy_utils.multi_instance_utils import config, make_matches
-from lucy_utils.policies import AuxACAttnPolicy
+from lucy_utils.multi_instance_utils import config, get_matches
+from lucy_utils.obs import NextoObsBuilder
+from lucy_utils.parsers import NextoAction
+from lucy_utils.policies import ActorCriticAttnPolicy
+from lucy_utils.rewards.nexto_reward import NextoRewardFunction
 from lucy_utils.rewards.sb3_log_reward import SB3NamedLogRewardCallback
 from rlgym_tools.sb3_utils import SB3MultipleInstanceEnv
 from rlgym_tools.sb3_utils.sb3_instantaneous_fps_callback import SB3InstantaneousFPSCallback
 from stable_baselines3.common.callbacks import CheckpointCallback
 from stable_baselines3.common.vec_env import VecMonitor
 
-from lucy_match_params import LucyReward, LucyTerminalConditions, LucyObs, LucyState, LucyAction
+from lucy_match_params import LucyTerminalConditions, LucyState
 
 models_folder = "models_folder/"
 tensorboard_log_dir = "bin"
-model_name = "Perceiver_LucyReward_v3_combined"
+model_name = "Nexto_Perceiver"
 
 if __name__ == '__main__':
     # ----- ENV CONFIG -----
 
-    num_instances = 20
+    num_instances = 1
     agents_per_match = 2 * 2  # self-play
     n_steps, batch_size, gamma, fps, save_freq = config(num_instances=num_instances,
                                                         avg_agents_per_match=agents_per_match,
                                                         target_steps=320_000,
                                                         target_batch_size=4_000,
                                                         callback_save_freq=10)
 
-    action_stacking = 5
+    action_stacking = 1
 
-    matches = make_matches(logged_reward_cls=lambda log=False: LucyReward(1, log),
-                           terminal_conditions=lambda: LucyTerminalConditions(fps),
-                           obs_builder_cls=lambda: LucyObs(stack_size=action_stacking),
-                           action_parser_cls=LucyAction,
-                           state_setter_cls=LucyState,
-                           sizes=[agents_per_match // 2] * num_instances  # self-play, hence // 2
-                           )
+    matches = get_matches(reward_cls=NextoRewardFunction,
+                          terminal_conditions=lambda: LucyTerminalConditions(fps),
+                          obs_builder_cls=NextoObsBuilder,
+                          action_parser_cls=NextoAction,
+                          state_setter_cls=LucyState,
+                          sizes=[agents_per_match // 2] * num_instances  # self-play, hence // 2
+                          )
 
     # ----- ENV SETUP -----
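
A note on the config(...) call above: the helper's implementation lives in lucy_utils and is not part of this diff, so the sketch below is only a hedged guess at how the step and batch-size targets might be budgeted across instances and agents. The name config_sketch and every formula in it are illustrative assumptions, not the library's actual code.

def config_sketch(num_instances, avg_agents_per_match,
                  target_steps, target_batch_size, callback_save_freq):
    # Assumption: the step budget is spread over all concurrently acting
    # agents, so one rollout gathers ~target_steps transitions in total.
    total_agents = num_instances * avg_agents_per_match  # 1 * 4 = 4 after this commit
    n_steps = target_steps // total_agents               # env steps per agent per rollout
    batch_size = target_batch_size                       # SGD minibatch size
    gamma = 0.99                                         # assumed discount factor
    fps = 120 // 8                                       # assumed RLGym default tick skip of 8
    save_freq = callback_save_freq * n_steps             # checkpoint interval in agent steps
    return n_steps, batch_size, gamma, fps, save_freq

With this commit's values, config_sketch(1, 4, 320_000, 4_000, 10) would return n_steps=80_000, batch_size=4_000 and save_freq=800_000.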

@@ -42,31 +45,38 @@
 
     # ----- MODEL SETUP -----
 
+    critic_net_arch = dict(
+        # minus one for the key padding mask
+        query_dims=env.observation_space.shape[-1] - 1,
+        # minus the stack for the previous actions
+        kv_dims=env.observation_space.shape[-1] - 1 - (action_stacking * 8),
+        # the rest is default arguments
+    )
+    actor_net_arch = dict(critic_net_arch)
+    actor_net_arch['player_emb_net_shape'] = [32]
+    actor_net_arch['action_emb_net_shape'] = [32] * 3
+
     policy_kwargs = dict(network_classes=PerceiverNet,
-                         net_arch=[dict(
-                             # minus one for the key padding mask
-                             query_dims=env.observation_space.shape[-1] - 1,
-                             # minus the stack for the previous actions
-                             kv_dims=env.observation_space.shape[-1] - 1 - (action_stacking * 8),
-                             # the rest is default arguments
-                         )] * 2,  # *2 because actor and critic will share the same architecture
+                         net_arch=[actor_net_arch, critic_net_arch],
                          action_stack_size=action_stacking,
-                         use_rp=True,
-                         use_sr=True,
-                         rp_seq_len=20,
-                         zero_rew_threshold=0.009)
+                         is_nexto=True,
+                         # use_rp=True,
+                         # use_sr=True,
+                         # rp_seq_len=20,
+                         # zero_rew_threshold=0.009
+                         )
 
     # model = DeviceAlternatingPPO.load("./models_folder/Perceiver/model_743680000_steps.zip", env)
-    model = DeviceAlternatingAuxPPO(policy=AuxACAttnPolicy,
-                                    env=env,
-                                    learning_rate=1e-4,
-                                    n_steps=n_steps,
-                                    gamma=gamma,
-                                    batch_size=batch_size,
-                                    tensorboard_log=tensorboard_log_dir,
-                                    policy_kwargs=policy_kwargs,
-                                    verbose=1,
-                                    )
+    model = DeviceAlternatingPPO(policy=ActorCriticAttnPolicy,
+                                 env=env,
+                                 learning_rate=1e-4,
+                                 n_steps=n_steps,
+                                 gamma=gamma,
+                                 batch_size=batch_size,
+                                 tensorboard_log=tensorboard_log_dir,
+                                 policy_kwargs=policy_kwargs,
+                                 verbose=1,
+                                 )
 
     # ----- TRAINING -----
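
The ENV SETUP and TRAINING sections of the file are collapsed in this view. Going only by the imports and the names referenced elsewhere in the diff (env, matches, save_freq, models_folder, model_name), they plausibly resemble the sketch below; the constructor arguments and the total_timesteps value are assumptions, not the collapsed code.

# Hedged reconstruction of the collapsed sections; assumptions throughout.
env = SB3MultipleInstanceEnv(matches)  # one Rocket League instance per match
env = VecMonitor(env)                  # episode reward/length stats for SB3 logging

# ... MODEL SETUP as shown in the diff above ...

model.learn(total_timesteps=1_000_000_000,
            callback=[SB3InstantaneousFPSCallback(),
                      CheckpointCallback(save_freq,
                                         models_folder + model_name,
                                         model_name)],
            tb_log_name=model_name)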
