-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
124 lines (107 loc) · 4.43 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def run(weight_path, x, num_inferences_per_task):
    """Build a ResNet50 from a weights file and run repeated predictions.

    Everything the function needs is imported inside its body, so it does
    not rely on module-level imports of the file that defines it.

    Args:
        weight_path: Value forwarded to ``ResNet50(weights=...)``.
        x: Input handed unchanged to ``model.predict`` on every iteration.
        num_inferences_per_task: Number of ``predict`` calls to perform.

    Returns:
        Always ``1``; ``main`` adds these values up to count finished tasks.
    """
    import os

    # Limit usable cores of tensorflow. The environment variables are set
    # before the tensorflow import below.
    for var_name in ('OMP_NUM_THREADS',
                     'TF_NUM_INTRAOP_THREADS',
                     'TF_NUM_INTEROP_THREADS'):
        os.environ[var_name] = '1'

    import tensorflow as tf

    threading_cfg = tf.config.threading
    threading_cfg.set_intra_op_parallelism_threads(1)
    threading_cfg.set_inter_op_parallelism_threads(1)
    tf.config.run_functions_eagerly(True)

    from tensorflow.keras.applications.resnet50 import ResNet50

    net = ResNet50(weights=weight_path)
    for _ in range(num_inferences_per_task):
        # The predictions themselves are discarded; only the work matters.
        net.predict(x)
    return 1
def run_serverless(x, num_inferences_per_task):
    """Run repeated predictions with the module-global ``model``.

    The model is not built here: it is read from the global name ``model``,
    which must have been defined beforehand (``main`` does this in
    ``local-s`` mode by copying the dict returned by ``context_setup`` into
    the module globals).

    Args:
        x: Input handed unchanged to ``model.predict`` on every iteration.
        num_inferences_per_task: Number of ``predict`` calls to perform.

    Returns:
        Always ``1``; ``main`` adds these values up to count finished tasks.
    """
    global model
    for _ in range(num_inferences_per_task):
        # Prediction results are intentionally thrown away.
        model.predict(x)
    return 1
def context_setup(args):
    """Load ResNet50 once and return it as a context dictionary.

    Args:
        args: Mapping that must contain ``'weight_path'``; its value is
            forwarded to ``ResNet50(weights=...)``.

    Returns:
        A dict with the loaded network under ``'model'`` and the flag
        ``'direct_mode': True``.
        NOTE(review): the meaning of ``direct_mode`` is not visible in this
        file -- presumably consumed by the TaskVine library; confirm.
    """
    import os

    # Limit usable cores of tensorflow. The environment variables are set
    # before the tensorflow import below.
    os.environ.update({
        'OMP_NUM_THREADS': '1',
        'TF_NUM_INTRAOP_THREADS': '1',
        'TF_NUM_INTEROP_THREADS': '1',
    })

    import tensorflow as tf

    threading_cfg = tf.config.threading
    threading_cfg.set_intra_op_parallelism_threads(1)
    threading_cfg.set_inter_op_parallelism_threads(1)
    tf.config.run_functions_eagerly(True)

    from tensorflow.keras.applications.resnet50 import ResNet50

    network = ResNet50(weights=args['weight_path'])
    return {'model': network, 'direct_mode': True}
