Skip to content

Commit

Permalink
Add actor state ttl support (#704)
Browse files Browse the repository at this point in the history
* feat: add set_state_ttl method and refactor set_state

Signed-off-by: KentHsu <chiahaohsu9@gmail.com>

* tests: add tests for set_state_ttl

Signed-off-by: KentHsu <chiahaohsu9@gmail.com>

* feat: update save_state method and state_change with ttl

Signed-off-by: KentHsu <chiahaohsu9@gmail.com>

* fix: update and fix test_save_state

Signed-off-by: KentHsu <chiahaohsu9@gmail.com>

* Apply ruff autoformatter

Signed-off-by: Bernd Verst <github@bernd.dev>

---------

Signed-off-by: KentHsu <chiahaohsu9@gmail.com>
Signed-off-by: Bernd Verst <github@bernd.dev>
Co-authored-by: Bernd Verst <github@bernd.dev>
  • Loading branch information
KentHsu and berndverst authored May 1, 2024
1 parent 15d0573 commit de8b454
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 11 deletions.
10 changes: 9 additions & 1 deletion dapr/actor/runtime/_state_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ async def save_state(
"operation": "upsert",
"request": {
"key": "key1",
"value": "myData"
"value": "myData",
"metadata": {
"ttlInSeconds": "3600"
}
}
},
{
Expand Down Expand Up @@ -94,6 +97,11 @@ async def save_state(
serialized = self._state_serializer.serialize(state.value)
json_output.write(b',"value":')
json_output.write(serialized)
if state.ttl_in_seconds is not None and state.ttl_in_seconds >= 0:
serialized = self._state_serializer.serialize(state.ttl_in_seconds)
json_output.write(b',"metadata":{"ttlInSeconds":"')
json_output.write(serialized)
json_output.write(b'"}')
json_output.write(b'}}')
first_state = False
json_output.write(b']')
Expand Down
15 changes: 13 additions & 2 deletions dapr/actor/runtime/state_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from enum import Enum
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Optional

T = TypeVar('T')

Expand All @@ -35,10 +35,17 @@ class StateChangeKind(Enum):


class ActorStateChange(Generic[T]):
def __init__(self, state_name: str, value: T, change_kind: StateChangeKind):
def __init__(
self,
state_name: str,
value: T,
change_kind: StateChangeKind,
ttl_in_seconds: Optional[int] = None,
):
self._state_name = state_name
self._value = value
self._change_kind = change_kind
self._ttl_in_seconds = ttl_in_seconds

@property
def state_name(self) -> str:
Expand All @@ -51,3 +58,7 @@ def value(self) -> T:
@property
def change_kind(self) -> StateChangeKind:
return self._change_kind

@property
def ttl_in_seconds(self) -> Optional[int]:
return self._ttl_in_seconds
35 changes: 31 additions & 4 deletions dapr/actor/runtime/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@


class StateMetadata(Generic[T]):
def __init__(self, value: T, change_kind: StateChangeKind):
def __init__(
self, value: T, change_kind: StateChangeKind, ttl_in_seconds: Optional[int] = None
):
self._value = value
self._change_kind = change_kind
self._ttl_in_seconds = ttl_in_seconds

@property
def value(self) -> T:
Expand All @@ -49,6 +52,14 @@ def change_kind(self) -> StateChangeKind:
def change_kind(self, new_kind: StateChangeKind) -> None:
self._change_kind = new_kind

@property
def ttl_in_seconds(self) -> Optional[int]:
return self._ttl_in_seconds

@ttl_in_seconds.setter
def ttl_in_seconds(self, new_ttl_in_seconds: int) -> None:
self._ttl_in_seconds = new_ttl_in_seconds


class ActorStateManager(Generic[T]):
def __init__(self, actor: 'Actor'):
Expand Down Expand Up @@ -103,10 +114,17 @@ async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]:
return has_value, val

async def set_state(self, state_name: str, value: T) -> None:
await self.set_state_ttl(state_name, value, None)

async def set_state_ttl(self, state_name: str, value: T, ttl_in_seconds: Optional[int]) -> None:
if ttl_in_seconds is not None and ttl_in_seconds < 0:
return

state_change_tracker = self._get_contextual_state_tracker()
if state_name in state_change_tracker:
state_metadata = state_change_tracker[state_name]
state_metadata.value = value
state_metadata.ttl_in_seconds = ttl_in_seconds

if (
state_metadata.change_kind == StateChangeKind.none
Expand All @@ -120,9 +138,13 @@ async def set_state(self, state_name: str, value: T) -> None:
self._type_name, self._actor.id.id, state_name
)
if existed:
state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.update)
state_change_tracker[state_name] = StateMetadata(
value, StateChangeKind.update, ttl_in_seconds
)
else:
state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
state_change_tracker[state_name] = StateMetadata(
value, StateChangeKind.add, ttl_in_seconds
)

async def remove_state(self, state_name: str) -> None:
if not await self.try_remove_state(state_name):
Expand Down Expand Up @@ -231,7 +253,12 @@ async def save_state(self) -> None:
if state_metadata.change_kind == StateChangeKind.none:
continue
state_changes.append(
ActorStateChange(state_name, state_metadata.value, state_metadata.change_kind)
ActorStateChange(
state_name,
state_metadata.value,
state_metadata.change_kind,
state_metadata.ttl_in_seconds,
)
)
if state_metadata.change_kind == StateChangeKind.remove:
states_to_remove.append(state_name)
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ pyOpenSSL>=23.2.0
# needed for type checking
Flask>=1.1
# needed for auto fix
ruff===0.4.2
ruff===0.2.2
# needed for dapr-ext-workflow
durabletask>=0.1.1a1
82 changes: 79 additions & 3 deletions tests/actor/test_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_set_state_for_new_state(self):
state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value1', state.value)
self.assertEqual(None, state.ttl_in_seconds)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_set_state_for_existing_state_only_in_mem(self):
Expand All @@ -131,6 +132,7 @@ def test_set_state_for_existing_state_only_in_mem(self):
state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value2', state.value)
self.assertEqual(None, state.ttl_in_seconds)

@mock.patch(
'tests.actor.fake_client.FakeDaprActorClient.get_state',
Expand All @@ -143,6 +145,73 @@ def test_set_state_for_existing_state(self):
state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.update, state.change_kind)
self.assertEqual('value2', state.value)
self.assertEqual(None, state.ttl_in_seconds)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_set_state_ttl_for_new_state(self):
state_manager = ActorStateManager(self._fake_actor)
state_change_tracker = state_manager._get_contextual_state_tracker()
_run(state_manager.set_state_ttl('state1', 'value1', 3600))

state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value1', state.value)
self.assertEqual(3600, state.ttl_in_seconds)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_set_state_ttl_for_existing_state_only_in_mem(self):
state_manager = ActorStateManager(self._fake_actor)
state_change_tracker = state_manager._get_contextual_state_tracker()
_run(state_manager.set_state_ttl('state1', 'value1', 3600))

state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value1', state.value)
self.assertEqual(3600, state.ttl_in_seconds)

_run(state_manager.set_state_ttl('state1', 'value2', 7200))
state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value2', state.value)
self.assertEqual(7200, state.ttl_in_seconds)

@mock.patch(
'tests.actor.fake_client.FakeDaprActorClient.get_state',
new=_async_mock(return_value=b'"value1"'),
)
def test_set_state_ttl_for_existing_state(self):
state_manager = ActorStateManager(self._fake_actor)
state_change_tracker = state_manager._get_contextual_state_tracker()
_run(state_manager.set_state_ttl('state1', 'value2', 3600))

state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.update, state.change_kind)
self.assertEqual('value2', state.value)
self.assertEqual(3600, state.ttl_in_seconds)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_set_state_ttl_lt_0_for_new_state(self):
state_manager = ActorStateManager(self._fake_actor)
state_change_tracker = state_manager._get_contextual_state_tracker()
_run(state_manager.set_state_ttl('state1', 'value1', -3600))
self.assertNotIn('state1', state_change_tracker)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_set_state_ttl_lt_0_for_existing_state_only_in_mem(self):
state_manager = ActorStateManager(self._fake_actor)
state_change_tracker = state_manager._get_contextual_state_tracker()
_run(state_manager.set_state_ttl('state1', 'value1', 3600))

