From a82ca1c65aed0eda57f38b14c9c81f2ef753a721 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 11:13:44 +0200 Subject: [PATCH 01/18] refactor: Allow MSGraphAsyncOperator and MSGraphSensor to directly start from trigger instead of worker --- .../microsoft/azure/hooks/msgraph.py | 1 - .../microsoft/azure/operators/msgraph.py | 66 ++++++++++++++++--- .../microsoft/azure/sensors/msgraph.py | 62 ++++++++++++++--- 3 files changed, 109 insertions(+), 20 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 2674b304f7acc..45e6f36c5947c 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -468,7 +468,6 @@ def request_information( header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" ) request_information.content = json.dumps(data).encode("utf-8") - print("Request Information:", request_information.url) return request_information @staticmethod diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index f4f6e91fea4de..3498ea341131f 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -17,10 +17,12 @@ # under the License. from __future__ import annotations +import datetime import warnings from collections.abc import Callable, Sequence from contextlib import suppress from copy import deepcopy +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -35,6 +37,21 @@ from airflow.providers.microsoft.azure.version_compat import BaseOperator from airflow.utils.xcom import XCOM_RETURN_KEY +try: + from airflow.triggers.base import StartTriggerArgs +except ImportError: + # TODO: Remove this when min airflow version is 2.10.0 for standard provider + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None + + if TYPE_CHECKING: from io import BytesIO @@ -108,8 +125,17 @@ class MSGraphAsyncOperator(BaseOperator): the message from the event, otherwise the response from the event payload is returned. :param serializer: Class which handles response serialization (default is ResponseSerializer). Bytes will be base64 encoded into a string, so it can be stored as an XCom. + :param start_from_trigger: If set to True, the operator will start directly from the triggerer without going into the worker first. """ + start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs={}, + next_method="execute_complete", + next_kwargs=None, + timeout=None, + ) + start_from_trigger = False template_fields: Sequence[str] = ( "url", "response_type", @@ -142,6 +168,7 @@ def __init__( result_processor: Callable[[Any, Context], Any] = lambda result, **context: result, event_handler: Callable[[dict[Any, Any] | None, Context], Any] | None = None, serializer: type[ResponseSerializer] = ResponseSerializer, + start_from_trigger: bool = False, **kwargs: Any, ): super().__init__(**kwargs) @@ -163,10 +190,9 @@ def __init__( self.result_processor = result_processor self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() - - def execute(self, context: Context) -> None: - self.defer( - trigger=MSGraphTrigger( + self.start_from_trigger = start_from_trigger + if self.start_from_trigger: + self.start_trigger_args.trigger_kwargs = dict( url=self.url, response_type=self.response_type, path_parameters=self.path_parameters, @@ -180,10 +206,30 @@ def execute(self, context: Context) -> None: proxies=self.proxies, scopes=self.scopes, api_version=self.api_version, - serializer=type(self.serializer), - ), - method_name=self.execute_complete.__name__, - ) + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ) + + def execute(self, context: Context) -> None: + if not self.start_from_trigger: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) def execute_complete( self, @@ -229,14 +275,14 @@ def execute_complete( self.trigger_next_link( response=response, method_name=self.execute_complete.__name__, context=context ) - except TaskDeferred as exception: + except TaskDeferred as task_deferred: self.append_result( results=results, result=result, append_result_as_list_if_absent=True, ) self.push_xcom(context=context, value=results) - raise exception + raise task_deferred if not results: return result diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index d59cb13a46f89..361ce54a6f321 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -17,7 +17,9 @@ # under the License. from __future__ import annotations +import datetime from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException @@ -27,6 +29,20 @@ from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS +try: + from airflow.triggers.base import StartTriggerArgs +except ImportError: + # TODO: Remove this when min airflow version is 2.10.0 for standard provider + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None + if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseSensorOperator else: @@ -63,8 +79,16 @@ class MSGraphSensor(BaseSensorOperator): `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string. :param serializer: Class which handles response serialization (default is ResponseSerializer). Bytes will be base64 encoded into a string, so it can be stored as an XCom. + :param start_from_trigger: If set to True, the sensor will start directly from the triggerer without going into the worker first. """ - + start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs={}, + next_method="execute_complete", + next_kwargs=None, + timeout=None, + ) + start_from_trigger = False template_fields: Sequence[str] = ( "url", "response_type", @@ -94,6 +118,7 @@ def __init__( result_processor: Callable[[Any, Context], Any] = lambda result, **context: result, serializer: type[ResponseSerializer] = ResponseSerializer, retry_delay: timedelta | float = 60, + start_from_trigger: bool = False, **kwargs, ): super().__init__(retry_delay=retry_delay, **kwargs) @@ -112,10 +137,9 @@ def __init__( self.event_processor = event_processor self.result_processor = result_processor self.serializer = serializer() - - def execute(self, context: Context): - self.defer( - trigger=MSGraphTrigger( + self.start_from_trigger = start_from_trigger + if self.start_from_trigger: + self.start_trigger_args.trigger_kwargs = dict( url=self.url, response_type=self.response_type, path_parameters=self.path_parameters, @@ -129,10 +153,30 @@ def execute(self, context: Context): proxies=self.proxies, scopes=self.scopes, api_version=self.api_version, - serializer=type(self.serializer), - ), - method_name=self.execute_complete.__name__, - ) + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ) + + def execute(self, context: Context): + if not self.start_from_trigger: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) def retry_execute( self, From cb3d71850827109e6d2929bf5a2c1f4879151d1f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 11:34:53 +0200 Subject: [PATCH 02/18] refactor: Reformatted MSGraphSensor --- .../src/airflow/providers/microsoft/azure/sensors/msgraph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 361ce54a6f321..eb3dcaeb5f112 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -43,6 +43,7 @@ class StartTriggerArgs: # type: ignore[no-redef] next_kwargs: dict[str, Any] | None = None timeout: datetime.timedelta | None = None + if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseSensorOperator else: @@ -81,6 +82,7 @@ class MSGraphSensor(BaseSensorOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. :param start_from_trigger: If set to True, the sensor will start directly from the triggerer without going into the worker first. """ + start_trigger_args = StartTriggerArgs( trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", trigger_kwargs={}, From 9ae30b2877b14ffd72ec831219e9517f3f7ee15c Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 15:33:34 +0200 Subject: [PATCH 03/18] refactor: Updated TODO in ImportError clause --- .../src/airflow/providers/microsoft/azure/operators/msgraph.py | 2 +- .../src/airflow/providers/microsoft/azure/sensors/msgraph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 3498ea341131f..d3433380f8c7a 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -40,7 +40,7 @@ try: from airflow.triggers.base import StartTriggerArgs except ImportError: - # TODO: Remove this when min airflow version is 2.10.0 for standard provider + # TODO: Remove this when min airflow version is 2.10.0 @dataclass class StartTriggerArgs: # type: ignore[no-redef] """Arguments required for start task execution from triggerer.""" diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index eb3dcaeb5f112..817231e88e125 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -32,7 +32,7 @@ try: from airflow.triggers.base import StartTriggerArgs except ImportError: - # TODO: Remove this when min airflow version is 2.10.0 for standard provider + # TODO: Remove this when min airflow version is 2.10.0 @dataclass class StartTriggerArgs: # type: ignore[no-redef] """Arguments required for start task execution from triggerer.""" From 57a119451183e45450737273620bda7a6eb82002 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 15:55:36 +0200 Subject: [PATCH 04/18] refactor: Refactored MSGraphAsyncOperator and MSGraphSensor to always start from trigger --- .../test_utils/operators/run_deferrable.py | 10 +++ .../microsoft/azure/operators/msgraph.py | 60 +++++----------- .../microsoft/azure/sensors/msgraph.py | 70 +++++++++---------- 3 files changed, 62 insertions(+), 78 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index cdcc781767648..504138b5e5f3f 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import TaskDeferred +from airflow.utils.module_loading import import_string from tests_common.test_utils.mock_context import mock_context @@ -50,6 +51,15 @@ async def deferrable_operator(context, operator): triggered_events = [] try: operator.render_template_fields(context=context) + if operator.start_from_trigger: + trigger_cls = import_string(operator.start_trigger_args.trigger_cls) + trigger = trigger_cls(**operator.start_trigger_args.trigger_kwargs) + raise TaskDeferred( + trigger=trigger, + method_name=operator.start_trigger_args.next_method, + kwargs=operator.start_trigger_args.next_kwargs, + timeout=operator.start_trigger_args.timeout, + ) result = operator.execute(context=context) except TaskDeferred as deferred: task = deferred diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index d3433380f8c7a..2223309bf82bc 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -125,7 +125,6 @@ class MSGraphAsyncOperator(BaseOperator): the message from the event, otherwise the response from the event payload is returned. :param serializer: Class which handles response serialization (default is ResponseSerializer). Bytes will be base64 encoded into a string, so it can be stored as an XCom. - :param start_from_trigger: If set to True, the operator will start directly from the triggerer without going into the worker first. """ start_trigger_args = StartTriggerArgs( @@ -135,7 +134,7 @@ class MSGraphAsyncOperator(BaseOperator): next_kwargs=None, timeout=None, ) - start_from_trigger = False + start_from_trigger = True template_fields: Sequence[str] = ( "url", "response_type", @@ -168,7 +167,6 @@ def __init__( result_processor: Callable[[Any, Context], Any] = lambda result, **context: result, event_handler: Callable[[dict[Any, Any] | None, Context], Any] | None = None, serializer: type[ResponseSerializer] = ResponseSerializer, - start_from_trigger: bool = False, **kwargs: Any, ): super().__init__(**kwargs) @@ -190,46 +188,26 @@ def __init__( self.result_processor = result_processor self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() - self.start_from_trigger = start_from_trigger - if self.start_from_trigger: - self.start_trigger_args.trigger_kwargs = dict( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ) + self.start_trigger_args.next_method = self.execute_complete.__name__ + self.start_trigger_args.trigger_kwargs = dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ) def execute(self, context: Context) -> None: - if not self.start_from_trigger: - self.defer( - trigger=MSGraphTrigger( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=type(self.serializer), - ), - method_name=self.execute_complete.__name__, - ) + return def execute_complete( self, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 817231e88e125..7b66152cc554b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -80,7 +80,6 @@ class MSGraphSensor(BaseSensorOperator): `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string. :param serializer: Class which handles response serialization (default is ResponseSerializer). Bytes will be base64 encoded into a string, so it can be stored as an XCom. - :param start_from_trigger: If set to True, the sensor will start directly from the triggerer without going into the worker first. """ start_trigger_args = StartTriggerArgs( @@ -90,7 +89,7 @@ class MSGraphSensor(BaseSensorOperator): next_kwargs=None, timeout=None, ) - start_from_trigger = False + start_from_trigger = True template_fields: Sequence[str] = ( "url", "response_type", @@ -120,7 +119,6 @@ def __init__( result_processor: Callable[[Any, Context], Any] = lambda result, **context: result, serializer: type[ResponseSerializer] = ResponseSerializer, retry_delay: timedelta | float = 60, - start_from_trigger: bool = False, **kwargs, ): super().__init__(retry_delay=retry_delay, **kwargs) @@ -139,9 +137,34 @@ def __init__( self.event_processor = event_processor self.result_processor = result_processor self.serializer = serializer() - self.start_from_trigger = start_from_trigger - if self.start_from_trigger: - self.start_trigger_args.trigger_kwargs = dict( + self.start_trigger_args.next_method = self.execute_complete.__name__ + self.start_trigger_args.trigger_kwargs = dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ) + + def execute(self, context: Context): + return + + def retry_execute( + self, + context: Context, + **kwargs, + ) -> Any: + self.defer( + trigger=MSGraphTrigger( url=self.url, response_type=self.response_type, path_parameters=self.path_parameters, @@ -155,37 +178,10 @@ def __init__( proxies=self.proxies, scopes=self.scopes, api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ) - - def execute(self, context: Context): - if not self.start_from_trigger: - self.defer( - trigger=MSGraphTrigger( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=type(self.serializer), - ), - method_name=self.execute_complete.__name__, - ) - - def retry_execute( - self, - context: Context, - **kwargs, - ) -> Any: - self.execute(context=context) + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) def execute_complete( self, From be8aaf17bce1e341991a87aa4d0d9cf3e97e7de6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 16:38:16 +0200 Subject: [PATCH 05/18] refactor: Added test for api version in KiotaRequestAdapterHook test --- .../tests/unit/microsoft/azure/hooks/test_msgraph.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index 6e06b61f931e5..f5aede98b0ab4 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -212,6 +212,16 @@ def test_api_version(self): assert hook.api_version == APIVersion.v1.value + def test_api_version_when_none_is_explicitly_passed_as_api_version(self): + with patch( + f"{BASEHOOK_PATCH_PATH}.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version=None) + actual = hook.api_version + + assert actual == APIVersion.v1.value + def test_get_api_version_when_empty_config_dict(self): with patch( f"{BASEHOOK_PATCH_PATH}.get_connection", From 8442e5dcf701eb527aae6faa002efd110ee66a33 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 16:47:49 +0200 Subject: [PATCH 06/18] refactor: Log RequestInformation URL in KiotaRequestAdapterHook --- .../azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 45e6f36c5947c..53c76b29eb4e8 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -468,6 +468,7 @@ def request_information( header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" ) request_information.content = json.dumps(data).encode("utf-8") + self.log.info("Request Information: %s", request_information.url) return request_information @staticmethod From 8279e588bfd5d3b38fced69eca8e59769577511f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 17:33:16 +0200 Subject: [PATCH 07/18] refactor: Do render templated fields before initialising StartTriggerArgs --- .../tests_common/test_utils/operators/run_deferrable.py | 9 ++++++--- .../providers/microsoft/azure/operators/msgraph.py | 6 +++++- .../airflow/providers/microsoft/azure/sensors/msgraph.py | 6 +++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index 504138b5e5f3f..923351eb54bec 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -22,10 +22,13 @@ from airflow.exceptions import TaskDeferred from airflow.utils.module_loading import import_string +from airflow.utils.session import NEW_SESSION from tests_common.test_utils.mock_context import mock_context if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.models import Operator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -46,14 +49,14 @@ def execute_operator(operator: Operator) -> tuple[Any, Any]: return asyncio.run(deferrable_operator(context, operator)) -async def deferrable_operator(context, operator): +async def deferrable_operator(context, operator, session: Session = NEW_SESSION): result = None triggered_events = [] try: operator.render_template_fields(context=context) - if operator.start_from_trigger: + if operator.expand_start_from_trigger(context=context, session=session): trigger_cls = import_string(operator.start_trigger_args.trigger_cls) - trigger = trigger_cls(**operator.start_trigger_args.trigger_kwargs) + trigger = trigger_cls(**operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs) raise TaskDeferred( trigger=trigger, method_name=operator.start_trigger_args.next_method, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 2223309bf82bc..7a9da5f63a1a1 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -54,8 +54,8 @@ class StartTriggerArgs: # type: ignore[no-redef] if TYPE_CHECKING: from io import BytesIO - from msgraph_core import APIVersion + from sqlalchemy.orm import Session from airflow.utils.context import Context @@ -189,6 +189,9 @@ def __init__( self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() self.start_trigger_args.next_method = self.execute_complete.__name__ + + def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + self.render_template_fields(context=context) self.start_trigger_args.trigger_kwargs = dict( url=self.url, response_type=self.response_type, @@ -205,6 +208,7 @@ def __init__( api_version=self.api_version, serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", ) + return self.start_trigger_args def execute(self, context: Context) -> None: return diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 7b66152cc554b..af5fc081070a3 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -52,8 +52,8 @@ class StartTriggerArgs: # type: ignore[no-redef] if TYPE_CHECKING: from datetime import timedelta from io import BytesIO - from msgraph_core import APIVersion + from sqlalchemy.orm import Session from airflow.utils.context import Context @@ -138,6 +138,9 @@ def __init__( self.result_processor = result_processor self.serializer = serializer() self.start_trigger_args.next_method = self.execute_complete.__name__ + + def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + self.render_template_fields(context=context) self.start_trigger_args.trigger_kwargs = dict( url=self.url, response_type=self.response_type, @@ -154,6 +157,7 @@ def __init__( api_version=self.api_version, serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", ) + return self.start_trigger_args def execute(self, context: Context): return From 666b026d9f766b29ae1e0fc4f46f506afd5859eb Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Jul 2025 18:34:09 +0200 Subject: [PATCH 08/18] refactor: Initialise StartTriggerArgs in constructor --- .../microsoft/azure/operators/msgraph.py | 16 ++++++---------- .../providers/microsoft/azure/sensors/msgraph.py | 16 ++++++---------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 7a9da5f63a1a1..a3f9b4a0287e7 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -127,13 +127,6 @@ class MSGraphAsyncOperator(BaseOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ - start_trigger_args = StartTriggerArgs( - trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", - trigger_kwargs={}, - next_method="execute_complete", - next_kwargs=None, - timeout=None, - ) start_from_trigger = True template_fields: Sequence[str] = ( "url", @@ -188,9 +181,12 @@ def __init__( self.result_processor = result_processor self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() - self.start_trigger_args.next_method = self.execute_complete.__name__ + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + next_method=self.execute_complete.__name__, + ) - def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: self.render_template_fields(context=context) self.start_trigger_args.trigger_kwargs = dict( url=self.url, @@ -208,7 +204,7 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St api_version=self.api_version, serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", ) - return self.start_trigger_args + return self.start_from_trigger def execute(self, context: Context) -> None: return diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index af5fc081070a3..df8e6ddca337b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -82,13 +82,6 @@ class MSGraphSensor(BaseSensorOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ - start_trigger_args = StartTriggerArgs( - trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", - trigger_kwargs={}, - next_method="execute_complete", - next_kwargs=None, - timeout=None, - ) start_from_trigger = True template_fields: Sequence[str] = ( "url", @@ -137,9 +130,12 @@ def __init__( self.event_processor = event_processor self.result_processor = result_processor self.serializer = serializer() - self.start_trigger_args.next_method = self.execute_complete.__name__ + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + next_method=self.execute_complete.__name__, + ) - def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: self.render_template_fields(context=context) self.start_trigger_args.trigger_kwargs = dict( url=self.url, @@ -157,7 +153,7 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St api_version=self.api_version, serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", ) - return self.start_trigger_args + return self.start_from_trigger def execute(self, context: Context): return From 2aaa6ba3e53cfed9896946072db842d900f8bbd0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 08:03:42 +0200 Subject: [PATCH 09/18] refactor: Only allow start_from_trigger in MSGraphOperator if Airflow is 3.1 or higher as it will contain fix for templated fields --- .../microsoft/azure/operators/msgraph.py | 64 ++++++++++------- .../microsoft/azure/sensors/msgraph.py | 68 ++++++++++++------- 2 files changed, 82 insertions(+), 50 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index a3f9b4a0287e7..bbcd3b69aad3e 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -28,13 +28,14 @@ Any, ) -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import ( MSGraphTrigger, ResponseSerializer, ) -from airflow.providers.microsoft.azure.version_compat import BaseOperator +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_1_PLUS, BaseOperator + +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.utils.xcom import XCOM_RETURN_KEY try: @@ -55,7 +56,6 @@ class StartTriggerArgs: # type: ignore[no-redef] if TYPE_CHECKING: from io import BytesIO from msgraph_core import APIVersion - from sqlalchemy.orm import Session from airflow.utils.context import Context @@ -127,7 +127,7 @@ class MSGraphAsyncOperator(BaseOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ - start_from_trigger = True + start_from_trigger = AIRFLOW_V_3_1_PLUS template_fields: Sequence[str] = ( "url", "response_type", @@ -183,30 +183,46 @@ def __init__( self.serializer: ResponseSerializer = serializer() self.start_trigger_args = StartTriggerArgs( trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), next_method=self.execute_complete.__name__, ) - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - self.render_template_fields(context=context) - self.start_trigger_args.trigger_kwargs = dict( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ) - return self.start_from_trigger - def execute(self, context: Context) -> None: + if not AIRFLOW_V_3_1_PLUS: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) return def execute_complete( diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index df8e6ddca337b..6b659516cde43 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -22,12 +22,13 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException -from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.operators.msgraph import execute_callable from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer -from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger try: from airflow.triggers.base import StartTriggerArgs @@ -53,7 +54,6 @@ class StartTriggerArgs: # type: ignore[no-redef] from datetime import timedelta from io import BytesIO from msgraph_core import APIVersion - from sqlalchemy.orm import Session from airflow.utils.context import Context @@ -82,7 +82,7 @@ class MSGraphSensor(BaseSensorOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ - start_from_trigger = True + start_from_trigger = AIRFLOW_V_3_1_PLUS template_fields: Sequence[str] = ( "url", "response_type", @@ -132,30 +132,46 @@ def __init__( self.serializer = serializer() self.start_trigger_args = StartTriggerArgs( trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), next_method=self.execute_complete.__name__, ) - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - self.render_template_fields(context=context) - self.start_trigger_args.trigger_kwargs = dict( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ) - return self.start_from_trigger - - def execute(self, context: Context): + def execute(self, context: Context) -> None: + if not AIRFLOW_V_3_1_PLUS: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) return def retry_execute( From 996aaf597a77d30b40b562a53edea27b2ee37fda Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 09:11:32 +0200 Subject: [PATCH 10/18] refactor: Check on start_from_trigger instead of Airflow version in execute method --- .../src/airflow/providers/microsoft/azure/operators/msgraph.py | 2 +- .../src/airflow/providers/microsoft/azure/sensors/msgraph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index bbcd3b69aad3e..1ce163573a5c5 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -203,7 +203,7 @@ def __init__( ) def execute(self, context: Context) -> None: - if not AIRFLOW_V_3_1_PLUS: + if not self.start_from_trigger: self.defer( trigger=MSGraphTrigger( url=self.url, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 6b659516cde43..1b9b8c951937d 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -152,7 +152,7 @@ def __init__( ) def execute(self, context: Context) -> None: - if not AIRFLOW_V_3_1_PLUS: + if not self.start_from_trigger: self.defer( trigger=MSGraphTrigger( url=self.url, From 1d81786928d16e5028ba2a218881bf99ef2d3c4a Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 11:29:26 +0200 Subject: [PATCH 11/18] refactor: Make sure base_url ends with slash --- .../airflow/providers/microsoft/azure/hooks/msgraph.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 53c76b29eb4e8..17c5be09c7d00 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -209,6 +209,13 @@ def get_host(self, connection: Connection) -> str: return f"{connection.schema}://{connection.host}" return self.host + def get_base_url(self, host: str, api_version: str, config: dict) -> str: + base_url = config.get("base_url", urljoin(host, api_version)).strip() + + if not base_url.endswith("/"): + return f"{base_url}/" + return base_url + @staticmethod def format_no_proxy_url(url: str) -> str: if "://" not in url: @@ -255,7 +262,7 @@ def get_conn(self) -> RequestAdapter: config = connection.extra_dejson if connection.extra else {} api_version = self.get_api_version(config) host = self.get_host(connection) # type: ignore[arg-type] - base_url = config.get("base_url", urljoin(host, api_version)) + base_url = self.get_base_url(host, api_version, config) authority = config.get("authority") proxies = self.get_proxies(config) httpx_proxies = self.to_httpx_proxies(proxies=proxies) From 6787f0eea2eec6a86e5717e3318721ab9557164c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 12:39:38 +0200 Subject: [PATCH 12/18] refactor: Disable start_trigger_args if import of StartTriggerArgs fails --- .../microsoft/azure/operators/msgraph.py | 63 ++++++++----------- .../microsoft/azure/sensors/msgraph.py | 63 ++++++++----------- 2 files changed, 52 insertions(+), 74 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 1ce163573a5c5..3ea044387e5ff 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import datetime import warnings from collections.abc import Callable, Sequence from contextlib import suppress from copy import deepcopy -from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -38,21 +36,6 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.utils.xcom import XCOM_RETURN_KEY -try: - from airflow.triggers.base import StartTriggerArgs -except ImportError: - # TODO: Remove this when min airflow version is 2.10.0 - @dataclass - class StartTriggerArgs: # type: ignore[no-redef] - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: datetime.timedelta | None = None - - if TYPE_CHECKING: from io import BytesIO from msgraph_core import APIVersion @@ -181,26 +164,32 @@ def __init__( self.result_processor = result_processor self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() - self.start_trigger_args = StartTriggerArgs( - trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", - trigger_kwargs=dict( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ), - next_method=self.execute_complete.__name__, - ) + if self.start_from_trigger: + try: + from airflow.triggers.base import StartTriggerArgs + + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), + next_method=self.execute_complete.__name__, + ) + except ImportError: + self.start_from_trigger = False def execute(self, context: Context) -> None: if not self.start_from_trigger: diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 1b9b8c951937d..45ce2e0d68eb8 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -17,9 +17,7 @@ # under the License. from __future__ import annotations -import datetime from collections.abc import Callable, Sequence -from dataclasses import dataclass from typing import TYPE_CHECKING, Any from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook @@ -30,21 +28,6 @@ from airflow.exceptions import AirflowException from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger -try: - from airflow.triggers.base import StartTriggerArgs -except ImportError: - # TODO: Remove this when min airflow version is 2.10.0 - @dataclass - class StartTriggerArgs: # type: ignore[no-redef] - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: datetime.timedelta | None = None - - if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseSensorOperator else: @@ -130,26 +113,32 @@ def __init__( self.event_processor = event_processor self.result_processor = result_processor self.serializer = serializer() - self.start_trigger_args = StartTriggerArgs( - trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", - trigger_kwargs=dict( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", - ), - next_method=self.execute_complete.__name__, - ) + if self.start_from_trigger: + try: + from airflow.triggers.base import StartTriggerArgs + + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), + next_method=self.execute_complete.__name__, + ) + except ImportError: + self.start_from_trigger = False def execute(self, context: Context) -> None: if not self.start_from_trigger: From 8375eba3b03b100c4cdc379eb98eebc48b4e2ef6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 12:43:02 +0200 Subject: [PATCH 13/18] refactor: Import BaseSensorOperator from compat_version --- .../airflow/providers/microsoft/azure/sensors/msgraph.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 45ce2e0d68eb8..461d83d8d229b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -23,16 +23,11 @@ from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.operators.msgraph import execute_callable from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer -from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, BaseSensorOperator from airflow.exceptions import AirflowException from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseSensorOperator -else: - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] - if TYPE_CHECKING: from datetime import timedelta from io import BytesIO From 980e3377f4649f94d7b17409e31b020adc405f99 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 13:02:15 +0200 Subject: [PATCH 14/18] refactor: Allow task to be None in deferrable_operator --- .../src/tests_common/test_utils/operators/run_deferrable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index 923351eb54bec..ab9f2c3151642 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -65,7 +65,7 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) ) result = operator.execute(context=context) except TaskDeferred as deferred: - task = deferred + task: TaskDeferred | None = deferred while task: events = await run_tigger(task.trigger) From 58dd886edf12c75f8be5b7d1a130bbda9ca728c5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 13:37:44 +0200 Subject: [PATCH 15/18] refactor: Reformatted files --- .../test_utils/operators/run_deferrable.py | 4 +++- .../providers/microsoft/azure/operators/msgraph.py | 4 ++-- .../providers/microsoft/azure/sensors/msgraph.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index ab9f2c3151642..2b5f672c0ef97 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -56,7 +56,9 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) operator.render_template_fields(context=context) if operator.expand_start_from_trigger(context=context, session=session): trigger_cls = import_string(operator.start_trigger_args.trigger_cls) - trigger = trigger_cls(**operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs) + trigger = trigger_cls( + **operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs + ) raise TaskDeferred( trigger=trigger, method_name=operator.start_trigger_args.next_method, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 3ea044387e5ff..090a088ef4f06 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -26,18 +26,18 @@ Any, ) +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import ( MSGraphTrigger, ResponseSerializer, ) from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_1_PLUS, BaseOperator - -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: from io import BytesIO + from msgraph_core import APIVersion from airflow.utils.context import Context diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 461d83d8d229b..c0b503a90a3c5 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -20,17 +20,20 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.operators.msgraph import execute_callable from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer -from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, BaseSensorOperator - -from airflow.exceptions import AirflowException -from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger +from airflow.providers.microsoft.azure.version_compat import ( + AIRFLOW_V_3_1_PLUS, + BaseSensorOperator, +) if TYPE_CHECKING: from datetime import timedelta from io import BytesIO + from msgraph_core import APIVersion from airflow.utils.context import Context From 9e2c7ecb2282b1f5f6ab8d8f08296d4659b433a5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 14:50:07 +0200 Subject: [PATCH 16/18] refactor: Refactored deferrable_operator to check on start_from_trigger instead of expand_start_from_trigger --- .../src/tests_common/test_utils/operators/run_deferrable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index 2b5f672c0ef97..103db30e807fd 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -53,8 +53,7 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) result = None triggered_events = [] try: - operator.render_template_fields(context=context) - if operator.expand_start_from_trigger(context=context, session=session): + if operator.start_from_trigger: trigger_cls = import_string(operator.start_trigger_args.trigger_cls) trigger = trigger_cls( **operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs @@ -65,6 +64,7 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) kwargs=operator.start_trigger_args.next_kwargs, timeout=operator.start_trigger_args.timeout, ) + operator.render_template_fields(context=context) result = operator.execute(context=context) except TaskDeferred as deferred: task: TaskDeferred | None = deferred From f51cd7bd10c8a8597e533315991f4bf9f11aca0f Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 15:52:07 +0200 Subject: [PATCH 17/18] refactor: Refactored deferrable_operator to use start_trigger_args instead of expand_start_trigger_args --- .../src/tests_common/test_utils/operators/run_deferrable.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index 103db30e807fd..499cdd8599cd8 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -53,18 +53,16 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) result = None triggered_events = [] try: + operator.render_template_fields(context=context) if operator.start_from_trigger: trigger_cls = import_string(operator.start_trigger_args.trigger_cls) - trigger = trigger_cls( - **operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs - ) + trigger = trigger_cls(**operator.start_trigger_args.trigger_kwargs) raise TaskDeferred( trigger=trigger, method_name=operator.start_trigger_args.next_method, kwargs=operator.start_trigger_args.next_kwargs, timeout=operator.start_trigger_args.timeout, ) - operator.render_template_fields(context=context) result = operator.execute(context=context) except TaskDeferred as deferred: task: TaskDeferred | None = deferred From a2677fbb6364230f4c227881f13f2e266c7c0043 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Jul 2025 19:12:35 +0200 Subject: [PATCH 18/18] Revert "refactor: Refactored deferrable_operator to use start_trigger_args instead of expand_start_trigger_args" This reverts commit f51cd7bd10c8a8597e533315991f4bf9f11aca0f. --- .../src/tests_common/test_utils/operators/run_deferrable.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index 499cdd8599cd8..103db30e807fd 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -53,16 +53,18 @@ async def deferrable_operator(context, operator, session: Session = NEW_SESSION) result = None triggered_events = [] try: - operator.render_template_fields(context=context) if operator.start_from_trigger: trigger_cls = import_string(operator.start_trigger_args.trigger_cls) - trigger = trigger_cls(**operator.start_trigger_args.trigger_kwargs) + trigger = trigger_cls( + **operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs + ) raise TaskDeferred( trigger=trigger, method_name=operator.start_trigger_args.next_method, kwargs=operator.start_trigger_args.next_kwargs, timeout=operator.start_trigger_args.timeout, ) + operator.render_template_fields(context=context) result = operator.execute(context=context) except TaskDeferred as deferred: task: TaskDeferred | None = deferred