-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtraining_worker.py
105 lines (90 loc) · 4.01 KB
/
training_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
'''
The worker class. The workers only execute the commands from the master.
'''
from __future__ import print_function

import math
import os
import shutil
import subprocess
import time

from constants import WorkerInstruction
class TrainingWorker:
    """Executes instructions sent by the master process.

    A worker owns a list of model graphs (instances of ``target_model_class``)
    and only trains, explores, reads, or overwrites them when the master rank
    tells it to (see ``main_loop``).
    """

    # Prefix handed to every model on construction. The save data of graph
    # ``cluster_id`` is assumed to live at ``MODEL_SAVE_PREFIX + str(cluster_id)``;
    # that path is deleted when a graph fails during training.
    MODEL_SAVE_PREFIX = './savedata/model_'

    def __init__(self, comm, master_rank, target_model_class):
        """Create a worker.

        Args:
            comm: communicator exposing ``Get_rank``, ``recv`` and ``send``
                (presumably an mpi4py communicator -- confirm with caller).
            master_rank: rank that instructions are received from and that
                replies are sent to.
            target_model_class: callable ``(cluster_id, hparam, save_prefix)``
                returning a graph object.
        """
        self.worker_graphs = []  # graphs currently owned by this worker
        # When True, explore_necessary_graphs() perturbs every graph, not only
        # those flagged with need_explore. (Misspelled name kept for
        # backward compatibility with existing readers of this attribute.)
        self.is_expolore_only = False
        self.rank = comm.Get_rank()
        self.comm = comm
        self.master_rank = master_rank
        self.target_model_class = target_model_class
        self.train_time = 0  # accumulated wall-clock seconds spent in train()
        self.explore_time = 0  # accumulated seconds spent exploring

    def main_loop(self):
        """Receive and execute instructions from the master until EXIT.

        Each message is a sequence whose first element is a
        ``WorkerInstruction``; the remaining elements are that instruction's
        arguments. GET and GET_PROFILING_INFO reply to the master.
        """
        while True:
            data = self.comm.recv(source=self.master_rank)
            inst = data[0]
            if inst == WorkerInstruction.ADD_GRAPHS:
                # payload: (inst, hparam_list, cluster_id_begin, explore_only)
                self.is_expolore_only = data[3]
                self.add_graphs(data[1], data[2])
            elif inst == WorkerInstruction.TRAIN:
                # payload: (inst, num_steps, total_epochs)
                self.train(data[1], data[2])
            elif inst == WorkerInstruction.GET:
                self.comm.send(self.get_all_values(), dest=self.master_rank)
            elif inst == WorkerInstruction.SET:
                # payload: (inst, values); each value's [0] is a cluster id
                self.set_values(data[1])
            elif inst == WorkerInstruction.EXPLORE:
                self.explore_necessary_graphs()
            elif inst == WorkerInstruction.GET_PROFILING_INFO:
                self.comm.send([self.train_time, self.explore_time],
                               dest=self.master_rank)
            elif inst == WorkerInstruction.EXIT:
                break
            else:
                print('Invalid instruction!!!!')

    def add_graphs(self, hparam_list, id_begin):
        """Construct one graph per hyper-parameter set.

        Graphs receive consecutive cluster ids starting at ``id_begin``.
        """
        print('[{}]Got {} hparams'.format(self.rank, len(hparam_list)))
        for offset, hparam in enumerate(hparam_list):
            new_graph = self.target_model_class(
                id_begin + offset, hparam, self.MODEL_SAVE_PREFIX)
            self.worker_graphs.append(new_graph)

    def _remove_model_data(self, cluster_id):
        """Best-effort deletion of a graph's save data (like ``rm -rf``)."""
        path = self.MODEL_SAVE_PREFIX + str(cluster_id)
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            elif os.path.lexists(path):
                # plain file or symlink: remove the entry itself
                os.remove(path)
        except OSError as e:
            # Cleanup must never abort training of the remaining graphs.
            print('[{}]Could not remove {}: {}'.format(self.rank, path, e))

    def train(self, num_epoches, total_epochs):
        """Train every graph, dropping any that fail or diverge.

        A graph is removed (together with its save data) when its training
        raises an ``Exception`` or its accuracy is NaN. ``BaseException``s
        such as ``KeyboardInterrupt`` are not swallowed.
        """
        train_begin_time = time.time()
        failed = []
        for g in self.worker_graphs:
            try:
                g.train(num_epoches, total_epochs)
                accuracy = g.get_accuracy()
                print('Model {} epoch = {}, acc = {}'.format(
                    g.cluster_id, g.epoches_trained, accuracy))
                # A NaN accuracy means training diverged.
                broken = math.isnan(accuracy)
            except Exception as e:
                print('Exception while training graph {}: {!r}'.format(
                    g.cluster_id, e))
                broken = True
            if broken:
                failed.append(g)
                self._remove_model_data(g.cluster_id)
                print('Error occured , graph {} removed'.format(g.cluster_id))
        if failed:
            # Identity-based, single-pass removal; mutate in place so any
            # outside reference to the list stays valid.
            failed_ids = set(id(g) for g in failed)
            self.worker_graphs[:] = [
                g for g in self.worker_graphs if id(g) not in failed_ids]
        self.train_time += time.time() - train_begin_time

    def get_all_values(self):
        """Return ``get_values()`` of every graph, in ownership order."""
        return [g.get_values() for g in self.worker_graphs]

    def set_values(self, values_to_set):
        """Overwrite graphs with values received from the master.

        Each entry ``v`` is applied to every graph whose ``cluster_id`` equals
        ``v[0]``; those graphs are then flagged for exploration.
        """
        graphs_by_id = {}
        for g in self.worker_graphs:
            graphs_by_id.setdefault(g.cluster_id, []).append(g)
        for v in values_to_set:
            for g in graphs_by_id.get(v[0], ()):
                g.set_values(v)
                g.need_explore = True

    def explore_necessary_graphs(self):
        """Perturb hyper-parameters of flagged graphs (or all, if explore-only).

        The ``need_explore`` flag is cleared on every graph that was perturbed.
        """
        explore_begin_time = time.time()
        for g in self.worker_graphs:
            if g.need_explore or self.is_expolore_only:
                print('[{}]Exploring graph {}'.format(self.rank, g.cluster_id))
                g.perturb_hparams()
                g.need_explore = False
        self.explore_time += time.time() - explore_begin_time