diff --git a/library/src/iqb/cli/pipeline_run.py b/library/src/iqb/cli/pipeline_run.py
index feed2aa..4cb3243 100644
--- a/library/src/iqb/cli/pipeline_run.py
+++ b/library/src/iqb/cli/pipeline_run.py
@@ -1,7 +1,5 @@
 """Pipeline run command."""
 
-# TODO(bassosimone): add support for -f/--force to bypass cache
-
 from dataclasses import dataclass
 from datetime import date, datetime
 from pathlib import Path
@@ -94,8 +92,9 @@ def coerce_str(value: object) -> str:
     metavar="WORKFLOW",
     help="Path to YAML workflow file (default: /pipeline.yaml)",
 )
+@click.option("-f", "--force", is_flag=True, default=False, help="Bypass cache and force sync")
 @click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose mode.")
-def run(data_dir: str | None, workflow_file: str | None, verbose: bool) -> None:
+def run(data_dir: str | None, workflow_file: str | None, force: bool, verbose: bool) -> None:
     """Run the BigQuery pipeline for all matrix entries."""
     console = get_console()
 
@@ -115,6 +114,7 @@ def run(data_dir: str | None, workflow_file: str | None, verbose: bool) -> None:
             enable_bigquery=True,
             start_date=start,
             end_date=end,
+            force=force,
         )
 
     raise SystemExit(interceptor.exitcode())
diff --git a/library/src/iqb/pipeline/pipeline.py b/library/src/iqb/pipeline/pipeline.py
index 55eef72..2436e44 100644
--- a/library/src/iqb/pipeline/pipeline.py
+++ b/library/src/iqb/pipeline/pipeline.py
@@ -3,6 +3,7 @@
 import hashlib
 import logging
 from datetime import datetime
+from functools import partial
 from importlib.resources import files
 from pathlib import Path
 
@@ -91,6 +92,7 @@ def get_cache_entry(
         enable_bigquery: bool,
         start_date: str,
         end_date: str,
+        force: bool = False,
     ) -> PipelineCacheEntry:
         """
         Get or create a cache entry for the given query template.
@@ -103,6 +105,7 @@
             enable_bigquery: Whether to enabled querying from BigQuery.
             start_date: Date when to start the query (included) -- format YYYY-MM-DD
             end_date: Date when to end the query (excluded) -- format YYYY-MM-DD
+            force: Whether to bypass cache and force BigQuery query execution.
 
         Returns:
             PipelineCacheEntry with paths to data.parquet and stats.json.
@@ -116,14 +119,16 @@
 
         # 2. prepare for synching from BigQuery
         if enable_bigquery:
-            entry.syncers.append(self._bq_syncer)
+            if force:
+                entry.syncers.clear()
+            entry.syncers.append(partial(self._bq_syncer, force=force))
 
         # 3. return the entry
         return entry
 
-    def _bq_syncer(self, entry: PipelineCacheEntry) -> bool:
+    def _bq_syncer(self, entry: PipelineCacheEntry, *, force: bool = False) -> bool:
         """Internal method to get the entry files using a BigQuery query."""
-        if entry.exists():
+        if not force and entry.exists():
             log.info("querying for %s... skipped (cached)", entry)
             return True
         try:
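Note on the `partial(self._bq_syncer, force=force)` change above: binding the keyword argument keeps the stored syncer compatible with the one-argument call contract that `entry.sync()` presumably uses, and the `entry.syncers.clear()` guard removes any previously registered syncers so a forced run cannot be satisfied from an earlier source. A minimal runnable sketch of the pattern, using hypothetical `FakeEntry`/`bq_syncer` stand-ins rather than the real iqb classes:

```python
# Sketch only: FakeEntry and bq_syncer are illustrative stand-ins,
# not the real PipelineCacheEntry/_bq_syncer.
from functools import partial
from typing import Callable


class FakeEntry:
    def __init__(self, exists: bool) -> None:
        self._exists = exists
        self.syncers: list[Callable[["FakeEntry"], bool]] = []

    def exists(self) -> bool:
        return self._exists

    def sync(self) -> None:
        # Run syncers in order until one succeeds.
        for syncer in self.syncers:
            if syncer(self):
                return


def bq_syncer(entry: FakeEntry, *, force: bool = False) -> bool:
    if not force and entry.exists():
        print("skipped (cached)")
        return True
    print("querying BigQuery...")
    return True


entry = FakeEntry(exists=True)
# partial() binds the keyword so the stored callable still matches the
# single-positional-argument signature that sync() expects.
entry.syncers.append(partial(bq_syncer, force=True))
entry.sync()  # prints "querying BigQuery..." despite the cached entry
```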
skipped (cached)", entry) return True try: diff --git a/library/src/iqb/scripting/iqb_pipeline.py b/library/src/iqb/scripting/iqb_pipeline.py index 4828a7e..1ae1d28 100644 --- a/library/src/iqb/scripting/iqb_pipeline.py +++ b/library/src/iqb/scripting/iqb_pipeline.py @@ -32,6 +32,7 @@ def sync_mlab( *, enable_bigquery: bool, end_date: str, + force: bool = False, start_date: str, ) -> None: """ @@ -67,9 +68,10 @@ def sync_mlab( enable_bigquery=enable_bigquery, start_date=start_date, end_date=end_date, + force=force, ) with entry.lock(): - if not entry.exists(): + if force or not entry.exists(): entry.sync() log.info( diff --git a/library/tests/iqb/cli/pipeline_run_test.py b/library/tests/iqb/cli/pipeline_run_test.py index b019bb7..891dbce 100644 --- a/library/tests/iqb/cli/pipeline_run_test.py +++ b/library/tests/iqb/cli/pipeline_run_test.py @@ -193,6 +193,7 @@ def test_bare_dates( enable_bigquery=True, start_date="2024-10-01", end_date="2024-11-01", + force=False, ) @@ -219,6 +220,34 @@ def test_valid_config( enable_bigquery=True, start_date="2024-10-01", end_date="2024-11-01", + force=False, + ) + + +class TestPipelineRunForceFlag: + """-f/--force flag is accepted and passed to sync.""" + + @patch("iqb.cli.pipeline_run.IQBPipeline") + @patch("iqb.cli.pipeline_run.Pipeline") + def test_force_accepted( + self, mock_pipeline_cls: MagicMock, mock_iqb_pipeline_cls: MagicMock, tmp_path: Path + ): + _write_config(tmp_path / "pipeline.yaml", _VALID_CONFIG) + mock_pipeline = MagicMock() + mock_pipeline_cls.return_value = mock_pipeline + + runner = CliRunner() + result = runner.invoke( + cli, + ["pipeline", "run", "-d", str(tmp_path), "--force"], + ) + assert result.exit_code == 0 + mock_pipeline.sync_mlab.assert_called_once_with( + "country", + enable_bigquery=True, + start_date="2024-10-01", + end_date="2024-11-01", + force=True, ) diff --git a/library/tests/iqb/pipeline/pipeline_test.py b/library/tests/iqb/pipeline/pipeline_test.py index 76b9bbf..037e691 100644 --- a/library/tests/iqb/pipeline/pipeline_test.py +++ b/library/tests/iqb/pipeline/pipeline_test.py @@ -316,6 +316,42 @@ def test_bq_syncer_skip_when_exists(self, mock_client, tmp_path): # Verify BigQuery was not attempted mock_client.return_value.execute_query.assert_not_called() + @patch("iqb.pipeline.pipeline.PipelineBQPQClient") + def test_bq_syncer_force_queries_when_exists(self, mock_client, tmp_path): + """Test that force=True bypasses cached skip and executes BigQuery.""" + data_dir = tmp_path / "iqb" + pipeline = IQBPipeline(project="test-project", data_dir=data_dir) + + cache_dir = ( + data_dir + / "cache" + / "v1" + / "20241001T000000Z" + / "20241101T000000Z" + / "downloads_by_country" + ) + cache_dir.mkdir(parents=True, exist_ok=True) + (cache_dir / "data.parquet").write_text("fake parquet data") + (cache_dir / "stats.json").write_text("{}") + + mock_result = MagicMock(spec=PipelineBQPQQueryResult) + mock_client.return_value.execute_query.return_value = mock_result + + entry = pipeline.get_cache_entry( + dataset_name="downloads_by_country", + enable_bigquery=True, + start_date="2024-10-01", + end_date="2024-11-01", + force=True, + ) + + with entry.lock(): + entry.sync() + + mock_client.return_value.execute_query.assert_called_once() + mock_result.save_data_parquet.assert_called_once() + mock_result.save_stats_json.assert_called_once() + @patch("iqb.pipeline.pipeline.PipelineBQPQClient") def test_bq_syncer_failure(self, mock_client, tmp_path): """Test that _bq_syncer handles exceptions and returns False.""" diff --git 
diff --git a/library/tests/iqb/scripting/iqb_pipeline_test.py b/library/tests/iqb/scripting/iqb_pipeline_test.py
index ceea0dd..34c138c 100644
--- a/library/tests/iqb/scripting/iqb_pipeline_test.py
+++ b/library/tests/iqb/scripting/iqb_pipeline_test.py
@@ -55,12 +55,14 @@ def test_syncs_missing_entries(self) -> None:
                 enable_bigquery=True,
                 start_date="2024-01-01",
                 end_date="2024-02-01",
+                force=False,
             ),
             call(
                 dataset_name="uploads_by_country",
                 enable_bigquery=True,
                 start_date="2024-01-01",
                 end_date="2024-02-01",
+                force=False,
             ),
         ]
 
@@ -81,6 +83,40 @@ def test_skips_existing_entries(self) -> None:
         assert entry_download.synced is False
         assert entry_upload.synced is False
 
+    def test_force_syncs_existing_entries(self) -> None:
+        entry_download = _DummyEntry(exists=True)
+        entry_upload = _DummyEntry(exists=True)
+        pipeline = Mock()
+        pipeline.get_cache_entry.side_effect = [entry_download, entry_upload]
+
+        wrapper = iqb_pipeline.Pipeline(pipeline=pipeline)
+        wrapper.sync_mlab(
+            "country",
+            enable_bigquery=True,
+            force=True,
+            start_date="2024-01-01",
+            end_date="2024-02-01",
+        )
+
+        assert entry_download.synced is True
+        assert entry_upload.synced is True
+        assert pipeline.get_cache_entry.call_args_list == [
+            call(
+                dataset_name="downloads_by_country",
+                enable_bigquery=True,
+                start_date="2024-01-01",
+                end_date="2024-02-01",
+                force=True,
+            ),
+            call(
+                dataset_name="uploads_by_country",
+                enable_bigquery=True,
+                start_date="2024-01-01",
+                end_date="2024-02-01",
+                force=True,
+            ),
+        ]
+
     def test_invalid_granularity_raises(self) -> None:
         pipeline = Mock()
         wrapper = iqb_pipeline.Pipeline(pipeline=pipeline)
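The `_DummyEntry` helper used here is defined elsewhere in iqb_pipeline_test.py and is not visible in this diff; the sketch below shows the shape these assertions imply (an `exists()` flag, a `synced` marker set by `sync()`, and a no-op `lock()` context manager). The exact implementation is an assumption:

```python
# Hypothetical sketch of a _DummyEntry-style test double; the real
# helper in iqb_pipeline_test.py may differ.
from contextlib import contextmanager
from typing import Iterator


class _DummyEntry:
    def __init__(self, exists: bool) -> None:
        self._exists = exists
        self.synced = False

    def exists(self) -> bool:
        return self._exists

    def sync(self) -> None:
        self.synced = True

    @contextmanager
    def lock(self) -> Iterator[None]:
        # Tests need no real locking; just satisfy the `with` protocol.
        yield


force = True
entry = _DummyEntry(exists=True)
with entry.lock():
    # Mirrors the guard sync_mlab now uses.
    if force or not entry.exists():
        entry.sync()
assert entry.synced is True
```

Such a double keeps the test focused on the `force or not entry.exists()` decision without touching the filesystem or BigQuery.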