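"""Model-vs-model evaluation script.

Loads two trained PPO checkpoints, pits them against each other in a 2v2
rlgym match (blue vs orange), and records each team's cumulative goal count
to a CSV until one team reaches `max_score_count` goals.
"""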
import sys
from pathlib import Path
from typing import Any, List
import numpy as np
import pandas as pd
from lucy_utils.multi_instance_utils import get_match
from lucy_utils.obs import GraphAttentionObsV1, GraphAttentionObsV2, NectoObs
from rlgym.utils.gamestates import GameState
from rlgym.utils.gamestates import PlayerData
from rlgym.utils.obs_builders import ObsBuilder
from rlgym.utils.reward_functions import DefaultReward
from rlgym.utils.state_setters import DefaultState
from rlgym.utils.terminal_conditions import common_conditions
from rlgym_tools.extra_action_parsers.kbm_act import KBMAction
from rlgym_tools.sb3_utils import SB3MultipleInstanceEnv
from stable_baselines3 import PPO

# make the deprecated `utils` package importable, so the trained Necto model
# can be un-pickled by PPO.load()
utils_path = str(Path.home() / "rocket_league_utils" / "old_deprecated_utils")
sys.path.insert(0, utils_path)


class MultiModelObs(ObsBuilder):
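    """Observation builder that routes each player to one of several ObsBuilders.

    `num_obs_players` gives how many players each builder in `obss` observes.
    Players are assigned in the order they are queried each step: the first
    `num_obs_players[0]` players use `obss[0]`, the next `num_obs_players[1]`
    use `obss[1]`, and so on. This lets models trained with different
    observations share a single match.
    """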
    def __init__(self, obss: List[ObsBuilder], num_obs_players: List[int]):
        super().__init__()
        assert len(obss) == len(num_obs_players), "`obss` and `num_obs_players` lengths must match"
        self.obss = obss
        # cumulative counts mark where one builder's players end and the next begin
        self.num_obs_players = np.cumsum(num_obs_players)
        self.p_idx = 0  # index of the player currently being observed this step
        self.curr_state = None
        self.autodetect = True

    def reset(self, initial_state: GameState):
        for o in self.obss:
            o.reset(initial_state)

    def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any:
        # skip the very first call after construction (env setup probe)
        if self.autodetect:
            self.autodetect = False
            return np.zeros(0)
        # a new game state means a new step: restart the per-player counter
        if self.curr_state != state:
            self.p_idx = 0
            self.curr_state = state
        # pick the builder whose cumulative player range contains p_idx
        obs_idx = (self.p_idx >= self.num_obs_players).sum()
        self.p_idx += 1
        return self.obss[obs_idx].build_obs(player, state, previous_action)
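

# Example (cf. the commented-out configurations in __main__ below):
# MultiModelObs([GraphAttentionObsV1(stack_size=5), NectoObs()], [2, 2])
# observes the first two players queried with GraphAttentionObsV1 and the
# last two with NectoObs; the evaluation loop below assumes the first two
# players are the blue team.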


if __name__ == '__main__':
    team_size = 2
    tick_skip = 8
    fps = 120 // tick_skip
    terminal_conditions = [common_conditions.TimeoutCondition(fps * 300),
                           common_conditions.NoTouchTimeoutCondition(fps * 45),
                           common_conditions.GoalScoredCondition()]

    # obs_builder = MultiModelObs([GraphAttentionObsV1(stack_size=5), NectoObs()], [2, 2])
    # obs_builder = MultiModelObs([GraphAttentionObsV1(), NectoObs()], [2, 2])
    # obs_builder = MultiModelObs([GraphAttentionObsV2(stack_size=5, add_boost_pads=True), NectoObs()], [2, 2])
    obs_builder = MultiModelObs([GraphAttentionObsV1(stack_size=5)], [4])
    # obs_builder = MultiModelObs([GraphAttentionObsV2(stack_size=5, add_boost_pads=True),
    #                              GraphAttentionObsV1(stack_size=5)], [2, 2])
    # obs_builder = MultiModelObs([NectoObs()], [4])

    match = get_match(reward=DefaultReward(),
                      terminal_conditions=terminal_conditions,
                      obs_builder=obs_builder,
                      action_parser=KBMAction(),
                      state_setter=DefaultState(),
                      team_size=team_size,
                      )
    env = SB3MultipleInstanceEnv([match])

    # `custom_objects` overrides attributes that PPO.load() would otherwise
    # restore from the checkpoint, since evaluation settings differ from training
    custom_objects = {
        # arbitrary
        'lr_schedule': 1e-4,
        'clip_range': .2,
        # 2v2
        'n_envs': 2,
    }
    blue_model = PPO.load("../models_folder/NectoReward_ownPerceiver_preproc_norm_stack5/model_501760000_steps.zip",
                          device="cpu", custom_objects=custom_objects)
    orange_model = PPO.load("../models_folder/Perceiver_LucyReward_v3/model_502400000_steps.zip",
                            device="cpu", custom_objects=custom_objects)

    max_score_count = 300
    match_name = f"Necto reduced + 5-stack vs v3, {max_score_count} goals, 500 million, 1"
    blue_score_sum = 0
    orange_score_sum = 0
    blue_scores = []
    orange_scores = []

    # play episodes until either team reaches max_score_count goals
    while True:
        obs = env.reset()
        done = [False]
        while not done[0]:
            # stack blue actions (first two obs) and orange actions (last two),
            # then add a leading axis for env.step
            action = np.concatenate((blue_model.predict(np.stack(obs[:2]))[0],
                                     orange_model.predict(np.stack(obs[2:]))[0]))[None]
            obs, reward, done, gameinfo = env.step(action)
        final_state: GameState = gameinfo[0]['state']
        # the state's team scores already accumulate across episodes,
        # so the running totals can be read off directly
        blue_score_sum = final_state.blue_score
        orange_score_sum = final_state.orange_score
        blue_scores.append(blue_score_sum)
        orange_scores.append(orange_score_sum)
        if blue_score_sum >= max_score_count or orange_score_sum >= max_score_count:
            break

    # write per-episode cumulative scores to CSV (create the folder if needed)
    Path("evaluation_results").mkdir(exist_ok=True)
    df = pd.DataFrame({"Blue score": blue_scores, "Orange score": orange_scores})
    df.to_csv(f"evaluation_results/{match_name}.csv")

    print("\n\n")
    print("====================")
    print("RESULT")
    print("====================")
    print("Blue:", blue_score_sum)
    print("Orange:", orange_score_sum)

    env.close()