diff --git a/README.md b/README.md index e3fa76a..a180f8f 100644 --- a/README.md +++ b/README.md @@ -61,5 +61,5 @@ The library is still under development. The current major limitations and deviat ### Service Discovery - Configuration and load balancing options in SOME/IP SD messages are not supported. -- TTL of Service Discovery entries is not checked yet. +- Stop subscribe message of notifications is not supported - The Initial Wait Phase and Repetition Phase of the Service Discovery specification are skipped. The Main Phase is directly entered, i.e. SD Offer Entries are immediately sent cyclically. diff --git a/src/someipy/_internal/store_with_timeout.py b/src/someipy/_internal/store_with_timeout.py new file mode 100644 index 0000000..c46a5bf --- /dev/null +++ b/src/someipy/_internal/store_with_timeout.py @@ -0,0 +1,83 @@ +import asyncio +from typing import Set, Generic, TypeVar, Protocol + + +T = TypeVar("T") + + +class ObjectWithTtl(Protocol): + ttl: int # The TTL of the object in seconds + + +class StoreWithTimeout: + + class ObjectWithTtlWrapper: + def __init__(self, value: ObjectWithTtl): + self.value: ObjectWithTtl = value + self.timeout_task: asyncio.Task = None + self.active = False # Flag needed to prevent callbacks being called during cancellation + + def __eq__(self, other): + # Two objects are equal if their values are equal even if the timeout is different + return self.value == other.value + + def __hash__(self): + return hash(self.value) + + async def _wait(self): + await asyncio.sleep(self.value.ttl) + + def __init__(self): + self.values: Set[self.ObjectWithTtlWrapper] = set() + self._current = 0 + + async def add(self, object_to_add: ObjectWithTtl, callback=None): + wrapper = self.ObjectWithTtlWrapper(object_to_add) + + if wrapper in self.values: + await self.remove(wrapper.value) + + wrapper.timeout_task = asyncio.create_task(wrapper._wait()) + wrapper.active = True + wrapper.timeout_task.add_done_callback( + lambda _: self._done_callback(wrapper, callback) + ) + self.values.add(wrapper) + + def _done_callback(self, caller: ObjectWithTtlWrapper, callback=None): + self.values.discard(caller) + if caller.active is True and callback is not None: + callback(caller.value) + + async def remove(self, object_to_remove: ObjectWithTtl): + wrapper = self.ObjectWithTtlWrapper(object_to_remove) + if wrapper in self.values: + for value in self.values: + if value == wrapper: + value.active = False + value.timeout_task.cancel() + try: + await value.timeout_task + except asyncio.CancelledError: + pass + break + + async def clear(self): + while len(self.values) > 0: + await self.remove(next(iter(self.values)).value) + + def __contains__(self, item): + wrapper = self.ObjectWithTtlWrapper(item) + return wrapper in self.values + + def __iter__(self): + self._current = 0 + return self + + def __next__(self): + if self._current < len(self.values): + result = next(iter(self.values)) + self._current += 1 + return result.value + else: + raise StopIteration diff --git a/src/someipy/client_service_instance.py b/src/someipy/client_service_instance.py index c0054b4..44d072b 100644 --- a/src/someipy/client_service_instance.py +++ b/src/someipy/client_service_instance.py @@ -34,6 +34,7 @@ ServiceDiscoveryObserver, ServiceDiscoverySender, ) +from someipy._internal.store_with_timeout import StoreWithTimeout from someipy._internal.utils import ( create_udp_socket, EndpointType, @@ -60,17 +61,6 @@ def __eq__(self, value: object) -> bool: return self.eventgroup_id == value.eventgroup_id -class FoundService: - service: SdService - - def __init__(self, service: SdService) -> None: - self.service = service - - def __eq__(self, __value: object) -> bool: - if isinstance(__value, FoundService): - return self.service == __value.service - - class ClientServiceInstance(ServiceDiscoveryObserver): _service: Service _instance_id: int @@ -84,7 +74,7 @@ class ClientServiceInstance(ServiceDiscoveryObserver): _expected_acks: List[ExpectedAck] _callback: Callable[[bytes], None] - _found_services: Iterable[FoundService] + _offered_services: StoreWithTimeout _subscription_active: bool _method_call_futures: Dict[int, asyncio.Future] @@ -121,7 +111,8 @@ def __init__( self._tcp_connection_established_event = asyncio.Event() self._shutdown_requested = False - self._found_services = [] + self._offered_services = StoreWithTimeout() + self._subscription_active = False self._method_call_futures: Dict[int, asyncio.Future] = {} self._client_id = client_id @@ -146,11 +137,8 @@ def service_found(self) -> bool: Returns whether the service instance represented by the ClientServiceInstance has been offered by a server and was found. """ has_service = False - for s in self._found_services: - if ( - s.service.service_id == self._service.id - and s.service.instance_id == self._instance_id - ): + for s in self._offered_services: + if s.service_id == self._service.id and s.instance_id == self._instance_id: has_service = True break return has_service @@ -201,8 +189,12 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult: call_future = asyncio.get_running_loop().create_future() self._method_call_futures[session_id] = call_future - dst_address = str(self._found_services[0].service.endpoint[0]) - dst_port = self._found_services[0].service.endpoint[1] + # At this point the service should be found since an exception would have been raised before + for s in self._offered_services: + if s.service_id == self._service.id and s.instance_id == self._instance_id: + dst_address = str(s.endpoint[0]) + dst_port = s.endpoint[1] + break if self._protocol == TransportLayerProtocol.TCP: # In case of TCP, first try to connect to the TCP server @@ -241,9 +233,19 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult: else: # In case of UDP, just send out the datagram and wait for the response + # At this point the service should be found since an exception would have been raised before + for s in self._offered_services: + if ( + s.service_id == self._service.id + and s.instance_id == self._instance_id + ): + dst_address = str(s.endpoint[0]) + dst_port = s.endpoint[1] + break + self._someip_endpoint.sendto( someip_message.serialize(), - endpoint_to_str_int_tuple(self._found_services[0].service.endpoint), + (dst_address, dst_port), ) # After sending the method call wait for maximum 10 seconds @@ -348,6 +350,11 @@ def handle_find_service(self): # Not needed in client service instance pass + def _timeout_of_offered_service(self, offered_service: SdService): + get_logger(_logger_name).debug( + f"Offered service timed out: service id 0x{offered_service.service_id:04x}, instance id 0x{offered_service.instance_id:04x}" + ) + def handle_offer_service(self, offered_service: SdService): if self._service.id != offered_service.service_id: return @@ -367,8 +374,11 @@ def handle_offer_service(self, offered_service: SdService): # 0xFFFFFFFF allows to handle any minor version return - if FoundService(offered_service) not in self._found_services: - self._found_services.append(FoundService(offered_service)) + asyncio.get_event_loop().create_task( + self._offered_services.add( + offered_service, self._timeout_of_offered_service + ) + ) if len(self._eventgroups_to_subscribe) == 0: return @@ -424,10 +434,9 @@ def handle_stop_offer_service(self, offered_service: SdService) -> None: if self._instance_id != offered_service.instance_id: return - # Remove the service from the found services - self._found_services = [ - f for f in self._found_services if f.service != offered_service - ] + asyncio.get_event_loop().create_task( + self._offered_services.remove(offered_service) + ) self._expected_acks = [] self._subscription_active = False diff --git a/tests/test_store_with_timeout.py b/tests/test_store_with_timeout.py new file mode 100644 index 0000000..6089e08 --- /dev/null +++ b/tests/test_store_with_timeout.py @@ -0,0 +1,94 @@ +import asyncio +import pytest +import pytest_asyncio +from someipy._internal.store_with_timeout import StoreWithTimeout + + +class MyTestObjectWithTtl: + def __init__(self, value: int, ttl: int): + self.ttl = ttl + self.value = value + + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return self.value == other.value + + +@pytest.mark.asyncio +async def test_add(): + store = StoreWithTimeout() + + await store.add(MyTestObjectWithTtl(1, 2)) + assert len(store.values) == 1 + assert (MyTestObjectWithTtl(1, 2)) in store + + await store.add(MyTestObjectWithTtl(1, 2)) + assert len(store.values) == 1 + + await store.add(MyTestObjectWithTtl(1, 1)) + assert len(store.values) == 1 + await asyncio.sleep(1.5) + assert len(store.values) == 0 + + await store.add(MyTestObjectWithTtl(2, 2)) + assert len(store.values) == 1 + + await asyncio.sleep(2.5) + assert len(store.values) == 0 + await store.clear() + + +@pytest.mark.asyncio +async def test_clear(): + store = StoreWithTimeout() + await store.add(MyTestObjectWithTtl(1, 5)) + await store.add(MyTestObjectWithTtl(2, 5)) + await store.add(MyTestObjectWithTtl(3, 5)) + assert len(store.values) == 3 + + await store.clear() + assert len(store.values) == 0 + + +@pytest.mark.asyncio +async def test_remove(): + store = StoreWithTimeout() + await store.add(MyTestObjectWithTtl(1, 5)) + await store.add(MyTestObjectWithTtl(2, 5)) + await store.add(MyTestObjectWithTtl(3, 5)) + assert len(store.values) == 3 + + await store.remove(MyTestObjectWithTtl(2, 5)) + assert len(store.values) == 2 + await store.clear() + + # Try to remove some object that is not in the store + # This shall not raise an error + await store.remove(MyTestObjectWithTtl(2, 5)) + + +@pytest.mark.asyncio +async def test_callback(): + store = StoreWithTimeout() + + callback_was_called = 0 + + def callback(): + nonlocal callback_was_called + print("Callback was called") + callback_was_called += 1 + + await store.add(MyTestObjectWithTtl(1, 1), callback) + await asyncio.sleep(1.5) + assert callback_was_called == 1 + assert len(store.values) == 0 + + callback_was_called = 0 + + await store.add(MyTestObjectWithTtl(1, 2), callback) + assert len(store.values) == 1 + await store.remove(MyTestObjectWithTtl(1, 2)) + assert len(store.values) == 0 + assert callback_was_called == 0