Fix FileTaskHandler only read from default executor (apache#45631)
* Fix FileTaskHandler only read from default executor

* Add cached_property back to avoid loading executors

* Add test for multi-executors scenario

* Allow calling load_executor without init_executors

* Refactor by caching necessary executors

* Refactor test with default executor case

* Fix side effect from executor_loader

* Fix KubernetesExecutor test

- Previous test failure was caused by the cached state of executor_instances
- Should set ti.state = RUNNING after ti.run

* Fix side effect from executor_loader

- The side effect only shows up with postgres as the backend, since the
  previous fix only resolved it with sqlite as the backend.
- Also refactor clean_executor_loader as a pytest fixture with setup and
  teardown

* Capitalize default executor key

* Refactor clean_executor_loader fixture
jason810496 authored Jan 24, 2025
1 parent ce2c891 commit 51dbabc
Showing 8 changed files with 223 additions and 50 deletions.
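
The heart of the fix is easiest to see in isolation: instead of caching one get_task_log callable bound to the default executor, the handler resolves an executor per task instance and memoizes each one. Below is a minimal, self-contained sketch of that pattern; the Stub* classes are simplified stand-ins, not the real Airflow APIs.

from __future__ import annotations

from typing import Callable

DEFAULT_EXECUTOR_KEY = "_default_executor"


class StubExecutor:
    """Stand-in for a BaseExecutor that can serve task logs."""

    def __init__(self, name: str) -> None:
        self.name = name

    def get_task_log(self, ti, try_number: int) -> tuple[list[str], list[str]]:
        return ([f"log read via {self.name}, try {try_number}"], [])


class StubHandler:
    # Class-level cache, mirroring FileTaskHandler.executor_instances.
    executor_instances: dict[str, StubExecutor] = {}

    def _get_executor_get_task_log(self, ti) -> Callable:
        # Prefer the executor recorded on the task instance; fall back to a
        # sentinel key for the default executor when none is set.
        key = ti.executor or DEFAULT_EXECUTOR_KEY
        if key not in self.executor_instances:
            # Load once and reuse; the real handler asks ExecutorLoader here.
            self.executor_instances[key] = StubExecutor(key)
        return self.executor_instances[key].get_task_log


class StubTI:
    def __init__(self, executor: str | None) -> None:
        self.executor = executor


handler = StubHandler()
ti = StubTI("KubernetesExecutor")
print(handler._get_executor_get_task_log(ti)(ti, 1))
# (['log read via KubernetesExecutor, try 1'], [])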
4 changes: 4 additions & 0 deletions airflow/executors/executor_loader.py
@@ -231,6 +231,10 @@ def init_executors(cls) -> list[BaseExecutor]:
@classmethod
def lookup_executor_name_by_str(cls, executor_name_str: str) -> ExecutorName:
# lookup the executor by alias first, if not check if we're given a module path
if not _classname_to_executors or not _module_to_executors or not _alias_to_executors:
# if we haven't loaded the executor names yet (e.g. load_executor was called directly), load them now
cls._get_executor_names()

if executor_name := _alias_to_executors.get(executor_name_str):
return executor_name
elif executor_name := _module_to_executors.get(executor_name_str):
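
This guard makes lookup_executor_name_by_str self-initializing, which is what allows FileTaskHandler to call load_executor without anyone having run init_executors first. A hedged usage sketch, assuming LocalExecutor is the configured core executor:

from airflow.executors.executor_loader import ExecutorLoader

# In a fresh process, no init_executors() call is needed any more: the name
# mappings are populated lazily on first lookup.
executor = ExecutorLoader.load_executor("LocalExecutor")
print(type(executor).__name__)  # LocalExecutor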
33 changes: 26 additions & 7 deletions airflow/utils/log/file_task_handler.py
@@ -24,7 +24,6 @@
from collections.abc import Iterable
from contextlib import suppress
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urljoin
@@ -44,6 +43,7 @@
if TYPE_CHECKING:
from pendulum import DateTime

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

@@ -179,6 +179,8 @@ class FileTaskHandler(logging.Handler):
inherits_from_empty_operator_log_message = (
"Operator inherits from empty operator and thus does not have logs"
)
executor_instances: dict[str, BaseExecutor] = {}
DEFAULT_EXECUTOR_KEY = "_default_executor"

def __init__(
self,
@@ -314,11 +316,27 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO
def _read_grouped_logs(self):
return False

@cached_property
def _executor_get_task_log(self) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]:
"""This cached property avoids loading executor repeatedly."""
executor = ExecutorLoader.get_default_executor()
return executor.get_task_log
def _get_executor_get_task_log(
self, ti: TaskInstance
) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]:
"""
Get the get_task_log method from the executor of the current task instance.

Since there can be multiple executors, we need the executor of the current task instance rather than the default executor.

:param ti: task instance object
:return: get_task_log method of the executor
"""
executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY
executor = self.executor_instances.get(executor_name)
if executor is not None:
return executor.get_task_log

if executor_name == self.DEFAULT_EXECUTOR_KEY:
self.executor_instances[executor_name] = ExecutorLoader.get_default_executor()
else:
self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name)
return self.executor_instances[executor_name].get_task_log

def _read(
self,
@@ -360,7 +378,8 @@ def _read(
messages_list.extend(remote_messages)
has_k8s_exec_pod = False
if ti.state == TaskInstanceState.RUNNING:
response = self._executor_get_task_log(ti, try_number)
executor_get_task_log = self._get_executor_get_task_log(ti)
response = executor_get_task_log(ti, try_number)
if response:
executor_messages, executor_logs = response
if executor_messages:
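
Taken together, _read now dispatches the executor log lookup per task instance instead of going through one cached default executor. A rough sketch of the resolution behavior; the SimpleNamespace objects are stand-ins for real TaskInstance rows, and the lookup calls are left commented since they load executors from your Airflow config:

from types import SimpleNamespace

from airflow.utils.log.file_task_handler import FileTaskHandler

handler = FileTaskHandler("")

# ti.executor set -> that executor's get_task_log is used (loaded once, cached)
ti_k8s = SimpleNamespace(executor="KubernetesExecutor")
# get_task_log = handler._get_executor_get_task_log(ti_k8s)

# ti.executor unset -> falls back to the default executor, cached under the
# "_default_executor" sentinel key
ti_plain = SimpleNamespace(executor=None)
# get_task_log = handler._get_executor_get_task_log(ti_plain)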
@@ -74,6 +74,7 @@ def teardown_method(self):
"airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor.get_task_log"
)
@pytest.mark.parametrize("state", [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS])
@pytest.mark.usefixtures("clean_executor_loader")
def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instance, state):
"""Test for k8s executor, the log is read from get_task_log method"""
mock_k8s_get_task_log.return_value = ([], [])
@@ -86,6 +87,7 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc
)
ti.state = state
ti.triggerer_job = None
ti.executor = executor_name
with conf_vars({("core", "executor"): executor_name}):
reload(executor_loader)
fth = FileTaskHandler("")
@@ -105,11 +107,12 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc
pytest.param(k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="pod-name-xxx")), "default"),
],
)
@patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="KubernetesExecutor")
@conf_vars({("core", "executor"): "KubernetesExecutor"})
@patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
def test_read_from_k8s_under_multi_namespace_mode(
self, mock_kube_client, pod_override, namespace_to_call
):
reload(executor_loader)
mock_read_log = mock_kube_client.return_value.read_namespaced_pod_log
mock_list_pod = mock_kube_client.return_value.list_namespaced_pod

@@ -139,6 +142,7 @@ def task_callable(ti):
)
ti = TaskInstance(task=task, run_id=dagrun.run_id)
ti.try_number = 3
ti.executor = "KubernetesExecutor"

logger = ti.log
ti.log.disabled = False
@@ -147,6 +151,8 @@ def task_callable(ti):
set_context(logger, ti)
ti.run(ignore_ti_state=True)
ti.state = TaskInstanceState.RUNNING
# clear executor_instances cache
file_handler.executor_instances = {}
file_handler.read(ti, 2)

# first we find pod name
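
The explicit reset above matters because executor_instances is class-level state shared across handlers; a stale entry from an earlier test (or a reconfigured process) would short-circuit the lookup. The same reset works globally; a sketch, not an API added by this commit:

from airflow.utils.log.file_task_handler import FileTaskHandler

# Drop every cached executor so the next log read reloads them via ExecutorLoader.
FileTaskHandler.executor_instances = {}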
89 changes: 48 additions & 41 deletions tests/executors/test_executor_loader.py
@@ -16,14 +16,13 @@
# under the License.
from __future__ import annotations

from importlib import reload
from unittest import mock

import pytest

from airflow.exceptions import AirflowConfigException
from airflow.executors import executor_loader
from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
from airflow.executors.executor_loader import ConnectorSource, ExecutorName
from airflow.executors.local_executor import LocalExecutor
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
@@ -35,24 +34,12 @@ class FakeExecutor:
pass


@pytest.mark.usefixtures("clean_executor_loader")
class TestExecutorLoader:
def setup_method(self) -> None:
from airflow.executors import executor_loader

reload(executor_loader)
global ExecutorLoader
ExecutorLoader = executor_loader.ExecutorLoader # type: ignore

def teardown_method(self) -> None:
from airflow.executors import executor_loader

reload(executor_loader)
ExecutorLoader.init_executors()

def test_no_executor_configured(self):
with conf_vars({("core", "executor"): None}):
with pytest.raises(AirflowConfigException, match=r".*not found in config$"):
ExecutorLoader.get_default_executor()
executor_loader.ExecutorLoader.get_default_executor()

@pytest.mark.parametrize(
"executor_name",
@@ -66,16 +53,18 @@ def test_no_executor_configured(self):
)
def test_should_support_executor_from_core(self, executor_name):
with conf_vars({("core", "executor"): executor_name}):
executor = ExecutorLoader.get_default_executor()
executor = executor_loader.ExecutorLoader.get_default_executor()
assert executor is not None
assert executor_name == executor.__class__.__name__
assert executor.name is not None
assert executor.name == ExecutorName(ExecutorLoader.executors[executor_name], alias=executor_name)
assert executor.name == ExecutorName(
executor_loader.ExecutorLoader.executors[executor_name], alias=executor_name
)
assert executor.name.connector_source == ConnectorSource.CORE

def test_should_support_custom_path(self):
with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}):
executor = ExecutorLoader.get_default_executor()
executor = executor_loader.ExecutorLoader.get_default_executor()
assert executor is not None
assert executor.__class__.__name__ == "FakeExecutor"
assert executor.name is not None
@@ -249,17 +238,17 @@ def test_get_hybrid_executors_from_config(
"airflow.executors.executor_loader.ExecutorLoader._get_team_executor_configs",
return_value=team_executor_config,
):
executors = ExecutorLoader._get_executor_names()
executors = executor_loader.ExecutorLoader._get_executor_names()
assert executors == expected_executors_list

def test_init_executors(self):
with conf_vars({("core", "executor"): "CeleryExecutor"}):
executors = ExecutorLoader.init_executors()
executor_name = ExecutorLoader.get_default_executor_name()
executors = executor_loader.ExecutorLoader.init_executors()
executor_name = executor_loader.ExecutorLoader.get_default_executor_name()
assert len(executors) == 1
assert isinstance(executors[0], CeleryExecutor)
assert "CeleryExecutor" in ExecutorLoader.executors
assert ExecutorLoader.executors["CeleryExecutor"] == executor_name.module_path
assert "CeleryExecutor" in executor_loader.ExecutorLoader.executors
assert executor_loader.ExecutorLoader.executors["CeleryExecutor"] == executor_name.module_path

@pytest.mark.parametrize(
"executor_config",
@@ -276,7 +265,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_
with pytest.raises(
AirflowConfigException, match=r".+Duplicate executors are not yet supported.+"
):
ExecutorLoader._get_executor_names()
executor_loader.ExecutorLoader._get_executor_names()

@pytest.mark.parametrize(
"executor_config",
@@ -292,7 +281,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_
def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, executor_config):
with conf_vars({("core", "executor"): executor_config}):
with pytest.raises(AirflowConfigException):
ExecutorLoader._get_executor_names()
executor_loader.ExecutorLoader._get_executor_names()

@pytest.mark.parametrize(
("executor_config", "expected_value"),
@@ -308,7 +297,7 @@ def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self,
)
def test_should_support_import_executor_from_core(self, executor_config, expected_value):
with conf_vars({("core", "executor"): executor_config}):
executor, import_source = ExecutorLoader.import_default_executor_cls()
executor, import_source = executor_loader.ExecutorLoader.import_default_executor_cls()
assert expected_value == executor.__name__
assert import_source == ConnectorSource.CORE

@@ -322,26 +311,43 @@ def test_should_support_import_executor_from_core(self, executor_config, expecte
)
def test_should_support_import_custom_path(self, executor_config):
with conf_vars({("core", "executor"): executor_config}):
executor, import_source = ExecutorLoader.import_default_executor_cls()
executor, import_source = executor_loader.ExecutorLoader.import_default_executor_cls()
assert executor.__name__ == "FakeExecutor"
assert import_source == ConnectorSource.CUSTOM_PATH

def test_load_executor(self):
with conf_vars({("core", "executor"): "LocalExecutor"}):
ExecutorLoader.init_executors()
assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor)
assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor)
assert isinstance(ExecutorLoader.load_executor(None), LocalExecutor)
executor_loader.ExecutorLoader.init_executors()
assert isinstance(executor_loader.ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor)
assert isinstance(
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
LocalExecutor,
)
assert isinstance(executor_loader.ExecutorLoader.load_executor(None), LocalExecutor)

def test_load_executor_alias(self):
with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}):
ExecutorLoader.init_executors()
assert isinstance(ExecutorLoader.load_executor("local_exec"), LocalExecutor)
executor_loader.ExecutorLoader.init_executors()
assert isinstance(executor_loader.ExecutorLoader.load_executor("local_exec"), LocalExecutor)
assert isinstance(
ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"),
executor_loader.ExecutorLoader.load_executor(
"airflow.executors.local_executor.LocalExecutor"
),
LocalExecutor,
)
assert isinstance(
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
LocalExecutor,
)
assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor)

@mock.patch(
"airflow.executors.executor_loader.ExecutorLoader._get_executor_names",
wraps=executor_loader.ExecutorLoader._get_executor_names,
)
def test_call_load_executor_method_without_init_executors(self, mock_get_executor_names):
with conf_vars({("core", "executor"): "LocalExecutor"}):
executor_loader.ExecutorLoader.load_executor("LocalExecutor")
mock_get_executor_names.assert_called_once()

@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor", autospec=True)
def test_load_custom_executor_with_classname(self, mock_executor):
@@ -353,15 +359,16 @@ def test_load_custom_executor_with_classname(self, mock_executor):
): "my_alias:airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
}
):
ExecutorLoader.init_executors()
assert isinstance(ExecutorLoader.load_executor("my_alias"), AwsEcsExecutor)
assert isinstance(ExecutorLoader.load_executor("AwsEcsExecutor"), AwsEcsExecutor)
executor_loader.ExecutorLoader.init_executors()
assert isinstance(executor_loader.ExecutorLoader.load_executor("my_alias"), AwsEcsExecutor)
assert isinstance(executor_loader.ExecutorLoader.load_executor("AwsEcsExecutor"), AwsEcsExecutor)
assert isinstance(
ExecutorLoader.load_executor(
executor_loader.ExecutorLoader.load_executor(
"airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
),
AwsEcsExecutor,
)
assert isinstance(
ExecutorLoader.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
AwsEcsExecutor,
)
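
The clean_executor_loader fixture these tests now rely on is defined outside the hunks shown here. Based on the setup_method/teardown_method it replaces, a plausible sketch (the fixture name is real; the body is an assumption):

from importlib import reload

import pytest

from airflow.executors import executor_loader


@pytest.fixture
def clean_executor_loader():
    # Assumption: reload the module to reset its module-level name caches
    # (_alias_to_executors, _module_to_executors, _classname_to_executors),
    # mirroring the removed per-test reload.
    reload(executor_loader)
    yield
    reload(executor_loader)
    executor_loader.ExecutorLoader.init_executors()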
1 change: 1 addition & 0 deletions tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -49,6 +49,7 @@ def side_effect(*args, **kwargs):
yield m


@pytest.mark.usefixtures("clean_executor_loader")
class TestNotInReschedulePeriodDep:
@pytest.fixture(autouse=True)
def setup_test_cases(self, request, create_task_instance):