Skip to content

Commit

Permalink
Fix: Don't drop intervals in state when restating models in a dev env…
Browse files Browse the repository at this point in the history
…ironment (#3580)
  • Loading branch information
izeigerman authored Jan 3, 2025
1 parent 795fd2b commit 2319c80
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 11 deletions.
21 changes: 10 additions & 11 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,24 +357,23 @@ def _demote_snapshots(
)

def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None:
if not plan.restatements:
if not plan.restatements or plan.is_dev:
return

snapshot_intervals_to_restate = {
(snapshots_by_name[name].table_info, intervals)
for name, intervals in plan.restatements.items()
}

if plan.is_prod:
# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
# This ensures that work done in dev environments can still be promoted to prod
# by forcing dev environments to re-run intervals that changed in prod
#
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
self._restatement_intervals_across_all_environments(plan.restatements)
)
# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
# This ensures that work done in dev environments can still be promoted to prod
# by forcing dev environments to re-run intervals that changed in prod
#
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
self._restatement_intervals_across_all_environments(plan.restatements)
)

self.state_sync.remove_intervals(
snapshot_intervals=list(snapshot_intervals_to_restate),
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def run(
execution_time = execution_time or now()

self.state_sync.refresh_snapshot_intervals(self.snapshots.values())
for s_id, interval in (restatements or {}).items():
self.snapshots[s_id].remove_interval(interval)

if auto_restatement_enabled:
auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time)
self.state_sync.add_snapshots_intervals(auto_restated_intervals)
Expand Down
71 changes: 71 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,77 @@ def test_forward_only_model_regular_plan_preview_enabled(init_and_plan_context:
assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")]


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t.Callable):
context, _ = init_and_plan_context("examples/sushi")

model_name = "memory.sushi.customer_max_revenue"
expressions = d.parse(
f"""
MODEL (
name {model_name},
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key customer_id,
forward_only true,
),
);
SELECT
customer_id, MAX(revenue) AS max_revenue
FROM memory.sushi.customer_revenue_lifetime
GROUP BY 1;
"""
)

model = load_sql_based_model(expressions)
assert model.forward_only
assert model.kind.full_history_restatement_only
context.upsert_model(model)

context.plan("prod", skip_tests=True, auto_apply=True)

model_kwargs = {
**model.dict(),
# Make a breaking change.
"query": model.query.order_by("customer_id"), # type: ignore
}
context.upsert_model(SqlModel.parse_obj(model_kwargs))

# Apply the model change in dev
plan = context.plan_builder("dev", skip_tests=True).build()
assert not plan.missing_intervals
context.apply(plan)

snapshot = context.get_snapshot(model, raise_if_missing=True)
snapshot_table_name = snapshot.table_name(False)

# Manually insert a dummy value to check that the table is recreated during the restatement
context.engine_adapter.insert_append(
snapshot_table_name,
pd.DataFrame({"customer_id": [-1], "max_revenue": [100]}),
)
df = context.engine_adapter.fetchdf(
"SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1"
)
assert df["cnt"][0] == 1

# Apply a restatement plan in dev
plan = context.plan("dev", restate_models=[model.name], auto_apply=True)
assert len(plan.missing_intervals) == 1

# Check that the dummy value is not present
df = context.engine_adapter.fetchdf(
"SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1"
)
assert df["cnt"][0] == 0

# Check that the table is not empty
df = context.engine_adapter.fetchdf(
"SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue"
)
assert df["cnt"][0] > 0


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_full_history_restatement_model_regular_plan_preview_enabled(
init_and_plan_context: t.Callable,
Expand Down

0 comments on commit 2319c80

Please sign in to comment.