Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,26 @@ from

references = analyze(dialects.BigQuery, sql, schema, "production")
print(references)
# {"production.shop.orders": {"id", "user_id", "payment.amount"}}
# [
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "id",
# },
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "payment.amount",
# },
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "user_id",
# },
# ]
```

### analyze_timespan
Expand Down Expand Up @@ -100,12 +119,16 @@ where

timespans = analyze_timespan(dialects.BigQuery, sql, schema, "production")
print(timespans)
# {
# "production.shop.orders.ordered_at": {
# [
# {
# "database": "production",
# "schema": "shop",
# "table": "orders",
# "column": "ordered_at",
# "lower": "2025-01-01",
# "upper": "2026-01-01"
# "upper": "2026-01-01",
# }
# }
# ]
```

To make `current_date()` deterministic, pass a provider:
Expand Down
21 changes: 19 additions & 2 deletions src/analytics_query_analyzer/references_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, schema: dict, default_catalog: str):
self.schema = schema
self.default_catalog = default_catalog

def analyze(self, expression: exp.Expression) -> dict:
def analyze(self, expression: exp.Expression) -> list[dict[str, str]]:
references: dict[str, set[str]] = {}

for scope in traverse_scope(expression):
Expand All @@ -27,7 +27,7 @@ def analyze(self, expression: exp.Expression) -> dict:
table_path = ".".join(full_path.split(".", 3)[:3])
_add(references, table_path, column_path)

return references
return _flatten_references(references)

def _add(references: dict[str, set[str]], table: str, column: str):
if table in references:
Expand All @@ -36,6 +36,23 @@ def _add(references: dict[str, set[str]], table: str, column: str):
references[table] = {column}


def _flatten_references(references: dict[str, set[str]]) -> list[dict[str, str]]:
rows: list[dict[str, str]] = []
for table_path, columns in references.items():
database, schema, table = table_path.split(".", 2)
for column in columns:
rows.append(
{
"database": database,
"schema": schema,
"table": table,
"column": column,
}
)
rows.sort(key=lambda row: (row["database"], row["schema"], row["table"], row["column"]))
return rows


def _extract_column(expr: exp.Expression) -> exp.Column | None:
if isinstance(expr, exp.Column):
return expr
Expand Down
31 changes: 29 additions & 2 deletions src/analytics_query_analyzer/timespan_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(
self.default_catalog = default_catalog
self.current_date_provider = current_date_provider or date.today

def analyze(self, expression: exp.Expression) -> dict[str, dict[str, str | None]]:
return _stringify_results(self._analyze_internal(expression))
def analyze(self, expression: exp.Expression) -> list[dict[str, str | None]]:
return _flatten_timespans(_stringify_results(self._analyze_internal(expression)))

def _analyze_internal(self, expression: exp.Expression) -> TimespanResults:
if isinstance(expression, exp.Union):
Expand Down Expand Up @@ -632,3 +632,30 @@ def _stringify_results(results: TimespanResults) -> dict[str, dict[str, str | No
"upper": bounds["upper"].isoformat() if bounds["upper"] else None,
}
return output


def _flatten_timespans(
results: dict[str, dict[str, str | None]],
) -> list[dict[str, str | None]]:
rows: list[dict[str, str | None]] = []
for full_path, bounds in results.items():
database, schema, table, column = full_path.split(".", 3)
rows.append(
{
"database": database,
"schema": schema,
"table": table,
"column": column,
"lower": bounds.get("lower"),
"upper": bounds.get("upper"),
}
)
rows.sort(
key=lambda row: (
row["database"] or "",
row["schema"] or "",
row["table"] or "",
row["column"] or "",
)
)
return rows
147 changes: 126 additions & 21 deletions tests/test_analyze_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,61 +32,125 @@


test_cases = [
{"name": "not referencing a table", "sql": "select 1", "expected": {}},
{"name": "not referencing a table", "sql": "select 1", "expected": []},
{
"name": "referencing a table but not columns",
"sql": "select count(1) from shop.orders",
"expected": {},
"expected": [],
},
{
"name": "simple column reference",
"sql": "select user_id from shop.orders",
"expected": {"production.shop.orders": {"user_id"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "user_id",
}
],
},
{
"name": "qualifying a table with a project",
"sql": "select user_id from production.shop.orders",
"expected": {"production.shop.orders": {"user_id"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "user_id",
}
],
},
{
"name": "qualifying a table with a non-default project",
"sql": "select id from development.shop.users",
"expected": {"development.shop.users": {"id"}},
"expected": [
{
"database": "development",
"schema": "shop",
"table": "users",
"column": "id",
}
],
},
{
"name": "referencing a column in a where clause",
"sql": "select count(1) from shop.orders where ordered_at >= '2026-01-01'",
"expected": {"production.shop.orders": {"ordered_at"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "ordered_at",
}
],
},
{
"name": "referencing a column in a join",
"sql": "select count(1) from shop.orders join shop.users on orders.user_id = users.id",
"expected": {
"production.shop.orders": {"user_id"},
"production.shop.users": {"id"},
},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "user_id",
},
{
"database": "production",
"schema": "shop",
"table": "users",
"column": "id",
},
],
},
{
"name": "referencing a column in an ORDER BY clause",
"sql": "select user_id from shop.orders order by payment_amount desc",
"expected": {"production.shop.orders": {"user_id", "payment_amount"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "payment_amount",
},
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "user_id",
},
],
},
{
"name": "wildcard pattern",
"sql": "select * from shop.users",
"expected": {
"production.shop.users": set(schema["production"]["shop"]["users"].keys())
},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "users",
"column": column,
}
for column in schema["production"]["shop"]["users"].keys()
],
},
{
"name": "using a wildcard with COUNT",
"sql": "select count(*) from shop.users",
"expected": {},
"expected": [],
},
{
"name": "SELECT EXCEPT pattern",
"sql": "select * except (name) from shop.users",
"expected": {"production.shop.users": {"id"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "users",
"column": "id",
}
],
},
{
"name": "CTE pattern",
Expand All @@ -105,17 +169,44 @@
from
amount_by_method
""",
"expected": {"production.shop.orders": {"payment_method", "payment_amount"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "payment_amount",
},
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "payment_method",
},
],
},
{
"name": "referencing a field of a struct",
"sql": "select brand.category from shop.items",
"expected": {"production.shop.items": {"brand.category"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "items",
"column": "brand.category",
}
],
},
{
"name": "referencing fields of a struct with a wildcard",
"sql": "select brand.* from shop.items",
"expected": {"production.shop.items": {"brand.*"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "items",
"column": "brand.*",
}
],
},
{
"name": "referencing a field in an unnested struct",
Expand All @@ -127,7 +218,14 @@
where
exists (select 1 from unnest(items) where amount > 1)
""",
"expected": {"production.shop.orders": {"items"}},
"expected": [
{
"database": "production",
"schema": "shop",
"table": "orders",
"column": "items",
}
],
},
]

Expand All @@ -139,4 +237,11 @@
)
def test_analyze_case(sql, expected):
result = analyze(dialects.BigQuery, sql, schema, "production")
assert result == expected
assert _sorted_rows(result) == _sorted_rows(expected)


def _sorted_rows(rows: list[dict[str, str]]) -> list[dict[str, str]]:
return sorted(
rows,
key=lambda row: (row["database"], row["schema"], row["table"], row["column"]),
)
Loading