preprocess.py
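"""Preprocessing wrappers for Gymnasium Atari environments: action repeat
with two-frame flicker maxing, grayscale resizing, [0, 1] pixel scaling,
and frame stacking (the standard DQN-style preprocessing pipeline)."""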
import gymnasium as gym
import numpy as np
from collections import deque
import cv2
class RepeatActionAndMaxFrame(gym.Wrapper):
    """Repeat each action `repeat` times and return the pixel-wise max of
    the last two raw frames, which removes Atari sprite flicker."""

    def __init__(self, env, repeat=4, clip_reward=True, no_ops=0, fire_first=False):
        super().__init__(env)
        self.repeat = repeat
        self.clip_reward = clip_reward
        self.no_ops = no_ops
        self.fire_first = fire_first
        self.frame_buffer = np.zeros(
            (2, *self.env.observation_space.shape), dtype=np.float32
        )

    def step(self, action):
        total_reward = 0.0
        term, trunc, info = False, False, {}
        for i in range(self.repeat):
            state, reward, term, trunc, info = self.env.step(action)
            if self.clip_reward:
                reward = np.clip(reward, -1, 1)
            total_reward += reward
            # Keep only the two most recent raw frames.
            self.frame_buffer[i % 2] = state
            if term or trunc:
                break
        # Pixel-wise max over the last two frames removes sprite flicker.
        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, total_reward, term, trunc, info

    def reset(self, seed=None, options=None):
        state, info = self.env.reset(seed=seed, options=options)
        # Take a random number of no-op steps so episodes start from
        # varied initial states; track the resulting state instead of
        # returning the stale post-reset frame.
        no_ops = np.random.randint(self.no_ops) + 1 if self.no_ops > 0 else 0
        for _ in range(no_ops):
            state, _, term, trunc, info = self.env.step(0)
            if term or trunc:
                state, info = self.env.reset()
        if self.fire_first:
            # Some games (e.g. Breakout) only start after the FIRE action.
            assert self.env.unwrapped.get_action_meanings()[1] == "FIRE"
            state, _, _, _, _ = self.env.step(1)
        self.frame_buffer = np.zeros(
            (2, *self.env.observation_space.shape), dtype=np.float32
        )
        self.frame_buffer[0] = state
        return state, info

class PreprocessFrame(gym.ObservationWrapper):
    """Convert RGB frames to grayscale, resize to `shape`, and scale
    pixel values to [0, 1]."""

    def __init__(self, env, shape=(84, 84)):
        super().__init__(env)
        self.shape = shape
        self.observation_space = gym.spaces.Box(0.0, 1.0, self.shape, dtype=np.float32)

    def observation(self, state):
        state = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
        # cv2.resize expects (width, height); the target here is square,
        # so the argument order does not matter.
        state = cv2.resize(state, self.shape, interpolation=cv2.INTER_AREA)
        # Cast to float32 so the output matches observation_space's dtype.
        return (state / 255.0).astype(np.float32)

class StackFrames(gym.ObservationWrapper):
    """Stack the last `size` observations along a new leading axis."""

    def __init__(self, env, size=4):
        super().__init__(env)
        self.size = int(size)
        self.stack = deque([], maxlen=self.size)
        shape = self.env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, (self.size, *shape), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        state, info = self.env.reset(seed=seed, options=options)
        # Fill the stack with copies of the first frame.
        self.stack = deque([state] * self.size, maxlen=self.size)
        return np.array(self.stack), info

    def observation(self, state):
        self.stack.append(state)
        return np.array(self.stack)

class AtariEnv:
    """Factory that chains the wrappers above around a Gymnasium Atari env."""

    def __init__(
        self,
        env,
        shape=(84, 84),
        repeat=4,
        clip_rewards=False,
        no_ops=0,
        fire_first=False,
    ):
        self.env = gym.make(env, render_mode="rgb_array")
        self.env = RepeatActionAndMaxFrame(
            self.env, repeat, clip_rewards, no_ops, fire_first
        )
        self.env = PreprocessFrame(self.env, shape)
        # The stack depth reuses `repeat`, so the agent sees `repeat` frames.
        self.env = StackFrames(self.env, repeat)

    def make(self):
        return self.env

if __name__ == "__main__":
env = AtariEnv("ALE/Pong-v5").make()
state, _ = env.reset()
print("Expected Shape:", env.observation_space.shape)
print("Actual Shape:", state.shape)