-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement timeouts/ttl of offered services
- Loading branch information
Showing
5 changed files
with
215 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import asyncio | ||
from typing import Set, Generic, TypeVar, Protocol | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
|
||
class ObjectWithTtl(Protocol): | ||
ttl: int # The TTL of the object in seconds | ||
|
||
|
||
class StoreWithTimeout: | ||
|
||
class ObjectWithTtlWrapper: | ||
def __init__(self, value: ObjectWithTtl): | ||
self.value: ObjectWithTtl = value | ||
self.timeout_task: asyncio.Task = None | ||
self.active = False # Flag needed to prevent callbacks being called during cancellation | ||
|
||
def __eq__(self, other): | ||
# Two objects are equal if their values are equal even if the timeout is different | ||
return self.value == other.value | ||
|
||
def __hash__(self): | ||
return hash(self.value) | ||
|
||
async def _wait(self): | ||
await asyncio.sleep(self.value.ttl) | ||
|
||
def __init__(self): | ||
self.values: Set[self.ObjectWithTtlWrapper] = set() | ||
self._current = 0 | ||
|
||
async def add(self, object_to_add: ObjectWithTtl, callback=None): | ||
wrapper = self.ObjectWithTtlWrapper(object_to_add) | ||
|
||
if wrapper in self.values: | ||
await self.remove(wrapper.value) | ||
|
||
wrapper.timeout_task = asyncio.create_task(wrapper._wait()) | ||
wrapper.active = True | ||
wrapper.timeout_task.add_done_callback( | ||
lambda _: self._done_callback(wrapper, callback) | ||
) | ||
self.values.add(wrapper) | ||
|
||
def _done_callback(self, caller: ObjectWithTtlWrapper, callback=None): | ||
self.values.discard(caller) | ||
if caller.active is True and callback is not None: | ||
callback(caller.value) | ||
|
||
async def remove(self, object_to_remove: ObjectWithTtl): | ||
wrapper = self.ObjectWithTtlWrapper(object_to_remove) | ||
if wrapper in self.values: | ||
for value in self.values: | ||
if value == wrapper: | ||
value.active = False | ||
value.timeout_task.cancel() | ||
try: | ||
await value.timeout_task | ||
except asyncio.CancelledError: | ||
pass | ||
break | ||
|
||
async def clear(self): | ||
while len(self.values) > 0: | ||
await self.remove(next(iter(self.values)).value) | ||
|
||
def __contains__(self, item): | ||
wrapper = self.ObjectWithTtlWrapper(item) | ||
return wrapper in self.values | ||
|
||
def __iter__(self): | ||
self._current = 0 | ||
return self | ||
|
||
def __next__(self): | ||
if self._current < len(self.values): | ||
result = next(iter(self.values)) | ||
self._current += 1 | ||
return result.value | ||
else: | ||
raise StopIteration |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import asyncio | ||
import pytest | ||
import pytest_asyncio | ||
from someipy._internal.store_with_timeout import StoreWithTimeout | ||
|
||
|
||
class MyTestObjectWithTtl: | ||
def __init__(self, value: int, ttl: int): | ||
self.ttl = ttl | ||
self.value = value | ||
|
||
def __hash__(self): | ||
return hash(self.value) | ||
|
||
def __eq__(self, other): | ||
return self.value == other.value | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_add(): | ||
store = StoreWithTimeout() | ||
|
||
await store.add(MyTestObjectWithTtl(1, 2)) | ||
assert len(store.values) == 1 | ||
assert (MyTestObjectWithTtl(1, 2)) in store | ||
|
||
await store.add(MyTestObjectWithTtl(1, 2)) | ||
assert len(store.values) == 1 | ||
|
||
await store.add(MyTestObjectWithTtl(1, 1)) | ||
assert len(store.values) == 1 | ||
await asyncio.sleep(1.5) | ||
assert len(store.values) == 0 | ||
|
||
await store.add(MyTestObjectWithTtl(2, 2)) | ||
assert len(store.values) == 1 | ||
|
||
await asyncio.sleep(2.5) | ||
assert len(store.values) == 0 | ||
await store.clear() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_clear(): | ||
store = StoreWithTimeout() | ||
await store.add(MyTestObjectWithTtl(1, 5)) | ||
await store.add(MyTestObjectWithTtl(2, 5)) | ||
await store.add(MyTestObjectWithTtl(3, 5)) | ||
assert len(store.values) == 3 | ||
|
||
await store.clear() | ||
assert len(store.values) == 0 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_remove(): | ||
store = StoreWithTimeout() | ||
await store.add(MyTestObjectWithTtl(1, 5)) | ||
await store.add(MyTestObjectWithTtl(2, 5)) | ||
await store.add(MyTestObjectWithTtl(3, 5)) | ||
assert len(store.values) == 3 | ||
|
||
await store.remove(MyTestObjectWithTtl(2, 5)) | ||
assert len(store.values) == 2 | ||
await store.clear() | ||
|
||
# Try to remove some object that is not in the store | ||
# This shall not raise an error | ||
await store.remove(MyTestObjectWithTtl(2, 5)) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_callback(): | ||
store = StoreWithTimeout() | ||
|
||
callback_was_called = 0 | ||
|
||
def callback(obj): | ||
nonlocal callback_was_called | ||
print("Callback was called") | ||
callback_was_called += 1 | ||
|
||
await store.add(MyTestObjectWithTtl(1, 1), callback) | ||
await asyncio.sleep(1.5) | ||
assert callback_was_called == 1 | ||
assert len(store.values) == 0 | ||
|
||
callback_was_called = 0 | ||
|
||
await store.add(MyTestObjectWithTtl(1, 2), callback) | ||
assert len(store.values) == 1 | ||
await store.remove(MyTestObjectWithTtl(1, 2)) | ||
assert len(store.values) == 0 | ||
assert callback_was_called == 0 |