-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_expert_data.py
33 lines (27 loc) · 999 Bytes
/
generate_expert_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from environment import RacerEnvironment
from stable_baselines3 import PPO
import pickle
# Roll out the trained PPO "racer" agent in RacerEnvironment and pickle the
# observations it saw and the actions it chose, as expert demonstration data
# (presumably for imitation learning -- confirm with the consuming script).
#
# Outputs, both written under MODEL_DIR:
#   env_obs_data.pkl -- list of observations, each converted with .tolist()
#   env_act_data.pkl -- list of actions, index-aligned with the observations

MODEL_DIR = "models/trained_ppo_agent"
num_episodes = 12

env = RacerEnvironment(render=True, evaluate=True)
model = PPO.load(f"{MODEL_DIR}/racer", env=env)

obs_data = []  # observations from committed episodes
act_data = []  # act_data[k] is the action chosen after observing obs_data[k]

for _ in range(num_episodes):
    print("New Episode")
    obs = env.reset()
    done = False
    episode_observation_data = []
    episode_action_data = []
    while not done:
        # SB3's predict() returns (action, recurrent_state); the state is not
        # used here. deterministic=False is SB3's default (sampled actions),
        # spelled out so the choice is visible when generating expert data.
        action, _state = model.predict(obs, deterministic=False)
        # Record the observation and the action taken from it before stepping,
        # so the two lists stay pairwise aligned.
        episode_observation_data.append(obs.tolist())
        episode_action_data.append(action.tolist())
        obs, _reward, done, info = env.step(action)
        # Only rollouts for which the environment puts a "done" entry in
        # ``info`` are kept. NOTE(review): if RacerEnvironment sets "done" in
        # ``info`` on more than one step of an episode, the episode prefix is
        # appended each time -- confirm it is only set on the final step.
        if "done" in info:
            obs_data.extend(episode_observation_data)
            act_data.extend(episode_action_data)
    # Report how the episode ended (info/done from the final env.step call).
    print(info, done)

for data, filename in (
    (obs_data, "env_obs_data.pkl"),
    (act_data, "env_act_data.pkl"),
):
    with open(f"{MODEL_DIR}/{filename}", "wb") as f:
        pickle.dump(data, f)

# Surface the otherwise-silent case where every episode was filtered out and
# empty lists were pickled.
if not obs_data:
    print("Warning: no episode reported 'done' in info; saved datasets are empty")
print(f"Saved {len(obs_data)} observation/action pairs to {MODEL_DIR}")