-
Notifications
You must be signed in to change notification settings - Fork 0
/
wrappers.py
executable file
·117 lines (98 loc) · 3.58 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2
# Turn off OpenCV's OpenCL code path for this process at import time.
cv2.ocl.setUseOpenCL(False)
class NoopResetEnv(gym.Wrapper):
    """Begin every episode with a random number of no-op steps.

    On ``reset`` the wrapped environment is reset and then action 0 is
    stepped a randomly drawn number of times, so episodes do not all
    start from the same state.
    """

    def __init__(self, env, noop_max=30):
        """Wrap *env*.

        Args:
            env: Environment to wrap; action 0 must be reported as 'NOOP'
                by ``env.unwrapped.get_action_meanings()``.
            noop_max: Upper bound used when drawing the number of no-op
                steps in ``reset``.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.noop_action = 0
        action_names = env.unwrapped.get_action_meanings()
        assert action_names[0] == 'NOOP'

    def reset(self, **kwargs):
        """Reset the environment, then take the random no-op steps.

        Keyword arguments are forwarded to the wrapped ``reset``.

        Returns:
            The observation after the last no-op step, or the observation
            of a fresh reset if that step ended the episode.
        """
        self.env.reset(**kwargs)
        rng = self.unwrapped.np_random
        # Drawn via randint(1, noop_max + 1), so at least one no-op is taken.
        noop_count = rng.randint(1, self.noop_max + 1)
        for _unused in range(noop_count):
            obs, _reward, episode_over, _info = self.env.step(self.noop_action)
            if episode_over:
                # The game ended during the no-ops; start over.
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, a):
        """Forward action *a* to the wrapped environment unchanged."""
        transition = self.env.step(a)
        return transition
class FireResetEnv(gym.Wrapper):
    """Take action 1 ('FIRE') and then action 2 right after every reset.

    The wrapped environment must report 'FIRE' as the meaning of action 1.
    """

    def __init__(self, env):
        """Wrap *env*.

        Args:
            env: Environment to wrap.

        Raises:
            AssertionError: If action 1 is not 'FIRE', or if the environment
                has fewer than three actions (``reset`` steps action 2).
        """
        gym.Wrapper.__init__(self, env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        # reset() steps action index 2, so fail here rather than mid-reset.
        assert len(meanings) >= 3

    def reset(self, **kwargs):
        """Reset the environment and perform the two start-up actions.

        Keyword arguments are forwarded to the wrapped ``reset``.

        Returns:
            The observation matching the environment's current state: the
            result of stepping action 2, or the observation of a fresh reset
            if that step ended the episode.
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the fresh observation, not the terminal frame of the
            # episode that was just discarded.
            obs = self.env.reset(**kwargs)
        return obs

    def step(self, a):
        """Forward action *a* to the wrapped environment unchanged."""
        return self.env.step(a)
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return a max-pooled frame.

    The reward is summed over the repeated steps and the returned
    observation is the element-wise maximum over the most recent frames
    (at most two) seen during the current call to ``step``.
    """

    def __init__(self, env, skip=4):
        """Wrap *env*.

        Args:
            env: Environment to wrap.
            skip: Number of times each action is repeated; must be >= 1.

        Raises:
            ValueError: If ``skip`` is smaller than 1.
        """
        gym.Wrapper.__init__(self, env)
        if skip < 1:
            # With no inner step there is no observation or info to return.
            raise ValueError("skip must be >= 1, got {!r}".format(skip))
        self._skip = skip
        # Holds at most the two most recent raw frames.
        self._obs = deque([], maxlen=2)

    def reset(self, **kwargs):
        """Reset the wrapped environment and drop any buffered frames."""
        # Frames from the previous episode must not leak into the next max.
        self._obs.clear()
        return self.env.reset(**kwargs)

    def step(self, a):
        """Repeat action *a* up to ``skip`` times.

        Stops early if the wrapped environment reports ``done``.

        Returns:
            Tuple ``(max_obs, total_reward, done, info)`` where ``max_obs``
            is the element-wise max over the buffered frames,
            ``total_reward`` is the sum of the per-step rewards, and
            ``done``/``info`` come from the last inner step.
        """
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(a)
            # Buffer every frame: the bounded deque keeps only the last two,
            # so an early `done` still pools over frames from this step
            # instead of stale (or missing) ones.
            self._obs.append(obs)
            total_reward += reward
            if done:
                break
        max_obs = np.array(self._obs).max(axis=0)
        return max_obs, total_reward, done, info
class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode.

    The underlying environment is only truly reset when it reported
    ``done`` itself; after a mere life loss, ``reset`` continues the same
    game by stepping action 0 once.
    """

    def __init__(self, env):
        """Wrap *env* and initialise the life bookkeeping."""
        gym.Wrapper.__init__(self, env)
        # Life count seen on the most recent reset/step.
        self.lives = 0
        # Whether the wrapped env itself signalled `done` on the last step.
        self.was_real_done = True

    def reset(self, **kwargs):
        """Start the next (pseudo-)episode.

        Keyword arguments are forwarded to the wrapped ``reset`` when a
        real reset is performed.

        Returns:
            The first observation of the new (pseudo-)episode.
        """
        if not self.was_real_done:
            # Only a life was lost: advance with action 0 instead of resetting.
            obs, _reward, _done, _info = self.env.step(0)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

    def step(self, a):
        """Step the wrapped env, flagging ``done`` whenever a life is lost.

        Returns:
            Tuple ``(obs, reward, done, info)``; ``done`` is forced to True
            when the life count dropped but is still above zero.
        """
        obs, reward, real_done, info = self.env.step(a)
        self.was_real_done = real_done
        lives_now = self.env.unwrapped.ale.lives()
        life_lost = 0 < lives_now < self.lives
        self.lives = lives_now
        reported_done = True if life_lost else real_done
        return obs, reward, reported_done, info
