From 1c8eee9b67c0ef9a3bab299059f8a85067d84872 Mon Sep 17 00:00:00 2001
From: Peter Saveliev
Date: Tue, 26 Jul 2022 18:04:20 +0200
Subject: [PATCH 1/5] nlsocket: recv_all(), recv_all_into()

---
 pyroute2/netlink/nlsocket.py              | 27 ++++++++++++++-
 tests/test_linux/test_ipr/test_compile.py | 40 ++++++++++++++++++++++-
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py
index 5a584488a..e0bcf48a5 100644
--- a/pyroute2/netlink/nlsocket.py
+++ b/pyroute2/netlink/nlsocket.py
@@ -91,7 +91,14 @@
 import time
 import traceback
 import warnings
-from socket import MSG_PEEK, SO_RCVBUF, SO_SNDBUF, SOCK_DGRAM, SOL_SOCKET
+from socket import (
+    MSG_DONTWAIT,
+    MSG_PEEK,
+    SO_RCVBUF,
+    SO_SNDBUF,
+    SOCK_DGRAM,
+    SOL_SOCKET,
+)
 
 from pyroute2 import config
 from pyroute2.common import DEFAULT_RCVBUF, AddrPool
@@ -617,6 +624,24 @@ def async_recv(self):
         else:
             return
 
+    def recv_all(self, bufsize=NL_BUFSIZE):
+        buffers = []
+        while True:
+            try:
+                buffers.append(self._sock.recv(bufsize, MSG_DONTWAIT))
+            except BlockingIOError:
+                return buffers
+
+    def recv_all_into(self, buffer):
+        count = 0
+        while True:
+            page = buffer.get_free_page()
+            try:
+                count += self._sock.recv_into(page.view, 0, MSG_DONTWAIT)
+            except BlockingIOError:
+                page.free()
+                return count
+
     def compile(self):
         return CompileContext(self)
 
diff --git a/tests/test_linux/test_ipr/test_compile.py b/tests/test_linux/test_ipr/test_compile.py
index bbad08683..1772eec2a 100644
--- a/tests/test_linux/test_ipr/test_compile.py
+++ b/tests/test_linux/test_ipr/test_compile.py
@@ -3,7 +3,8 @@
 import pytest
 
 from pyroute2 import IPRoute
-from pyroute2.netlink import NLM_F_DUMP, NLM_F_REQUEST
+from pyroute2.netlink import NLM_F_DUMP, NLM_F_REQUEST, NLMSG_DONE
+from pyroute2.netlink.buffer import Buffer
 from pyroute2.netlink.rtnl import (
     RTM_GETLINK,
     RTM_GETROUTE,
@@ -69,3 +70,40 @@ def test_compile_context_manager(ipr, name, argv, kwarg, msg_type, msg_flags):
     assert ipr.compiled is None
     for msg in getattr(ipr, name)(*argv, **kwarg):
         assert msg['header']['type'] == msg_type[1]
+
+
+@pytest.mark.parametrize(*test_config)
+def test_compile_recv_all(ipr, name, argv, kwarg, msg_type, msg_flags):
+    with ipr.compile():
+        data = getattr(ipr, name)(*argv, **kwarg)
+    assert (msg_type[0], msg_flags) == struct.unpack_from(
+        'HH', data[0], offset=4
+    )
+    assert ipr.compiled is None
+    for request in data:
+        ipr.sendto(request, (0, 0))
+    response = ipr.recv_all()
+    for page in response:
+        assert {msg_type[1], NLMSG_DONE} > set(
+            struct.unpack_from('H', page, offset=4)
+        )
+
+
+@pytest.mark.parametrize(*test_config)
+def test_compile_recv_all_into(ipr, name, argv, kwarg, msg_type, msg_flags):
+    with ipr.compile():
+        data = getattr(ipr, name)(*argv, **kwarg)
+    assert (msg_type[0], msg_flags) == struct.unpack_from(
+        'HH', data[0], offset=4
+    )
+    assert ipr.compiled is None
+    for request in data:
+        ipr.sendto(request, (0, 0))
+    buffer = Buffer()
+    ipr.recv_all_into(buffer)
+    for page in buffer.directory.values():
+        if page.is_free:
+            continue
+        assert {msg_type[1], NLMSG_DONE} > set(
+            struct.unpack_from('H', page.view, offset=4)
+        )
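Usage note for the patch above: recv_all() drains the socket in non-blocking
mode and returns a list of raw chunks; recv_all_into() does the same but
writes into pre-allocated pages of a Buffer and returns the byte count. A
minimal sketch along the lines of the tests, not a definitive walkthrough:
it assumes the patched library, and that under ipr.compile() the request
method returns raw packets instead of sending them; get_links() stands in
here for whatever methods test_config actually covers.

    from pyroute2 import IPRoute

    with IPRoute() as ipr:
        with ipr.compile():
            # compiled mode (assumption based on the tests above):
            # the call returns raw request bytes instead of sending
            requests = ipr.get_links()
        for request in requests:
            ipr.sendto(request, (0, 0))
        # drain the socket without blocking: recv() is called with
        # MSG_DONTWAIT until it raises BlockingIOError, one bytes
        # chunk per successful call
        chunks = ipr.recv_all()

Note that recv_all() returns as soon as the socket would block; on a large
dump the kernel may still be queuing data, so the caller may need further
recv_all() rounds before NLMSG_DONE actually arrives.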
From ea7cecb5b6d683c5972766ba466eb57fccb416d4 Mon Sep 17 00:00:00 2001
From: Peter Saveliev
Date: Wed, 27 Jul 2022 00:35:53 +0200
Subject: [PATCH 2/5] nlsocket: re-implement get()

---
 pyroute2/netlink/nlsocket.py | 368 ++++++++++-------------------------
 1 file changed, 105 insertions(+), 263 deletions(-)

diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py
index e0bcf48a5..378416c1c 100644
--- a/pyroute2/netlink/nlsocket.py
+++ b/pyroute2/netlink/nlsocket.py
@@ -89,16 +89,9 @@
 import struct
 import threading
 import time
-import traceback
 import warnings
-from socket import (
-    MSG_DONTWAIT,
-    MSG_PEEK,
-    SO_RCVBUF,
-    SO_SNDBUF,
-    SOCK_DGRAM,
-    SOL_SOCKET,
-)
+from queue import Queue
+from socket import MSG_DONTWAIT, SO_RCVBUF, SO_SNDBUF, SOCK_DGRAM, SOL_SOCKET
 
 from pyroute2 import config
 from pyroute2.common import DEFAULT_RCVBUF, AddrPool
@@ -131,17 +124,16 @@
     NetlinkHeaderDecodeError,
 )
 
-try:
-    from Queue import Queue
-except ImportError:
-    from queue import Queue
-
 log = logging.getLogger(__name__)
 
 Stats = collections.namedtuple('Stats', ('qsize', 'delta', 'delay'))
 
 NL_BUFSIZE = 32768
 
 
+class Enough:
+    pass
+
+
 class CompileContext:
     def __init__(self, netlink_socket):
         self.netlink_socket = netlink_socket
@@ -605,25 +597,6 @@ def recv_into(self, data, *argv, **kwarg):
             return len(data_in)
         return self._sock.recv_into(data, *argv, **kwarg)
 
-    def async_recv(self):
-        poll = select.poll()
-        poll.register(self._sock, select.POLLIN | select.POLLPRI)
-        poll.register(self._ctrl_read, select.POLLIN | select.POLLPRI)
-        sockfd = self._sock.fileno()
-        while True:
-            events = poll.poll()
-            for (fd, event) in events:
-                if fd == sockfd:
-                    try:
-                        data = bytearray(64000)
-                        self._sock.recv_into(data, 64000)
-                        self.buffer_queue.put_nowait(data)
-                    except Exception as e:
-                        self.buffer_queue.put(e)
-                        return
-                else:
-                    return
-
     def recv_all(self, bufsize=NL_BUFSIZE):
         buffers = []
         while True:
@@ -723,6 +696,93 @@ def sendto_gate(self, msg, addr):
             return self.compiled.append(msg.data)
         return self._sock.sendto(msg.data, addr)
 
+    def load_backlog(self, chunk, msg_seq=0, callback=None):
+        '''
+        Parse a chunk and load messages into the backlog
+        '''
+        seqs = set()
+        for msg in self.marshal.parse(chunk, msg_seq, callback):
+            seq = msg['header']['sequence_number']
+            msg['header']['target'] = self.target
+            msg['header']['stats'] = Stats(0, 0, 0)
+            if seq not in self.backlog:
+                seq = 0
+            seqs.add(seq)
+            self.backlog[seq].append(msg)
+            for cr in self.callbacks:
+                try:
+                    if cr[0](msg):
+                        cr[1](msg, *cr[2])
+                except Exception as e:
+                    log.warning(f'Callback fail: {cr} -> {e}')
+        for seq in seqs:
+            os.write(self._ctrl_write, b'\x01')
+
+    def fetch_backlog(self, msg_seq=0, terminate=None, noraise=False):
+        '''
+        Fetch and yield already parsed messages from the backlog.
+        '''
+        enough = False
+        tmsg = None
+        with self.lock[msg_seq]:
+            ret = self.backlog[msg_seq]
+            self.backlog[msg_seq] = []
+
+        if ret:
+            os.read(self._ctrl_read, 1)
+
+        # Collect messages up to the terminator.
+        # Terminator conditions:
+        #  * NLMSG_ERROR != 0
+        #  * NLMSG_DONE
+        #  * terminate() function (if defined)
+        #  * not NLM_F_MULTI
+        for index in range(len(ret)):
+            msg = ret[index]
+
+            if msg_seq == 0:
+                yield msg
+                continue
+
+            # If there is an error, raise exception
+            if msg['header']['error'] is not None and not noraise:
+                with self.lock[msg_seq]:
+                    # reschedule all the remaining messages, including
+                    # errors and acks, into a separate deque
+                    self.error_deque.extend(ret[index + 1 :])
+                    self.error_deque.extend(self.backlog[msg_seq])
+                    del self.backlog[msg_seq]
+                raise msg['header']['error']
+
+            # If it is a terminator message, say "enough"
+            # and requeue all the rest into the zero backlog
+            if callable(terminate):
+                tmsg = terminate(msg)
+                if isinstance(tmsg, nlmsg):
+                    yield msg
+
+            if (msg['header']['type'] == NLMSG_DONE) or tmsg:
+                enough = True
+
+            # If it is just a normal message, append it to
+            # the response
+            if not enough:
+                # finish the loop on single messages
+                if not msg['header']['flags'] & NLM_F_MULTI:
+                    enough = True
+                yield msg
+
+            # Enough is enough, requeue the rest and delete
+            # our backlog
+            if enough:
+                with self.lock[0]:
+                    self.backlog[0].extend(ret[index + 1 :])
+                    del self.backlog[msg_seq]
+                break
+
+        if enough or msg_seq == 0:
+            yield Enough
+
     def get(
         self,
         bufsize=DEFAULT_RCVBUF,
@@ -748,230 +808,20 @@ def get(
         msg_seq=0,
         terminate=None,
         callback=None,
        noraise=False,
     ):
         '''
         If `noraise` is true, error messages will be treated
         as any other message.
         '''
-        ctime = time.time()
+        while True:
+            with self.read_lock:
+                fdlist = (self.fileno(), self._ctrl_read)
+                rlist, _, _ = select.select(fdlist, [], fdlist)
 
-        with self.lock[msg_seq]:
-            if bufsize == -1:
-                # get bufsize from the network data
-                bufsize = struct.unpack("I", self.recv(4, MSG_PEEK))[0]
-            elif bufsize == 0:
-                # get bufsize from SO_RCVBUF
-                bufsize = self.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2
-
-            tmsg = None
-            enough = False
-            backlog_acquired = False
-            try:
-                while not enough:
-                    # 8<-----------------------------------------------------------
-                    #
-                    # This stage changes the backlog, so use mutex to
-                    # prevent side changes
-                    self.backlog_lock.acquire()
-                    backlog_acquired = True
-                    ##
-                    # Stage 1. BEGIN
-                    #
-                    # 8<-----------------------------------------------------------
-                    #
-                    # Check backlog and return already collected
-                    # messages.
-                    #
-                    if msg_seq == 0 and self.backlog[0]:
-                        # Zero queue.
-                        #
-                        # Load the backlog, if there is valid
-                        # content in it
-                        for msg in self.backlog[0]:
-                            yield msg
-                        self.backlog[0] = []
-                        # And just exit
-                        break
-                    elif msg_seq != 0 and len(self.backlog.get(msg_seq, [])):
-                        # Any other msg_seq.
-                        #
-                        # Collect messages up to the terminator.
-                        # Terminator conditions:
-                        #  * NLMSG_ERROR != 0
-                        #  * NLMSG_DONE
-                        #  * terminate() function (if defined)
-                        #  * not NLM_F_MULTI
-                        #
-                        # Please note, that if terminator not occured,
-                        # more `recv()` rounds CAN be required.
-                        for msg in tuple(self.backlog[msg_seq]):
-
-                            # Drop the message from the backlog, if any
-                            self.backlog[msg_seq].remove(msg)
-
-                            # If there is an error, raise exception
-                            if (
-                                msg['header']['error'] is not None
-                                and not noraise
-                            ):
-                                # reschedule all the remaining messages,
-                                # including errors and acks, into a
-                                # separate deque
-                                self.error_deque.extend(self.backlog[msg_seq])
-                                # flush the backlog for this msg_seq
-                                del self.backlog[msg_seq]
-                                # The loop is done
-                                raise msg['header']['error']
-
-                            # If it is the terminator message, say "enough"
-                            # and requeue all the rest into Zero queue
-                            if terminate is not None:
-                                tmsg = terminate(msg)
-                                if isinstance(tmsg, nlmsg):
-                                    yield msg
-                            if (msg['header']['type'] == NLMSG_DONE) or tmsg:
-                                # The loop is done
-                                enough = True
-
-                            # If it is just a normal message, append it to
-                            # the response
-                            if not enough:
-                                # finish the loop on single messages
-                                if not msg['header']['flags'] & NLM_F_MULTI:
-                                    enough = True
-                                yield msg
-
-                            # Enough is enough, requeue the rest and delete
-                            # our backlog
-                            if enough:
-                                self.backlog[0].extend(self.backlog[msg_seq])
-                                del self.backlog[msg_seq]
-                                break
-
-                        # Next iteration
-                        self.backlog_lock.release()
-                        backlog_acquired = False
-                    else:
-                        # Stage 1. END
-                        #
-                        # 8<-------------------------------------------------------
-                        #
-                        # Stage 2. BEGIN
-                        #
-                        # 8<-------------------------------------------------------
-                        #
-                        # Receive the data from the socket and put the messages
-                        # into the backlog
-                        #
-                        self.backlog_lock.release()
-                        backlog_acquired = False
-                        ##
-                        #
-                        # Control the timeout. We should not be within the
-                        # function more than TIMEOUT seconds. All the locks
-                        # MUST be released here.
-                        #
-                        if (msg_seq != 0) and (
-                            time.time() - ctime > self.get_timeout
-                        ):
-                            # requeue already received for that msg_seq
-                            self.backlog[0].extend(self.backlog[msg_seq])
-                            del self.backlog[msg_seq]
-                            # throw an exception
-                            if self.get_timeout_exception:
-                                raise self.get_timeout_exception()
-                            else:
-                                return
-                        #
-                        if self.read_lock.acquire(False):
-                            try:
-                                self.change_master.clear()
-                                # If the socket is free to read from, occupy
-                                # it and wait for the data
-                                #
-                                # This is a time consuming process, so all the
-                                # locks, except the read lock must be released
-                                data = self.recv(bufsize)
-                                # Parse data
-                                msgs = self.marshal.parse(
-                                    data, msg_seq, callback
-                                )
-                                # Reset ctime -- timeout should be measured
-                                # for every turn separately
-                                ctime = time.time()
-                                #
-                                current = self.buffer_queue.qsize()
-                                delta = current - self.qsize
-                                delay = 0
-                                if delta > 10:
-                                    delay = min(
-                                        3, max(0.01, float(current) / 60000)
-                                    )
-                                    message = (
-                                        "Packet burst: "
-                                        "delta=%s qsize=%s delay=%s"
-                                        % (delta, current, delay)
-                                    )
-                                    if delay < 1:
-                                        log.debug(message)
-                                    else:
-                                        log.warning(message)
-                                    time.sleep(delay)
-                                self.qsize = current
-
-                                # We've got the data, lock the backlog again
-                                with self.backlog_lock:
-                                    for msg in msgs:
-                                        msg['header']['target'] = self.target
-                                        msg['header']['stats'] = Stats(
-                                            current, delta, delay
-                                        )
-                                        seq = msg['header']['sequence_number']
-                                        if seq not in self.backlog:
-                                            if (
-                                                msg['header']['type']
-                                                == NLMSG_ERROR
-                                            ):
-                                                # Drop orphaned NLMSG_ERROR
-                                                # messages
-                                                continue
-                                            seq = 0
-                                        # 8<-----------------------------------
-                                        # Callbacks section
-                                        for cr in self.callbacks:
-                                            try:
-                                                if cr[0](msg):
-                                                    cr[1](msg, *cr[2])
-                                            except:
-                                                # FIXME
-                                                #
-                                                # Usually such code formatting
-                                                # means that the method should
-                                                # be refactored to avoid such
-                                                # indentation.
-                                                #
-                                                # Plz do something with it.
-                                                #
-                                                lw = log.warning
-                                                lw("Callback fail: %s" % (cr))
-                                                lw(traceback.format_exc())
-                                        # 8<-----------------------------------
-                                        self.backlog[seq].append(msg)
-
-                                # Now wake up other threads
-                                self.change_master.set()
-                            finally:
-                                # Finally, release the read lock: all data
-                                # processed
-                                self.read_lock.release()
-                        else:
-                            # If the socket is occupied and there is still no
-                            # data for us, wait for the next master change or
-                            # for a timeout
-                            self.change_master.wait(1)
-                        # 8<-------------------------------------------------------
-                        #
-                        # Stage 2. END
-                        #
-                        # 8<-------------------------------------------------------
-            finally:
-                if backlog_acquired:
-                    self.backlog_lock.release()
+                if self.fileno() in rlist:
+                    chunk = self.recv(32768)
+                    self.load_backlog(chunk, msg_seq, callback)
+                    continue
+
+            for msg in self.fetch_backlog(msg_seq, terminate, noraise):
+                if msg is Enough:
+                    return
+                yield msg
 
     def nlm_request_batch(self, msgs, noraise=False):
         """
@@ -1204,7 +1054,6 @@ def bind(self, groups=0, pid=None, **kwarg):
                 'use "async_cache" instead of "async", '
                 '"async" is a keyword from Python 3.7'
             )
-        async_cache = kwarg.get('async_cache') or kwarg.get('async')
 
         self.groups = groups
         # if we have pre-defined port, use it strictly
@@ -1224,13 +1073,6 @@ def bind(self, groups=0, pid=None, **kwarg):
                 self.post_init()
         else:
             raise KeyError('no free address available')
-        # all is OK till now, so start async recv, if we need
-        if async_cache:
-            self.pthread = threading.Thread(
-                name="Netlink async cache", target=self.async_recv
-            )
-            self.pthread.daemon = True
-            self.pthread.start()
 
     def add_membership(self, group):
         self.setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, group)
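The core of the re-implementation above: get() no longer parses, filters and
throttles in one monolithic loop. load_backlog() files parsed messages into
per-sequence backlogs, and fetch_backlog() yields them up to a terminator,
signalling completion by yielding the Enough class itself as a sentinel. A
self-contained sketch of that sentinel pattern in plain Python (not the
nlsocket code; fetch/consume/done are illustrative names):

    # the producer yields parsed items and, once a terminator condition
    # is met, yields the Enough class itself; the consumer stops on the
    # sentinel instead of tracking extra termination state
    class Enough:
        pass


    def fetch(items, done):
        for item in items:
            yield item
            if done(item):
                yield Enough
                return


    def consume(items, done):
        for item in fetch(items, done):
            if item is Enough:
                return
            yield item


    print(list(consume([1, 2, 3, 4], lambda x: x == 3)))  # -> [1, 2, 3]

Using `yield Enough` keeps the generator protocol intact: the caller needs no
out-of-band flag, and any message object remains a legal value to yield.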
From 3c96c9d81f953234eedbea9c48b51d5190877a39 Mon Sep 17 00:00:00 2001
From: Peter Saveliev
Date: Thu, 28 Jul 2022 16:21:07 +0200
Subject: [PATCH 3/5] ndb.source: fix a potential deadlock

---
 pyroute2/ndb/source.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyroute2/ndb/source.py b/pyroute2/ndb/source.py
index dfb285238..32c6f682f 100644
--- a/pyroute2/ndb/source.py
+++ b/pyroute2/ndb/source.py
@@ -327,10 +327,10 @@ def receiver(self):
         # The routine exists on an event with error code == 104
         #
         while self.state.get() != 'stop':
-            with self.lock:
-                if self.shutdown.is_set():
-                    break
+            if self.shutdown.is_set():
+                break
 
+            with self.lock:
                 if self.nl is not None:
                     try:
                         self.nl.close(code=0)
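The fix above moves the shutdown check out of the critical section: if the
closing side sets the flag while holding the same lock and then joins the
receiver, a receiver that blocks on the lock before checking the flag can
never exit. A minimal standalone sketch of the corrected ordering
(hypothetical names, not the NDB classes):

    import threading

    lock = threading.Lock()
    shutdown = threading.Event()


    def receiver():
        # corrected ordering: check the flag BEFORE taking the lock;
        # with the old ordering (lock first, flag second) a closer
        # holding the lock while joining this thread would deadlock
        while True:
            if shutdown.is_set():
                break
            with lock:
                pass  # ... restart or close the source socket here ...


    th = threading.Thread(target=receiver)
    th.start()
    shutdown.set()  # the closing side signals first ...
    th.join()       # ... and can then join safely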
From ede63eddf7ef5f395d165541af14f6919e4c3ac9 Mon Sep 17 00:00:00 2001
From: Peter Saveliev
Date: Fri, 29 Jul 2022 11:31:06 +0200
Subject: [PATCH 4/5] ndb: close sources using the proper call

---
 pyroute2/ndb/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyroute2/ndb/main.py b/pyroute2/ndb/main.py
index a1f0ff078..21fff19fd 100644
--- a/pyroute2/ndb/main.py
+++ b/pyroute2/ndb/main.py
@@ -750,8 +750,8 @@ def check_sources_started(self, _locals, target, event):
         for target in tuple(self.sources.cache):
             source = self.sources.remove(target, sync=False)
             if source is not None and source.th is not None:
-                source.shutdown.set()
-                source.th.join()
+                self.log.debug(f'closing source {source}')
+                source.close()
             if self._db_cleanup:
                 self.log.debug('flush DB for the target %s' % target)
                 self.schema.flush(target)
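Rationale: the old teardown re-implemented part of the source shutdown by
hand; close() owns the whole sequence, so callers should invoke it instead
of duplicating shutdown.set() plus th.join(). A simplified model of the
idea (a hypothetical class, not the real Source from pyroute2/ndb/source.py):

    import threading

    # close() performs the complete teardown, so the caller makes one
    # call instead of poking at the thread and the event directly
    class Source:
        def __init__(self):
            self.shutdown = threading.Event()
            self.th = threading.Thread(target=self.shutdown.wait)
            self.th.start()

        def close(self):
            self.shutdown.set()  # signal the receiver thread
            # ... close the netlink socket, flush events, etc. ...
            self.th.join()       # and only then wait for the thread


    source = Source()
    source.close()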
From 6c027d49f2dc315c94c70a08462e2332bfec3f3b Mon Sep 17 00:00:00 2001
From: Peter Saveliev
Date: Fri, 29 Jul 2022 11:32:54 +0200
Subject: [PATCH 5/5] ndb.objects.route: optimize dump

---
 pyroute2/ndb/objects/route.py | 75 ++++++++++++++++++-----------------
 1 file changed, 39 insertions(+), 36 deletions(-)

diff --git a/pyroute2/ndb/objects/route.py b/pyroute2/ndb/objects/route.py
index 399eee811..8463e1df6 100644
--- a/pyroute2/ndb/objects/route.py
+++ b/pyroute2/ndb/objects/route.py
@@ -478,45 +478,48 @@ def dump(cls, view):
         for record in view.ndb.schema.fetch(req + where, values):
             route_id = record[-1]
             record = list(record[:-1])
-            #
-            # fetch metrics
-            metrics = tuple(
-                view.ndb.schema.fetch(
-                    '''
-                    SELECT * FROM metrics WHERE f_route_id = %s
-                    '''
-                    % (plch,),
-                    (route_id,),
+            if route_id is not None:
+                #
+                # fetch metrics
+                metrics = tuple(
+                    view.ndb.schema.fetch(
+                        '''
+                        SELECT * FROM metrics WHERE f_route_id = %s
+                        '''
+                        % (plch,),
+                        (route_id,),
+                    )
                 )
-            )
-            if metrics:
-                ret = {}
-                names = view.ndb.schema.compiled['metrics']['norm_names']
-                for k, v in zip(names, metrics[0]):
-                    if v is not None and k not in (
-                        'target',
-                        'route_id',
-                        'tflags',
-                    ):
-                        ret[k] = v
-                record.append(json.dumps(ret))
-            else:
-                record.append(None)
-            #
-            # fetch encap
-            enc_mpls = tuple(
-                view.ndb.schema.fetch(
-                    '''
-                    SELECT * FROM enc_mpls WHERE f_route_id = %s
-                    '''
-                    % (plch,),
-                    (route_id,),
+                if metrics:
+                    ret = {}
+                    names = view.ndb.schema.compiled['metrics']['norm_names']
+                    for k, v in zip(names, metrics[0]):
+                        if v is not None and k not in (
+                            'target',
+                            'route_id',
+                            'tflags',
+                        ):
+                            ret[k] = v
+                    record.append(json.dumps(ret))
+                else:
+                    record.append(None)
+                #
+                # fetch encap
+                enc_mpls = tuple(
+                    view.ndb.schema.fetch(
+                        '''
+                        SELECT * FROM enc_mpls WHERE f_route_id = %s
+                        '''
+                        % (plch,),
+                        (route_id,),
+                    )
                 )
-            )
-            if enc_mpls:
-                record.append(enc_mpls[0][2])
+                if enc_mpls:
+                    record.append(enc_mpls[0][2])
+                else:
+                    record.append(None)
             else:
-                record.append(None)
+                record.extend((None, None))
             yield record
 
     @classmethod
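The dump optimization boils down to one rule: skip the per-row child queries
entirely when the foreign key is NULL, and append the placeholders directly.
A standalone sqlite3 sketch of the same pattern (hypothetical schema, not
the NDB one):

    import sqlite3

    db = sqlite3.connect(':memory:')
    db.executescript('''
        CREATE TABLE routes (id INTEGER, route_id INTEGER);
        CREATE TABLE metrics (route_id INTEGER, mtu INTEGER);
        INSERT INTO routes VALUES (1, 10), (2, NULL);
        INSERT INTO metrics VALUES (10, 1500);
    ''')

    for rid, route_id in db.execute('SELECT id, route_id FROM routes'):
        if route_id is not None:
            # child lookup only when the key is present
            row = db.execute(
                'SELECT mtu FROM metrics WHERE route_id = ?', (route_id,)
            ).fetchone()
            print(rid, row)
        else:
            # no extra round trip for rows without a route_id
            print(rid, None)

For dumps where many rows carry no metrics or encap, this trades two
guaranteed-empty SELECTs per row for a single in-Python check.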