Skip to content

Commit

Permalink
Implement timeouts/ttl of offered services
Browse files Browse the repository at this point in the history
  • Loading branch information
chrizog committed Nov 9, 2024
1 parent 0af1e53 commit d072d44
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ The library is still under development. The current major limitations and deviat
### Service Discovery

- Configuration and load balancing options in SOME/IP SD messages are not supported.
- TTL of Service Discovery entries is not checked yet.
- Stop subscribe message of notifications is not supported
- The Initial Wait Phase and Repetition Phase of the Service Discovery specification are skipped. The Main Phase is directly entered, i.e. SD Offer Entries are immediately sent cyclically.
83 changes: 83 additions & 0 deletions src/someipy/_internal/store_with_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
from typing import Set, Generic, TypeVar, Protocol


T = TypeVar("T")


class ObjectWithTtl(Protocol):
ttl: int # The TTL of the object in seconds


class StoreWithTimeout:

class ObjectWithTtlWrapper:
def __init__(self, value: ObjectWithTtl):
self.value: ObjectWithTtl = value
self.timeout_task: asyncio.Task = None
self.active = False # Flag needed to prevent callbacks being called during cancellation

def __eq__(self, other):
# Two objects are equal if their values are equal even if the timeout is different
return self.value == other.value

def __hash__(self):
return hash(self.value)

async def _wait(self):
await asyncio.sleep(self.value.ttl)

def __init__(self):
self.values: Set[self.ObjectWithTtlWrapper] = set()
self._current = 0

async def add(self, object_to_add: ObjectWithTtl, callback=None):
wrapper = self.ObjectWithTtlWrapper(object_to_add)

if wrapper in self.values:
await self.remove(wrapper.value)

wrapper.timeout_task = asyncio.create_task(wrapper._wait())
wrapper.active = True
wrapper.timeout_task.add_done_callback(
lambda _: self._done_callback(wrapper, callback)
)
self.values.add(wrapper)

def _done_callback(self, caller: ObjectWithTtlWrapper, callback=None):
self.values.discard(caller)
if caller.active is True and callback is not None:
callback(caller.value)

async def remove(self, object_to_remove: ObjectWithTtl):
wrapper = self.ObjectWithTtlWrapper(object_to_remove)
if wrapper in self.values:
for value in self.values:
if value == wrapper:
value.active = False
value.timeout_task.cancel()
try:
await value.timeout_task
except asyncio.CancelledError:
pass
break

async def clear(self):
while len(self.values) > 0:
await self.remove(next(iter(self.values)).value)

def __contains__(self, item):
wrapper = self.ObjectWithTtlWrapper(item)
return wrapper in self.values

def __iter__(self):
self._current = 0
return self

