forked from p4lang/p4app-switchML
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrpc_server.py
382 lines (305 loc) · 13.9 KB
/
grpc_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# Copyright 2021 Intel-KAUST-Microsoft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import asyncio
import ipaddress
import threading
import switchml_pb2
import switchml_pb2_grpc
from grpc import aio
from concurrent import futures
from common import PacketSize
class GRPCServer(switchml_pb2_grpc.SessionServicer,
                 switchml_pb2_grpc.SyncServicer):
    ''' gRPC server for SwitchML worker coordination and session setup.

    Implements two services:
      - Sync: Barrier and Broadcast, collective coordination primitives
        used by the workers of a job.
      - Session: RdmaSession and UdpSession, which register workers with
        the switch controller and return the switch-side parameters.
    '''

    def __init__(self, ip='[::]', port=50099, folded_pipe=False):
        ''' Store server parameters.

        Args:
            ip: address the gRPC server binds to.
            port: TCP port the gRPC server listens on.
            folded_pipe: True if the switch pipeline is folded, which is
                required to process 1024B payloads per packet.
        '''
        self.log = logging.getLogger(__name__)

        self.ip = ip
        self.port = port
        self.folded_pipe = folded_pipe

        # Event used by stop() to make run() exit its event loop
        self._stopped = asyncio.Event()

    async def _serve(self, controller):
        ''' Server task: build the gRPC server, initialize the
        barrier/broadcast state, and start serving.

        Args:
            controller: switch controller object, or None when testing
                (session RPCs then echo back the request parameters).
        '''
        # Setup server
        self._server = aio.server()
        switchml_pb2_grpc.add_SessionServicer_to_server(self, self._server)
        switchml_pb2_grpc.add_SyncServicer_to_server(self, self._server)
        self._server.add_insecure_port('{}:{}'.format(self.ip, self.port))

        # Lock to synchronize Barrier/Broadcast in case of reset.
        # NOTE(review): this threading.RLock is held across awaits in
        # Barrier/Broadcast. That only behaves because all RPC coroutines
        # run on a single event-loop thread and RLock is reentrant
        # per-thread -- confirm before introducing additional threads.
        self.lock = threading.RLock()

        ## Barrier
        # Incrementing operation id; identifies the current barrier round
        self._barrier_op_id = 0
        # Worker counters and release events, keyed by operation id
        self._barrier_ctrs = {self._barrier_op_id: 0}
        self._barrier_events = {self._barrier_op_id: asyncio.Event()}

        ## Broadcast
        # Parallel lists, one entry per in-flight broadcast operation:
        # value from the root, per-rank completion bitmap, release event
        self._bcast_values = []
        self._bcast_bitmap = []
        self._bcast_events = []

        # Controller
        self.ctrl = controller

        # Start gRPC server
        await self._server.start()

    def reset(self):
        ''' Reset broadcast and barrier state to their initial values.

        Any operation id / pending state is discarded; callers must make
        sure no collective is in flight when this is invoked.
        '''
        with self.lock:
            ## Barrier
            self._barrier_op_id = 0
            self._barrier_ctrs = {self._barrier_op_id: 0}
            self._barrier_events = {self._barrier_op_id: asyncio.Event()}

            ## Broadcast
            self._bcast_values = []
            self._bcast_bitmap = []
            self._bcast_events = []

    def run(self, loop, controller):
        ''' Run the gRPC server until stop() is called.

        Blocks the calling thread driving the given event loop.

        Args:
            loop: asyncio event loop to run the server on.
            controller: switch controller object passed to _serve().
        '''
        # Submit gRPC server task
        loop.create_task(self._serve(controller))

        # Run event loop until stop() sets the event
        loop.run_until_complete(self._stopped.wait())

        # Stop gRPC server (set by _serve, which ran on this loop)
        if self._server:
            loop.run_until_complete(self._server.stop(None))
            loop.run_until_complete(self._server.wait_for_termination())

    def stop(self):
        ''' Stop the gRPC server by releasing run()'s event loop. '''
        # Stop event loop
        self._stopped.set()

    async def Barrier(self, request, context):
        ''' Barrier method.

        All the requests for the same session ID return at the same time
        only when all the requests are received.
        '''
        with self.lock:
            # Increment counter for this operation
            self._barrier_ctrs[self._barrier_op_id] += 1

            if self._barrier_ctrs[self._barrier_op_id] < request.num_workers:
                # Barrier incomplete: remember the operation id, since
                # the releaser advances _barrier_op_id before we wake up
                tmp_id = self._barrier_op_id

                # Wait for completion event
                await self._barrier_events[tmp_id].wait()

                # Decrement counter and delete entries for this operation
                # once all are released
                self._barrier_ctrs[tmp_id] -= 1
                if self._barrier_ctrs[tmp_id] == 0:
                    del self._barrier_ctrs[tmp_id]
                    del self._barrier_events[tmp_id]
            else:
                # This completes the barrier -> release
                self._barrier_events[self._barrier_op_id].set()
                self._barrier_ctrs[self._barrier_op_id] -= 1

                # Create entries for next operation
                self._barrier_op_id += 1
                self._barrier_ctrs[self._barrier_op_id] = 0
                self._barrier_events[self._barrier_op_id] = asyncio.Event()

        return switchml_pb2.BarrierResponse()

    async def Broadcast(self, request, context):
        ''' Broadcast method.

        The value received from the root (with rank = root) is sent back
        to all the participants.
        The requests received before the one from the root are kept on hold
        and released when the request from the root is received.
        The ones received afterwards return immediately.
        '''
        with self.lock:
            # Remove completed operations (all bits set). Iterate the
            # indices in reverse so deletions don't shift the positions of
            # entries not yet examined (walking forward over
            # range(len(...)) while deleting skips entries and can raise
            # IndexError once the index passes the shrunken length).
            for idx in reversed(range(len(self._bcast_bitmap))):
                if all(self._bcast_bitmap[idx]):
                    del self._bcast_bitmap[idx]
                    del self._bcast_values[idx]
                    del self._bcast_events[idx]

            # Scan bitmap: find the first pending operation this worker
            # has not yet participated in
            idx = -1
            for idx in range(len(self._bcast_bitmap)):
                if not self._bcast_bitmap[idx][request.rank]:
                    break

            if idx == -1 or self._bcast_bitmap[idx][request.rank]:
                # If there is no operation pending with the bit 0 for this
                # worker then this is a new operation
                idx += 1
                self._bcast_bitmap.append(
                    [False for _ in range(request.num_workers)])
                self._bcast_values.append(None)
                self._bcast_events.append(asyncio.Event())

            if request.rank == request.root:
                # Root: write value and release the waiting workers
                self._bcast_values[idx] = request.value
                self._bcast_events[idx].set()

                # Set bit for this worker
                self._bcast_bitmap[idx][request.rank] = True
            else:
                # Non-root
                if self._bcast_values[idx] is None:
                    # Value not available yet
                    await self._bcast_events[idx].wait()

                # Set bit for this worker (after waiting)
                self._bcast_bitmap[idx][request.rank] = True

        return switchml_pb2.BroadcastResponse(value=self._bcast_values[idx])

    def RdmaSession(self, request, context):
        ''' RDMA session setup.

        Registers the worker with the controller and replies with the
        switch-side MAC/IP/rkey/QPNs/PSNs. Without a controller (test
        mode) the request parameters are echoed back. On registration
        failure an all-zero response is returned.
        '''
        # Convert MAC to string (request.mac is an integer)
        mac_hex = '{:012X}'.format(request.mac)
        mac_str = ':'.join(mac_hex[i:i + 2] for i in range(0, len(mac_hex), 2))

        # Convert IP to string (request.ipv4 is an integer)
        ipv4_str = str(ipaddress.ip_address(request.ipv4))

        self.log.debug(
            '# RDMA:\n Session ID: {}\n Rank: {}\n Num workers: {}\n MAC: {}\n'
            ' IP: {}\n Rkey: {}\n Pkt size: {}B\n Msg size: {}B\n QPs: {}\n'
            ' PSNs: {}\n'.format(
                request.session_id, request.rank, request.num_workers, mac_str,
                ipv4_str, request.rkey,
                str(PacketSize(request.packet_size)).split('.')[1][4:],
                request.message_size, request.qpns, request.psns))

        if not self.folded_pipe and PacketSize(
                request.packet_size) == PacketSize.MTU_1024:
            # 1024B payloads need the folded pipeline; downgrade to 256B
            self.log.error(
                "Processing 1024B per packet requires a folded pipeline. Using 256B payload."
            )
            request.packet_size = int(PacketSize.MTU_256)

        if not self.ctrl:
            # This is a test, return the received parameters
            return switchml_pb2.RdmaSessionResponse(
                session_id=request.session_id,
                mac=request.mac,
                ipv4=request.ipv4,
                rkey=request.rkey,
                qpns=request.qpns,
                psns=request.psns)

        if request.rank == 0:
            # This is the first message, clear out old workers state
            self.ctrl.clear_rdma_workers(request.session_id)

        # Add new worker
        success, error_msg = self.ctrl.add_rdma_worker(
            request.session_id, request.rank, request.num_workers, mac_str,
            ipv4_str, request.rkey, request.packet_size, request.message_size,
            zip(request.qpns, request.psns))
        if not success:
            self.log.error(error_msg)
            #TODO return error message
            return switchml_pb2.RdmaSessionResponse(session_id=0,
                                                    mac=0,
                                                    ipv4=0,
                                                    rkey=0,
                                                    qpns=[],
                                                    psns=[])

        # Get switch addresses
        switch_mac, switch_ipv4 = self.ctrl.get_switch_mac_and_ip()
        switch_mac = int(switch_mac.replace(':', ''), 16)
        switch_ipv4 = int(ipaddress.ip_address(switch_ipv4))

        # Mirror this worker's rkey, since the switch doesn't care
        switch_rkey = request.rkey

        # Switch QPNs are used for two purposes:
        # 1. Indexing into the PSN registers
        # 2. Differentiating between processes running on the same server
        #
        # Additionally, there are two restrictions:
        #
        # 1. In order to make debugging easier, we should
        # avoid QPN 0 (sometimes used for management) and QPN
        # 0xffffff (sometimes used for multicast) because
        # Wireshark decodes them improperly, even when the NIC
        # treats them properly.
        #
        # 2. Due to the way the switch sends aggregated
        # packets that are part of a message, only one message
        # should be in flight at a time on a given QPN to
        # avoid reordering packets. The clients will take care
        # of this as long as we give them as many QPNs as they
        # give us.
        #
        # Thus, we construct QPNs as follows.
        # - Bit 23 is always 1. This ensures we avoid QPN 0.
        # - Bits 22 through 16 are the rank of the
        # client. Since we only support 32 clients per
        # aggregation in the current design, we will never
        # use QPN 0xffffff.
        # - Bits 15 through 0 are just the index of the queue;
        # if 4 queues are requested, these bits will
        # represent 0, 1, 2, and 3.
        #
        # So if a client with rank 3 sends us a request with 4
        # QPNs, we will reply with QPNs 0x830000, 0x830001,
        # 0x830002, and 0x830003.
        switch_qpns = [
            0x800000 | (request.rank << 16) | i
            for i, _ in enumerate(request.qpns)
        ]

        # Initial PSNs don't matter; they're overwritten by each _FIRST or _ONLY packet.
        switch_psns = [i for i, _ in enumerate(request.qpns)]

        return switchml_pb2.RdmaSessionResponse(session_id=request.session_id,
                                                mac=switch_mac,
                                                ipv4=switch_ipv4,
                                                rkey=switch_rkey,
                                                qpns=switch_qpns,
                                                psns=switch_psns)

    def UdpSession(self, request, context):
        ''' UDP session setup.

        Registers the worker with the controller and replies with the
        switch-side MAC/IP. Without a controller (test mode) the request
        parameters are echoed back. On registration failure an all-zero
        response is returned.
        '''
        # Convert MAC to string (request.mac is an integer)
        mac_hex = '{:012X}'.format(request.mac)
        mac_str = ':'.join(mac_hex[i:i + 2] for i in range(0, len(mac_hex), 2))

        # Convert IP to string (request.ipv4 is an integer)
        ipv4_str = str(ipaddress.ip_address(request.ipv4))

        self.log.debug(
            '# UDP:\n Session ID: {}\n Rank: {}\n Num workers: {}\n MAC: {}\n'
            ' IP: {}\n Pkt size: {}\n'.format(request.session_id, request.rank,
                                              request.num_workers, mac_str,
                                              ipv4_str, request.packet_size))

        if not self.ctrl:
            # This is a test, return the received parameters
            return switchml_pb2.UdpSessionResponse(
                session_id=request.session_id,
                mac=request.mac,
                ipv4=request.ipv4)

        if request.rank == 0:
            # This is the first message, clear out old workers state
            self.ctrl.clear_udp_workers(request.session_id)

        # Add new worker
        success, error_msg = self.ctrl.add_udp_worker(request.session_id,
                                                      request.rank,
                                                      request.num_workers,
                                                      mac_str, ipv4_str)
        if not success:
            self.log.error(error_msg)
            #TODO return error message
            return switchml_pb2.UdpSessionResponse(session_id=0, mac=0, ipv4=0)

        # Get switch addresses
        switch_mac, switch_ipv4 = self.ctrl.get_switch_mac_and_ip()
        switch_mac = int(switch_mac.replace(':', ''), 16)
        switch_ipv4 = int(ipaddress.ip_address(switch_ipv4))

        return switchml_pb2.UdpSessionResponse(session_id=request.session_id,
                                               mac=switch_mac,
                                               ipv4=switch_ipv4)
if __name__ == '__main__':
    import time

    # Set up gRPC server
    grpc_server = GRPCServer()

    # Run event loop for gRPC server in a separate thread
    with futures.ThreadPoolExecutor(max_workers=1) as executor:
        loop = asyncio.get_event_loop()
        future = executor.submit(grpc_server.run, loop, None)

        try:
            # Sleep instead of busy-waiting: `while True: pass` pinned a
            # full CPU core while doing nothing; a 1s sleep still lets
            # KeyboardInterrupt through promptly
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print('\nExiting...')
        finally:
            # Stop gRPC server and event loop (thread-safe: the loop
            # runs in the executor thread)
            loop.call_soon_threadsafe(grpc_server.stop)

            # Wait for thread to end
            future.result()
            loop.close()

    # Flush log
    logging.shutdown()