Skip to content

Commit

Permalink
Fix is_deferrable in airflow agent (#2109)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@gmail.com>
  • Loading branch information
pingsutw authored Jan 17, 2024
1 parent 892b474 commit 80ca660
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
1 change: 1 addition & 0 deletions plugins/flytekit-airflow/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
apache-airflow-providers-apache-beam[google]
apache-airflow[google]
7 changes: 5 additions & 2 deletions plugins/flytekit-airflow/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ annotated-types==0.6.0
# via pydantic
anyio==4.0.0
# via httpx
apache-airflow==2.8.0
apache-airflow[google]==2.8.0
# via
# -r dev-requirements.in
# apache-airflow-providers-apache-beam
# apache-airflow-providers-common-sql
# apache-airflow-providers-ftp
Expand All @@ -40,7 +41,9 @@ apache-airflow-providers-common-sql==1.8.0
apache-airflow-providers-ftp==3.6.0
# via apache-airflow
apache-airflow-providers-google==10.11.0
# via apache-airflow-providers-apache-beam
# via
# apache-airflow
# apache-airflow-providers-apache-beam
apache-airflow-providers-http==4.6.0
# via apache-airflow
apache-airflow-providers-imap==3.4.0
Expand Down
17 changes: 10 additions & 7 deletions plugins/flytekit-airflow/flytekitplugins/airflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,7 @@ def _get_airflow_instance(

obj_module = importlib.import_module(name=airflow_obj.module)
obj_def = getattr(obj_module, airflow_obj.name)
if (
issubclass(obj_def, airflow_models.BaseOperator)
and not issubclass(obj_def, airflow_sensors.BaseSensorOperator)
and _is_deferrable(obj_def)
):
if _is_deferrable(obj_def):
try:
return obj_def(**airflow_obj.parameters, deferrable=True)
except airflow.exceptions.AirflowException as e:
Expand All @@ -163,12 +159,19 @@ def _get_airflow_instance(
def _is_deferrable(cls: Type) -> bool:
"""
This function is used to check if the Airflow operator is deferrable.
If the operator is not deferrable, we run it in a container instead of the agent.
"""
# Only Airflow operators are deferrable.
if not issubclass(cls, airflow_models.BaseOperator):
return False
# Airflow sensors are not deferrable. Sensor is a subclass of BaseOperator.
if issubclass(cls, airflow_sensors.BaseSensorOperator):
return False
try:
from airflow.providers.apache.beam.operators.beam import BeamBasePipelineOperator

# Dataflow operators are not deferrable.
if not issubclass(cls, BeamBasePipelineOperator):
if issubclass(cls, BeamBasePipelineOperator):
return False
except ImportError:
logger.debug("Failed to import BeamBasePipelineOperator")
Expand All @@ -194,7 +197,7 @@ def _flyte_operator(*args, **kwargs):
task_id = kwargs["task_id"] or cls.__name__
config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs)

if _is_deferrable(cls):
if not _is_deferrable(cls):
# Dataflow operators are not deferrable, so we run them in a container.
return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)()
return AirflowTask(name=task_id, task_config=config)()
Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-airflow/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jsonpickle
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator
from airflow.sensors.bash import BashSensor
from airflow.utils.context import Context
from flytekitplugins.airflow.task import (
Expand Down Expand Up @@ -34,8 +35,9 @@ def test_xcom_push():


def test_is_deferrable():
assert _is_deferrable(BeamRunJavaPipelineOperator) is True
assert _is_deferrable(BeamRunJavaPipelineOperator) is False
assert _is_deferrable(BashSensor) is False
assert _is_deferrable(DataprocCreateClusterOperator) is True


def test_airflow_task():
Expand Down

0 comments on commit 80ca660

Please sign in to comment.