def main():
    """Run the ResNet50 inference benchmark in the mode given on the command line.

    The mode is read from ``sys.argv[1]``:

    * ``local-p``  -- call ``run`` once in this process, then exit.
    * ``local-s``  -- call ``context_setup``, publish its result as module
      globals, call ``run_serverless`` once, then exit.
    * ``remote-p`` -- submit ``num_tasks`` TaskVine ``PythonTask``s running
      ``run``, each with the poncho environment tarball and the weights
      file attached as inputs.
    * ``remote-r`` -- submit ``num_tasks`` ``PythonTask``s running ``run``
      whose command is prefixed with ``conda run -n lnni`` and which receive
      the weights as a path under the current working directory.
    * ``remote-s`` -- install a library built from ``run_serverless`` and
      ``context_setup`` and submit ``num_tasks`` ``FunctionCall``s to it.

    For the remote modes the function waits until the manager queue is
    empty, then prints the completed-task count and the elapsed time.

    Raises:
        SystemExit: if no mode argument was supplied (usage message), and
            at the end of the two local modes.
        ValueError: if the mode is not one of the five listed above.
    """
    import os
    import sys
    import time

    start = time.time()

    num_inferences_per_task = 16
    num_tasks = 10
    env_tarball = 'env.tar.gz'
    weight_path = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
    init_command = 'conda run -n lnni '
    cores_per_task = 2  # each worker has 4 cores as defined in run_worker.sh
    port = 9126
    img_path = 'elephant.jpg'

    # Validate the run mode first, so a typo fails immediately instead of
    # after TensorFlow has been imported and a manager has bound the port.
    valid_modes = ('local-p', 'local-s', 'remote-p', 'remote-r', 'remote-s')
    if len(sys.argv) < 2:
        sys.exit('usage: run.py <mode>  (one of: ' + ', '.join(valid_modes) + ')')
    mode = sys.argv[1]
    if mode not in valid_modes:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of: {', '.join(valid_modes)}"
        )

    # preprocess input data
    import numpy as np
    from tensorflow.keras.applications.resnet50 import preprocess_input
    from tensorflow.keras.preprocessing import image

    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    if mode == 'local-p':
        run(weight_path, x, num_inferences_per_task)
        # NOTE(review): exits with status 1 although the run succeeded;
        # kept as-is in case an outer script relies on it -- confirm.
        sys.exit(1)

    if mode == 'local-s':
        # Publish the context entries as module globals so that
        # run_serverless can find ``model``.
        globals().update(context_setup({'weight_path': weight_path}))
        run_serverless(x, num_inferences_per_task)
        # NOTE(review): exits with status 2 although the run succeeded;
        # kept as-is in case an outer script relies on it -- confirm.
        sys.exit(2)

    # Remote modes: everything below talks to a TaskVine manager.
    import ndcctools.taskvine as vine

    q = vine.Manager(port=port, name='nn_exp')
    print(f"TaskVine manager listening on port {q.port}")
    weight_path_vine_file = q.declare_file(weight_path, cache=True, peer_transfer=True)

    if mode == 'remote-p':
        env_tarball_vine_file = q.declare_poncho(env_tarball, cache=True, peer_transfer=True)
        for _ in range(num_tasks):
            t = vine.PythonTask(run, weight_path, x, num_inferences_per_task)
            t.add_environment(env_tarball_vine_file)
            t.add_input(weight_path_vine_file, weight_path)
            t.set_cores(cores_per_task)
            q.submit(t)
    elif mode == 'remote-r':
        # No input file is attached in this mode; the task is handed a path
        # under the manager's working directory instead.
        # NOTE(review): assumes workers can read that path -- confirm.
        local_weight_path = os.path.join(os.getcwd(), weight_path)
        for _ in range(num_tasks):
            t = vine.PythonTask(run, local_weight_path, x, num_inferences_per_task)
            t.set_cores(cores_per_task)
            t.set_command(init_command + t.command)
            q.submit(t)
    else:  # mode == 'remote-s' (guaranteed by the validation above)
        print("Creating library from functions...")
        libtask = q.create_library_from_functions(
            'lib',
            run_serverless,
            context=context_setup,
            context_arg={'weight_path': weight_path},
            poncho_env=env_tarball,
        )
        libtask.add_input(weight_path_vine_file, weight_path)
        libtask.set_cores(cores_per_task)
        q.install_library(libtask)
        print("Submitting function call tasks...")
        for _ in range(num_tasks):
            t = vine.FunctionCall('lib', 'run_serverless', x, num_inferences_per_task)
            t.set_exec_method('direct')
            q.submit(t)

    num_completed_tasks = 0
    while not q.empty():
        t = q.wait(5)
        if not t:
            continue
        # Read the output once and reuse it for both the tally and the log.
        output = t.output
        if isinstance(output, int):
            num_completed_tasks += output
            print(f'Task {num_completed_tasks}/{num_tasks} returned with output {output}')
        else:
            # Both task functions return the integer 1; anything else is
            # reported instead of crashing the wait loop with a TypeError.
            print(f'Task returned without an integer result: {output!r}')

    print(f'Completed: {num_completed_tasks}/{num_tasks}, mode: {mode}')
    print('Elapsed:', time.time() - start)
# Run the benchmark only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == '__main__':
    main()