state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value1', state.value)
self.assertEqual(3600, state.ttl_in_seconds)

_run(state_manager.set_state_ttl('state1', 'value2', -3600))
state = state_change_tracker['state1']
self.assertEqual(StateChangeKind.add, state.change_kind)
self.assertEqual('value1', state.value)
self.assertEqual(3600, state.ttl_in_seconds)

@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
def test_remove_state_for_non_existing_state(self):
Expand Down Expand Up @@ -360,13 +429,20 @@ def test_save_state(self):
_run(state_manager.remove_state('state4'))
# set state which is StateChangeKind.update
_run(state_manager.set_state('state5', 'value5'))
expected = b'[{"operation":"upsert","request":{"key":"state1","value":"value1"}},{"operation":"upsert","request":{"key":"state2","value":"value2"}},{"operation":"delete","request":{"key":"state4"}},{"operation":"upsert","request":{"key":"state5","value":"value5"}}]' # noqa: E501
_run(state_manager.set_state('state5', 'new_value5'))
# set state with ttl >= 0
_run(state_manager.set_state_ttl('state6', 'value6', 3600))
_run(state_manager.set_state_ttl('state7', 'value7', 0))
# set state with ttl < 0
_run(state_manager.set_state_ttl('state8', 'value8', -3600))

expected = b'[{"operation":"upsert","request":{"key":"state1","value":"value1"}},{"operation":"upsert","request":{"key":"state2","value":"value2"}},{"operation":"delete","request":{"key":"state4"}},{"operation":"upsert","request":{"key":"state5","value":"new_value5"}},{"operation":"upsert","request":{"key":"state6","value":"value6","metadata":{"ttlInSeconds":"3600"}}},{"operation":"upsert","request":{"key":"state7","value":"value7","metadata":{"ttlInSeconds":"0"}}}]' # noqa: E501

# Save the state
def mock_save_state(actor_type, actor_id, data):
async def mock_save_state(actor_type, actor_id, data):
self.assertEqual(expected, data)

self._fake_client.save_state_transactionally.mock = mock_save_state
self._fake_client.save_state_transactionally = mock_save_state
_run(state_manager.save_state())


Expand Down

0 comments on commit de8b454

Please sign in to comment.