Skip to content

Commit 3ebdba2

Browse files
committed
Implement timeouts/ttl of offered services
1 parent 0af1e53 commit 3ebdba2

File tree

5 files changed

+215
-28
lines changed

5 files changed

+215
-28
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
python -m pip install --upgrade pip
2727
pip install -e .
2828
pip install pytest
29+
pip install pytest-asyncio
2930
3031
- name: Run pytest
3132
run: pytest tests

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,5 @@ The library is still under development. The current major limitations and deviat
6161
### Service Discovery
6262

6363
- Configuration and load balancing options in SOME/IP SD messages are not supported.
64-
- TTL of Service Discovery entries is not checked yet.
64+
- Stop subscribe message of notifications is not supported
6565
- The Initial Wait Phase and Repetition Phase of the Service Discovery specification are skipped. The Main Phase is directly entered, i.e. SD Offer Entries are immediately sent cyclically.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import asyncio
2+
from typing import Set, Generic, TypeVar, Protocol
3+
4+
5+
T = TypeVar("T")
6+
7+
8+
class ObjectWithTtl(Protocol):
9+
ttl: int # The TTL of the object in seconds
10+
11+
12+
class StoreWithTimeout:
13+
14+
class ObjectWithTtlWrapper:
15+
def __init__(self, value: ObjectWithTtl):
16+
self.value: ObjectWithTtl = value
17+
self.timeout_task: asyncio.Task = None
18+
self.active = False # Flag needed to prevent callbacks being called during cancellation
19+
20+
def __eq__(self, other):
21+
# Two objects are equal if their values are equal even if the timeout is different
22+
return self.value == other.value
23+
24+
def __hash__(self):
25+
return hash(self.value)
26+
27+
async def _wait(self):
28+
await asyncio.sleep(self.value.ttl)
29+
30+
def __init__(self):
31+
self.values: Set[self.ObjectWithTtlWrapper] = set()
32+
self._current = 0
33+
34+
async def add(self, object_to_add: ObjectWithTtl, callback=None):
35+
wrapper = self.ObjectWithTtlWrapper(object_to_add)
36+
37+
if wrapper in self.values:
38+
await self.remove(wrapper.value)
39+
40+
wrapper.timeout_task = asyncio.create_task(wrapper._wait())
41+
wrapper.active = True
42+
wrapper.timeout_task.add_done_callback(
43+
lambda _: self._done_callback(wrapper, callback)
44+
)
45+
self.values.add(wrapper)
46+
47+
def _done_callback(self, caller: ObjectWithTtlWrapper, callback=None):
48+
self.values.discard(caller)
49+
if caller.active is True and callback is not None:
50+
callback(caller.value)
51+
52+
async def remove(self, object_to_remove: ObjectWithTtl):
53+
wrapper = self.ObjectWithTtlWrapper(object_to_remove)
54+
if wrapper in self.values:
55+
for value in self.values:
56+
if value == wrapper:
57+
value.active = False
58+
value.timeout_task.cancel()
59+
try:
60+
await value.timeout_task
61+
except asyncio.CancelledError:
62+
pass
63+
break
64+
65+
async def clear(self):
66+
while len(self.values) > 0:
67+
await self.remove(next(iter(self.values)).value)
68+
69+
def __contains__(self, item):
70+
wrapper = self.ObjectWithTtlWrapper(item)
71+
return wrapper in self.values
72+
73+
def __iter__(self):
74+
self._current = 0
75+
return self
76+
77+
def __next__(self):
78+
if self._current < len(self.values):
79+
result = next(iter(self.values))
80+
self._current += 1
81+
return result.value
82+
else:
83+
raise StopIteration

