-
Notifications
You must be signed in to change notification settings - Fork 2
/
central_scheduler.py
533 lines (446 loc) · 23 KB
/
central_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
import sys
import time
import json
import uuid
import logging
import requests
from queue import Queue, Empty
from threading import Thread
from collections import defaultdict
from funcx import FuncXClient
from funcx.serialize import FuncXSerializer
from utils import colored, endpoint_name
from transfer import TransferManager
from strategies import init_strategy
from predictors import init_runtime_predictor, TransferPredictor, \
ImportPredictor
# Module-wide logger. Every message is sent to a stream handler whose
# format prefixes it with a yellow "[SCHEDULER]" tag.
logger = logging.getLogger(__name__)
ch = logging.StreamHandler()
ch.setFormatter(
    logging.Formatter(fmt=colored("[SCHEDULER] %(message)s", 'yellow'))
)
logger.addHandler(ch)

# Base URL of the FuncX web service that task batches are POSTed to.
FUNCX_API = 'https://funcx.org/api/v1'

# Endpoints send regular heartbeats. One whose latest sign of life
# (heartbeat or task result) is older than this many seconds is
# treated as dead.
HEARTBEAT_THRESHOLD = 75.0

CLIENT_ID = 'f06739da-ad7d-40bd-887f-abb1d23bbd6f'

