test.py
# Written by Ruijia Li (ruijia2017@163.com), UESTC, 2020-12-1.
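# Example invocation (a sketch; assumes a pretrained policy checkpoint is
# saved under ./Results/Example, as loaded below):
#   python test.py --env_name PointMaze --max_test_steps 500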
import argparse
from Functions import ahrl
from Environments.AntPush.maze_env import AntPushEnv
from Environments.PointMaze.maze_env import PointMazeEnv
from Environments.DoubleInvertedPendulum.DoubleInvertedPendulum import DoubleInvertedPendulumEnv
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="PointMaze")          # Environment name: PointMaze, DoubleInvertedPendulum or AntPush
    parser.add_argument("--max_test_steps", default=500, type=int)  # Max test steps (int literal so the default matches the declared type)
    args = parser.parse_args()
    # Build the chosen environment.
    if args.env_name == "AntPush":
        env = AntPushEnv()
    elif args.env_name == "PointMaze":
        env = PointMazeEnv()
    else:
        env = DoubleInvertedPendulumEnv()
    # The observation is a dict; the policy consumes its 'state' entry.
    obs = env.reset()
    state = obs['state']
    state_dim = state.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    maxEpisode_step = env.max_step()

    # Load the pretrained policy, e.g. "PointMaze50", from ./Results/Example.
    file_name = "%s%s" % (args.env_name, str(50))
    policy = ahrl.AHRL(state_dim=state_dim, action_dim=action_dim, scale=max_action, args=args)
    policy.load(file_name, "./Results/Example")
    Reward = 0
    total_step = 0
    env_done = False
    episode_step = 0

    # Test
    while total_step < args.max_test_steps:
        # env.render()
        # Start a new episode when the previous one ends or hits the step limit.
        if env_done or episode_step == maxEpisode_step:
            print(" Reward={}".format(round(Reward)))
            obs = env.reset()
            state = obs['state']
            anchor = obs['achieved_goal']
            Reward = 0
            achieved = 0
            episode_step = 0
        episode_step += 1
        total_step += 1
        action = policy.select_action(state)
        next_obs, reward, env_done, _ = env.step(action)
        next_state = next_obs['state']
        achieved_goal = next_obs['achieved_goal']
        Reward = Reward + reward
        state = next_state