Skip to content

Commit

Permalink
Fix: Take max interval end for a model into account even when there's…
Browse files Browse the repository at this point in the history
… a restatement interval (#3913)
  • Loading branch information
izeigerman committed Feb 27, 2025
1 parent 3111e5c commit b38bc85
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
5 changes: 0 additions & 5 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,6 @@ def _build_dag(self) -> DAG[SnapshotId]:
def _build_restatements(
self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike
) -> t.Dict[SnapshotId, Interval]:
def is_restateable_snapshot(snapshot: Snapshot) -> bool:
if not self._is_dev and snapshot.disable_restatement:
return False
return not snapshot.is_symbolic and not snapshot.is_seed

restate_models = self._restate_models
if restate_models == set():
# This is a warning but we print this as error since the Console is lacking API for warnings.
Expand Down
16 changes: 8 additions & 8 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,14 +1745,14 @@ def missing_intervals(
snapshot = snapshot.copy()
snapshot.intervals = snapshot.intervals.copy()
snapshot.remove_interval(restated_interval)
else:
existing_interval_end = interval_end_per_model.get(snapshot.name)
if existing_interval_end:
if to_timestamp(snapshot_start_date) >= existing_interval_end:
# The start exceeds the provided interval end, so we can skip this snapshot
# since it doesn't have missing intervals by definition
continue
snapshot_end_date = existing_interval_end

existing_interval_end = interval_end_per_model.get(snapshot.name)
if existing_interval_end:
if to_timestamp(snapshot_start_date) >= existing_interval_end:
# The start exceeds the provided interval end, so we can skip this snapshot
# since it doesn't have missing intervals by definition
continue
snapshot_end_date = existing_interval_end

missing_interval_end_date = snapshot_end_date
node_end_date = snapshot.node.end
Expand Down
27 changes: 27 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4572,6 +4572,33 @@ def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable)
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_restatement_shouldnt_backfill_beyond_prod_intervals(init_and_plan_context: t.Callable):
context, _ = init_and_plan_context("examples/sushi")

model = context.get_model("sushi.top_waiters")
context.upsert_model(SqlModel.parse_obj({**model.dict(), "cron": "@hourly"}))

context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True)
context.run()

with time_machine.travel("2023-01-09 02:00:00 UTC"):
# It's time to backfill the waiter_revenue_by_day model but it hasn't run yet
restatement_plan = context.plan(
restate_models=["sushi.waiter_revenue_by_day"],
no_prompts=True,
skip_tests=True,
)
intervals_by_id = {i.snapshot_id: i for i in restatement_plan.missing_intervals}
# Make sure the intervals don't go beyond the prod intervals
assert intervals_by_id[context.get_snapshot("sushi.top_waiters").snapshot_id].intervals[-1][
1
] == to_timestamp("2023-01-08 15:00:00 UTC")
assert intervals_by_id[
context.get_snapshot("sushi.waiter_revenue_by_day").snapshot_id
].intervals[-1][1] == to_timestamp("2023-01-08 00:00:00 UTC")


def initial_add(context: Context, environment: str):
assert not context.state_reader.get_environment(environment)

Expand Down

0 comments on commit b38bc85

Please sign in to comment.