From 65acde3ec5eb1d385e1757e5556cd13ab8675fb2 Mon Sep 17 00:00:00 2001 From: Leon Morten Richter Date: Wed, 26 Jun 2024 11:43:09 +0200 Subject: [PATCH] Improve filter api (#142) * feat: improves API of FilterChain * chore: updates docs for updated API of FilterChain * chore: bumps versions & updates changelog --- CHANGELOG.txt | 7 ++++++ README.md | 8 +++--- docs/filters.rst | 6 ++--- examples/filters.py | 25 +++++------------- pyais/__init__.py | 5 ++-- pyais/filter.py | 57 ++++++++++++++++++++++-------------------- tests/test_examples.py | 10 +++++++- tests/test_filters.py | 22 ++++++++++++++++ 8 files changed, 84 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 0ea40e9..27caaac 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,6 +1,13 @@ ==================== pyais CHANGELOG ==================== +------------------------------------------------------------------------------- + Version 2.6.6 26 Jun 2024 +------------------------------------------------------------------------------- +* improves the API of `FilterChain` + * `FilterChain.filter(stream)` now accepts a stream instance + * this stream MUST implement the `Stream` interface defined in pyais.stream + * individual messages can be filtered using `IterMessages(...)` ------------------------------------------------------------------------------- Version 2.6.5 10 May 2024 ------------------------------------------------------------------------------- diff --git a/README.md b/README.md index 91595e8..a32e6b9 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ This is useful for debugging or for getting used to pyais. It is also possible to encode messages. | :exclamation: Every message needs at least a single keyword argument: `mmsi`. All other fields have most likely default values. | -|----------------------------------------------------------------------------------------------------------------------------------| +| -------------------------------------------------------------------------------------------------------------------------------- | ### Encode data using a dictionary @@ -385,7 +385,7 @@ The filtering system is built around a series of filter classes, each designed t ## Example Usage ```python -from pyais import decode +from pyais import decode, TCPConnection # ... (importing necessary classes) # Define and initialize filters @@ -405,8 +405,8 @@ chain = FilterChain([ ]) # Decode AIS data and filter -data = [decode(b"!AIVDM..."), ...] -filtered_data = list(chain.filter(data)) +stream = TCPConnection(...) +filtered_data = list(chain.filter(stream)) for msg in filtered_data: print(msg.lat, msg.lon) diff --git a/docs/filters.rst b/docs/filters.rst index 4be37c4..062e23e 100644 --- a/docs/filters.rst +++ b/docs/filters.rst @@ -62,7 +62,7 @@ Example Usage .. code-block:: python - from pyais import decode + from pyais import decode, TCPConnection # ... (importing necessary classes) # Define and initialize filters @@ -82,8 +82,8 @@ Example Usage ]) # Decode AIS data and filter - data = [decode(b"!AIVDM..."), ...] - filtered_data = list(chain.filter(data)) + stream = TCPConnection(...) + filtered_data = chain.filter(stream) for msg in filtered_data: print(msg.lat, msg.lon) diff --git a/examples/filters.py b/examples/filters.py index 20b9dd5..1e37bd8 100644 --- a/examples/filters.py +++ b/examples/filters.py @@ -1,4 +1,3 @@ -from pyais import decode from pyais.filter import ( AttributeFilter, DistanceFilter, @@ -7,6 +6,7 @@ MessageTypeFilter, NoneFilter ) +from pyais.stream import TCPConnection # Define the filter chain with various criteria chain = FilterChain([ @@ -26,21 +26,8 @@ GridFilter(lat_min=50, lon_min=0, lat_max=52, lon_max=5), ]) -# Example AIS data to filter -data = [ - decode(b"!AIVDM,1,1,,B,15NG6V0P01G?cFhE`R2IU?wn28R>,0*05"), - decode(b"!AIVDM,1,1,,A,13HOI:0P0000VOHLCnHQKwvL05Ip,0*23"), - decode(b"!AIVDM,1,1,,B,100h00PP0@PHFV`Mg5gTH?vNPUIp,0*3B"), - decode(b"!AIVDM,1,1,,A,133sVfPP00PD>hRMDH@jNOvN20S8,0*7F"), - decode(b"!AIVDM,1,1,,B,13eaJF0P00Qd388Eew6aagvH85Ip,0*45"), - decode(b"!AIVDM,1,1,,A,14eGrSPP00ncMJTO5C6aBwvP2D0?,0*7A"), - decode(b"!AIVDM,1,1,,A,15MrVH0000KH<:V:NtBLoqFP2H9:,0*2F"), - decode(b"!AIVDM,1,1,,A,702R5`hwCjq8,0*6B"), -] - -# Filter the data using the defined chain -filtered_data = list(chain.filter(data)) - -# Print the latitude and longitude of each message that passed the filters -for msg in filtered_data: - print(msg.lat, msg.lon) +# Create a stream of ais messages +with TCPConnection('153.44.253.27', port=5631) as ais_stream: + for ais_msg in chain.filter(ais_stream): + # Only messages that pass this filter chain are printed + print(ais_msg) diff --git a/pyais/__init__.py b/pyais/__init__.py index d84d292..c03aa8f 100644 --- a/pyais/__init__.py +++ b/pyais/__init__.py @@ -1,11 +1,11 @@ from pyais.messages import NMEAMessage, ANY_MESSAGE, AISSentence -from pyais.stream import TCPConnection, FileReaderStream, IterMessages +from pyais.stream import TCPConnection, FileReaderStream, IterMessages, Stream from pyais.encode import encode_dict, encode_msg, ais_to_nmea_0183 from pyais.decode import decode from pyais.tracker import AISTracker, AISTrack __license__ = 'MIT' -__version__ = '2.6.5' +__version__ = '2.6.6' __author__ = 'Leon Morten Richter' __all__ = ( @@ -18,6 +18,7 @@ 'TCPConnection', 'IterMessages', 'FileReaderStream', + 'Stream', 'decode', 'AISTracker', 'AISTrack', diff --git a/pyais/filter.py b/pyais/filter.py index 5a7cc76..750e5b3 100644 --- a/pyais/filter.py +++ b/pyais/filter.py @@ -6,12 +6,15 @@ """ import math +import socket import typing import pyais # Type Aliases for readability -AIS_STREAM = typing.Generator[pyais.AISSentence, None, None] -FILTER_FUNCTION = typing.Callable[[pyais.AISSentence], bool] +F = typing.TypeVar("F", typing.BinaryIO, socket.socket, None) +AIS_STREAM = pyais.Stream[F] +MESSAGE_STREAM = typing.Generator[pyais.ANY_MESSAGE, None, None] +FILTER_FUNCTION = typing.Callable[[pyais.ANY_MESSAGE], bool] LAT_LON = typing.Tuple[float, float] # Tuple type for latitude and longitude @@ -66,30 +69,30 @@ def set_next(self, filter: 'Filter') -> None: """ self.next_filter = filter - def filter(self, data: AIS_STREAM) -> AIS_STREAM: + def filter(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Apply the filter to the data and then pass it to the next filter. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Returns: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ data = self.filter_data(data) if self.next_filter: return self.next_filter.filter(data) return data - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Abstract method to filter data. Should be implemented by subclasses. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Returns: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ raise NotImplementedError("This method should be overridden by subclasses.") @@ -109,15 +112,15 @@ def __init__(self, ff: FILTER_FUNCTION) -> None: super().__init__() self.ff = ff - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Filter the data based on the user-defined function. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Yields: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ yield from filter(self.ff, data) @@ -137,15 +140,15 @@ def __init__(self, *attrs: str) -> None: super().__init__() self.attrs = attrs - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Filter the data, allowing only messages where specified attributes are not None. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Yields: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ for msg in data: if all(getattr(msg, attr, None) is not None for attr in self.attrs): @@ -167,18 +170,18 @@ def __init__(self, *types: int) -> None: super().__init__() self.types = types - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Filter the data, allowing only messages of specified types. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Yields: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ for msg in data: - if msg.msg_type not in self.types: # type: ignore + if msg.msg_type not in self.types: continue yield msg @@ -200,15 +203,15 @@ def __init__(self, ref_lat_lon: LAT_LON, distance_km: float) -> None: self.ref_lat_lon = ref_lat_lon self.distance_km = distance_km - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Filter the data based on distance from a reference point. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Yields: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ for msg in data: if hasattr(msg, 'lat'): @@ -235,15 +238,15 @@ def __init__(self, lat_min: float, lon_min: float, lat_max: float, lon_max: floa self.lat_max = lat_max self.lon_max = lon_max - def filter_data(self, data: AIS_STREAM) -> AIS_STREAM: + def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM: """ Filter the data based on whether it falls within a specified grid. Parameters: - data (AIS_STREAM): The stream of data to filter. + data (MESSAGE_STREAM): The stream of data to filter. Yields: - AIS_STREAM: The filtered data stream. + MESSAGE_STREAM: The filtered data stream. """ for msg in data: if hasattr(msg, 'lat'): @@ -274,14 +277,14 @@ def __init__(self, filters: typing.List[Filter]) -> None: self.filters = filters self.start = filters[0] - def filter(self, data: AIS_STREAM) -> AIS_STREAM: + def filter(self, stream: AIS_STREAM[F]) -> MESSAGE_STREAM: """ Apply the chain of filters to the data. Parameters: - data (AIS_STREAM): The stream of data to filter. + stream (AIS_STREAM): The stream of data to filter. Yields: AIS_STREAM: The filtered data stream. """ - yield from self.start.filter(data) + yield from self.start.filter(x.decode() for x in stream) diff --git a/tests/test_examples.py b/tests/test_examples.py index 5e61971..6dac7c9 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -4,6 +4,14 @@ import subprocess import unittest +KEYWORDS_TO_IGNORE = ( + 'tcp', + 'udp', + 'live', + 'tracking', + 'filters', +) + class TestExamples(unittest.TestCase): """ @@ -14,7 +22,7 @@ def test_run_every_file(self): i = -1 exe = sys.executable for i, file in enumerate(pathlib.Path(__file__).parent.parent.joinpath('examples').glob('*.py')): - if 'tcp' not in str(file) and 'udp' not in str(file) and 'live' not in str(file) and 'tracking' not in str(file): + if all(kw not in str(file) for kw in KEYWORDS_TO_IGNORE): env = os.environ env['PYTHONPATH'] = f':{pathlib.Path(__file__).parent.parent.absolute()}' assert subprocess.check_call(f'{exe} {file}'.split(), env=env, shell=False) == 0 diff --git a/tests/test_filters.py b/tests/test_filters.py index 753a7a0..775c49e 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,5 +1,7 @@ +import pathlib import unittest from pyais.filter import AttributeFilter, DistanceFilter, FilterChain, GridFilter, MessageTypeFilter, NoneFilter, haversine +from pyais.stream import FileReaderStream class MockAISMessage: @@ -9,6 +11,9 @@ def __init__(self, msg_type=None, lat=None, lon=None, other_attr=None): self.lon = lon self.other_attr = other_attr + def decode(self): + return self + class TestNoneFilter(unittest.TestCase): def test_filtering_none_attributes(self): @@ -165,6 +170,7 @@ def test_filter_chain(self): filter1 = NoneFilter('lat', 'lon') filter2 = MessageTypeFilter(1, 2) chain = FilterChain([filter1, filter2]) + mock_data = [MockAISMessage(lat=1, lon=1, msg_type=1), MockAISMessage(lat=None, lon=1, msg_type=2)] # Execute @@ -200,6 +206,22 @@ def test_complex_filter_chain(self): self.assertEqual(filtered_data[0].lon, -73.965) self.assertEqual(filtered_data[0].msg_type, 1) + def test_filter_chain_with_file_stream(self): + # Setup: Define the filters and chain + chain = FilterChain([AttributeFilter(lambda x: x.mmsi == 445451000)]) + + # Setup: Define sample file + file = pathlib.Path(__file__).parent.joinpath('messages.ais') + + with FileReaderStream(file) as ais_stream: + total = len(list(ais_stream)) + + with FileReaderStream(file) as ais_stream: + filtered = list(chain.filter(ais_stream)) + + self.assertEqual(len(filtered), 2) + self.assertEqual(total, 6) + class TestAttributeFilter(unittest.TestCase): def test_attribute_filtering(self):