diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index abc7f2edc..13477c283 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -357,7 +357,7 @@ def _demote_snapshots( ) def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None: - if not plan.restatements: + if not plan.restatements or plan.is_dev: return snapshot_intervals_to_restate = { @@ -365,16 +365,15 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho for name, intervals in plan.restatements.items() } - if plan.is_prod: - # Restating intervals on prod plans should mean that the intervals are cleared across - # all environments, not just the version currently in prod - # This ensures that work done in dev environments can still be promoted to prod - # by forcing dev environments to re-run intervals that changed in prod - # - # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod - snapshot_intervals_to_restate.update( - self._restatement_intervals_across_all_environments(plan.restatements) - ) + # Restating intervals on prod plans should mean that the intervals are cleared across + # all environments, not just the version currently in prod + # This ensures that work done in dev environments can still be promoted to prod + # by forcing dev environments to re-run intervals that changed in prod + # + # Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod + snapshot_intervals_to_restate.update( + self._restatement_intervals_across_all_environments(plan.restatements) + ) self.state_sync.remove_intervals( snapshot_intervals=list(snapshot_intervals_to_restate), diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index ed7eee92d..260ef1b1b 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -293,6 +293,9 @@ def run( execution_time = execution_time or now() 
self.state_sync.refresh_snapshot_intervals(self.snapshots.values()) + for s_id, interval in (restatements or {}).items(): + self.snapshots[s_id].remove_interval(interval) + if auto_restatement_enabled: auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time) self.state_sync.add_snapshots_intervals(auto_restated_intervals) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index feaf08105..3964b3f97 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -359,6 +359,77 @@ def test_forward_only_model_regular_plan_preview_enabled(init_and_plan_context: assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + model_name = "memory.sushi.customer_max_revenue" + expressions = d.parse( + f""" + MODEL ( + name {model_name}, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key customer_id, + forward_only true, + ), + ); + + SELECT + customer_id, MAX(revenue) AS max_revenue + FROM memory.sushi.customer_revenue_lifetime + GROUP BY 1; + """ + ) + + model = load_sql_based_model(expressions) + assert model.forward_only + assert model.kind.full_history_restatement_only + context.upsert_model(model) + + context.plan("prod", skip_tests=True, auto_apply=True) + + model_kwargs = { + **model.dict(), + # Make a breaking change. 
+ "query": model.query.order_by("customer_id"), # type: ignore + } + context.upsert_model(SqlModel.parse_obj(model_kwargs)) + + # Apply the model change in dev + plan = context.plan_builder("dev", skip_tests=True).build() + assert not plan.missing_intervals + context.apply(plan) + + snapshot = context.get_snapshot(model, raise_if_missing=True) + snapshot_table_name = snapshot.table_name(False) + + # Manually insert a dummy value to check that the table is recreated during the restatement + context.engine_adapter.insert_append( + snapshot_table_name, + pd.DataFrame({"customer_id": [-1], "max_revenue": [100]}), + ) + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1" + ) + assert df["cnt"][0] == 1 + + # Apply a restatement plan in dev + plan = context.plan("dev", restate_models=[model.name], auto_apply=True) + assert len(plan.missing_intervals) == 1 + + # Check that the dummy value is not present + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1" + ) + assert df["cnt"][0] == 0 + + # Check that the table is not empty + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue" + ) + assert df["cnt"][0] > 0 + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_full_history_restatement_model_regular_plan_preview_enabled( init_and_plan_context: t.Callable,