Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hook to run generic pre- and post-task logic #65

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ JOBS = {
}
```

#### Pre & Post Task Hooks
You can also run pre task or post task hooks, which happen in the normal processing of your `Job` instances and are executed inside the worker process.

Both pre and post task hooks receive your `Job` instance as their only argument. Here's an example:

```python
def my_pre_task_hook(job):
... # configure something before running your task
```

To ensure these hooks are run, simply add a `pre_task_hook` or `post_task_hook` key (or both, if needed) to your job config like so:

```python
JOBS = {
"my_job": {
"tasks": ["project.common.jobs.my_task"],
"pre_task_hook": "project.common.jobs.my_pre_task_hook",
"post_task_hook": "project.common.jobs.my_post_task_hook",
},
}
```

Notes:

* If the `pre_task_hook` fails (raises an exception), the task function is not run, and django-db-queue behaves as if the task function itself had failed: the failure hook is called, and the job is goes into the `FAILED` state.
* The `post_task_hook` is always run, even if the job fails. In this case, it runs after the `failure_hook`.
* If the `post_task_hook` raises an exception, this is logged but the the job is **not marked as failed** and the failure hook does not run. This is because the `post_task_hook` might need to perform cleanup that always happens after the task, no matter whether it succeeds or fails.


### Start the worker

In another terminal:
Expand Down
2 changes: 1 addition & 1 deletion django_dbq/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.1.0"
__version__ = "3.2.0"
21 changes: 9 additions & 12 deletions django_dbq/management/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,23 @@ def _process_job(self):
self.current_job = job

try:
task_function = import_string(job.next_task)
task_function(job)
job.run_pre_task_hook()
job.run_next_task()
job.update_next_task()

if not job.next_task:
job.state = Job.STATES.COMPLETE
else:
job.state = Job.STATES.READY
except Exception as exception:
logger.exception("Job id=%s failed", job.pk)
job.state = Job.STATES.FAILED

failure_hook_name = job.get_failure_hook_name()
if failure_hook_name:
logger.info(
"Running failure hook %s for job id=%s", failure_hook_name, job.pk
)
failure_hook_function = import_string(failure_hook_name)
failure_hook_function(job, exception)
else:
logger.info("No failure hook for job id=%s", job.pk)
job.run_failure_hook(exception)
finally:
try:
job.run_post_task_hook()
except:
logger.exception("Job id=%s post_task_hook failed", job.pk)

logger.info(
'Updating job: name="%s" id=%s state=%s next_task=%s',
Expand Down
35 changes: 34 additions & 1 deletion django_dbq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from django.utils.module_loading import import_string
from django_dbq.tasks import (
get_next_task_name,
get_pre_task_hook_name,
get_post_task_hook_name,
get_failure_hook_name,
get_creation_hook_name,
)
Expand Down Expand Up @@ -126,16 +128,47 @@ def save(self, *args, **kwargs):
def update_next_task(self):
self.next_task = get_next_task_name(self.name, self.next_task) or ""

def run_next_task(self):
next_task_function = import_string(self.next_task)
next_task_function(self)

def get_pre_task_hook_name(self):
return get_pre_task_hook_name(self.name)

def get_post_task_hook_name(self):
return get_post_task_hook_name(self.name)

def get_failure_hook_name(self):
return get_failure_hook_name(self.name)

def get_creation_hook_name(self):
return get_creation_hook_name(self.name)

def run_pre_task_hook(self):
pre_task_hook_name = self.get_pre_task_hook_name()
if pre_task_hook_name:
logger.info("Running pre_task hook %s for job", pre_task_hook_name)
pre_task_hook_function = import_string(pre_task_hook_name)
pre_task_hook_function(self)

def run_post_task_hook(self):
post_task_hook_name = self.get_post_task_hook_name()
if post_task_hook_name:
logger.info("Running post_task hook %s for job", post_task_hook_name)
post_task_hook_function = import_string(post_task_hook_name)
post_task_hook_function(self)

def run_failure_hook(self, exception):
failure_hook_name = self.get_failure_hook_name()
if failure_hook_name:
logger.info("Running failure hook %s for job", failure_hook_name)
failure_hook_function = import_string(failure_hook_name)
failure_hook_function(self, exception)

def run_creation_hook(self):
creation_hook_name = self.get_creation_hook_name()
if creation_hook_name:
logger.info("Running creation hook %s for new job", creation_hook_name)
logger.info("Running creation hook %s for job", creation_hook_name)
creation_hook_function = import_string(creation_hook_name)
creation_hook_function(self)

Expand Down
12 changes: 12 additions & 0 deletions django_dbq/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


TASK_LIST_KEY = "tasks"
PRE_TASK_HOOK_KEY = "pre_task_hook"
POST_TASK_HOOK_KEY = "post_task_hook"
FAILURE_HOOK_KEY = "failure_hook"
CREATION_HOOK_KEY = "creation_hook"

Expand All @@ -24,6 +26,16 @@ def get_next_task_name(job_name, current_task=None):
return None


def get_pre_task_hook_name(job_name):
"""Return the name of the pre task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(PRE_TASK_HOOK_KEY)


def get_post_task_hook_name(job_name):
"""Return the name of the post_task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(POST_TASK_HOOK_KEY)


def get_failure_hook_name(job_name):
"""Return the name of the failure hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(FAILURE_HOOK_KEY)
Expand Down
52 changes: 52 additions & 0 deletions django_dbq/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,25 @@ def failing_task(job):
raise Exception("uh oh")


def pre_task_hook(job):
job.workspace["output"] = "pre task hook ran"
job.workspace["job_id"] = str(job.id)


def post_task_hook(job):
job.workspace["output"] = "post task hook ran"
job.workspace["job_id"] = str(job.id)


def failure_hook(job, exception):
job.workspace["output"] = "failure hook ran"
job.workspace["exception"] = str(exception)
job.workspace["job_id"] = str(job.id)


def creation_hook(job):
job.workspace["output"] = "creation hook ran"
job.workspace["job_id"] = str(job.id)


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down Expand Up @@ -316,6 +329,7 @@ def test_creation_hook(self):
job = Job.objects.create(name="testjob")
job = Job.objects.get()
self.assertEqual(job.workspace["output"], "creation hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))

def test_creation_hook_only_runs_on_create(self):
job = Job.objects.create(name="testjob")
Expand All @@ -326,6 +340,42 @@ def test_creation_hook_only_runs_on_create(self):
self.assertEqual(job.workspace["output"], "creation hook output removed")


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.test_task"],
"pre_task_hook": "django_dbq.tests.pre_task_hook",
}
}
)
class JobPreTaskHookTestCase(TestCase):
def test_pre_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.COMPLETE)
self.assertEqual(job.workspace["output"], "pre task hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.test_task"],
"post_task_hook": "django_dbq.tests.post_task_hook",
}
}
)
class JobPostTaskHookTestCase(TestCase):
def test_post_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.COMPLETE)
self.assertEqual(job.workspace["output"], "post task hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
Expand All @@ -341,6 +391,8 @@ def test_failure_hook(self):
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "failure hook ran")
self.assertIn("uh oh", job.workspace["exception"])
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down
Loading