import io
import logging
from collections import OrderedDict
from random import randint

import gym
import numpy as np
from gym import spaces, Env
from gym.spaces import Box
from prettytable import PrettyTable

log = logging.getLogger(__name__)


class ColoringEnv(gym.Env):
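    """Grid-coloring environment.

    The agent walks over a 2-D grid and colors ("stitches") the cells of a target
    pattern. Observations have three channels (pattern, already colored cells,
    agent position); the action space is Discrete(4): up, down, left, right.
    """
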
    emb_pattern_empty = 0
    emb_pattern_to_color = 1
    agent_pos_code = 1
    agent_pos_empty = 0
    channel_pattern = 0
    channel_stitch = 1
    channel_agent = 2

    def __init__(self, c, worker_id, start_position, with_step_penalty, with_revisit_penalty,
                 stay_inside, with_color_reward, total_reward, covered_steps_ratio,
                 depth_channel_first=True, changing_start_positions=False, as_image=False, color_on_visit=True):
        Env.__init__(self)
        log.info('creating environment for files {}'.format(c.data_files))
        # needed in order to simulate a gym environment
        self.reward_range = None
        self.metadata = {'render.modes': []}
        self.spec = None
        self.enabled = False
        self.observation_space = None
        self.c = c
        # Observation channels:
        #   first channel  - 0: blank cell,   1: pattern cell
        #   second channel - 0: not stitched, 1: stitched
        #   third channel  - 0: no agent,     1: agent
        self.with_step_penalty = with_step_penalty
        self.with_revisit_penalty = with_revisit_penalty
        self.stay_inside = stay_inside
        self.action_encodings = {0: 'u', 1: 'd', 2: 'l', 3: 'r'}
        self.with_color_reward = with_color_reward
        self.total_reward = total_reward
        self.covered_steps_ratio = covered_steps_ratio
        self.inv_action_encodings = {v: k for k, v in self.action_encodings.items()}
        self.action_space = spaces.Discrete(len(self.action_encodings))
        self.layer_descriptions = OrderedDict([
            (ColoringEnv.channel_pattern, 'Pattern'),
            (ColoringEnv.channel_stitch, 'Completed pattern'),
            (ColoringEnv.channel_agent, 'Agent position'),
        ])
        self.worker_id = worker_id
        self.start_position = start_position
        self.env_reset_count = 0
        self.steps = []
        self.emb_pattern_layer = None
        self.emb_pattern_count = 0
        self.x_dim = None
        self.y_dim = None
        self.base_observation = None
        self.initial_observation = None
        self.max_steps = -1
        self.step_count = 0
        self.data_file = None
        self.done = False
        self.depth_channel_first = depth_channel_first
        self.changing_start_positions = changing_start_positions
        self.as_image = as_image
        self.color_on_visit = color_on_visit
        self.alice_state = None
        self.init_uncovered_count = 0
        self.reset()

    def reset(self, alice_state=None):
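        """Load the next pattern file and rebuild the base observation.

        If ``alice_state`` is given, the stitch channel is pre-filled so that only
        the cells covered by Alice remain to be colored (see the comment further
        down); otherwise the episode starts with an empty stitch channel.
        """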
        self.data_file, x = self.c.load_file(self.env_reset_count)
        # TODO make sure all the environments are between [0, 1]
        assert x.min() == 0
        assert x.max() == 1
        self.alice_state = alice_state
        self.emb_pattern_layer = x
        del x
        self.nonzero_indices = [(i, j) for i, j in zip(*np.nonzero(self.emb_pattern_layer))]
        self.emb_pattern_count = np.count_nonzero(self.emb_pattern_layer)
        log.info('Worker {} loaded the file {} (reset count {})'
                 .format(self.worker_id, self.data_file, self.env_reset_count))
        self.env_reset_count += 1
        # TODO do not do that for embroideries > 10x10
        self.max_steps = np.prod(self.emb_pattern_layer.shape) * 2  # i.e., 28 * 28 * 2
        # self.max_steps = 30
        self.step_count = 0
        self.x_dim, self.y_dim = self.emb_pattern_layer.shape
        original_position = self.nonzero_indices[randint(0, len(self.nonzero_indices) - 1)] \
            if self.changing_start_positions else self.start_position
        # print('new position:', original_position, 'value at the new position:', self.emb_pattern_layer[original_position])
        # no .clear() here because we do not want to modify the step data already returned to REINFORCE
        self.steps = []
        self.done = False
        # Base observation
        self.base_observation = np.zeros((len(self.layer_descriptions), self.x_dim, self.y_dim), dtype=int)
        self.base_observation[ColoringEnv.channel_pattern] = self.emb_pattern_layer.astype(float)
        self.base_observation[ColoringEnv.channel_stitch] = np.zeros_like(self.emb_pattern_layer, dtype=int)
        self.base_observation[ColoringEnv.channel_agent] = np.zeros_like(self.emb_pattern_layer, dtype=int)
        self.base_observation[ColoringEnv.channel_agent][original_position] = ColoringEnv.agent_pos_code
        self.initial_observation = np.copy(self.base_observation)
        curr_agent_position = self._get_agent_position()
        # Bob starts in a state where only the cells covered by Alice are still uncolored;
        # the remaining pattern cells (ColoringEnv.channel_pattern) are already filled in
        if alice_state is not None:
            alice_state_channel = alice_state[ColoringEnv.channel_stitch].astype(self.emb_pattern_layer.dtype)
            diff = self.base_observation[ColoringEnv.channel_pattern] - alice_state_channel
            diff[curr_agent_position] = ColoringEnv.emb_pattern_to_color
            # Alice has not completed the whole pattern
            if not (self.base_observation[ColoringEnv.channel_pattern] == diff).all():
                self.base_observation[ColoringEnv.channel_stitch] = diff
            # del diff, alice_state_channel
        # del alice_state
        if self.base_observation[ColoringEnv.channel_pattern][curr_agent_position] == ColoringEnv.emb_pattern_to_color:
            self.base_observation[ColoringEnv.channel_stitch][curr_agent_position] = ColoringEnv.emb_pattern_to_color
        self.init_uncovered_count = self.emb_pattern_count - self.covered_count(self.base_observation)
        # print('init_uncovered_count:', self.init_uncovered_count, 'self.base_observation[ColoringEnv.channel_stitch]:', self.base_observation[ColoringEnv.channel_stitch])
        if self.depth_channel_first:
            box_shape = (self.base_observation.shape[0], self.base_observation.shape[1], self.base_observation.shape[2])
        else:
            box_shape = (self.base_observation.shape[1], self.base_observation.shape[2], self.base_observation.shape[0])
        self.observation_space = Box(low=0, high=255, shape=box_shape, dtype=np.uint8)
        return self._gen_state(self.base_observation)

    def covered_count(self, state):
        assert state.shape[0] == len(self.layer_descriptions)
        return np.count_nonzero(state[ColoringEnv.channel_stitch])

    def seed(self, s):
        pass

    def _get_agent_position(self):
        agent_position = tuple(np.argwhere(self.base_observation[ColoringEnv.channel_agent]
                                           == ColoringEnv.agent_pos_code)[0])
        return agent_position

    # scales the observation to image range if requested and, for channels-last
    # consumers, moves the channel axis to the end
    def _gen_state(self, obs):
        obs = np.copy(obs)
        if self.as_image:
            obs = obs * 255
        if not self.depth_channel_first:
            # https://machinelearningmastery.com/a-gentle-introduction-to-channels-first-and-channels-last-image-formats-for-deep-learning/
            # tensorflow expects the channel to be the last dimension,
            # i.e. (28, 28, 3) instead of (3, 28, 28)
            obs = np.moveaxis(obs, 0, -1)
        return obs

    def get_rewards(self):
        return [s[2] for s in self.steps]

    def get_rewards_sum(self):
        return sum(self.get_rewards())

    def get_stitches(self):
        # self.steps already stores the decoded action string, so no lookup through
        # action_encodings is needed; note 's' is not among the current encodings
        return [s[0] for s in self.steps if s[1] == 's']

    def get_positions_and_actions(self):
        return [(s[0], s[1]) for s in self.steps]

    # See #layer_descriptions
    def step(self, a_t_num):
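        """Apply one discrete action (0: up, 1: down, 2: left, 3: right) and return (obs, reward, done, info)."""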
        if self.done:
            raise RuntimeError('Calling step when environment is done')
        self.step_count += 1
        a_t = self.action_encodings[a_t_num]
        x_lower_boundary = 0
        y_lower_boundary = 0
        x_upper_boundary = self.x_dim - 1
        y_upper_boundary = self.y_dim - 1
        s_tp1 = np.copy(self.base_observation)
        r_tp1 = None
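        # move_agent applies a (row, column) offset to the agent, clips the move at the
        # grid border (and, with stay_inside, at the pattern border), colors the new cell
        # if it belongs to the pattern, and returns the per-step reward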
        def move_agent(step):
            old_agent_position = self._get_agent_position()
            new_agent_position = tuple(np.array(old_agent_position) + step)
            # the agent must not leave the frame or (when stay_inside is set) the pattern cells
            if not (x_lower_boundary <= new_agent_position[0] <= x_upper_boundary and
                    y_lower_boundary <= new_agent_position[1] <= y_upper_boundary and
                    (not self.stay_inside or s_tp1[ColoringEnv.channel_pattern][new_agent_position] != 0)):
                new_agent_position = old_agent_position
            # move the agent to the new position
            s_tp1[ColoringEnv.channel_agent][old_agent_position] = ColoringEnv.agent_pos_empty
            s_tp1[ColoringEnv.channel_agent][new_agent_position] = ColoringEnv.agent_pos_code
            ret_val = 0
            if self.with_step_penalty:
                ret_val += self.c.step_reward
            if self.with_revisit_penalty:
                if s_tp1[ColoringEnv.channel_stitch][new_agent_position] == ColoringEnv.emb_pattern_to_color:
                    ret_val += self.c.step_reward
            if s_tp1[ColoringEnv.channel_pattern][new_agent_position] == ColoringEnv.emb_pattern_to_color:
                if self.with_color_reward and s_tp1[ColoringEnv.channel_stitch][new_agent_position] != ColoringEnv.emb_pattern_to_color:
                    ret_val -= self.c.step_reward
                    # ret_val = 1
                s_tp1[ColoringEnv.channel_stitch][new_agent_position] = ColoringEnv.emb_pattern_to_color
            return ret_val
        curr_agent_position = self._get_agent_position()
        if a_t == 'u':
            r_tp1 = move_agent((-1, 0))
        elif a_t == 'd':
            r_tp1 = move_agent((+1, 0))
        elif a_t == 'l':
            r_tp1 = move_agent((0, -1))
        elif a_t == 'r':
            r_tp1 = move_agent((0, +1))
        # print('step a_t', a_t, 'state:', self.base_observation, 's_tp1', s_tp1)
        # done stitching if all color cells are stitched
        covered_count = self.covered_count(s_tp1)
        if self.step_count >= self.max_steps or covered_count == self.emb_pattern_count:
            self.done = True
        self.base_observation = s_tp1
        if self.done:
            if self.total_reward:
                r_tp1 += self.c.step_reward * self.step_count
            if self.covered_steps_ratio:
                r_tp1 += covered_count / self.step_count
        if r_tp1 is None:
            raise RuntimeError('Return cannot be None')
        self.steps.append([curr_agent_position, a_t, r_tp1, self.done])
        assert self.step_count == len(self.steps)
        ret_s_tp1 = self._gen_state(s_tp1)
        reduced_ratio = self.init_uncovered_count / self.step_count
        assert not self.done or 0 < reduced_ratio <= 1
        infos = {
            'episode': None,
            'steps': self.steps,
            'steps_count': self.step_count,
            'covered_count': covered_count,
            'total_count': self.emb_pattern_count,
            'init_uncovered_count': self.init_uncovered_count,
            'reduced_ratio': reduced_ratio
        }
        return ret_s_tp1, r_tp1, self.done, infos

    def render(self):
        x = PrettyTable(['{}, {} (Layer {})'.format(self.worker_id, v, k) for k, v in self.layer_descriptions.items()])
        row = []
        for r in self.base_observation:
            # https://stackoverflow.com/a/42046765/256002
            bio = io.BytesIO()
            np.savetxt(bio, r, fmt='%d')
            mystr = bio.getvalue().decode('latin1').rstrip('\n')
            row.append(mystr)
        x.add_row(row)
        return x.get_string()

    def __str__(self):
        return '<{}>'.format(type(self).__name__)
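

# Minimal usage sketch, assuming a small config object that provides `data_files`,
# `step_reward`, and `load_file(idx)` returning (name, 2-D binary numpy array), since
# these are the only attributes ColoringEnv reads from `c`. The config below is a
# hypothetical stand-in, not part of the project itself.
if __name__ == '__main__':
    class _DemoConfig:
        """Hypothetical stand-in for the real config object passed as `c`."""
        data_files = ['demo_pattern']
        step_reward = -1

        def load_file(self, idx):
            # a 3x3 grid with an L-shaped set of cells to color
            pattern = np.array([[1, 0, 0],
                                [1, 0, 0],
                                [1, 1, 1]])
            return self.data_files[0], pattern

    env = ColoringEnv(_DemoConfig(), worker_id=0, start_position=(0, 0),
                      with_step_penalty=True, with_revisit_penalty=True, stay_inside=False,
                      with_color_reward=True, total_reward=False, covered_steps_ratio=False)
    obs = env.reset()
    done = False
    while not done:
        # random policy, just to exercise the environment
        obs, reward, done, info = env.step(env.action_space.sample())
    print(env.render())
    print('covered {} of {} pattern cells'.format(info['covered_count'], info['total_count']))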