-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathstart_federated_workers.py
More file actions
executable file
·143 lines (126 loc) · 8.28 KB
/
start_federated_workers.py
File metadata and controls
executable file
·143 lines (126 loc) · 8.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#-----------------------------------------------------------------------------------------------#
# #
# I M P O R T G L O B A L L I B R A R I E S #
# #
#-----------------------------------------------------------------------------------------------#
import sys
import signal
import subprocess
import logging
import asyncio
import websockets
import argparse
from pathlib import Path
#-----------------------------------------------------------------------------------------------#
# #
# I M P O R T L O C A L L I B R A R I E S / F I L E S #
# #
#-----------------------------------------------------------------------------------------------#
#-----------------------------------------------------------------------------------------------#
# #
# Define global parameters. #
# #
#-----------------------------------------------------------------------------------------------#
# Interpreter used to launch the worker subprocesses. The full path of the
# interpreter running this script is used on purpose: keeping only the base
# name (e.g. "python3") would defer to a PATH lookup, which can resolve to a
# different Python installation when a virtualenv is not activated.
# Falls back to "python" in the rare case sys.executable is empty.
PYTHON_PATH = sys.executable or "python"
# Worker entry-point script, expected to sit in the same directory as this file.
S_FILE_PATH = Path(__file__).resolve().parent / "run_websocket_worker.py"
# Popen handles of every worker subprocess started by this script.
PROCESS_LIST = []
#***********************************************************************************************#
# #
# description: #
# sends the id, host and port of each local worker to the remote federated server. #
# #
#***********************************************************************************************#
async def send_local_info(remote_host, remote_port, worker_list):
    """Announce every local worker to the remote websocket endpoint.

    A separate connection to ``ws://<remote_host>:<remote_port>`` is opened
    for each worker. Three text messages are sent on it, in this order: the
    worker id, the worker host and the worker port.

    Args:
        remote_host: host name or address of the remote endpoint.
        remote_port: port of the remote endpoint.
        worker_list: iterable of worker entries indexed as
            ``[host, port, id, ...]``.
    """
    endpoint = f"ws://{remote_host}:{remote_port}"
    print(f"establishing connection at {endpoint}")
    for entry in worker_list:
        # one short-lived connection per worker
        async with websockets.connect(endpoint) as connection:
            worker_id, host, port = entry[2], entry[0], entry[1]
            await connection.send(worker_id)
            await connection.send(host)
            # the port is transmitted as text
            await connection.send(f"{port}")
#***********************************************************************************************#
# #
# description: #
# helper function to create a list of workers assuming consecutive ports are available. #
# #
#***********************************************************************************************#
def generate_worker_list(suffix_id, worker_host, starting_port, count, rank):
    """Build the description of ``count`` workers on consecutive ports.

    Args:
        suffix_id: first component of every generated worker id.
        worker_host: host shared by all generated workers.
        starting_port: port of the first worker; worker ``k`` gets
            ``starting_port + k``.
        count: number of worker entries to generate.
        rank: rank of the first worker; worker ``k`` gets ``rank + k``.

    Returns:
        A list of ``[host, port, id, rank]`` lists, where ``id`` is
        ``"<suffix_id>_<k>_<rank + k>"`` and ``rank`` is a string.
    """
    return [
        [
            worker_host,
            starting_port + offset,
            f"{suffix_id}_{offset}_{rank + offset}",
            f"{rank + offset}",
        ]
        for offset in range(count)
    ]
