Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions library/src/iqb/cli/pipeline_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Pipeline run command."""

# TODO(bassosimone): add support for -f/--force to bypass cache

from dataclasses import dataclass
from datetime import date, datetime
from pathlib import Path
Expand Down Expand Up @@ -94,8 +92,9 @@ def coerce_str(value: object) -> str:
metavar="WORKFLOW",
help="Path to YAML workflow file (default: <dir>/pipeline.yaml)",
)
@click.option("-f", "--force", is_flag=True, default=False, help="Bypass cache and force sync")
@click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose mode.")
def run(data_dir: str | None, workflow_file: str | None, verbose: bool) -> None:
def run(data_dir: str | None, workflow_file: str | None, force: bool, verbose: bool) -> None:
"""Run the BigQuery pipeline for all matrix entries."""

console = get_console()
Expand All @@ -115,6 +114,7 @@ def run(data_dir: str | None, workflow_file: str | None, verbose: bool) -> None:
enable_bigquery=True,
start_date=start,
end_date=end,
force=force,
)

raise SystemExit(interceptor.exitcode())
11 changes: 8 additions & 3 deletions library/src/iqb/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import logging
from datetime import datetime
from functools import partial
from importlib.resources import files
from pathlib import Path

Expand Down Expand Up @@ -91,6 +92,7 @@ def get_cache_entry(
enable_bigquery: bool,
start_date: str,
end_date: str,
force: bool = False,
) -> PipelineCacheEntry:
"""
Get or create a cache entry for the given query template.
Expand All @@ -103,6 +105,7 @@ def get_cache_entry(
enable_bigquery: Whether to enable querying from BigQuery.
start_date: Date when to start the query (included) -- format YYYY-MM-DD
end_date: Date when to end the query (excluded) -- format YYYY-MM-DD
force: Whether to bypass cache and force BigQuery query execution.

Returns:
PipelineCacheEntry with paths to data.parquet and stats.json.
Expand All @@ -116,14 +119,16 @@ def get_cache_entry(

# 2. prepare for synching from BigQuery
if enable_bigquery:
entry.syncers.append(self._bq_syncer)
if force:
entry.syncers.clear()
entry.syncers.append(partial(self._bq_syncer, force=force))

# 3. return the entry
return entry

def _bq_syncer(self, entry: PipelineCacheEntry) -> bool:
def _bq_syncer(self, entry: PipelineCacheEntry, *, force: bool = False) -> bool:
"""Internal method to get the entry files using a BigQuery query."""
if entry.exists():
if not force and entry.exists():
log.info("querying for %s... skipped (cached)", entry)
return True
try:
Expand Down
4 changes: 3 additions & 1 deletion library/src/iqb/scripting/iqb_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def sync_mlab(
*,
enable_bigquery: bool,
end_date: str,
force: bool = False,
start_date: str,
) -> None:
"""
Expand Down Expand Up @@ -67,9 +68,10 @@ def sync_mlab(
enable_bigquery=enable_bigquery,
start_date=start_date,
end_date=end_date,
force=force,
)
with entry.lock():
if not entry.exists():
if force or not entry.exists():
entry.sync()

log.info(
Expand Down
29 changes: 29 additions & 0 deletions library/tests/iqb/cli/pipeline_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_bare_dates(
enable_bigquery=True,
start_date="2024-10-01",
end_date="2024-11-01",
force=False,
)


Expand All @@ -219,6 +220,34 @@ def test_valid_config(
enable_bigquery=True,
start_date="2024-10-01",
end_date="2024-11-01",
force=False,
)


class TestPipelineRunForceFlag:
    """-f/--force flag is accepted and passed to sync."""

    @patch("iqb.cli.pipeline_run.IQBPipeline")
    @patch("iqb.cli.pipeline_run.Pipeline")
    def test_force_accepted(
        self, mock_pipeline_cls: MagicMock, mock_iqb_pipeline_cls: MagicMock, tmp_path: Path
    ):
        # Arrange: a valid workflow file plus a stubbed Pipeline instance.
        _write_config(tmp_path / "pipeline.yaml", _VALID_CONFIG)
        fake_pipeline = MagicMock()
        mock_pipeline_cls.return_value = fake_pipeline

        # Act: invoke the CLI with the long-form --force flag.
        outcome = CliRunner().invoke(
            cli,
            ["pipeline", "run", "-d", str(tmp_path), "--force"],
        )

        # Assert: clean exit, and force=True forwarded to sync_mlab.
        assert outcome.exit_code == 0
        fake_pipeline.sync_mlab.assert_called_once_with(
            "country",
            enable_bigquery=True,
            start_date="2024-10-01",
            end_date="2024-11-01",
            force=True,
        )


Expand Down
36 changes: 36 additions & 0 deletions library/tests/iqb/pipeline/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,42 @@ def test_bq_syncer_skip_when_exists(self, mock_client, tmp_path):
# Verify BigQuery was not attempted
mock_client.return_value.execute_query.assert_not_called()

@patch("iqb.pipeline.pipeline.PipelineBQPQClient")
def test_bq_syncer_force_queries_when_exists(self, mock_client, tmp_path):
"""Test that force=True bypasses cached skip and executes BigQuery."""
data_dir = tmp_path / "iqb"
pipeline = IQBPipeline(project="test-project", data_dir=data_dir)

cache_dir = (
data_dir
/ "cache"
/ "v1"
/ "20241001T000000Z"
/ "20241101T000000Z"
/ "downloads_by_country"
)
cache_dir.mkdir(parents=True, exist_ok=True)
(cache_dir / "data.parquet").write_text("fake parquet data")
(cache_dir / "stats.json").write_text("{}")

mock_result = MagicMock(spec=PipelineBQPQQueryResult)
mock_client.return_value.execute_query.return_value = mock_result

entry = pipeline.get_cache_entry(
dataset_name="downloads_by_country",
enable_bigquery=True,
start_date="2024-10-01",
end_date="2024-11-01",
force=True,
)

with entry.lock():
entry.sync()

mock_client.return_value.execute_query.assert_called_once()
mock_result.save_data_parquet.assert_called_once()
mock_result.save_stats_json.assert_called_once()

@patch("iqb.pipeline.pipeline.PipelineBQPQClient")
def test_bq_syncer_failure(self, mock_client, tmp_path):
"""Test that _bq_syncer handles exceptions and returns False."""
Expand Down
36 changes: 36 additions & 0 deletions library/tests/iqb/scripting/iqb_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ def test_syncs_missing_entries(self) -> None:
enable_bigquery=True,
start_date="2024-01-01",
end_date="2024-02-01",
force=False,
),
call(
dataset_name="uploads_by_country",
enable_bigquery=True,
start_date="2024-01-01",
end_date="2024-02-01",
force=False,
),
]

Expand All @@ -81,6 +83,40 @@ def test_skips_existing_entries(self) -> None:
assert entry_download.synced is False
assert entry_upload.synced is False

def test_force_syncs_existing_entries(self) -> None:
    """force=True re-syncs entries even when they already exist on disk."""
    # Two already-existing cache entries backed by a stub pipeline.
    downloads, uploads = _DummyEntry(exists=True), _DummyEntry(exists=True)
    backend = Mock()
    backend.get_cache_entry.side_effect = [downloads, uploads]

    iqb_pipeline.Pipeline(pipeline=backend).sync_mlab(
        "country",
        enable_bigquery=True,
        force=True,
        start_date="2024-01-01",
        end_date="2024-02-01",
    )

    # Both entries must have been synced despite already existing.
    assert downloads.synced is True
    assert uploads.synced is True

    # get_cache_entry received force=True for both datasets, in order.
    expected_calls = [
        call(
            dataset_name=name,
            enable_bigquery=True,
            start_date="2024-01-01",
            end_date="2024-02-01",
            force=True,
        )
        for name in ("downloads_by_country", "uploads_by_country")
    ]
    assert backend.get_cache_entry.call_args_list == expected_calls

def test_invalid_granularity_raises(self) -> None:
pipeline = Mock()
wrapper = iqb_pipeline.Pipeline(pipeline=pipeline)
Expand Down