diff --git a/README.md b/README.md index 5d9b9d88..f101f081 100644 --- a/README.md +++ b/README.md @@ -98,40 +98,6 @@ def run(self): ``` ## Advanced -### Task cache collision lock -#### Require -You need to install [redis](https://redis.io/topics/quickstart) for this advanced function. - -#### Description -Task lock is implemented to prevent task cache collision. -(Originally, task cache collision may occur when same task with same parameters run at different applications parallelly.) - -1. Set up a redis server at somewhere accessible from gokart/luigi jobs. - - Following will run redis at your localhost. - - ```bash - $ redis-server - ``` - -2. Set redis server hostname and port number as parameters to gokart.TaskOnKart(). - - You can set it by adding `--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]` options to gokart python script. - - e.g. - ```bash - - python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379 - ``` - - Alternatively, you may set parameters at config file. - - ```conf.ini - [TaskOnKart] - redis_host=localhost - redis_port=6379 - ``` - ### Inherit task parameters with decorator #### Description ```python diff --git a/docs/using_task_cache_collision_lock.rst b/docs/using_task_cache_collision_lock.rst new file mode 100644 index 00000000..e3b85f53 --- /dev/null +++ b/docs/using_task_cache_collision_lock.rst @@ -0,0 +1,91 @@ +1. Task cache collision lock +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Requires +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You need to install `redis <https://redis.io/topics/quickstart>`_ for this +advanced function. + +Description +^^^^^^^^^^^ + +Task lock is implemented to prevent task cache collision. (Originally, +task cache collision may occur when the same task with the same parameters runs +in different applications in parallel.) + +1. Set up a redis server somewhere accessible from gokart/luigi jobs. + + The following will run redis on your localhost. + + .. code:: bash + + $ redis-server + +2. 
Set redis server hostname and port number as parameters to gokart.TaskOnKart(). + + You can set it by adding ``--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]`` options to the gokart python script. + + e.g. + + .. code:: bash + + python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379 + + + Alternatively, you may set the parameters in a config file. + + .. code:: + + [TaskOnKart] + redis_host=localhost + redis_port=6379 + +2. Using efficient task cache collision lock +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +Description +^^^^^^^^^^^ + +The above task lock will prevent cache collision. However, the above setting checks for collisions only when the task accesses the cache file (i.e. ``task.dump()``, ``task.load()`` and ``task.remove()``). This will allow applications to run ``run()`` of the same task at the same time, which +is not efficient. + +Settings in this section will prevent running ``run()`` at the same time for efficiency. + +1. Set normal cache collision lock. Set the cache collision lock following ``1. Task cache collision lock``. + +2. Decorate ``run()`` with ``@RunWithLock``. Decorate ``run()`` of your gokart tasks which you want to lock with ``@RunWithLock``. + + .. code:: python + + from gokart.run_with_lock import RunWithLock + + class SomeTask(gokart.TaskOnKart): + @RunWithLock + def run(self): + + +3. Set ``redis_fail_on_collision`` parameter to true. This parameter will affect the behavior when the task’s lock is taken by another application. By setting ``redis_fail_on_collision=True``, the task will fail if the task’s lock is taken by another application. The locked task will be skipped and other independent tasks will be done first. If ``redis_fail_on_collision=False``, it will wait until the other application’s lock is released. + + The parameter can be set in a config file. + + .. code:: + + [TaskOnKart] + redis_host=localhost + redis_port=6379 + redis_fail_on_collision=true + +4. Set retry parameters. 
Set following parameters to retry when task + failed. Values of ``retry_count`` and ``retry_delay``\ can be set to + any value depends on your situation. + + :: + + [scheduler] + retry_count=10000 + retry_delay=10 + + [worker] + keep_alive=true \ No newline at end of file diff --git a/gokart/redis_lock.py b/gokart/redis_lock.py index a287fd01..c9db427d 100644 --- a/gokart/redis_lock.py +++ b/gokart/redis_lock.py @@ -7,8 +7,6 @@ logger = getLogger(__name__) -# TODO: Codes of this file should be implemented to gokart - class RedisParams(NamedTuple): redis_host: str @@ -16,6 +14,11 @@ class RedisParams(NamedTuple): redis_timeout: int redis_key: str should_redis_lock: bool + redis_fail_on_collision: bool + + +class TaskLockException(Exception): + pass class RedisClient: @@ -43,8 +46,10 @@ def with_lock(func, redis_params: RedisParams): def wrapper(*args, **kwargs): redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client() - redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, blocking=True, thread_local=False) - redis_lock.acquire() + blocking = not redis_params.redis_fail_on_collision + redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False) + if not redis_lock.acquire(blocking=blocking): + raise TaskLockException('Lock already taken by other task.') def extend_lock(): redis_lock.extend(additional_time=redis_params.redis_timeout, replace_ttl=True) @@ -62,6 +67,7 @@ def extend_lock(): return result except BaseException as e: logger.debug(f'Task lock of {redis_params.redis_key} released with BaseException.') + redis_lock.release() scheduler.shutdown() raise e @@ -73,12 +79,18 @@ def make_redis_key(file_path: str, unique_id: str): return f'{basename_without_ext}_{unique_id}' -def make_redis_params(file_path: str, unique_id: str, redis_host: str, redis_port: str, redis_timeout: int): +def 
make_redis_params(file_path: str, + unique_id: str, + redis_host: str = None, + redis_port: str = None, + redis_timeout: int = None, + redis_fail_on_collision: bool = False): redis_key = make_redis_key(file_path, unique_id) should_redis_lock = redis_host is not None and redis_port is not None redis_params = RedisParams(redis_host=redis_host, redis_port=redis_port, redis_key=redis_key, should_redis_lock=should_redis_lock, - redis_timeout=redis_timeout) + redis_timeout=redis_timeout, + redis_fail_on_collision=redis_fail_on_collision) return redis_params diff --git a/gokart/run_with_lock.py b/gokart/run_with_lock.py new file mode 100644 index 00000000..a2a627a9 --- /dev/null +++ b/gokart/run_with_lock.py @@ -0,0 +1,25 @@ +from functools import partial + +import luigi + + +class RunWithLock: + def __init__(self, func): + self._func = func + + def __call__(self, instance): + instance._lock_at_dump = False + output_list = luigi.task.flatten(instance.output()) + return self._run_with_lock(partial(self._func, self=instance), output_list) + + def __get__(self, instance, owner_class): + return partial(self.__call__, instance) + + @classmethod + def _run_with_lock(cls, func, output_list: list): + if len(output_list) == 0: + return func() + + output = output_list.pop() + wrapped_func = output.wrap_with_lock(func) + return cls._run_with_lock(func=wrapped_func, output_list=output_list) diff --git a/gokart/target.py b/gokart/target.py index ee8f7db3..24c3c6e1 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -26,13 +26,16 @@ def exists(self) -> bool: return self._exists() def load(self) -> Any: - return self._with_lock(self._load)() + return self.wrap_with_lock(self._load)() - def dump(self, obj) -> None: - self._with_lock(self._dump)(obj) + def dump(self, obj, lock_at_dump: bool = True) -> None: + if lock_at_dump: + self.wrap_with_lock(self._dump)(obj) + else: + self._dump(obj) def remove(self) -> None: - return self._with_lock(self._remove)() + return 
self.wrap_with_lock(self._remove)() def last_modification_time(self) -> datetime: return self._last_modification_time() @@ -40,7 +43,7 @@ def last_modification_time(self) -> datetime: def path(self) -> str: return self._path() - def _with_lock(self, func): + def wrap_with_lock(self, func): return with_lock(func=func, redis_params=self._get_redis_params()) @abstractmethod @@ -211,17 +214,12 @@ def _get_last_modification_time(path: str) -> datetime: return datetime.fromtimestamp(os.path.getmtime(path)) -def make_target(file_path: str, - unique_id: Optional[str] = None, - processor: Optional[FileProcessor] = None, - redis_host: str = None, - redis_port: str = None, - redis_timeout: int = 180) -> TargetOnKart: - redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, redis_host=redis_host, redis_port=redis_port, redis_timeout=redis_timeout) +def make_target(file_path: str, unique_id: Optional[str] = None, processor: Optional[FileProcessor] = None, redis_params: RedisParams = None) -> TargetOnKart: + _redis_params = redis_params if redis_params is not None else make_redis_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path) file_system_target = _make_file_system_target(file_path, processor=processor) - return SingleFileTarget(target=file_system_target, processor=processor, redis_params=redis_params) + return SingleFileTarget(target=file_system_target, processor=processor, redis_params=_redis_params) def make_model_target(file_path: str, @@ -229,14 +227,12 @@ def make_model_target(file_path: str, save_function, load_function, unique_id: Optional[str] = None, - redis_host: str = None, - redis_port: str = None, - redis_timeout: int = 180) -> TargetOnKart: - redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, redis_host=redis_host, redis_port=redis_port, redis_timeout=redis_timeout) + redis_params: RedisParams = None) -> 
TargetOnKart: + _redis_params = redis_params if redis_params is not None else make_redis_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest()) return ModelTarget(file_path=file_path, temporary_directory=temporary_directory, save_function=save_function, load_function=load_function, - redis_params=redis_params) + redis_params=_redis_params) diff --git a/gokart/task.py b/gokart/task.py index 62ad929c..e857c488 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -1,3 +1,4 @@ +from functools import partial import hashlib import os import sys @@ -14,6 +15,7 @@ from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import TaskInstanceParameter, ListTaskInstanceParameter from gokart.target import TargetOnKart +from gokart.redis_lock import make_redis_params logger = getLogger(__name__) @@ -52,6 +54,10 @@ class TaskOnKart(luigi.Task): redis_host = luigi.Parameter(default=None, description='Task lock check is deactivated, when None.', significant=False) redis_port = luigi.Parameter(default=None, description='Task lock check is deactivated, when None.', significant=False) redis_timeout = luigi.IntParameter(default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False) + redis_fail_on_collision: bool = luigi.BoolParameter( + default=False, + description='True for failing the task immediately when the cache is locked, instead of waiting for the lock to be released', + significant=False) fail_on_empty_dump: bool = gokart.ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False) def __init__(self, *args, **kwargs): @@ -61,6 +67,7 @@ def __init__(self, *args, **kwargs): self.task_unique_id = None super(TaskOnKart, self).__init__(*args, **kwargs) self._rerun_state = self.rerun + self._lock_at_dump = True def output(self): return 
self.make_target() @@ -137,26 +144,32 @@ def make_target(self, relative_file_path: str = None, use_unique_id: bool = True f"{type(self).__name__}.pkl") file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None - return gokart.target.make_target(file_path=file_path, + + redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, - processor=processor, redis_host=self.redis_host, redis_port=self.redis_port, - redis_timeout=self.redis_timeout) + redis_timeout=self.redis_timeout, + redis_fail_on_collision=self.redis_fail_on_collision) + return gokart.target.make_target(file_path=file_path, unique_id=unique_id, processor=processor, redis_params=redis_params) def make_large_data_frame_target(self, relative_file_path: str = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace(".", "/"), f"{type(self).__name__}.zip") file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None + redis_params = make_redis_params(file_path=file_path, + unique_id=unique_id, + redis_host=self.redis_host, + redis_port=self.redis_port, + redis_timeout=self.redis_timeout, + redis_fail_on_collision=self.redis_fail_on_collision) return gokart.target.make_model_target(file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save, load_function=gokart.target.LargeDataFrameProcessor.load, - redis_host=self.redis_host, - redis_port=self.redis_port, - redis_timeout=self.redis_timeout) + redis_params=redis_params) def make_model_target(self, relative_file_path: str, @@ -174,14 +187,18 @@ def make_model_target(self, file_path = os.path.join(self.workspace_directory, 
relative_file_path) assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.' unique_id = self.make_unique_id() if use_unique_id else None + redis_params = make_redis_params(file_path=file_path, + unique_id=unique_id, + redis_host=self.redis_host, + redis_port=self.redis_port, + redis_timeout=self.redis_timeout, + redis_fail_on_collision=self.redis_fail_on_collision) return gokart.target.make_model_target(file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=save_function, load_function=load_function, - redis_host=self.redis_host, - redis_port=self.redis_port, - redis_timeout=self.redis_timeout) + redis_params=redis_params) def load(self, target: Union[None, str, TargetOnKart] = None) -> Any: def _load(targets): @@ -230,7 +247,7 @@ def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame): assert not obj.empty - self._get_output_target(target).dump(obj) + self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump) def make_unique_id(self): self.task_unique_id = self.task_unique_id or self._make_hash_id() diff --git a/test/test_redis_lock.py b/test/test_redis_lock.py index c7a35ca7..7cfacde7 100644 --- a/test/test_redis_lock.py +++ b/test/test_redis_lock.py @@ -32,11 +32,31 @@ def test_make_redis_key(self): class TestMakeRedisParams(unittest.TestCase): def test_make_redis_params_with_valid_host(self): - result = make_redis_params(file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', redis_port='12345', redis_timeout=180) - expected = RedisParams(redis_host='0.0.0.0', redis_port='12345', redis_key='aaa_123', should_redis_lock=True, redis_timeout=180) + result = make_redis_params(file_path='gs://aaa.pkl', + unique_id='123', + redis_host='0.0.0.0', + redis_port='12345', + redis_timeout=180, + 
redis_fail_on_collision=False) + expected = RedisParams(redis_host='0.0.0.0', + redis_port='12345', + redis_key='aaa_123', + should_redis_lock=True, + redis_timeout=180, + redis_fail_on_collision=False) self.assertEqual(result, expected) def test_make_redis_params_with_no_host(self): - result = make_redis_params(file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port='12345', redis_timeout=180) - expected = RedisParams(redis_host=None, redis_port='12345', redis_key='aaa_123', should_redis_lock=False, redis_timeout=180) + result = make_redis_params(file_path='gs://aaa.pkl', + unique_id='123', + redis_host=None, + redis_port='12345', + redis_timeout=180, + redis_fail_on_collision=False) + expected = RedisParams(redis_host=None, + redis_port='12345', + redis_key='aaa_123', + should_redis_lock=False, + redis_timeout=180, + redis_fail_on_collision=False) self.assertEqual(result, expected) diff --git a/test/test_target.py b/test/test_target.py index 8f2de01f..681d469a 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -2,6 +2,7 @@ import os import shutil import unittest +from unittest.mock import patch from datetime import datetime import boto3 @@ -141,6 +142,24 @@ def test_save_pandas_series(self): pd.testing.assert_series_equal(loaded['column_name'], obj) + def test_dump_with_lock(self): + with patch('gokart.target.TargetOnKart.wrap_with_lock') as wrap_with_lock_mock: + obj = 1 + file_path = os.path.join(_get_temporary_directory(), 'test.pkl') + target = make_target(file_path=file_path, unique_id=None) + target.dump(obj, lock_at_dump=True) + + wrap_with_lock_mock.assert_called_once() + + def test_dump_without_lock(self): + with patch('gokart.target.TargetOnKart.wrap_with_lock') as wrap_with_lock_mock: + obj = 1 + file_path = os.path.join(_get_temporary_directory(), 'test.pkl') + target = make_target(file_path=file_path, unique_id=None) + target.dump(obj, lock_at_dump=False) + + wrap_with_lock_mock.assert_not_called() + class 
S3TargetTest(unittest.TestCase): @mock_s3 diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index 312eea80..b44b3193 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -11,6 +11,7 @@ from gokart.parameter import TaskInstanceParameter, ListTaskInstanceParameter from gokart.file_processor import XmlFileProcessor from gokart.target import TargetOnKart, SingleFileTarget, ModelTarget +from gokart.run_with_lock import RunWithLock class _DummyTask(gokart.TaskOnKart): @@ -56,6 +57,32 @@ class _DummyTaskD(gokart.TaskOnKart): task_namespace = __name__ +class _DummyTaskWithLock(gokart.TaskOnKart): + task_namespace = __name__ + + @RunWithLock + def run(self): + pass + + +class _DummyTaskWithLockMultipleOutput(gokart.TaskOnKart): + task_namespace = __name__ + + @RunWithLock + def run(self): + pass + + def output(self): + return dict(dataA=self.make_target('fileA.pkl'), dataB=self.make_target('fileB.pkl')) + + +class _DummyTaskWithoutLock(gokart.TaskOnKart): + task_namespace = __name__ + + def run(self): + pass + + class TaskTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() @@ -403,6 +430,39 @@ class _Task(gokart.TaskOnKart): f'list_task_param=[{__name__}._SubTask({sub_task_id}), {__name__}._SubTask({sub_task_id})])' self.assertEqual(expected, str(task)) + def test_run_with_lock_decorator(self): + task = _DummyTaskWithLock() + + def _wrap(func): + return func + + with patch('gokart.target.TargetOnKart.wrap_with_lock') as mock_obj: + mock_obj.side_effect = _wrap + task.run() + mock_obj.assert_called_once() + + def test_run_with_lock_decorator_multiple_output(self): + task = _DummyTaskWithLockMultipleOutput() + + def _wrap(func): + return func + + with patch('gokart.target.TargetOnKart.wrap_with_lock') as mock_obj: + mock_obj.side_effect = _wrap + task.run() + self.assertEqual(mock_obj.call_count, 2) + + def test_run_without_lock_decorator(self): + task = _DummyTaskWithoutLock() + + def _wrap(func): + return func + 
+ with patch('gokart.target.TargetOnKart.wrap_with_lock') as mock_obj: + mock_obj.side_effect = _wrap + task.run() + mock_obj.assert_not_called() + if __name__ == '__main__': unittest.main()