Skip to content

Commit

Permalink
Use pytest.param for test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle-Neale committed Dec 18, 2024
1 parent 74d9888 commit adec372
Showing 1 changed file with 51 additions and 23 deletions.
74 changes: 51 additions & 23 deletions airflow/tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,52 @@ def test_dag_total_tasks(aggregator, task_instance):
aggregator.assert_metric('airflow.dag.task.total_running', value=1, count=1)


def test_dag_task_ongoing_duration(aggregator, task_instance):
instance = common.FULL_CONFIG['instances'][0]
check = AirflowCheck('airflow', common.FULL_CONFIG, [instance])

with mock.patch('datadog_checks.airflow.airflow.AirflowCheck._get_version', return_value='2.6.2'):

with mock.patch('datadog_checks.base.utils.http.requests') as req:
mock_resp = mock.MagicMock(status_code=200)
mock_resp.json.side_effect = [
{'metadatabase': {'status': 'healthy'}, 'scheduler': {'status': 'healthy'}},
]
req.get.return_value = mock_resp
with mock.patch(
'datadog_checks.airflow.airflow.AirflowCheck._get_all_task_instances',
return_value=task_instance.get('task_instances'),
):
check.check(None)

aggregator.assert_metric(
'airflow.dag.task.ongoing_duration',
tags=['key:my-tag', 'url:http://localhost:8080', 'dag_id:tutorial', 'task_id:sleep'],
count=1,
)


@pytest.mark.parametrize(
"collect_ongoing_duration, expected_metric_count, should_call_method",
"collect_ongoing_duration, should_call_method",
[
(True, 1, True), # `collect_task_duration=True`: metric is collected, method is called
(False, 0, False), # `collect_task_duration=False`: metric is NOT collected, method is NOT called
pytest.param(
True,
[
mock.call(
'http://localhost:8080/api/v1/dags/~/dagRuns/~/taskInstances?state=running',
['url:http://localhost:8080', 'key:my-tag'],
)
],
id="collect",
),
pytest.param(
False,
[],
id="don't collect",
),
],
)
def test_dag_task_ongoing_duration(
aggregator, task_instance, collect_ongoing_duration, expected_metric_count, should_call_method
):
def test_config_collect_ongoing_duration(collect_ongoing_duration, should_call_method):
instance = {**common.FULL_CONFIG['instances'][0], 'collect_ongoing_duration': collect_ongoing_duration}
check = AirflowCheck('airflow', common.FULL_CONFIG, [instance])

Expand All @@ -117,21 +153,13 @@ def test_dag_task_ongoing_duration(
]
req.get.return_value = mock_resp

if should_call_method:
with mock.patch(
'datadog_checks.airflow.airflow.AirflowCheck._get_all_task_instances',
return_value=task_instance.get('task_instances'),
):
check.check(None)
else:
with mock.patch(
'datadog_checks.airflow.airflow.AirflowCheck._get_all_task_instances'
) as mock_get_all_task_instances:
check.check(None)
mock_get_all_task_instances.assert_not_called()
with mock.patch(
'datadog_checks.airflow.airflow.AirflowCheck._get_all_task_instances'
) as mock_get_all_task_instances:
check.check(None)

# Assert the metric count
aggregator.assert_metric(
'airflow.dag.task.ongoing_duration',
count=expected_metric_count,
)
# Assert method calls
if collect_ongoing_duration:
mock_get_all_task_instances.assert_has_calls(should_call_method)
else:
mock_get_all_task_instances.assert_not_called()

0 comments on commit adec372

Please sign in to comment.