Skip to content

Commit

Permalink
Feature/implement fast load lock (#276)
Browse files Browse the repository at this point in the history
* add redis test

* fix yapf

* fix isort

* use mock to test function arguments

* add fast redis lock

fix test

fix test

* restore lock

* fix test

* restore existance check at TargetOnKart.remove()

* add doc

* fix docs

* remove unnecessary try block
  • Loading branch information
mski-iksm authored Mar 28, 2022
1 parent d41436b commit 904268f
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/using_task_cache_collision_lock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ This setting must be done to each gokart task which you want to lock the ``run()
from gokart.run_with_lock import RunWithLock
@RunWithLock
class SomeTask(gokart.TaskOnKart):
@RunWithLock
def run(self):
...
Expand Down
114 changes: 95 additions & 19 deletions gokart/redis_lock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import os
from logging import getLogger
from typing import NamedTuple
from typing import Callable, NamedTuple

import redis
from apscheduler.schedulers.background import BackgroundScheduler
Expand Down Expand Up @@ -41,28 +42,39 @@ def get_redis_client(self):
return self._redis_client


def with_lock(func, redis_params: RedisParams):
def _extend_lock(redis_lock: redis.lock.Lock, redis_timeout: int):
redis_lock.extend(additional_time=redis_timeout, replace_ttl=True)


def _set_redis_lock(redis_params: RedisParams) -> RedisClient:
redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client()
blocking = not redis_params.redis_fail_on_collision
redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False)
if not redis_lock.acquire(blocking=blocking):
raise TaskLockException('Lock already taken by other task.')
return redis_lock


def _set_lock_scheduler(redis_lock: redis.lock.Lock, redis_params: RedisParams) -> BackgroundScheduler:
scheduler = BackgroundScheduler()
extend_lock = functools.partial(_extend_lock, redis_lock=redis_lock, redis_timeout=redis_params.redis_timeout)
scheduler.add_job(extend_lock,
'interval',
seconds=redis_params.lock_extend_seconds,
max_instances=999999999,
misfire_grace_time=redis_params.redis_timeout,
coalesce=False)
scheduler.start()
return scheduler


def _wrap_with_lock(func, redis_params: RedisParams):
if not redis_params.should_redis_lock:
return func

def wrapper(*args, **kwargs):
redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client()
blocking = not redis_params.redis_fail_on_collision
redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False)
if not redis_lock.acquire(blocking=blocking):
raise TaskLockException('Lock already taken by other task.')

def extend_lock():
redis_lock.extend(additional_time=redis_params.redis_timeout, replace_ttl=True)

scheduler = BackgroundScheduler()
scheduler.add_job(extend_lock,
'interval',
seconds=redis_params.lock_extend_seconds,
max_instances=999999999,
misfire_grace_time=redis_params.redis_timeout,
coalesce=False)
scheduler.start()
redis_lock = _set_redis_lock(redis_params=redis_params)
scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params)

try:
logger.debug(f'Task lock of {redis_params.redis_key} locked.')
Expand All @@ -80,6 +92,70 @@ def extend_lock():
return wrapper


def wrap_with_run_lock(func, redis_params: RedisParams):
"""Redis lock wrapper function for RunWithLock.
When a fucntion is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
"""
return _wrap_with_lock(func=func, redis_params=redis_params)


def wrap_with_dump_lock(func: Callable, redis_params: RedisParams, exist_check: Callable):
"""Redis lock wrapper function for TargetOnKart.dump().
When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check.
https://github.com/m3dev/gokart/issues/265
"""

if not redis_params.should_redis_lock:
return func

def wrapper(*args, **kwargs):
redis_lock = _set_redis_lock(redis_params=redis_params)
scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params)

try:
logger.debug(f'Task lock of {redis_params.redis_key} locked.')
if not exist_check():
func(*args, **kwargs)
finally:
logger.debug(f'Task lock of {redis_params.redis_key} released.')
redis_lock.release()
scheduler.shutdown()

return wrapper


def wrap_with_load_lock(func, redis_params: RedisParams):
"""Redis lock wrapper function for TargetOnKart.load().
When TargetOnKart.load() is called, redis lock will be locked and released before load().
https://github.com/m3dev/gokart/issues/265
"""

if not redis_params.should_redis_lock:
return func

def wrapper(*args, **kwargs):
redis_lock = _set_redis_lock(redis_params=redis_params)
scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params)

logger.debug(f'Task lock of {redis_params.redis_key} locked.')
redis_lock.release()
logger.debug(f'Task lock of {redis_params.redis_key} released.')
scheduler.shutdown()
result = func(*args, **kwargs)
return result

return wrapper


def wrap_with_remove_lock(func, redis_params: RedisParams):
"""Redis lock wrapper function for TargetOnKart.remove().
When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
"""
return _wrap_with_lock(func=func, redis_params=redis_params)


def make_redis_key(file_path: str, unique_id: str):
basename_without_ext = os.path.splitext(os.path.basename(file_path))[0]
return f'{basename_without_ext}_{unique_id}'
Expand Down
10 changes: 5 additions & 5 deletions gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from gokart.file_processor import FileProcessor, make_file_processor
from gokart.object_storage import ObjectStorage
from gokart.redis_lock import RedisParams, make_redis_params, with_lock
from gokart.redis_lock import RedisParams, make_redis_params, wrap_with_dump_lock, wrap_with_load_lock, wrap_with_remove_lock, wrap_with_run_lock
from gokart.zip_client_util import make_zip_client

logger = getLogger(__name__)
Expand All @@ -26,17 +26,17 @@ def exists(self) -> bool:
return self._exists()

def load(self) -> Any:
return self.wrap_with_lock(self._load)()
return wrap_with_load_lock(func=self._load, redis_params=self._get_redis_params())()

def dump(self, obj, lock_at_dump: bool = True) -> None:
if lock_at_dump:
self.wrap_with_lock(self._dump)(obj)
wrap_with_dump_lock(func=self._dump, redis_params=self._get_redis_params(), exist_check=self.exists)(obj)
else:
self._dump(obj)

def remove(self) -> None:
if self.exists():
self.wrap_with_lock(self._remove)()
wrap_with_remove_lock(self._remove, redis_params=self._get_redis_params())()

def last_modification_time(self) -> datetime:
return self._last_modification_time()
Expand All @@ -45,7 +45,7 @@ def path(self) -> str:
return self._path()

def wrap_with_lock(self, func):
return with_lock(func=func, redis_params=self._get_redis_params())
return wrap_with_run_lock(func=func, redis_params=self._get_redis_params())

@abstractmethod
def _exists(self) -> bool:
Expand Down
Loading

0 comments on commit 904268f

Please sign in to comment.