utils.py
import cv2
import gym
import numpy as np
import torch


def run_test_episode(model, env, device, max_steps=1000):
    """Run one greedy episode; return the total reward and the stacked raw RGB frames."""
    frames = []
    obs = env.reset()
    frames.append(env.frame)

    idx = 0
    done = False
    reward = 0
    while not done and idx < max_steps:
        # Greedy action: index of the largest predicted Q-value.
        action = model(torch.Tensor(obs).unsqueeze(0).to(device)).max(-1)[-1].item()
        obs, r, done, _ = env.step(action)
        reward += r
        frames.append(env.frame)
        idx += 1
    return reward, np.stack(frames, 0)
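

# Helper sketch (not in the original file): writes the (T, H, W, 3) uint8 RGB
# frame array returned by run_test_episode to disk via cv2.VideoWriter. The
# function name, the fps default, and the mp4v codec are assumptions; adjust
# them for your setup.
def save_frames_as_video(frames, path, fps=30):
    """Write a (T, H, W, 3) uint8 RGB frame array to a video file."""
    height, width = frames.shape[1], frames.shape[2]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for frame in frames:
        # OpenCV expects BGR channel order.
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()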


class FrameStackingAndResizingEnv(gym.Wrapper):
    """Wraps an env so observations are a stack of the last `num_stack`
    grayscale frames resized to (h, w), with the newest frame at index 0."""

    def __init__(self, env, w, h, num_stack=4):
        self.env = env  # stored directly; gym.Wrapper.__init__ is not called
        self.n = num_stack
        self.w = w
        self.h = h
        self.buffer = np.zeros((num_stack, h, w), 'uint8')
        self.frame = None  # last raw RGB frame, kept for rendering/video

    def _preprocess_frame(self, frame):
        image = cv2.resize(frame, (self.w, self.h))
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return image

    def step(self, action):
        im, reward, done, info = self.env.step(action)
        self.frame = im.copy()
        im = self._preprocess_frame(im)
        # Shift the older frames back one slot and put the newest at index 0.
        self.buffer[1:self.n, :, :] = self.buffer[0:self.n - 1, :, :]
        self.buffer[0, :, :] = im
        return self.buffer.copy(), reward, done, info

    @property
    def observation_space(self):
        # Placeholder array used only for its shape, not a gym.spaces.Space.
        return np.zeros((self.n, self.h, self.w))

    @property
    def action_space(self):
        return self.env.action_space

    def reset(self):
        im = self.env.reset()
        self.frame = im.copy()
        im = self._preprocess_frame(im)
        # Fill the whole stack with copies of the first frame.
        self.buffer = np.stack([im] * self.n, 0)
        return self.buffer.copy()

    def render(self, mode):
        return self.env.render(mode)
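

# Usage sketch (assumption, not part of the original file): wraps an Atari env
# and runs one greedy episode with an untrained network. The env id
# "Breakout-v0" and the tiny linear Q-network below are illustrative
# placeholders; substitute your own trained model and environment.
if __name__ == "__main__":
    env = FrameStackingAndResizingEnv(gym.make("Breakout-v0"), w=84, h=84, num_stack=4)

    class RandomQNet(torch.nn.Module):
        """Placeholder Q-network: flattens the frame stack into action values."""
        def __init__(self, obs_shape, num_actions):
            super().__init__()
            self.fc = torch.nn.Linear(int(np.prod(obs_shape)), num_actions)

        def forward(self, x):
            return self.fc(x.float().reshape(x.shape[0], -1))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = RandomQNet(env.observation_space.shape, env.action_space.n).to(device)
    reward, frames = run_test_episode(model, env, device, max_steps=200)
    print(f"episode reward: {reward}, frames captured: {frames.shape}")
    save_frames_as_video(frames, "test_episode.mp4")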