-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
104 lines (82 loc) · 3.02 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import time
import comet_ml
import hydra
from dotenv import load_dotenv
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from roble import ppo, sac
# Root directory for run outputs; my_main derives the per-run directory from it.
PATH = "./runs"

# load_dotenv() runs before the lookup below, presumably so the Comet key can
# be supplied through a local .env file as well as the real environment.
load_dotenv()
comet_api_key = os.environ.get("COMET_API_KEY")
class Log:
    """Wrapper around the project's tabular logger for a single run.

    Constructing it makes sure the run directory exists, registers that
    directory and a ``log_data.csv`` file with the underlying logger, and
    saves the run configuration next to them as ``conf.yaml``.
    """

    def __init__(self, cfg, PATH):
        """Prepare logging outputs under ``PATH`` and snapshot ``cfg``.

        Args:
            cfg: Configuration object accepted by ``OmegaConf.to_yaml``.
            PATH: Directory that receives ``log_data.csv`` and ``conf.yaml``.
        """
        # Imported inside the constructor, exactly as the original code does.
        from roble.utils.logging_utils import Logger as TableLogger

        # Create the run directory *before* any output is pointed at it, so
        # neither the logger nor the config dump depends on someone else
        # having created it first.
        os.makedirs(PATH, exist_ok=True)

        self._logger = TableLogger()
        self._logger.add_folder_output(folder_name=f"{PATH}")
        self._logger.add_tabular_output(file_name=f"{PATH}/log_data.csv")

        # Leaving the ``with`` block flushes and closes the file, so no
        # explicit flush is needed. UTF-8 is forced so the dump does not
        # depend on the platform's default encoding.
        with open(f"{PATH}/conf.yaml", "w", encoding="utf-8") as fd:
            fd.write(OmegaConf.to_yaml(cfg))

    def log_dict(self, dico):
        """Record every entry of ``dico`` and then dump the tabular output.

        Entries whose value is an empty list are skipped (presumably because
        no statistics can be computed for them -- confirm against the
        logger). ``dump_tabular`` is called once per invocation, after all
        entries were recorded, even when ``dico`` is empty.

        Args:
            dico: Mapping from stat name to the value(s) to record.
        """
        for key, value in dico.items():
            if isinstance(value, list) and not value:
                continue
            self._logger.record_tabular_misc_stat(key, value)
        self._logger.dump_tabular()
@hydra.main(config_path="conf", config_name="config")
def my_main(args: DictConfig):
    """Hydra entry point: set up one run, then dispatch to SAC or PPO.

    Steps, in order:
      1. Change back to the directory the program was launched from.
      2. Build a run name from the ``meta`` config plus a timestamp and
         store it in ``args.meta.run_name``.
      3. Create a ``Log`` writing to ``<PATH>/<run_name>``.
      4. If ``args.meta.track`` is set, create a Comet experiment and attach
         it to the logger.
      5. Merge ``meta``, ``sim2real`` and ``environment`` into one flat
         config and dispatch:
           * ``meta.sac_instead``       -> ``sac.train``
           * ``meta.real_application``  -> ``ppo.run_ppo_policy`` (receives
             the original, non-flattened config)
           * otherwise                  -> ``ppo.train``

    The module-level ``PATH`` is only read as the base directory and is no
    longer rebound: the previous ``global PATH`` reassignment made a second
    call in the same process (e.g. a Hydra multirun sweep) nest its run
    directory inside the previous run's directory.

    Args:
        args: Composed Hydra configuration. This function reads the
            ``meta``, ``sim2real`` and ``environment`` groups and, depending
            on the branch taken, ``sac`` or ``ppo``.
    """

    def get_arg_dict(conf):
        """Return the ``_content`` entry of ``conf``'s instance attributes."""
        return dict(vars(conf))["_content"]

    def flatten_conf(conf1, conf2):
        """Build a new config holding the entries of ``conf1`` then ``conf2``.

        NOTE(review): ``get_arg_dict`` hands back ``conf1``'s own
        ``_content`` object, so ``update`` appears to modify ``conf1`` in
        place as well. That is kept as-is because later code may rely on it.
        """
        merged = get_arg_dict(conf1)
        merged.update(get_arg_dict(conf2))
        return OmegaConf.create(merged)

    os.chdir(get_original_cwd())

    run_name = (
        f"{args.meta.env_id}__{args.meta.exp_name}__"
        f"{args.meta.add_to_runname}__{args.meta.seed}__{int(time.time())}"
    )
    args.meta.run_name = run_name

    # Per-run output directory, derived from the unchanged module-level base.
    run_path = f"{PATH}/{args.meta.run_name}"
    logger = Log(args, run_path)

    if args.meta.track:
        experiment = comet_ml.Experiment(
            api_key=comet_api_key,
            project_name="ballbalancingrobot",
            workspace="ift6163-project",
        )
        experiment.add_tag("project")
        experiment.set_name(args.meta.run_name)
        experiment.set_filename(fname="cometML_test")
        logger._logger.set_comet_logger(experiment)

    # Fold the sim2real and environment groups into the meta settings, each
    # kept under its own key.
    sim2real = args.sim2real
    env_args = args.environment
    new_args = flatten_conf(args.meta, OmegaConf.create({"sim2real": get_arg_dict(sim2real)}))
    new_args = flatten_conf(new_args, OmegaConf.create({"environment": get_arg_dict(env_args)}))

    if args.meta.sac_instead:
        sac_args = flatten_conf(new_args, args.sac)
        sac_args.buffer_size = int(sac_args.buffer_size)
        sac.train(sac_args, logger, run_path)
    elif args.meta.real_application:
        # This branch passes the original Hydra config, not a flattened one.
        ppo.run_ppo_policy(args, run_path)
    else:
        ppo_args = flatten_conf(new_args, args.ppo)
        # The config may carry the literal string "None" for "no KL target".
        ppo_args.target_kl = None if ppo_args.target_kl == "None" else ppo_args.target_kl
        ppo_args.batch_size = int(ppo_args.num_envs * ppo_args.num_steps)
        ppo_args.minibatch_size = int(ppo_args.batch_size // ppo_args.num_minibatches)
        ppo_args.num_iterations = ppo_args.total_timesteps // ppo_args.batch_size
        ppo.train(ppo_args, logger, run_path)
# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    my_main()