def __next__(self):
if self._current < len(self.values):
result = next(iter(self.values))
self._current += 1
return result.value
else:
raise StopIteration
63 changes: 36 additions & 27 deletions src/someipy/client_service_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ServiceDiscoveryObserver,
ServiceDiscoverySender,
)
from someipy._internal.store_with_timeout import StoreWithTimeout
from someipy._internal.utils import (
create_udp_socket,
EndpointType,
Expand All @@ -60,17 +61,6 @@ def __eq__(self, value: object) -> bool:
return self.eventgroup_id == value.eventgroup_id


class FoundService:
service: SdService

def __init__(self, service: SdService) -> None:
self.service = service

def __eq__(self, __value: object) -> bool:
if isinstance(__value, FoundService):
return self.service == __value.service


class ClientServiceInstance(ServiceDiscoveryObserver):
_service: Service
_instance_id: int
Expand All @@ -84,7 +74,7 @@ class ClientServiceInstance(ServiceDiscoveryObserver):
_expected_acks: List[ExpectedAck]

_callback: Callable[[bytes], None]
_found_services: Iterable[FoundService]
_offered_services: StoreWithTimeout
_subscription_active: bool

_method_call_futures: Dict[int, asyncio.Future]
Expand Down Expand Up @@ -121,7 +111,8 @@ def __init__(
self._tcp_connection_established_event = asyncio.Event()
self._shutdown_requested = False

self._found_services = []
self._offered_services = StoreWithTimeout()

self._subscription_active = False
self._method_call_futures: Dict[int, asyncio.Future] = {}
self._client_id = client_id
Expand All @@ -146,11 +137,8 @@ def service_found(self) -> bool:
Returns whether the service instance represented by the ClientServiceInstance has been offered by a server and was found.
"""
has_service = False
for s in self._found_services:
if (
s.service.service_id == self._service.id
and s.service.instance_id == self._instance_id
):
for s in self._offered_services:
if s.service_id == self._service.id and s.instance_id == self._instance_id:
has_service = True
break
return has_service
Expand Down Expand Up @@ -201,8 +189,12 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:
call_future = asyncio.get_running_loop().create_future()
self._method_call_futures[session_id] = call_future

dst_address = str(self._found_services[0].service.endpoint[0])
dst_port = self._found_services[0].service.endpoint[1]
# At this point the service should be found since an exception would have been raised before
for s in self._offered_services:
if s.service_id == self._service.id and s.instance_id == self._instance_id:
dst_address = str(s.endpoint[0])
dst_port = s.endpoint[1]
break

if self._protocol == TransportLayerProtocol.TCP:
# In case of TCP, first try to connect to the TCP server
Expand Down Expand Up @@ -241,9 +233,19 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:

else:
# In case of UDP, just send out the datagram and wait for the response
# At this point the service should be found since an exception would have been raised before
for s in self._offered_services:
if (
s.service_id == self._service.id
and s.instance_id == self._instance_id
):
dst_address = str(s.endpoint[0])
dst_port = s.endpoint[1]
break

self._someip_endpoint.sendto(
someip_message.serialize(),
endpoint_to_str_int_tuple(self._found_services[0].service.endpoint),
(dst_address, dst_port),
)

# After sending the method call wait for maximum 10 seconds
Expand Down Expand Up @@ -348,6 +350,11 @@ def handle_find_service(self):
# Not needed in client service instance
pass

def _timeout_of_offered_service(self, offered_service: SdService):
get_logger(_logger_name).debug(
f"Offered service timed out: service id 0x{offered_service.service_id:04x}, instance id 0x{offered_service.instance_id:04x}"
)

def handle_offer_service(self, offered_service: SdService):
if self._service.id != offered_service.service_id:
return
Expand All @@ -367,8 +374,11 @@ def handle_offer_service(self, offered_service: SdService):
# 0xFFFFFFFF allows to handle any minor version
return

if FoundService(offered_service) not in self._found_services:
self._found_services.append(FoundService(offered_service))
asyncio.get_event_loop().create_task(
self._offered_services.add(
offered_service, self._timeout_of_offered_service
)
)

if len(self._eventgroups_to_subscribe) == 0:
return
Expand Down Expand Up @@ -424,10 +434,9 @@ def handle_stop_offer_service(self, offered_service: SdService) -> None:
if self._instance_id != offered_service.instance_id:
return

# Remove the service from the found services
self._found_services = [
f for f in self._found_services if f.service != offered_service
]
asyncio.get_event_loop().create_task(
self._offered_services.remove(offered_service)
)

self._expected_acks = []
self._subscription_active = False
Expand Down
94 changes: 94 additions & 0 deletions tests/test_store_with_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
import pytest
import pytest_asyncio
from someipy._internal.store_with_timeout import StoreWithTimeout


class MyTestObjectWithTtl:
def __init__(self, value: int, ttl: int):
self.ttl = ttl
self.value = value

def __hash__(self):
return hash(self.value)

def __eq__(self, other):
return self.value == other.value


@pytest.mark.asyncio
async def test_add():
store = StoreWithTimeout()

await store.add(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 1
assert (MyTestObjectWithTtl(1, 2)) in store

await store.add(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 1

await store.add(MyTestObjectWithTtl(1, 1))
assert len(store.values) == 1
await asyncio.sleep(1.5)
assert len(store.values) == 0

await store.add(MyTestObjectWithTtl(2, 2))
assert len(store.values) == 1

await asyncio.sleep(2.5)
assert len(store.values) == 0
await store.clear()


@pytest.mark.asyncio
async def test_clear():
store = StoreWithTimeout()
await store.add(MyTestObjectWithTtl(1, 5))
await store.add(MyTestObjectWithTtl(2, 5))
await store.add(MyTestObjectWithTtl(3, 5))
assert len(store.values) == 3

await store.clear()
assert len(store.values) == 0


@pytest.mark.asyncio
async def test_remove():
store = StoreWithTimeout()
await store.add(MyTestObjectWithTtl(1, 5))
await store.add(MyTestObjectWithTtl(2, 5))
await store.add(MyTestObjectWithTtl(3, 5))
assert len(store.values) == 3

await store.remove(MyTestObjectWithTtl(2, 5))
assert len(store.values) == 2
await store.clear()

# Try to remove some object that is not in the store
# This shall not raise an error
await store.remove(MyTestObjectWithTtl(2, 5))


@pytest.mark.asyncio
async def test_callback():
store = StoreWithTimeout()

callback_was_called = 0

def callback():
nonlocal callback_was_called
print("Callback was called")
callback_was_called += 1

await store.add(MyTestObjectWithTtl(1, 1), callback)
await asyncio.sleep(1.5)
assert callback_was_called == 1
assert len(store.values) == 0

callback_was_called = 0

await store.add(MyTestObjectWithTtl(1, 2), callback)
assert len(store.values) == 1
await store.remove(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 0
assert callback_was_called == 0

0 comments on commit d072d44

Please sign in to comment.