-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
33 lines (29 loc) · 1.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from argument import parameters
from trainer import train
from test import pytest
from visualization.plot_curves import learning_curve
import numpy as np
import os
import multiprocessing
# Set process-wide so that child processes started later inherit it.
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate
# runtime" abort (OMP Error #15) when several libraries each bundle their own
# copy of the runtime -- confirm it is still required on the target platform.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def main_multiprocessing(args, seeds):
    """Run ``train`` once per seed, each in its own process, and wait for all.

    Args:
        args: Run configuration forwarded to ``train``. Its ``seed`` attribute
            is overwritten with each seed in turn, so on return it holds the
            last seed that was launched.
        seeds: Iterable of values convertible with ``int``; one process is
            started per element.
    """
    workers = []
    for run_seed in seeds:
        # The same args object is reused for every run: assign the seed right
        # before launching so the process started below gets this run's seed.
        args.seed = int(run_seed)
        worker = multiprocessing.Process(target=train, args=(args,))
        worker.start()
        workers.append(worker)

    # Every run is started before any is joined, so the runs overlap in time.
    for worker in workers:
        worker.join()
if __name__ == '__main__':
    # Script entry point: depending on the parsed options, either plot
    # learning curves, evaluate, or launch several training runs in parallel.

    # NOTE(review): env_list is not referenced anywhere in this script;
    # presumably it records the environment ids the project supports.
    env_list = [
        "Swimmer-v2", "Reacher-v2", "Hopper-v2", "HalfCheetah-v2",
        "Walker2d-v2", "Humanoid-v2",
        'MountainCarContinuous-v0', "LunarLanderContinuous-v2",
        "BipedalWalker-v3",
    ]
    # Algorithm names handed to the plotting routine.
    algorithm_list = [
        'PPO-KL', 'PPO-Clip', 'PPO-S', 'TR-PPO', 'TR-PPO-SRB', 'TR-PPO-RB',
    ]

    args = parameters.get_paras()

    # NOTE(review): the explicit `== True` comparisons are kept on purpose --
    # the parsed types of `plot` / `evaluate` are not visible here, and plain
    # truthiness could select a different branch for non-bool values.
    if args.plot == True:
        learning_curve(args, algorithm_list)
    elif args.evaluate == True:
        # `pytest` is the project's own function imported from the local
        # `test` module, not the pytest framework.
        pytest(args)
    else:
        # Draw args.num_para seeds from [1, 999]. Sampling without replacement
        # guarantees that no two parallel runs share a seed (randint samples
        # with replacement and can hand out duplicates). Replacement is only
        # re-enabled when more seeds are requested than the pool contains, so
        # large requests keep working as before.
        seed_pool = np.arange(1, 1000)
        seeds = np.random.choice(
            seed_pool,
            size=args.num_para,
            replace=args.num_para > seed_pool.size,
        )
        main_multiprocessing(args, seeds)