src/someipy/client_service_instance.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ServiceDiscoveryObserver,
3535
ServiceDiscoverySender,
3636
)
37+
from someipy._internal.store_with_timeout import StoreWithTimeout
3738
from someipy._internal.utils import (
3839
create_udp_socket,
3940
EndpointType,
@@ -60,17 +61,6 @@ def __eq__(self, value: object) -> bool:
6061
return self.eventgroup_id == value.eventgroup_id
6162

6263

63-
class FoundService:
64-
service: SdService
65-
66-
def __init__(self, service: SdService) -> None:
67-
self.service = service
68-
69-
def __eq__(self, __value: object) -> bool:
70-
if isinstance(__value, FoundService):
71-
return self.service == __value.service
72-
73-
7464
class ClientServiceInstance(ServiceDiscoveryObserver):
7565
_service: Service
7666
_instance_id: int
@@ -84,7 +74,7 @@ class ClientServiceInstance(ServiceDiscoveryObserver):
8474
_expected_acks: List[ExpectedAck]
8575

8676
_callback: Callable[[bytes], None]
87-
_found_services: Iterable[FoundService]
77+
_offered_services: StoreWithTimeout
8878
_subscription_active: bool
8979

9080
_method_call_futures: Dict[int, asyncio.Future]
@@ -121,7 +111,8 @@ def __init__(
121111
self._tcp_connection_established_event = asyncio.Event()
122112
self._shutdown_requested = False
123113

124-
self._found_services = []
114+
self._offered_services = StoreWithTimeout()
115+
125116
self._subscription_active = False
126117
self._method_call_futures: Dict[int, asyncio.Future] = {}
127118
self._client_id = client_id
@@ -146,11 +137,8 @@ def service_found(self) -> bool:
146137
Returns whether the service instance represented by the ClientServiceInstance has been offered by a server and was found.
147138
"""
148139
has_service = False
149-
for s in self._found_services:
150-
if (
151-
s.service.service_id == self._service.id
152-
and s.service.instance_id == self._instance_id
153-
):
140+
for s in self._offered_services:
141+
if s.service_id == self._service.id and s.instance_id == self._instance_id:
154142
has_service = True
155143
break
156144
return has_service
@@ -201,8 +189,12 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:
201189
call_future = asyncio.get_running_loop().create_future()
202190
self._method_call_futures[session_id] = call_future
203191

204-
dst_address = str(self._found_services[0].service.endpoint[0])
205-
dst_port = self._found_services[0].service.endpoint[1]
192+
# At this point the service should be found since an exception would have been raised before
193+
for s in self._offered_services:
194+
if s.service_id == self._service.id and s.instance_id == self._instance_id:
195+
dst_address = str(s.endpoint[0])
196+
dst_port = s.endpoint[1]
197+
break
206198

207199
if self._protocol == TransportLayerProtocol.TCP:
208200
# In case of TCP, first try to connect to the TCP server
@@ -241,9 +233,19 @@ async def call_method(self, method_id: int, payload: bytes) -> MethodResult:
241233

242234
else:
243235
# In case of UDP, just send out the datagram and wait for the response
236+
# At this point the service should be found since an exception would have been raised before
237+
for s in self._offered_services:
238+
if (
239+
s.service_id == self._service.id
240+
and s.instance_id == self._instance_id
241+
):
242+
dst_address = str(s.endpoint[0])
243+
dst_port = s.endpoint[1]
244+
break
245+
244246
self._someip_endpoint.sendto(
245247
someip_message.serialize(),
246-
endpoint_to_str_int_tuple(self._found_services[0].service.endpoint),
248+
(dst_address, dst_port),
247249
)
248250

249251
# After sending the method call wait for maximum 10 seconds
@@ -348,6 +350,11 @@ def handle_find_service(self):
348350
# Not needed in client service instance
349351
pass
350352

353+
def _timeout_of_offered_service(self, offered_service: SdService):
354+
get_logger(_logger_name).debug(
355+
f"Offered service timed out: service id 0x{offered_service.service_id:04x}, instance id 0x{offered_service.instance_id:04x}"
356+
)
357+
351358
def handle_offer_service(self, offered_service: SdService):
352359
if self._service.id != offered_service.service_id:
353360
return
@@ -367,8 +374,11 @@ def handle_offer_service(self, offered_service: SdService):
367374
# 0xFFFFFFFF allows to handle any minor version
368375
return
369376

370-
if FoundService(offered_service) not in self._found_services:
371-
self._found_services.append(FoundService(offered_service))
377+
asyncio.get_event_loop().create_task(
378+
self._offered_services.add(
379+
offered_service, self._timeout_of_offered_service
380+
)
381+
)
372382

373383
if len(self._eventgroups_to_subscribe) == 0:
374384
return
@@ -424,10 +434,9 @@ def handle_stop_offer_service(self, offered_service: SdService) -> None:
424434
if self._instance_id != offered_service.instance_id:
425435
return
426436

427-
# Remove the service from the found services
428-
self._found_services = [
429-
f for f in self._found_services if f.service != offered_service
430-
]
437+
asyncio.get_event_loop().create_task(
438+
self._offered_services.remove(offered_service)
439+
)
431440

432441
self._expected_acks = []
433442
self._subscription_active = False

tests/test_store_with_timeout.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import asyncio
2+
import pytest
3+
import pytest_asyncio
4+
from someipy._internal.store_with_timeout import StoreWithTimeout
5+
6+
7+
class MyTestObjectWithTtl:
8+
def __init__(self, value: int, ttl: int):
9+
self.ttl = ttl
10+
self.value = value
11+
12+
def __hash__(self):
13+
return hash(self.value)
14+
15+
def __eq__(self, other):
16+
return self.value == other.value
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_add():
21+
store = StoreWithTimeout()
22+
23+
await store.add(MyTestObjectWithTtl(1, 2))
24+
assert len(store.values) == 1
25+
assert (MyTestObjectWithTtl(1, 2)) in store
26+
27+
await store.add(MyTestObjectWithTtl(1, 2))
28+
assert len(store.values) == 1
29+
30+
await store.add(MyTestObjectWithTtl(1, 1))
31+
assert len(store.values) == 1
32+
await asyncio.sleep(1.5)
33+
assert len(store.values) == 0
34+
35+
await store.add(MyTestObjectWithTtl(2, 2))
36+
assert len(store.values) == 1
37+
38+
await asyncio.sleep(2.5)
39+
assert len(store.values) == 0
40+
await store.clear()
41+
42+
43+
@pytest.mark.asyncio
44+
async def test_clear():
45+
store = StoreWithTimeout()
46+
await store.add(MyTestObjectWithTtl(1, 5))
47+
await store.add(MyTestObjectWithTtl(2, 5))
48+
await store.add(MyTestObjectWithTtl(3, 5))
49+
assert len(store.values) == 3
50+
51+
await store.clear()
52+
assert len(store.values) == 0
53+
54+
55+
@pytest.mark.asyncio
56+
async def test_remove():
57+
store = StoreWithTimeout()
58+
await store.add(MyTestObjectWithTtl(1, 5))
59+
await store.add(MyTestObjectWithTtl(2, 5))
60+
await store.add(MyTestObjectWithTtl(3, 5))
61+
assert len(store.values) == 3
62+
63+
await store.remove(MyTestObjectWithTtl(2, 5))
64+
assert len(store.values) == 2
65+
await store.clear()
66+
67+
# Try to remove some object that is not in the store
68+
# This shall not raise an error
69+
await store.remove(MyTestObjectWithTtl(2, 5))
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_callback():
74+
store = StoreWithTimeout()
75+
76+
callback_was_called = 0
77+
78+
def callback():
79+
nonlocal callback_was_called
80+
print("Callback was called")
81+
callback_was_called += 1
82+
83+
await store.add(MyTestObjectWithTtl(1, 1), callback)
84+
await asyncio.sleep(1.5)
85+
assert callback_was_called == 1
86+
assert len(store.values) == 0
87+
88+
callback_was_called = 0
89+
90+
await store.add(MyTestObjectWithTtl(1, 2), callback)
91+
assert len(store.values) == 1
92+
await store.remove(MyTestObjectWithTtl(1, 2))
93+
assert len(store.values) == 0
94+
assert callback_was_called == 0

0 commit comments

Comments
 (0)