Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ print(references)
# ]
```

### analyze_timespan
### analyze with time bounds

Extract time bounds from filters.
Extract time bounds from filters by passing `with_timebounds=True` to `analyze`.

```python
from analytics_query_analyzer import analyze_timespan
from analytics_query_analyzer import analyze
from sqlglot import dialects

schema = {
Expand All @@ -117,28 +117,51 @@ where
and ordered_at < "2026-01-01"
"""

timespans = analyze_timespan(dialects.BigQuery, sql, schema, "production")
print(timespans)
timebounds = analyze(
dialects.BigQuery,
sql,
schema,
"production",
with_timebounds=True,
)
print(timebounds)
# [
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "id",
# "lower": None,
# "upper": None,
# },
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "ordered_at",
# "lower": "2025-01-01",
# "upper": "2026-01-01",
# }
# },
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "user_id",
# "lower": None,
# "upper": None,
# },
# ]
```

To make `current_date()` deterministic, pass a `current_date_provider` callable that returns the date to use:

```python
timespans = analyze_timespan(
timebounds = analyze(
dialects.BigQuery,
"select * from shop.orders where ordered_at >= current_date()",
schema,
"production",
with_timebounds=True,
current_date_provider=lambda: "2026-01-01",
)
```
Expand All @@ -158,7 +181,7 @@ print(schema)
- Authentication uses Application Default Credentials (ADC).
- When `table` is omitted, it scans all tables in the dataset.
- When both `dataset` and `table` are omitted, it scans all datasets in the project.
- The returned `schema` can be passed directly to `analyze` and `analyze_timespan`.
- The returned `schema` can be passed directly to `analyze`.

Fetching from Redshift is also supported:

Expand Down
4 changes: 2 additions & 2 deletions src/analytics_query_analyzer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .analyzer import analyze, analyze_timespan
from .analyzer import analyze
from .schema_builder import build_schema

__all__ = ["analyze", "analyze_timespan", "build_schema"]
__all__ = ["analyze", "build_schema"]
70 changes: 53 additions & 17 deletions src/analytics_query_analyzer/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Literal, overload

from sqlglot import dialects, optimizer, parse_one

from .references_analyzer import ReferencesAnalyzer
from .timespan_analyzer import TimespanAnalyzer
from .timebounds_analyzer import TimeboundsAnalyzer
from .types import ReferenceRow, TimeboundsRow


def _resolve_dialect(
Expand All @@ -12,32 +15,36 @@ def _resolve_dialect(
return dialect


@overload
def analyze(
dialect: str | type[dialects.Dialect],
sql: str,
schema: dict,
default_catalog: str,
):
dialect = _resolve_dialect(dialect)
expression = parse_one(sql, read=dialect)
with_timebounds: Literal[False] = False,
current_date_provider=None,
) -> list[ReferenceRow]: ...

qualified = optimizer.qualify.qualify(
expression,
schema=schema,
catalog=default_catalog,
validate_qualify_columns=False,
)
analyzer = ReferencesAnalyzer(schema, default_catalog)
return analyzer.analyze(qualified)

# Typing-only overload: when ``with_timebounds`` is the literal ``True``,
# the declared return type is ``list[TimeboundsRow]``.
@overload
def analyze(
dialect: str | type[dialects.Dialect],
sql: str,
schema: dict,
default_catalog: str,
with_timebounds: Literal[True],
current_date_provider=None,
) -> list[TimeboundsRow]: ...


def analyze_timespan(
def analyze(
dialect: str | type[dialects.Dialect],
sql: str,
schema: dict,
default_catalog: str,
with_timebounds: bool = False,
current_date_provider=None,
) -> dict:
) -> list[ReferenceRow] | list[TimeboundsRow]:
dialect = _resolve_dialect(dialect)
expression = parse_one(sql, read=dialect)

Expand All @@ -47,8 +54,37 @@ def analyze_timespan(
catalog=default_catalog,
validate_qualify_columns=False,
)
analyzer = TimespanAnalyzer(schema, default_catalog, current_date_provider)
return analyzer.analyze(qualified)
analyzer = ReferencesAnalyzer(schema, default_catalog)
references: list[ReferenceRow] = analyzer.analyze(qualified)
if not with_timebounds:
return references
return _merge_timebounds(
references,
TimeboundsAnalyzer(schema, default_catalog, current_date_provider).analyze(
qualified
),
)


def _merge_timebounds(
    references: list[ReferenceRow],
    timebounds: list[TimeboundsRow],
) -> list[TimeboundsRow]:
    """Merge reference rows and time-bound rows into one sorted list.

    Rows are matched on the ``(database, schema, table, column)`` key.

    Args:
        references: Rows for every referenced column. Each one appears in
            the result with ``lower`` and ``upper`` set to ``None`` unless a
            matching time-bound row supplies values.
        timebounds: Rows carrying ``lower``/``upper`` values. A row whose key
            is not present in ``references`` is included in the result too.
            If several rows share a key, the last one's bounds win.

    Returns:
        One row per distinct key, sorted by database, schema, table, column.
        Neither input list nor any of its rows is modified.
    """
    merged: dict[tuple[str, str, str, str], TimeboundsRow] = {}

    for row in references:
        key = (row["database"], row["schema"], row["table"], row["column"])
        merged[key] = {**row, "lower": None, "upper": None}

    for row in timebounds:
        key = (row["database"], row["schema"], row["table"], row["column"])
        entry = merged.get(key)
        if entry is None:
            # Store a copy: a later row with the same key updates this entry
            # in place, and that must not write through to the caller's row.
            merged[key] = {**row}
        else:
            entry["lower"] = row.get("lower")
            entry["upper"] = row.get("upper")

    # The keys are already (database, schema, table, column) tuples, so their
    # natural ordering is the required sort order.
    return [merged[key] for key in sorted(merged)]


__all__ = ["analyze", "analyze_timespan"]
__all__ = ["analyze"]
23 changes: 13 additions & 10 deletions src/analytics_query_analyzer/references_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from sqlglot.optimizer.scope import traverse_scope

from .column_resolver import resolve_column_path
from .types import ReferenceRow


class ReferencesAnalyzer:
def __init__(self, schema: dict, default_catalog: str):
self.schema = schema
self.default_catalog = default_catalog

def analyze(self, expression: exp.Expression) -> list[dict[str, str]]:
def analyze(self, expression: exp.Expression) -> list[ReferenceRow]:
references: dict[str, set[str]] = {}

for scope in traverse_scope(expression):
Expand All @@ -36,20 +37,22 @@ def _add(references: dict[str, set[str]], table: str, column: str):
references[table] = {column}


def _flatten_references(references: dict[str, set[str]]) -> list[dict[str, str]]:
rows: list[dict[str, str]] = []
def _flatten_references(references: dict[str, set[str]]) -> list[ReferenceRow]:
rows: list[ReferenceRow] = []
for table_path, columns in references.items():
database, schema, table = table_path.split(".", 2)
for column in columns:
rows.append(
{
"database": database,
"schema": schema,
"table": table,
"column": column,
}
ReferenceRow(
database=database,
schema=schema,
table=table,
column=column,
)
)
rows.sort(key=lambda row: (row["database"], row["schema"], row["table"], row["column"]))
rows.sort(
key=lambda row: (row["database"], row["schema"], row["table"], row["column"])
)
return rows


Expand Down
Loading