Skip to content

Implement timeouts/ttl of offered services #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
python -m pip install --upgrade pip
pip install -e .
pip install pytest
pip install pytest-asyncio

- name: Run pytest
run: pytest tests
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ The library is still under development. The current major limitations and deviat
### Service Discovery

- Configuration and load balancing options in SOME/IP SD messages are not supported.
- TTL of Service Discovery entries is not checked yet.
- Stop subscribe message of notifications is not supported
- The Initial Wait Phase and Repetition Phase of the Service Discovery specification are skipped. The Main Phase is directly entered, i.e. SD Offer Entries are immediately sent cyclically.
83 changes: 83 additions & 0 deletions src/someipy/_internal/store_with_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
from typing import Set, Generic, TypeVar, Protocol


T = TypeVar("T")


class ObjectWithTtl(Protocol):
ttl: int # The TTL of the object in seconds


class StoreWithTimeout:

class ObjectWithTtlWrapper:
def __init__(self, value: ObjectWithTtl):
self.value: ObjectWithTtl = value
self.timeout_task: asyncio.Task = None
self.active = False # Flag needed to prevent callbacks being called during cancellation

def __eq__(self, other):
# Two objects are equal if their values are equal even if the timeout is different
return self.value == other.value

def __hash__(self):
return hash(self.value)

async def _wait(self):
await asyncio.sleep(self.value.ttl)

def __init__(self):
self.values: Set[self.ObjectWithTtlWrapper] = set()
self._current = 0

async def add(self, object_to_add: ObjectWithTtl, callback=None):
wrapper = self.ObjectWithTtlWrapper(object_to_add)

if wrapper in self.values:
await self.remove(wrapper.value)

wrapper.timeout_task = asyncio.create_task(wrapper._wait())
wrapper.active = True
wrapper.timeout_task.add_done_callback(
lambda _: self._done_callback(wrapper, callback)
)
self.values.add(wrapper)

def _done_callback(self, caller: ObjectWithTtlWrapper, callback=None):
self.values.discard(caller)
if caller.active is True and callback is not None:
callback(caller.value)

async def remove(self, object_to_remove: ObjectWithTtl):
wrapper = self.ObjectWithTtlWrapper(object_to_remove)
if wrapper in self.values:
for value in self.values:
if value == wrapper:
value.active = False
value.timeout_task.cancel()
try:
await value.timeout_task
except asyncio.CancelledError:
pass
break

async def clear(self):
while len(self.values) > 0:
await self.remove(next(iter(self.values)).value)

def __contains__(self, item):
wrapper = self.ObjectWithTtlWrapper(item)
return wrapper in self.values

def __iter__(self):
self._current = 0
return self

