Add conditional output processing in SQL operators (apache#31136)
The change adds conditional processing of output based on criteria
that can be overridden by an operator extending the common.sql
BaseSQLOperator. Previously, output processing ran only when
"do_xcom_push" was enabled, but in some cases we want to run it
even when do_xcom_push is disabled (for example, the Databricks
SQL operator should process output when it is redirected to a
file).

The new _should_run_output_processing() method enables this.

Fixes: apache#31080
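
For illustration, a minimal sketch of how an operator extending the common.sql base could use the new hook. The FileExportSqlOperator class and its output_path attribute are hypothetical, invented for this example; only _should_run_output_processing() and do_xcom_push come from the change itself:

from __future__ import annotations

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


class FileExportSqlOperator(SQLExecuteQueryOperator):
    """Hypothetical operator that writes query results to a local file."""

    def __init__(self, *, output_path: str | None = None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_path = output_path

    def _should_run_output_processing(self) -> bool:
        # Run output processing when results are pushed to XCom or written
        # to a file, not only when do_xcom_push is enabled.
        return self.do_xcom_push or bool(self.output_path)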
potiuk authored May 9, 2023
1 parent 521dae5 commit edd7133
Showing 6 changed files with 32 additions and 6 deletions.
7 changes: 5 additions & 2 deletions airflow/providers/common/sql/operators/sql.py
@@ -258,6 +258,9 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
         self.log.info("Operator output is: %s", results)
         return results
 
+    def _should_run_output_processing(self) -> bool:
+        return self.do_xcom_push
+
     def execute(self, context):
         self.log.info("Executing: %s", self.sql)
         hook = self.get_db_hook()
@@ -269,11 +272,11 @@ def execute(self, context):
             sql=self.sql,
             autocommit=self.autocommit,
             parameters=self.parameters,
-            handler=self.handler if self.do_xcom_push else None,
+            handler=self.handler if self._should_run_output_processing() else None,
             return_last=self.return_last,
             **extra_kwargs,
         )
-        if not self.do_xcom_push:
+        if not self._should_run_output_processing():
             return None
         if return_single_query_results(self.sql, self.return_last, self.split_statements):
             # For simplicity, we pass always list as input to _process_output, regardless if
1 change: 1 addition & 0 deletions airflow/providers/common/sql/provider.yaml
@@ -23,6 +23,7 @@ description: |
 suspended: false
 versions:
+  - 1.5.0
   - 1.4.0
   - 1.3.4
   - 1.3.3
3 changes: 3 additions & 0 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -120,6 +120,9 @@ def get_db_hook(self) -> DatabricksSqlHook:
         }
         return DatabricksSqlHook(self.databricks_conn_id, **hook_params)
 
+    def _should_run_output_processing(self) -> bool:
+        return self.do_xcom_push or bool(self._output_path)
+
     def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
             return list(zip(descriptions, results))
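
With the override above, a usage sketch based on the operator parameters exercised in the tests below (the connection id, query, and output path are placeholders) of exporting results to a file without pushing them to XCom:

from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

export_results = DatabricksSqlOperator(
    task_id="export_results",
    databricks_conn_id="databricks_default",  # placeholder connection id
    sql="select * from dummy",  # placeholder query
    output_path="/tmp/results.csv",  # non-empty path now triggers output processing
    do_xcom_push=False,  # nothing is pushed to XCom; the file is still written
)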
2 changes: 1 addition & 1 deletion airflow/providers/databricks/provider.yaml
@@ -46,7 +46,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.4.0
-  - apache-airflow-providers-common-sql>=1.3.1
+  - apache-airflow-providers-common-sql>=1.5.0
   - requests>=2.27,<3
   - databricks-sql-connector>=2.0.0, <3.0.0
   - aiohttp>=3.6.3, <4
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -237,7 +237,7 @@
   "databricks": {
     "deps": [
       "aiohttp>=3.6.3, <4",
-      "apache-airflow-providers-common-sql>=1.3.1",
+      "apache-airflow-providers-common-sql>=1.5.0",
       "apache-airflow>=2.4.0",
       "databricks-sql-connector>=2.0.0, <3.0.0",
       "requests>=2.27,<3"
23 changes: 21 additions & 2 deletions tests/providers/databricks/operators/test_databricks_sql.py
@@ -152,14 +152,15 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
 
 
 @pytest.mark.parametrize(
-    "return_last, split_statements, sql, descriptions, hook_results",
+    "return_last, split_statements, sql, descriptions, hook_results, do_xcom_push",
     [
         pytest.param(
             True,
             False,
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement False",
         ),
         pytest.param(
@@ -168,6 +169,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             "select * from dummy",
             [[("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: return_last False and split_statement True",
         ),
         pytest.param(
@@ -176,6 +178,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and no split_statement True",
         ),
         pytest.param(
@@ -184,6 +187,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last False and split_statement is False",
         ),
         pytest.param(
@@ -195,6 +199,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
                 [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
                 [Row(id=1, value="value1"), Row(id=2, value="value2")],
             ],
+            True,
             id="Non-Scalar: return_last False and split_statement is True",
         ),
         pytest.param(
@@ -203,6 +208,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             "select * from dummy2; select * from dummy",
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement is True",
         ),
         pytest.param(
@@ -211,6 +217,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             "select * from dummy2; select * from dummy",
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement is True",
         ),
         pytest.param(
@@ -219,6 +226,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             ["select * from dummy2", "select * from dummy"],
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: sql is list and return_last is True",
         ),
         pytest.param(
@@ -227,11 +235,21 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
             ["select * from dummy2", "select * from dummy"],
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: sql is list and return_last is False",
         ),
+        pytest.param(
+            False,
+            True,
+            ["select * from dummy2", "select * from dummy"],
+            [[("id2",), ("value2",)], [("id",), ("value",)]],
+            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            False,
+            id="Write output when do_xcom_push is False",
+        ),
     ],
 )
-def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results):
+def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results, do_xcom_push):
     """
     Test the execute function in case where SQL query was successful and data is written as CSV
     """
@@ -242,6 +260,7 @@ def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results):
             sql=sql,
             output_path=tempfile_path,
             return_last=return_last,
+            do_xcom_push=do_xcom_push,
             split_statements=split_statements,
         )
         db_mock = db_mock_class.return_value