class WarpFrame(gym.ObservationWrapper):
    """Convert observations to 84x84 single-channel uint8 frames."""

    def __init__(self, env):
        """Wrap *env* and declare the warped observation space."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width),
            dtype=np.uint8,
        )

    def observation(self, frame):
        """Return *frame* converted to grayscale and resized.

        Args:
            frame: Image array accepted by ``cv2.cvtColor`` with
                ``COLOR_RGB2GRAY`` (presumably an RGB frame from the
                wrapped environment — confirm against the env in use).

        Returns:
            A uint8 array resized to ``self.width`` x ``self.height``.
        """
        frame = frame.astype('float32')
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = frame.astype('uint8')
        # cv2.resize expects dsize as (width, height), not (height, width).
        frame = cv2.resize(
            frame, (self.width, self.height), interpolation=cv2.INTER_AREA
        )
        return frame
def make_env(env_id, noop_max, skip=4):
    """Create the environment *env_id* wrapped in the preprocessing stack.

    Wrappers are applied innermost-first in this order: NoopResetEnv,
    MaxAndSkipEnv, EpisodicLifeEnv, FireResetEnv (only if the action set
    contains 'FIRE'), WarpFrame.

    Args:
        env_id: Id passed to ``gym.make``; the resulting spec id must
            contain 'NoFrameskip'.
        noop_max: Forwarded to ``NoopResetEnv``.
        skip: Forwarded to ``MaxAndSkipEnv``.

    Returns:
        The fully wrapped environment.
    """
    base_env = gym.make(env_id)
    assert 'NoFrameskip' in base_env.spec.id
    wrapped = EpisodicLifeEnv(
        MaxAndSkipEnv(
            NoopResetEnv(base_env, noop_max=noop_max),
            skip=skip,
        )
    )
    has_fire = 'FIRE' in wrapped.unwrapped.get_action_meanings()
    if has_fire:
        wrapped = FireResetEnv(wrapped)
    return WarpFrame(wrapped)