Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from celery import states as celery_states
from deprecated import deprecated

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.executors import (
celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043
celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043.
)
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
Expand All @@ -49,18 +49,23 @@
log = logging.getLogger(__name__)


CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload"


if TYPE_CHECKING:
from collections.abc import Sequence

from celery.result import AsyncResult

from airflow.cli.cli_config import GroupCommand
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery

if AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads.types import WorkloadKey


# PEP562
def __getattr__(name):
Expand All @@ -84,7 +89,7 @@ class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow.

It allows distributing the execution of task instances to multiple worker nodes.
It allows distributing the execution of workloads (task instances and callbacks) to multiple worker nodes.

Celery is a simple, flexible and reliable distributed system to process
vast amounts of messages, while providing operations with the tools
Expand All @@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor):
if TYPE_CHECKING:
if AIRFLOW_V_3_0_PLUS:
# TODO: TaskSDK: move this type change into BaseExecutor
queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment]
queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -127,7 +132,7 @@ def __init__(self, *args, **kwargs):

self.celery_app = create_celery_app(self.conf)

# Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
# Celery doesn't support bulk sending the workloads (which can become a bottleneck on bigger clusters)
# so we use a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0)
Expand All @@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs):
from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher

self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app)
self.tasks = {}
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)
self.workloads: dict[WorkloadKey, AsyncResult] = {}
self.workload_publish_retries: Counter[WorkloadKey] = Counter()
self.workload_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)

def start(self) -> None:
self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)

def _num_tasks_per_send_process(self, to_send_count: int) -> int:
def _num_workloads_per_send_process(self, to_send_count: int) -> int:
"""
How many Celery tasks should each worker process send.
How many Celery workloads should each worker process send.

:return: Number of tasks that should be sent per process
:return: Number of workloads that should be sent per process
"""
return max(1, math.ceil(to_send_count / self._sync_parallelism))

def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
# Airflow V2 version
# Airflow V2 compatibility path — converts task tuples into workload-compatible tuples.

task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples]

self._send_tasks(task_tuples_to_send)
self._send_workloads(task_tuples_to_send)

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
# Airflow V3 version -- have to delay imports until we know we are on v3
# Airflow V3 version -- have to delay imports until we know we are on v3.
from airflow.executors.workloads import ExecuteTask

if AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads import ExecuteCallback

tasks: list[WorkloadInCelery] = []
workloads_to_be_sent: list[WorkloadInCelery] = []
for workload in workloads:
if isinstance(workload, ExecuteTask):
tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
workloads_to_be_sent.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
# Use default queue for callbacks, or extract from callback data if available
# Use default queue for callbacks, or extract from callback data if available.
queue = "default"
if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data:
queue = workload.callback.data["queue"]
tasks.append((workload.callback.key, workload, queue, self.team_name))
workloads_to_be_sent.append((workload.callback.key, workload, queue, self.team_name))
else:
raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}")

self._send_tasks(tasks)
self._send_workloads(workloads_to_be_sent)

def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
# Celery state queries will be stuck if we do not use one same backend
# for all tasks.
# for all workloads.
cached_celery_backend = self.celery_app.backend

key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
self.log.debug("Sent all tasks.")
key_and_async_results = self._send_workloads_to_celery(workload_tuples_to_send)
self.log.debug("Sent all workloads.")
from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback

for key, _, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback) and isinstance(
result.exception, AirflowTaskTimeout
):
retries = self.task_publish_retries[key]
if retries < self.task_publish_max_retries:
retries = self.workload_publish_retries[key]
if retries < self.workload_publish_max_retries:
Stats.incr("celery.task_timeout_error")
self.log.info(
"[Try %s of %s] Task Timeout Error for Task: (%s).",
self.task_publish_retries[key] + 1,
self.task_publish_max_retries,
"[Try %s of %s] Task Timeout Error for Workload: (%s).",
self.workload_publish_retries[key] + 1,
self.workload_publish_max_retries,
tuple(key),
)
self.task_publish_retries[key] = retries + 1
self.workload_publish_retries[key] = retries + 1
continue
if key in self.queued_tasks:
self.queued_tasks.pop(key)
else:
self.queued_callbacks.pop(key, None)
self.task_publish_retries.pop(key, None)
self.workload_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback)
self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
self.tasks[key] = result
self.workloads[key] = result

# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# Store the Celery task_id (workload execution ID) in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
# which point we don't need the ID anymore anyway
# which point we don't need the ID anymore anyway.
self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)

def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor
def _send_workloads_to_celery(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
from airflow.providers.celery.executors.celery_executor_utils import send_workload_to_executor

if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
# One tuple, or max one process -> send it in the main thread.
return list(map(send_task_to_executor, task_tuples_to_send))
return list(map(send_workload_to_executor, workload_tuples_to_send))

# Use chunks instead of a work queue to reduce context switching
# since tasks are roughly uniform in size
chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
# since workloads are roughly uniform in size.
chunksize = self._num_workloads_per_send_process(len(workload_tuples_to_send))
num_processes = min(len(workload_tuples_to_send), self._sync_parallelism)

# Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues.
# Use ProcessPoolExecutor with team_name instead of workload objects to avoid pickling issues.
# Subprocesses reconstruct the team-specific Celery app from the team name and existing config.
with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
key_and_async_results = list(
send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
send_pool.map(send_workload_to_executor, workload_tuples_to_send, chunksize=chunksize)
)
return key_and_async_results

def sync(self) -> None:
if not self.tasks:
self.log.debug("No task to query celery, skipping sync")
if not self.workloads:
self.log.debug("No workload to query celery, skipping sync")
return
self.update_all_task_states()
self.update_all_workload_states()

def debug_dump(self) -> None:
"""Debug dump; called in response to SIGUSR2 by the scheduler."""
super().debug_dump()
self.log.info(
"executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items()))
"executor.workloads (%d)\n\t%s",
len(self.workloads),
"\n\t".join(map(repr, self.workloads.items())),
)

def update_all_task_states(self) -> None:
"""Update states of the tasks."""
self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())
def update_all_workload_states(self) -> None:
"""Update states of the workloads."""
self.log.debug("Inquiring about %s celery workload(s)", len(self.workloads))
state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.workloads.values())

self.log.debug("Inquiries completed.")
for key, async_result in list(self.tasks.items()):
for key, async_result in list(self.workloads.items()):
state, info = state_and_info_by_celery_task_id.get(async_result.task_id)
if state:
self.update_task_state(key, state, info)
self.update_workload_state(key, state, info)

def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
super().change_state(key, state, info, remove_running=remove_running)
self.tasks.pop(key, None)
self.workloads.pop(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
"""Update state of a single task."""
def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None:
"""Update state of a single workload."""
try:
if state == celery_states.SUCCESS:
self.success(key, info)
self.success(cast("TaskInstanceKey", key), info)
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
self.fail(cast("TaskInstanceKey", key), info)
elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY):
pass
else:
Expand All @@ -288,7 +295,9 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None

def end(self, synchronous: bool = False) -> None:
if synchronous:
while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()):
while any(
workload.state not in celery_states.READY_STATES for workload in self.workloads.values()
):
time.sleep(5)
self.sync()

Expand Down Expand Up @@ -322,7 +331,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
not_adopted_tis.append(ti)

if not celery_tasks:
# Nothing to adopt
# Nothing to adopt.
return tis

states_by_celery_task_id = self.bulk_state_fetcher.get_many(
Expand All @@ -342,9 +351,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

# Set the correct elements of the state dicts, then update this
# like we just queried it.
self.tasks[ti.key] = result
self.workloads[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
self.update_workload_state(ti.key, state, info)
adopted.append(f"{ti} in state {state}")

if adopted:
Expand Down Expand Up @@ -373,7 +382,7 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
return reprs

def revoke_task(self, *, ti: TaskInstance):
celery_async_result = self.tasks.pop(ti.key, None)
celery_async_result = self.workloads.pop(ti.key, None)
if celery_async_result:
try:
self.celery_app.control.revoke(celery_async_result.task_id)
Expand Down
Loading
Loading