diff --git a/forch/__main__.py b/forch/__main__.py index 2762498cb..ea840ac6c 100644 --- a/forch/__main__.py +++ b/forch/__main__.py @@ -4,6 +4,7 @@ import functools import os import sys +import multiprocessing as mp import forch.faucet_event_client from forch.forchestrator import Forchestrator @@ -100,4 +101,5 @@ def main(): if __name__ == '__main__': + mp.set_start_method('spawn') main() diff --git a/forch/device_report_client.py b/forch/device_report_client.py index fe81f91d5..e17581deb 100644 --- a/forch/device_report_client.py +++ b/forch/device_report_client.py @@ -1,10 +1,13 @@ """Server to handle incoming session requests""" +import sys +import signal import threading +import multiprocessing as mp + import grpc import forch.endpoint_handler as endpoint_handler - from forch.proto.shared_constants_pb2 import PortBehavior from forch.proto.devices_state_pb2 import DevicesState from forch.base_classes import DeviceStateReporter @@ -36,10 +39,12 @@ def __init__(self, result_handler, target, unauth_vlan, tunnel_ip): self._logger = get_logger('devreport') self._logger.info('Initializing with unauthenticated vlan %s', unauth_vlan) self._logger.info('Using target %s, proto %s', target, bool(PORT_BEHAVIOR_SESSION_RESULT)) - self._channel = grpc.insecure_channel(target) - self._stub = None + self._target = target self._dp_mac_map = {} - self._mac_sessions = {} + self._mac_session_processes = {} + self._progress_q = mp.Queue() + self._progress_q_thread = None + self._mac_device_vlan_map = {} self._mac_assigned_vlan_map = {} self._unauth_vlan = unauth_vlan @@ -50,30 +55,48 @@ def __init__(self, result_handler, target, unauth_vlan, tunnel_ip): def start(self): """Start the client handler""" - grpc.channel_ready_future(self._channel).result(timeout=CONNECT_TIMEOUT_SEC) - self._stub = SessionServerStub(self._channel) + # Context may be set already + try: + mp.set_start_method('spawn') + except RuntimeError: + pass + self._progress_q_thread = threading.Thread(target=self._process_progress_q) + self._progress_q_thread.start() def stop(self): """Stop client handler""" - - def _connect(self, mac, vlan, assigned): - self._logger.info('Connecting %s to %s/%s', mac, vlan, assigned) + self._progress_q.put((None, None)) + if self._progress_q_thread: + self._progress_q_thread.join() + + # pylint: disable=too-many-arguments + @classmethod + def _connect(cls, mac, vlan, assigned, target, tunnel_ip, progress_q): + channel = grpc.insecure_channel(target, options=(('grpc.so_reuseport', 0),)) + grpc.channel_ready_future(channel).result(timeout=CONNECT_TIMEOUT_SEC) + stub = SessionServerStub(channel) session_params = SessionParams() session_params.device_mac = mac session_params.device_vlan = vlan session_params.assigned_vlan = assigned - session_params.endpoint.ip = self._tunnel_ip or DEFAULT_SERVER_ADDRESS - session = self._stub.StartSession(session_params) - thread = threading.Thread(target=lambda: self._process_progress(mac, session)) - thread.start() - return session + session_params.endpoint.ip = tunnel_ip + session = stub.StartSession(session_params) + + def terminate(*args): + session.cancel() + progress_q.put((mac, None)) + sys.exit() + signal.signal(signal.SIGTERM, terminate) + for progress in session: + progress_q.put((mac, progress)) + progress_q.put((mac, None)) def disconnect(self, mac): with self._lock: - session = self._mac_sessions.get(mac) - if session: - session.cancel() - self._mac_sessions.pop(mac) + process = self._mac_session_processes.get(mac) + if process: + process.terminate() + self._mac_session_processes.pop(mac) self._logger.info('Device %s disconnected', mac) else: self._logger.warning('Attempt to disconnect unconnected device %s', mac) @@ -99,18 +122,20 @@ def _convert_and_handle(self, mac, progress): self._endpoint_handler.process_endpoint(progress.endpoint) return False - def _process_progress(self, mac, session): - try: - for progress in session: - if self._convert_and_handle(mac, progress): - break - self._logger.info('Progress complete for %s', mac) - except Exception as e: - self._logger.error('Progress exception: %s', e) - self.disconnect(mac) + def _process_progress_q(self): + while True: + mac, progress = self._progress_q.get(block=True) + if not mac: # device client shutdown + break + try: + if not progress or self._convert_and_handle(mac, progress): + self._logger.info('Progress complete for %s', mac) + self.disconnect(mac) + except Exception as e: + self._logger.error('Progress exception for %s: %s', mac, e) def _process_session_ready(self, mac): - if mac in self._mac_sessions: + if mac in self._mac_session_processes: self._logger.info('Ignoring b/c existing session %s', mac) return device_vlan = self._mac_device_vlan_map.get(mac) @@ -119,7 +144,11 @@ def _process_session_ready(self, mac): good_device_vlan = device_vlan and device_vlan not in (self._unauth_vlan, assigned_vlan) if assigned_vlan and good_device_vlan: - self._mac_sessions[mac] = self._connect(mac, device_vlan, assigned_vlan) + self._logger.info('Connecting %s to %s/%s', mac, device_vlan, assigned_vlan) + args = (mac, device_vlan, assigned_vlan, self._target, + self._tunnel_ip or DEFAULT_SERVER_ADDRESS, self._progress_q) + self._mac_session_processes[mac] = mp.Process(target=self._connect, args=args) + self._mac_session_processes[mac].start() def process_port_state(self, dp_name, port, state): """Process faucet port state events"""