# Exception types that cause a function to be blocked on the endpoint
# which raised them.
BLOCK_ERRORS = [ModuleNotFoundError, MemoryError]
class CentralScheduler(object):
    """Schedule FuncX tasks across a set of endpoints.

    Clients are handed scheduler-issued ("virtual") task ids. Each virtual
    task may be submitted to the FuncX service more than once (backup
    tasks), and submission may be delayed until a data transfer finishes.

    Two background threads are started by the constructor:

    * an endpoint watchdog (``_check_endpoints``) which polls endpoint
      statuses, tracks dead/alive and cold/warm state, and triggers backups;
    * a task watchdog (``_monitor_tasks``) which submits scheduled tasks to
      the FuncX service once their transfers (if any) have completed.

    NOTE(review): shared state is accessed from these threads and from the
    caller's thread without locking.
    """

    def __init__(self, endpoints, strategy='round-robin',
                 runtime_predictor='rolling-average', last_n=3, train_every=1,
                 log_level='INFO', import_model_file=None,
                 transfer_model_file=None, sync_level='exists',
                 max_backups=0, backup_delay_threshold=2.0,
                 *args, **kwargs):
        """Create the scheduler and start its background threads.

        Args:
            endpoints: Mapping keyed by endpoint id. A value may contain a
                ``'launch_time'`` entry, used for cold-start estimates.
            strategy: Strategy name handed to ``init_strategy``.
            runtime_predictor: Predictor name handed to
                ``init_runtime_predictor``.
            last_n: Forwarded to the runtime predictor.
            train_every: Forwarded to the runtime and transfer predictors.
            log_level: Level for this module's logger; also forwarded to
                the transfer manager.
            import_model_file: ``state_file`` for the import predictor.
            transfer_model_file: ``state_file`` for the transfer predictor.
            sync_level: Forwarded to the transfer manager.
            max_backups: A task is given a new backup only while the number
                of endpoints it was sent to does not exceed this value.
            backup_delay_threshold: A task becomes a backup candidate once
                elapsed / expected execution time exceeds this ratio.
            *args, **kwargs: Forwarded to ``FuncXClient``.
        """
        self._fxc = FuncXClient(*args, **kwargs)

        # Initialize a transfer client
        self._transfer_manger = TransferManager(endpoints=endpoints,
                                                sync_level=sync_level,
                                                log_level=log_level)

        # Info about FuncX endpoints we can execute on
        self._endpoints = endpoints
        self._dead_endpoints = set()
        self.last_result_time = defaultdict(float)
        self.temperature = defaultdict(lambda: 'WARM')
        self._imports = defaultdict(list)
        self._imports_required = defaultdict(list)

        # Track which endpoints a function can't run on
        self._blocked = defaultdict(set)

        # Track pending tasks
        # We will provide the client our own task ids, since we may submit
        # the same task multiple times to the FuncX service, and sometimes
        # we may wait to submit a task to FuncX (e.g., wait for a data
        # transfer).
        self._task_id_translation = {}
        self._pending = {}
        self._pending_by_endpoint = defaultdict(set)
        self._task_info = {}

        # List of endpoints a (virtual) task was scheduled to
        self._endpoints_sent_to = defaultdict(list)
        self.max_backups = max_backups
        self.backup_delay_threshold = backup_delay_threshold
        self._latest_status = {}
        self._last_task_ETA = defaultdict(float)

        # Maximum ETA, if any, of a task which we allow to be scheduled on
        # an endpoint. This is to prevent backfill tasks to be longer than
        # the estimated time for when a pending data transfer will finish.
        self._transfer_ETAs = defaultdict(dict)

        # Estimated error in the pending-task time of an endpoint.
        # Updated every time a task result is received from an endpoint.
        self._queue_error = defaultdict(float)

        # Set logging levels
        logger.setLevel(log_level)
        self.execution_log = []

        # Initialize serializer
        self.fx_serializer = FuncXSerializer()
        self.fx_serializer.use_custom('03\n', 'code')

        # Initialize runtime predictor
        self.runtime = init_runtime_predictor(runtime_predictor,
                                              endpoints=endpoints,
                                              last_n=last_n,
                                              train_every=train_every)
        logger.info(f"Runtime predictor using strategy {self.runtime}")

        # Initialize transfer-time predictor
        self.transfer_time = TransferPredictor(endpoints=endpoints,
                                               train_every=train_every,
                                               state_file=transfer_model_file)

        # Initialize import-time predictor
        self.import_predictor = ImportPredictor(endpoints=endpoints,
                                                state_file=import_model_file)

        # Initialize scheduling strategy
        self.strategy = init_strategy(strategy, endpoints=endpoints,
                                      runtime_predictor=self.runtime,
                                      queue_predictor=self.queue_delay,
                                      cold_start_predictor=self.cold_start,
                                      transfer_predictor=self.transfer_time)
        logger.info(f"Scheduler using strategy {self.strategy}")

        # Create everything the background threads read *before* starting
        # either of them: the endpoint watchdog can reach _schedule_task
        # (through backups), which puts onto this queue.
        self._scheduled_tasks = Queue()
        self._task_watchdog_sleep = 0.15

        # Start thread to check on endpoints regularly
        self._endpoint_watchdog = Thread(target=self._check_endpoints)
        self._endpoint_watchdog.start()

        # Start thread to monitor tasks and send tasks to FuncX service
        self._task_watchdog = Thread(target=self._monitor_tasks)
        self._task_watchdog.start()

    def block(self, func, endpoint):
        """Forbid running ``func`` on ``endpoint``.

        Returns:
            ``{'status': 'Success'}``, or a dict with ``'status': 'Failed'``
            and a ``'reason'`` when the endpoint is unknown or blocking it
            would leave ``func`` with no endpoint.
        """
        if endpoint not in self._endpoints:
            logger.error('Cannot block unknown endpoint {}'
                         .format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Unknown endpoint {}'.format(endpoint)
            }
        elif len(self._blocked[func]) == len(self._endpoints) - 1:
            logger.error('Cannot block last remaining endpoint {}'
                         .format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Cannot block all endpoints for {}'.format(func)
            }
        else:
            logger.info('Blocking endpoint {} for function {}'
                        .format(endpoint_name(endpoint), func))
            self._blocked[func].add(endpoint)
            return {'status': 'Success'}

    def register_imports(self, func, imports):
        """Record the packages ``func`` needs (used by ``cold_start``)."""
        logger.info('Registered function {} with imports {}'
                    .format(func, imports))
        self._imports_required[func] = imports

    def batch_submit(self, tasks, headers):
        """Schedule every ``(func, payload)`` pair in ``tasks``.

        The serialized kwargs inside each payload must contain a
        ``'_globus_files'`` entry.

        Returns:
            ``(task_ids, endpoints)``: two parallel lists holding the
            virtual task id and the chosen endpoint of each task.
        """
        # TODO: smarter scheduling for batch submissions
        task_ids = []
        endpoints = []

        for func, payload in tasks:
            _, ser_kwargs = self.fx_serializer.unpack_buffers(payload)
            kwargs = self.fx_serializer.deserialize(ser_kwargs)
            files = kwargs['_globus_files']

            task_id, endpoint = self._schedule_task(func=func,
                                                    payload=payload,
                                                    headers=headers,
                                                    files=files)
            task_ids.append(task_id)
            endpoints.append(endpoint)

        return task_ids, endpoints

    def _schedule_task(self, func, payload, headers, files,
                       task_id=None):
        """Pick an endpoint for a task and queue it for submission.

        Args:
            task_id: ``None`` for a brand-new task (a virtual id is
                created and the task's metadata is stored); an existing
                virtual id when scheduling a backup.

        Returns:
            ``(task_id, endpoint)``.
        """
        # If this is the first time scheduling this task_id
        # (i.e., non-backup task), record the necessary metadata
        if task_id is None:
            # Create (fake) task id to return to client
            task_id = str(uuid.uuid4())

            # Store task information
            self._task_id_translation[task_id] = set()

            # Information required to schedule the task, now and in the
            # future
            self._task_info[task_id] = {
                'function_id': func,
                'payload': payload,
                'headers': headers,
                'files': files,
                'time_requested': time.time()
            }

        # TODO: do not choose a dead endpoint (reliably); dead endpoints
        # are currently NOT excluded from the choice below.
        if len(self._dead_endpoints) > 0:
            logger.warning('{} endpoints seem dead. Hope they still work!'
                           .format(len(self._dead_endpoints)))

        exclude = self._blocked[func] | set(self._endpoints_sent_to[task_id])
        choice = self.strategy.choose_endpoint(
            func, payload=payload, files=files, exclude=exclude,
            transfer_ETAs=self._transfer_ETAs)
        endpoint = choice['endpoint']
        logger.info('Choosing endpoint {} for func {}, task id {}'
                    .format(endpoint_name(endpoint), func, task_id))
        choice['ETA'] = self.strategy.predict_ETA(func, endpoint, payload,
                                                  files=files)

        # Start Globus transfer of required files, if any
        if len(files) > 0:
            transfer_num = self._transfer_manger.transfer(files, endpoint,
                                                          task_id)
            if transfer_num is not None:
                transfer_ETA = (time.time()
                                + self.transfer_time(files, endpoint))
                self._transfer_ETAs[endpoint][transfer_num] = transfer_ETA
        else:
            transfer_num = None
            # Record endpoint ETA for queue-delay prediction here,
            # since task will be immediately scheduled
            self._last_task_ETA[endpoint] = choice['ETA']

        # If a cold endpoint is being started, mark it as no longer cold,
        # so that subsequent launch-time predictions are correct (i.e., 0)
        if self.temperature[endpoint] == 'COLD':
            self.temperature[endpoint] = 'WARMING'
            logger.info('A cold endpoint {} was chosen; marked as warming.'
                        .format(endpoint_name(endpoint)))

        # Schedule task for sending to FuncX
        self._endpoints_sent_to[task_id].append(endpoint)
        self._scheduled_tasks.put((task_id, endpoint, transfer_num))

        return task_id, endpoint

    def translate_task_id(self, task_id):
        """Return the set of real FuncX task ids behind a virtual id.

        Raises:
            KeyError: if ``task_id`` was never issued by this scheduler.
        """
        return self._task_id_translation[task_id]

    def log_status(self, real_task_id, data):
        """Record a status message received for a real FuncX task.

        A ``'result'`` updates the runtime predictor and the endpoint's
        known imports; an ``'exception'`` may block the function on that
        endpoint (see ``BLOCK_ERRORS``). Either one completes the task.
        """
        if real_task_id not in self._pending:
            logger.warning('Ignoring unknown task id {}'
                           .format(real_task_id))
            return

        task_id = self._pending[real_task_id]['task_id']
        func = self._pending[real_task_id]['function_id']
        endpoint = self._pending[real_task_id]['endpoint_id']

        # Don't overwrite latest status if it is a result/exception
        if task_id not in self._latest_status or \
                self._latest_status[task_id].get('status') == 'PENDING':
            self._latest_status[task_id] = data

        if 'result' in data:
            result = self.fx_serializer.deserialize(data['result'])
            runtime = result['runtime']
            name = endpoint_name(endpoint)
            logger.info('Got result from {} for task {} with time {}'
                        .format(name, real_task_id, runtime))

            self.runtime.update(self._pending[real_task_id], runtime)
            self._pending[real_task_id]['runtime'] = runtime
            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()
            self._imports[endpoint] = result['imports']

        elif 'exception' in data:
            exception = self.fx_serializer.deserialize(data['exception'])
            try:
                exception.reraise()
            except Exception as e:
                logger.error('Got exception on task {}: {}'
                             .format(real_task_id, e))
                # Exact-type match, as listed in BLOCK_ERRORS
                if type(e) in BLOCK_ERRORS:
                    self.block(func, endpoint)

            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()

        elif 'status' in data and data['status'] == 'PENDING':
            pass

        else:
            logger.error('Unexpected status message: {}'.format(data))

    def get_status(self, task_id):
        """Return the latest known status dict of a virtual task.

        Returns:
            ``{'status': 'PENDING'}`` while nothing has been submitted or
            reported yet, the latest recorded status otherwise, and
            ``None`` (after logging a warning) for an unknown ``task_id``.
        """
        if task_id not in self._task_id_translation:
            logger.warning('Unknown client task id {}'.format(task_id))
            return None

        elif len(self._task_id_translation[task_id]) == 0:
            return {'status': 'PENDING'}  # Task has not been scheduled yet

        elif task_id not in self._latest_status:
            return {'status': 'PENDING'}  # Status has not been queried yet

        else:
            return self._latest_status[task_id]

    def queue_delay(self, endpoint):
        """Return the time at which ``endpoint`` is expected to be free.

        This is the ETA of the most recent task sent to the endpoint plus
        the estimated error in that prediction, but never earlier than the
        current time. With no pending tasks both summands are 0, so the
        current time is returned (i.e., no queue delay).
        """
        delay = self._last_task_ETA[endpoint] + self._queue_error[endpoint]
        return max(delay, time.time())

    def _record_completed(self, real_task_id):
        """Log a finished real task and stop tracking it.

        Also updates the endpoint's queue-error estimate and drops the
        stored info of the virtual task, so no further backups are sent.
        """
        info = self._pending[real_task_id]
        endpoint = info['endpoint_id']
        now = time.time()

        # If this is the last pending task on this endpoint, reset ETA
        # offset
        if len(self._pending_by_endpoint[endpoint]) == 1:
            self._last_task_ETA[endpoint] = 0.0
            self._queue_error[endpoint] = 0.0
        else:
            self._queue_error[endpoint] = now - info['ETA']

        info['ATA'] = now
        # Headers are not kept in the execution log
        info.pop('headers', None)
        self.execution_log.append(info)

        logger.info('Task exec time: expected = {:.3f}, actual = {:.3f}'
                    .format(info['ETA'] - info['time_sent'],
                            now - info['time_sent']))

        # Stop tracking this task
        del self._pending[real_task_id]
        self._pending_by_endpoint[endpoint].discard(real_task_id)
        self._task_info.pop(info['task_id'], None)

    def cold_start(self, endpoint, func):
        """Estimate the start-up cost of running ``func`` on ``endpoint``.

        Returns:
            Launch time (0.0 unless the endpoint is cold) plus the
            predicted import time of every package ``func`` requires that
            the endpoint is not known to have imported.
        """
        # If endpoint is warm, there is no launch time
        if self.temperature[endpoint] != 'COLD':
            launch_time = 0.0
        # Otherwise, return the launch time in the endpoint config
        elif 'launch_time' in self._endpoints[endpoint]:
            launch_time = self._endpoints[endpoint]['launch_time']
        else:
            logger.warning('Endpoint {} should always be warm, but is cold'
                           .format(endpoint_name(endpoint)))
            launch_time = 0.0

        # Time to import dependencies
        import_time = 0.0
        for pkg in self._imports_required[func]:
            if pkg not in self._imports[endpoint]:
                logger.debug('Cold-start has import time for pkg {} on {}'
                             .format(pkg, endpoint_name(endpoint)))
                import_time += self.import_predictor(pkg, endpoint)

        return launch_time + import_time

    def _monitor_tasks(self):
        """Thread body: submit scheduled tasks to FuncX when ready.

        Loops forever. A task is ready when it has no data transfer or
        its transfer has completed. Tasks whose submission fails remain
        queued and are retried on the next pass.
        """
        logger.info('Starting task-watchdog thread')

        # Tasks picked off the queue but not yet submitted, by virtual id
        scheduled = {}

        while True:
            time.sleep(self._task_watchdog_sleep)

            # Get newly scheduled tasks
            while True:
                try:
                    task_id, end, num = self._scheduled_tasks.get_nowait()
                except Empty:
                    break

                stored = self._task_info.get(task_id)
                if stored is None:
                    logger.warning('Task id {} scheduled but no info found'
                                   .format(task_id))
                    continue

                info = dict(stored)  # Create new copy of info
                info['task_id'] = task_id
                info['endpoint_id'] = end
                info['transfer_num'] = num
                scheduled[task_id] = info

            # Filter out all tasks whose data transfer has not been
            # completed
            ready_to_send = []
            for task_id, info in scheduled.items():
                transfer_num = info['transfer_num']
                if transfer_num is None:
                    info['transfer_time'] = 0.0
                elif self._transfer_manger.is_complete(transfer_num):
                    # pop() rather than del: if a previous submission
                    # attempt failed, the entry is already gone.
                    self._transfer_ETAs[info['endpoint_id']].pop(
                        transfer_num, None)
                    info['transfer_time'] = \
                        self._transfer_manger.get_transfer_time(transfer_num)
                else:  # This task cannot be scheduled yet
                    continue
                ready_to_send.append(task_id)

            if len(ready_to_send) == 0:
                logger.debug('No new tasks to send. Task watchdog sleeping...')
                continue

            # TODO: different clients send different headers. change
            # eventually. For now, use the headers of a task in this batch.
            headers = scheduled[ready_to_send[0]]['headers']

            logger.info('Scheduling a batch of {} tasks'
                        .format(len(ready_to_send)))

            # Submit all ready tasks to FuncX
            data = {'tasks': []}
            for task_id in ready_to_send:
                info = scheduled[task_id]
                submit_info = (info['function_id'], info['endpoint_id'],
                               info['payload'])
                data['tasks'].append(submit_info)

            try:
                res_str = requests.post(f'{FUNCX_API}/submit',
                                        headers=headers,
                                        data=json.dumps(data))
            except requests.RequestException as e:
                # Keep the watchdog alive; the batch is retried next pass
                logger.error(f'Could not reach FuncX service: {e}')
                continue

            try:
                res = res_str.json()
            except ValueError:
                logger.error(f'Could not parse JSON from {res_str.text}')
                continue

            if not isinstance(res, dict) or res.get('status') != 'Success':
                logger.error('Could not send tasks to FuncX. Got response: {}'
                             .format(res))
                continue

            # Update task info with submission info
            for task_id, real_task_id in zip(ready_to_send,
                                             res['task_uuids']):
                info = scheduled[task_id]
                # This ETA calculation does not take into account transfer
                # time since, at this point, the transfer has already
                # completed.
                info['ETA'] = self.strategy.predict_ETA(info['function_id'],
                                                        info['endpoint_id'],
                                                        info['payload'])
                # Record if this ETA prediction is "reliable". If it is not
                # (e.g., when we have not learned about this (func, ep)
                # pair), backup tasks will not be sent for this task if it
                # is delayed.
                info['is_ETA_reliable'] = self.runtime.has_learned(
                    info['function_id'], info['endpoint_id'])

                info['time_sent'] = time.time()

                endpoint = info['endpoint_id']
                self._task_id_translation[task_id].add(real_task_id)

                self._pending[real_task_id] = info
                self._pending_by_endpoint[endpoint].add(real_task_id)

                # Record endpoint ETA for queue-delay prediction
                self._last_task_ETA[endpoint] = info['ETA']

                logger.info('Sent task id {} to {} with real task id {}'
                            .format(task_id, endpoint_name(endpoint),
                                    real_task_id))

            # Stop tracking all newly sent tasks
            for task_id in ready_to_send:
                del scheduled[task_id]

    def _check_endpoints(self):
        """Thread body: poll endpoint statuses every five seconds.

        Loops forever, maintaining the dead-endpoint set and each
        endpoint's temperature, then sends backup tasks if needed.
        """
        logger.info('Starting endpoint-watchdog thread')

        while True:
            for end in self._endpoints:
                statuses = self._fxc.get_endpoint_status(end)
                if len(statuses) == 0:
                    logger.warning('Endpoint {} does not have any statuses'
                                   .format(endpoint_name(end)))
                    continue

                status = statuses[0]  # Most recent endpoint status

                # Mark endpoint as dead/alive based on heartbeat's age
                # Heartbeats are delayed when an endpoint is executing
                # tasks, so take into account last execution too
                age = time.time() - max(status['timestamp'],
                                        self.last_result_time[end])
                is_dead = end in self._dead_endpoints
                if not is_dead and age > HEARTBEAT_THRESHOLD:
                    self._dead_endpoints.add(end)
                    logger.warning('Endpoint {} seems to have died! '
                                   'Last heartbeat was {:.2f} seconds ago.'
                                   .format(endpoint_name(end), age))
                elif is_dead and age <= HEARTBEAT_THRESHOLD:
                    self._dead_endpoints.remove(end)
                    logger.warning('Endpoint {} is back alive! '
                                   'Last heartbeat was {:.2f} seconds ago.'
                                   .format(endpoint_name(end), age))

                # Mark endpoint as "cold" or "warm" depending on if it
                # has active managers (nodes) allocated to it
                if self.temperature[end] == 'WARM' \
                        and status['active_managers'] == 0:
                    self.temperature[end] = 'COLD'
                    logger.info('Endpoint {} is cold!'
                                .format(endpoint_name(end)))
                elif self.temperature[end] != 'WARM' \
                        and status['active_managers'] > 0:
                    self.temperature[end] = 'WARM'
                    logger.info('Endpoint {} is warm again!'
                                .format(endpoint_name(end)))

            # Send backup tasks if needed
            self._send_backups_if_needed()

            # Sleep before checking statuses again
            time.sleep(5)

    def _send_backups_if_needed(self):
        """Re-schedule unfinished tasks that look stuck.

        A virtual task is a backup candidate when it is still unfinished
        (present in ``_task_info``) and either has a pending real task on
        a dead endpoint, or has a pending real task with a reliable ETA
        whose elapsed time exceeds ``backup_delay_threshold`` times the
        expected time. A candidate is re-scheduled unless it has already
        been sent to more than ``max_backups`` endpoints.
        """
        task_ids = set()

        # Get all tasks which have not been completed yet and still have a
        # pending (real) task on a dead endpoint. Snapshots are iterated
        # because other threads add to and remove from these containers.
        for endpoint in self._dead_endpoints:
            for real_task_id in tuple(self._pending_by_endpoint[endpoint]):
                info = self._pending.get(real_task_id)
                if info is not None and info['task_id'] in self._task_info:
                    task_ids.add(info['task_id'])

        # Get all tasks for which we had ETA-predictions but haven't
        # been completed even past their ETA
        for real_task_id, info in list(self._pending.items()):
            # If the predicted ETA wasn't reliable, don't send backups
            if not info['is_ETA_reliable']:
                continue

            # The virtual task was already completed by a sibling
            # submission; there is nothing left to back up.
            if info['task_id'] not in self._task_info:
                continue

            expected = info['ETA'] - info['time_sent']
            # A non-positive expected duration gives no meaningful ratio
            # (and would divide by zero when exactly 0).
            if expected <= 0:
                continue

            elapsed = time.time() - info['time_sent']
            if elapsed / expected > self.backup_delay_threshold:
                task_ids.add(info['task_id'])

        for task_id in task_ids:
            if len(self._endpoints_sent_to[task_id]) > self.max_backups:
                logger.debug(f'Skipping sending new backup task for {task_id}')
                continue

            info = self._task_info.get(task_id)
            if info is None:  # Completed since the candidate scan
                continue

            logger.info(f'Sending new backup task for {task_id}')
            self._schedule_task(info['function_id'], info['payload'],
                                info['headers'], info['files'], task_id)