#***********************************************************************************************#
# #
# description: #
# starting a websocket server for each client / worker node. Each new worker has a server. #
# #
#***********************************************************************************************#
def start_federated_workers(worker_list, world):
    """Spawn one worker subprocess per entry of ``worker_list``.

    Each subprocess runs ``S_FILE_PATH`` with the ``PYTHON_PATH`` interpreter
    and receives its settings through ``--host``, ``--port``, ``--id``,
    ``--rank`` and ``--world``. The resulting ``Popen`` handle is appended to
    ``PROCESS_LIST``.

    Args:
        worker_list: sequence of worker entries indexed as
            ``[host, port, id, rank]``.
        world: value forwarded to every worker as ``--world``.
    """
    for entry in worker_list:
        # assemble the command line: interpreter, script, then flag/value pairs
        command = [PYTHON_PATH, S_FILE_PATH]
        options = (
            ("--host", entry[0]),
            ("--port", entry[1]),
            ("--id", entry[2]),
            ("--rank", entry[3]),
            ("--world", world),
        )
        for flag, value in options:
            command += [flag, f"{value}"]
        # launch the worker and remember its handle
        PROCESS_LIST.append(subprocess.Popen(command))
    print(f"started a total of {len(worker_list)} workers")
#***********************************************************************************************#
# #
# description: #
# helper function to forcefully terminate all processes once Ctrl+C is hit. #
# #
#***********************************************************************************************#
def signal_handler(sig, frame):
    """Signal callback: terminate all tracked workers, then exit with status 0.

    Args:
        sig: number of the received signal (unused).
        frame: stack frame that was interrupted (unused).
    """
    print("You pressed Ctrl+C!")
    # ask every tracked worker subprocess to stop
    for child in PROCESS_LIST:
        child.terminate()
    # same effect as sys.exit(0)
    raise SystemExit(0)
#***********************************************************************************************#
# #
# description: #
# argument parsing and configurations for setting up the websocket server. #
# #
#***********************************************************************************************#
if __name__ == "__main__":
    # Script entry point: parse the command line, spawn the local workers,
    # register them with the remote server, then wait until Ctrl+C.

    # parsing arguments
    parser = argparse.ArgumentParser(description="Run websocket server worker.")
    parser.add_argument("--remotehost", type=str, default="localhost", help="host addres of remote server.")
    # the usage example names this option's own flag (it previously showed --port)
    parser.add_argument("--remoteport", type=int, help="port number of federated server, e.g. --remoteport 8778", required=True)
    parser.add_argument("--host", type=str, default="localhost", help="current and local worker's deployment host.")
    parser.add_argument("--port", type=int, help="port number current worker, e.g. --port 8778", required=True)
    parser.add_argument("--count", type=int, help="number of workers to instantiate on this machine, e.g. 5", required=True)
    parser.add_argument("--rank", type=int, help="the starting rank of workers on this machine", required=True)
    parser.add_argument("--world", type=int, help="the total number of workers in the entire federation", required=True)
    parser.add_argument("--id", type=str, help="suffix to the name (id) of the websocket workers, e.g. --id vw", required=True)
    args = parser.parse_args()

    # Logging setup
    FORMAT = "%(asctime)s | %(message)s"
    logging.basicConfig(format=FORMAT)
    logger = logging.getLogger("run_websocket_client")
    logger.setLevel(level=logging.DEBUG)

    # Websockets setup
    websockets_logger = logging.getLogger("websockets")
    websockets_logger.setLevel(logging.INFO)
    websockets_logger.addHandler(logging.StreamHandler())

    # create a worker list
    worker_list = generate_worker_list(args.id, args.host, args.port, args.count, args.rank)

    # Install the Ctrl+C handler BEFORE any worker is spawned, so that an
    # interrupt during start-up or registration still terminates the workers
    # instead of leaving them running as orphans.
    signal.signal(signal.SIGINT, signal_handler)

    # start the required number of workers
    start_federated_workers(worker_list, args.world)

    # connect to server module and send detailed information
    try:
        asyncio.run(send_local_info(args.remotehost, args.remoteport, worker_list))
    except Exception:
        # Registration failed (e.g. remote server unreachable): stop the
        # workers that were already started, then surface the original error.
        for process in PROCESS_LIST:
            process.terminate()
        raise

    # sleep until a signal arrives; SIGINT is handled by signal_handler
    signal.pause()