Skip to content

Commit

Permalink
refactor: Improved mocking of Context with TaskInstance to allow unit…
Browse files Browse the repository at this point in the history
… testing operators
  • Loading branch information
davidblain-infrabel committed Apr 22, 2024
1 parent 5033423 commit d994f21
Showing 1 changed file with 47 additions and 29 deletions.
76 changes: 47 additions & 29 deletions tests/providers/microsoft/azure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@
import asyncio
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable
from unittest.mock import patch

from kiota_http.httpx_request_adapter import HttpxRequestAdapter

from airflow.exceptions import TaskDeferred
from airflow.models import Operator, TaskInstance
from airflow.models import Operator
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.utils.session import NEW_SESSION
from airflow.utils.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from tests.providers.microsoft.conftest import get_airflow_connection, mock_context

if TYPE_CHECKING:
Expand All @@ -38,38 +36,58 @@
from airflow.triggers.base import BaseTrigger, TriggerEvent


class MockedTaskInstance(TaskInstance):
def mock_context(task) -> Context:
from datetime import datetime

from airflow.models import TaskInstance
from airflow.utils.session import NEW_SESSION
from airflow.utils.state import TaskInstanceState
from airflow.utils.xcom import XCOM_RETURN_KEY

values = {}

def xcom_pull(
self,
task_ids: Iterable[str] | str | None = None,
dag_id: str | None = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
*,
map_indexes: Iterable[int] | int | None = None,
default: Any | None = None,
) -> Any:
self.task_id = task_ids
self.dag_id = dag_id
return self.values.get(f"{task_ids}_{dag_id}_{key}")

def xcom_push(
self,
key: str,
value: Any,
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
class MockedTaskInstance(TaskInstance):
def __init__(
self,
task,
execution_date: datetime | None = None,
run_id: str | None = "run_id",
state: str | None = TaskInstanceState.RUNNING,
map_index: int = -1,
):
super().__init__(task=task, execution_date=execution_date, run_id=run_id, state=state, map_index=map_index)
self.values = {}

def xcom_pull(
self,
task_ids: Iterable[str] | str | None = None,
dag_id: str | None = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
*,
map_indexes: Iterable[int] | int | None = None,
default: Any | None = None,
) -> Any:
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")

def xcom_push(
self,
key: str,
value: Any,
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}"] = value

values["ti"] = MockedTaskInstance(task=task)

return Context(values)


class Base:
def teardown_method(self, method):
KiotaRequestAdapterHook.cached_request_adapters.clear()
MockedTaskInstance.values.clear()

@contextmanager
def patch_hook_and_request_adapter(self, response):
Expand Down

0 comments on commit d994f21

Please sign in to comment.