diff --git a/docs/using_task_cache_collision_lock.rst b/docs/using_task_cache_collision_lock.rst index e750fb4d..6c6ab607 100644 --- a/docs/using_task_cache_collision_lock.rst +++ b/docs/using_task_cache_collision_lock.rst @@ -81,8 +81,8 @@ This setting must be done to each gokart task which you want to lock the ``run() from gokart.run_with_lock import RunWithLock - @RunWithLock class SomeTask(gokart.TaskOnKart): + @RunWithLock def run(self): ... diff --git a/gokart/redis_lock.py b/gokart/redis_lock.py index eac6e41f..b59890b9 100644 --- a/gokart/redis_lock.py +++ b/gokart/redis_lock.py @@ -1,6 +1,7 @@ +import functools import os from logging import getLogger -from typing import NamedTuple +from typing import Callable, NamedTuple import redis from apscheduler.schedulers.background import BackgroundScheduler @@ -41,28 +42,39 @@ def get_redis_client(self): return self._redis_client -def with_lock(func, redis_params: RedisParams): +def _extend_lock(redis_lock: redis.lock.Lock, redis_timeout: int): + redis_lock.extend(additional_time=redis_timeout, replace_ttl=True) + + +def _set_redis_lock(redis_params: RedisParams) -> RedisClient: + redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client() + blocking = not redis_params.redis_fail_on_collision + redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False) + if not redis_lock.acquire(blocking=blocking): + raise TaskLockException('Lock already taken by other task.') + return redis_lock + + +def _set_lock_scheduler(redis_lock: redis.lock.Lock, redis_params: RedisParams) -> BackgroundScheduler: + scheduler = BackgroundScheduler() + extend_lock = functools.partial(_extend_lock, redis_lock=redis_lock, redis_timeout=redis_params.redis_timeout) + scheduler.add_job(extend_lock, + 'interval', + seconds=redis_params.lock_extend_seconds, + max_instances=999999999, + misfire_grace_time=redis_params.redis_timeout, + coalesce=False) + scheduler.start() + return scheduler + + +def _wrap_with_lock(func, redis_params: RedisParams): if not redis_params.should_redis_lock: return func def wrapper(*args, **kwargs): - redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client() - blocking = not redis_params.redis_fail_on_collision - redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False) - if not redis_lock.acquire(blocking=blocking): - raise TaskLockException('Lock already taken by other task.') - - def extend_lock(): - redis_lock.extend(additional_time=redis_params.redis_timeout, replace_ttl=True) - - scheduler = BackgroundScheduler() - scheduler.add_job(extend_lock, - 'interval', - seconds=redis_params.lock_extend_seconds, - max_instances=999999999, - misfire_grace_time=redis_params.redis_timeout, - coalesce=False) - scheduler.start() + redis_lock = _set_redis_lock(redis_params=redis_params) + scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) try: logger.debug(f'Task lock of {redis_params.redis_key} locked.') @@ -80,6 +92,70 @@ def extend_lock(): return wrapper +def wrap_with_run_lock(func, redis_params: RedisParams): + """Redis lock wrapper function for RunWithLock. + When a fucntion is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock. + https://github.com/m3dev/gokart/issues/265 + """ + return _wrap_with_lock(func=func, redis_params=redis_params) + + +def wrap_with_dump_lock(func: Callable, redis_params: RedisParams, exist_check: Callable): + """Redis lock wrapper function for TargetOnKart.dump(). + When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check. + https://github.com/m3dev/gokart/issues/265 + """ + + if not redis_params.should_redis_lock: + return func + + def wrapper(*args, **kwargs): + redis_lock = _set_redis_lock(redis_params=redis_params) + scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) + + try: + logger.debug(f'Task lock of {redis_params.redis_key} locked.') + if not exist_check(): + func(*args, **kwargs) + finally: + logger.debug(f'Task lock of {redis_params.redis_key} released.') + redis_lock.release() + scheduler.shutdown() + + return wrapper + + +def wrap_with_load_lock(func, redis_params: RedisParams): + """Redis lock wrapper function for TargetOnKart.load(). + When TargetOnKart.load() is called, redis lock will be locked and released before load(). + https://github.com/m3dev/gokart/issues/265 + """ + + if not redis_params.should_redis_lock: + return func + + def wrapper(*args, **kwargs): + redis_lock = _set_redis_lock(redis_params=redis_params) + scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) + + logger.debug(f'Task lock of {redis_params.redis_key} locked.') + redis_lock.release() + logger.debug(f'Task lock of {redis_params.redis_key} released.') + scheduler.shutdown() + result = func(*args, **kwargs) + return result + + return wrapper + + +def wrap_with_remove_lock(func, redis_params: RedisParams): + """Redis lock wrapper function for TargetOnKart.remove(). + When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock. + https://github.com/m3dev/gokart/issues/265 + """ + return _wrap_with_lock(func=func, redis_params=redis_params) + + def make_redis_key(file_path: str, unique_id: str): basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] return f'{basename_without_ext}_{unique_id}' diff --git a/gokart/target.py b/gokart/target.py index 11415ff6..b888ff50 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -14,7 +14,7 @@ from gokart.file_processor import FileProcessor, make_file_processor from gokart.object_storage import ObjectStorage -from gokart.redis_lock import RedisParams, make_redis_params, with_lock +from gokart.redis_lock import RedisParams, make_redis_params, wrap_with_dump_lock, wrap_with_load_lock, wrap_with_remove_lock, wrap_with_run_lock from gokart.zip_client_util import make_zip_client logger = getLogger(__name__) @@ -26,17 +26,17 @@ def exists(self) -> bool: return self._exists() def load(self) -> Any: - return self.wrap_with_lock(self._load)() + return wrap_with_load_lock(func=self._load, redis_params=self._get_redis_params())() def dump(self, obj, lock_at_dump: bool = True) -> None: if lock_at_dump: - self.wrap_with_lock(self._dump)(obj) + wrap_with_dump_lock(func=self._dump, redis_params=self._get_redis_params(), exist_check=self.exists)(obj) else: self._dump(obj) def remove(self) -> None: if self.exists(): - self.wrap_with_lock(self._remove)() + wrap_with_remove_lock(self._remove, redis_params=self._get_redis_params())() def last_modification_time(self) -> datetime: return self._last_modification_time() @@ -45,7 +45,7 @@ def path(self) -> str: return self._path() def wrap_with_lock(self, func): - return with_lock(func=func, redis_params=self._get_redis_params()) + return wrap_with_run_lock(func=func, redis_params=self._get_redis_params()) @abstractmethod def _exists(self) -> bool: diff --git a/test/test_redis_lock.py b/test/test_redis_lock.py index 0b80c468..990f2a73 100644 --- a/test/test_redis_lock.py +++ b/test/test_redis_lock.py @@ -5,7 +5,7 @@ import fakeredis -from gokart.redis_lock import RedisClient, RedisParams, make_redis_key, make_redis_params, with_lock +from gokart.redis_lock import RedisClient, RedisParams, make_redis_key, make_redis_params, wrap_with_dump_lock, wrap_with_remove_lock, wrap_with_run_lock class TestRedisClient(unittest.TestCase): @@ -28,16 +28,16 @@ def test_redis_client_is_singleton(self): self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client()) -class TestWithLock(unittest.TestCase): +def _sample_func_with_error(a: int, b: str = None): + raise Exception() - @staticmethod - def _sample_func_with_error(a: int, b: str = None): - raise Exception() - @staticmethod - def _sample_long_func(a: int, b: str = None): - time.sleep(3) - return dict(a=a, b=b) +def _sample_long_func(a: int, b: str = None): + time.sleep(2.7) + return dict(a=a, b=b) + + +class TestWrapWithRunLock(unittest.TestCase): def test_no_redis(self): redis_params = make_redis_params( @@ -47,7 +47,213 @@ def test_no_redis(self): redis_port=None, ) mock_func = MagicMock() - resulted = with_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + def test_use_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + mock_func = MagicMock() + redis_mock.side_effect = fakeredis.FakeRedis + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + def test_check_lock_extended(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + redis_timeout=2, + lock_extend_seconds=1, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + resulted = wrap_with_run_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') + expected = dict(a=123, b='abc') + self.assertEqual(resulted, expected) + + def test_lock_is_removed_after_func_is_finished(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + mock_func = MagicMock() + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + def test_lock_is_removed_after_func_is_finished_with_error(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + try: + wrap_with_run_lock(func=_sample_func_with_error, redis_params=redis_params)(a=123, b='abc') + except Exception: + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + +class TestWrapWithDumpLock(unittest.TestCase): + + def test_no_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host=None, + redis_port=None, + ) + mock_func = MagicMock() + wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + + def test_use_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + mock_func = MagicMock() + wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + + def test_if_func_is_skipped_when_cache_already_exists(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + mock_func = MagicMock() + wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: True)(123, b='abc') + + mock_func.assert_not_called() + + def test_check_lock_extended(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + redis_timeout=2, + lock_extend_seconds=1, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + wrap_with_dump_lock(func=_sample_long_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + + def test_lock_is_removed_after_func_is_finished(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + mock_func = MagicMock() + wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + def test_lock_is_removed_after_func_is_finished_with_error(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + try: + wrap_with_dump_lock(func=_sample_func_with_error, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + except Exception: + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + +class TestWrapWithLoadLock(unittest.TestCase): + + def test_no_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host=None, + redis_port=None, + ) + mock_func = MagicMock() + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -67,7 +273,7 @@ def test_use_redis(self): with patch('gokart.redis_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() - resulted = with_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -88,7 +294,7 @@ def test_check_lock_extended(self): with patch('gokart.redis_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis - resulted = with_lock(func=self._sample_long_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) @@ -105,13 +311,109 @@ def test_lock_is_removed_after_func_is_finished(self): with patch('gokart.redis_lock.redis.Redis') as redis_mock: redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) mock_func = MagicMock() - resulted = with_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + def test_lock_is_removed_after_func_is_finished_with_error(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + try: + wrap_with_run_lock(func=_sample_func_with_error, redis_params=redis_params)(123, b='abc') + except Exception: + fake_redis = fakeredis.FakeStrictRedis(server=server) + with self.assertRaises(KeyError): + fake_redis[redis_params.redis_key] + + +class TestWrapWithRemoveLock(unittest.TestCase): + + def test_no_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host=None, + redis_port=None, + ) + mock_func = MagicMock() + resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + def test_use_redis(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + mock_func = MagicMock() + resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123, )) self.assertDictEqual(called_kwargs, dict(b='abc')) + self.assertEqual(resulted, mock_func()) + + def test_check_lock_extended(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + redis_timeout=2, + lock_extend_seconds=1, + ) + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.side_effect = fakeredis.FakeRedis + resulted = wrap_with_remove_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') + expected = dict(a=123, b='abc') + self.assertEqual(resulted, expected) + def test_lock_is_removed_after_func_is_finished(self): + redis_params = make_redis_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + ) + + server = fakeredis.FakeServer() + + with patch('gokart.redis_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + mock_func = MagicMock() + resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + mock_func.assert_called_once() + called_args, called_kwargs = mock_func.call_args + self.assertTupleEqual(called_args, (123, )) + self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) fake_redis = fakeredis.FakeStrictRedis(server=server) @@ -131,7 +433,7 @@ def test_lock_is_removed_after_func_is_finished_with_error(self): with patch('gokart.redis_lock.redis.Redis') as redis_mock: redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) try: - with_lock(func=self._sample_func_with_error, redis_params=redis_params)(123, b='abc') + wrap_with_remove_lock(func=_sample_func_with_error, redis_params=redis_params)(123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): diff --git a/test/test_target.py b/test/test_target.py index a73c36c4..781eb0cf 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -153,7 +153,7 @@ def test_save_pandas_series(self): pd.testing.assert_series_equal(loaded['column_name'], obj) def test_dump_with_lock(self): - with patch('gokart.target.TargetOnKart.wrap_with_lock') as wrap_with_lock_mock: + with patch('gokart.target.wrap_with_dump_lock') as wrap_with_lock_mock: obj = 1 file_path = os.path.join(_get_temporary_directory(), 'test.pkl') target = make_target(file_path=file_path, unique_id=None) @@ -162,7 +162,7 @@ def test_dump_with_lock(self): wrap_with_lock_mock.assert_called_once() def test_dump_without_lock(self): - with patch('gokart.target.TargetOnKart.wrap_with_lock') as wrap_with_lock_mock: + with patch('gokart.target.wrap_with_dump_lock') as wrap_with_lock_mock: obj = 1 file_path = os.path.join(_get_temporary_directory(), 'test.pkl') target = make_target(file_path=file_path, unique_id=None)