def __next__(self):
if self._current < len(self.values):
result = next(iter(self.values))
self._current += 1
return result.value
else:
raise StopIteration
63 changes: 36 additions & 27 deletions src/someipy/client_service_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ServiceDiscoveryObserver,
ServiceDiscoverySender,
)
from someipy._internal.store_with_timeout import StoreWithTimeout
from someipy._internal.utils import (
create_udp_socket,
EndpointType,
Expand All @@ -60,17 +61,6 @@ def __eq__(self, value: object) -> bool:
return self.eventgroup_id == value.eventgroup_id


class FoundService:
service: SdService

def __init__(self, service: SdService) -> None:
self.service = service

def __eq__(self, __value: object) -> bool:
if isinstance(__value, FoundService):
return self.service == __value.service


class ClientServiceInstance(ServiceDiscoveryObserver):
_service: Service
_instance_id: int
Expand All @@ -84,7 +74,7 @@ class ClientServiceInstance(ServiceDiscoveryObserver):
_expected_acks: List[ExpectedAck]

_callback: Callable[[bytes], None]
_found_services: Iterable[FoundService]
_offered_services: StoreWithTimeout
_subscription_active: bool

_method_call_futures: Dict[int, asyncio.Future]
Expand Down Expand Up @@ -121,7 +111,8 @@ def __init__(
self._tcp_connection_established_event = asyncio.Event()
self._shutdown_requested = False

self._found_services = []
self._offered_services = StoreWithTimeout()

self._subscription_active = False
self._method_call_futures: Dict[int, asyncio.Future] = {}
self._client_id = client_id
Expand All @@ -146,11 +137,8 @@ def service_found(self) -> bool:
Returns whether the service instance represented by the ClientServiceInstance has been offered by a server and was found.
"""
has_service = False
for s in self._found_services:
if (
s.service.service_id == self._service.id
and s.service.instance_id == self._instance_id
):
for s in self._offered_services:
if s.service_id == self._service.id and s.instance_id == self._instance_id:
has_service = True
break
return has_service
Expand Down Expand Up @@ -201,8 +189,12 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:
call_future = asyncio.get_running_loop().create_future()
self._method_call_futures[session_id] = call_future

dst_address = str(self._found_services[0].service.endpoint[0])
dst_port = self._found_services[0].service.endpoint[1]
# At this point the service should be found since an exception would have been raised before
for s in self._offered_services:
if s.service_id == self._service.id and s.instance_id == self._instance_id:
dst_address = str(s.endpoint[0])
dst_port = s.endpoint[1]
break

if self._protocol == TransportLayerProtocol.TCP:
# In case of TCP, first try to connect to the TCP server
Expand Down Expand Up @@ -241,9 +233,19 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:

else:
# In case of UDP, just send out the datagram and wait for the response
# At this point the service should be found since an exception would have been raised before
for s in self._offered_services:
if (
s.service_id == self._service.id
and s.instance_id == self._instance_id
):
dst_address = str(s.endpoint[0])
dst_port = s.endpoint[1]
break

self._someip_endpoint.sendto(
someip_message.serialize(),
endpoint_to_str_int_tuple(self._found_services[0].service.endpoint),
(dst_address, dst_port),
)

# After sending the method call wait for maximum 10 seconds
Expand Down Expand Up @@ -348,6 +350,11 @@ def handle_find_service(self):
# Not needed in client service instance
pass

def _timeout_of_offered_service(self, offered_service: SdService):
get_logger(_logger_name).debug(
f"Offered service timed out: service id 0x{offered_service.service_id:04x}, instance id 0x{offered_service.instance_id:04x}"
)

def handle_offer_service(self, offered_service: SdService):
if self._service.id != offered_service.service_id:
return
Expand All @@ -367,8 +374,11 @@ def handle_offer_service(self, offered_service: SdService):
# 0xFFFFFFFF allows to handle any minor version
return

if FoundService(offered_service) not in self._found_services:
self._found_services.append(FoundService(offered_service))
asyncio.get_event_loop().create_task(
self._offered_services.add(
offered_service, self._timeout_of_offered_service
)
)

if len(self._eventgroups_to_subscribe) == 0:
return
Expand Down Expand Up @@ -424,10 +434,9 @@ def handle_stop_offer_service(self, offered_service: SdService) -> None:
if self._instance_id != offered_service.instance_id:
return

# Remove the service from the found services
self._found_services = [
f for f in self._found_services if f.service != offered_service
]
asyncio.get_event_loop().create_task(
self._offered_services.remove(offered_service)
)

self._expected_acks = []
self._subscription_active = False
Expand Down
94 changes: 94 additions & 0 deletions tests/test_store_with_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
import pytest
import pytest_asyncio
from someipy._internal.store_with_timeout import StoreWithTimeout


class MyTestObjectWithTtl:
def __init__(self, value: int, ttl: int):
self.ttl = ttl
self.value = value

def __hash__(self):
return hash(self.value)

def __eq__(self, other):
return self.value == other.value


@pytest.mark.asyncio
async def test_add():
store = StoreWithTimeout()

await store.add(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 1
assert (MyTestObjectWithTtl(1, 2)) in store

await store.add(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 1

await store.add(MyTestObjectWithTtl(1, 1))
assert len(store.values) == 1
await asyncio.sleep(1.5)
assert len(store.values) == 0

await store.add(MyTestObjectWithTtl(2, 2))
assert len(store.values) == 1

await asyncio.sleep(2.5)
assert len(store.values) == 0
await store.clear()


@pytest.mark.asyncio
async def test_clear():
store = StoreWithTimeout()
await store.add(MyTestObjectWithTtl(1, 5))
await store.add(MyTestObjectWithTtl(2, 5))
await store.add(MyTestObjectWithTtl(3, 5))
assert len(store.values) == 3

await store.clear()
assert len(store.values) == 0


@pytest.mark.asyncio
async def test_remove():
store = StoreWithTimeout()
await store.add(MyTestObjectWithTtl(1, 5))
await store.add(MyTestObjectWithTtl(2, 5))
await store.add(MyTestObjectWithTtl(3, 5))
assert len(store.values) == 3

await store.remove(MyTestObjectWithTtl(2, 5))
assert len(store.values) == 2
await store.clear()

# Try to remove some object that is not in the store
# This shall not raise an error
await store.remove(MyTestObjectWithTtl(2, 5))


@pytest.mark.asyncio
async def test_callback():
store = StoreWithTimeout()

callback_was_called = 0

def callback(obj):
nonlocal callback_was_called
print("Callback was called")
callback_was_called += 1

await store.add(MyTestObjectWithTtl(1, 1), callback)
await asyncio.sleep(1.5)
assert callback_was_called == 1
assert len(store.values) == 0

callback_was_called = 0

await store.add(MyTestObjectWithTtl(1, 2), callback)
assert len(store.values) == 1
await store.remove(MyTestObjectWithTtl(1, 2))
assert len(store.values) == 0
assert callback_was_called == 0
Loading