train_imitation.py — 129 lines (105 loc) · 3.81 KB
import glob
import os
import shutil
import hydra
from ray import air, tune
from ray.rllib.algorithms.bc import BCConfig, BC
from ray.rllib.algorithms.marwil import MARWILConfig, MARWIL
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from control_pcgrl.configs.config import Config, PoDConfig
from control_pcgrl.il.utils import make_pod_env
from control_pcgrl.il.wrappers import obfuscate_observation
from control_pcgrl.rl.envs import make_env
from control_pcgrl.rl.models import CustomFeedForwardModel
from control_pcgrl.rl.utils import validate_config
@hydra.main(config_path="control_pcgrl/configs", config_name="pod")
def main(cfg: PoDConfig):
    """Train, or run inference with, an offline imitation-learning policy.

    Fits a BC or MARWIL policy (selected by ``cfg.offline_algo``) on JSON
    trajectory files found under ``<cfg.log_dir>/repair-paths`` using Ray
    Tune. If ``cfg.infer`` is set, the best checkpoint from a previous run
    is restored instead and the policy is rolled out in the environment.
    """
    # validate_config returns the literal False (not merely a falsy config)
    # on failure — hence the identity check below.
    cfg = validate_config(cfg)
    if cfg is False:
        print("Invalid config!")
        return
    # Expert demonstration trajectories live here as JSON files.
    traj_dir = os.path.join(cfg.log_dir, "repair-paths")
    register_env('pcgrl', make_env)
    model_cls = CustomFeedForwardModel
    ModelCatalog.register_custom_model("custom_model", model_cls)
    # Build the offline-algorithm config; both BC and MARWIL learn from
    # logged trajectories rather than live environment interaction.
    if cfg.offline_algo == "BC":
        algo_config = BCConfig(
        )
    elif cfg.offline_algo == "MARWIL":
        algo_config = MARWILConfig(
        )
    else:
        raise ValueError(f"Invalid offline algorithm: {cfg.offline_algo}")
    # Route observations through the custom model registered above.
    algo_config.model = {
        'custom_model': 'custom_model',
        'custom_model_config': {
        },
    }
    # Print out some default values.
    print(algo_config.beta)
    # Update the config object.
    algo_config.training(
        # lr=tune.grid_search([0.001, 0.0001]), beta=0.0
        lr=0.001,
    )
    # Get all json files in the directory
    traj_glob = os.path.join(traj_dir, "*.json")
    # Set the config object's data path.
    # Run this from the ray directory root.
    algo_config.offline_data(
        # input_="./tmp/demo-out/output-2023-0"
        # input_=os.path.join(cfg.log_dir, "demo-out")
        input_=traj_glob,
    )
    # Set the config object's env, used for evaluation.
    algo_config.environment(env='pcgrl')
    # NOTE(review): env_config is assigned directly rather than passed via
    # algo_config.environment(env_config=...) — confirm the two paths are
    # equivalent for the Ray version in use.
    algo_config.env_config = {**cfg}
    algo_config.framework("torch")
    il_log_dir = "il_logs"
    exp_name = cfg.offline_algo
    exp_dir = os.path.join(il_log_dir, exp_name)
    # Resume an existing experiment unless overwriting; otherwise wipe any
    # stale results and construct a fresh Tune run.
    if not cfg.overwrite and os.path.exists(exp_dir):
        tuner = tune.Tuner.restore(exp_dir)
    else:
        shutil.rmtree(exp_dir, ignore_errors=True)
        run_config = air.RunConfig(
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
                num_to_keep=2,
            ),
            local_dir=il_log_dir,
        )
        tuner = tune.Tuner(
            cfg.offline_algo,
            # "BC",
            param_space=algo_config.to_dict(),
            tune_config = tune.TuneConfig(
                metric="info/learner/default_policy/learner_stats/policy_loss",
                mode="min",
            ),
            run_config=run_config,
        )
    if cfg.infer:
        # Inference: restore the best checkpoint (ranked by the policy-loss
        # metric configured above) and roll the policy out indefinitely,
        # rendering each step.
        # NOTE(review): this path queries tuner.get_results(), which assumes
        # the tuner was restored from a completed run — a freshly created
        # tuner (the else-branch above) has no results to query. Confirm
        # infer is only used with overwrite unset and an existing exp_dir.
        algo_cls = BC if cfg.offline_algo == "BC" else MARWIL
        best_result = tuner.get_results().get_best_result()
        ckpt = best_result.best_checkpoints[0][0]
        bc_model = algo_cls.from_checkpoint(ckpt)
        print(f"Restored from checkpoint {ckpt}")
        # bc_model.evaluate()
        env = make_pod_env(cfg)
        while True:
            obs, info = env.reset()
            done, truncated = False, False
            while not done and not truncated:
                # action = bc_model.compute_single_action(obfuscate_observation(obs), explore=False)
                action = bc_model.compute_single_action((obs), explore=True)
                obs, reward, done, truncated, info = env.step(action)
                env.render()
    else:
        # Use to_dict() to get the old-style python config dict
        # when running with tune.
        result = tuner.fit()
# Entry point: hydra parses CLI overrides and injects the composed config.
if __name__ == "__main__":
    main()