Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/analytics_query_analyzer/timespan_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,10 @@ def extract_literal(self, expr: exp.Expression) -> date | None:
return self._extract_parts_date(expr)
if isinstance(expr, exp.UnixToTime):
return self._extract_epoch_date(expr)
if isinstance(expr, (exp.DateAdd, exp.DateSub)):
if isinstance(expr, (exp.DateAdd, exp.DateSub, exp.TsOrDsAdd)):
return self._extract_date_arithmetic(expr)
if isinstance(expr, (exp.Add, exp.Sub)):
return self._extract_interval_arithmetic(expr)
return None

def _evaluate_current_date(self) -> date | None:
Expand Down Expand Up @@ -388,6 +390,27 @@ def _extract_date_arithmetic(self, expr: exp.Expression) -> date | None:
delta_value = -delta_value
return self._apply_date_delta(base, delta_value, unit)

def _extract_interval_arithmetic(self, expr: exp.Expression) -> date | None:
    """Resolve ``<date> +/- INTERVAL`` or ``INTERVAL + <date>`` to a date.

    Args:
        expr: A binary node; ``extract_literal`` routes ``exp.Add`` and
            ``exp.Sub`` expressions here.

    Returns:
        The date operand shifted by the interval, or ``None`` when
        neither operand is an ``exp.Interval``, when the date operand,
        the interval amount or the interval unit cannot be resolved, or
        when the expression has the form ``INTERVAL - <date>``.
    """
    left = expr.args.get("this")
    right = expr.args.get("expression")
    is_subtraction = isinstance(expr, exp.Sub)

    if isinstance(right, exp.Interval):
        date_operand, interval = left, right
        sign = -1 if is_subtraction else 1
    elif isinstance(left, exp.Interval) and not is_subtraction:
        # Only addition commutes. ``INTERVAL - <date>`` does not denote a
        # date, so it must not be evaluated as ``<date> + INTERVAL``.
        date_operand, interval = right, left
        sign = 1
    else:
        return None

    base = self.extract_literal(date_operand)
    delta = self._extract_int_literal(interval.args.get("this"))
    unit = self._extract_unit_literal(interval.args.get("unit"))
    if base is None or delta is None or not unit:
        return None
    return self._apply_date_delta(base, sign * delta, unit)

def _extract_trunc_date(self, expr: exp.Expression) -> date | None:
target = expr.args.get("this") or expr.this
unit_expr = expr.args.get("unit") or expr.args.get("part")
Expand Down Expand Up @@ -425,7 +448,13 @@ def _extract_int_literal(self, expr: exp.Expression | None) -> int | None:
try:
return int(expr.this)
except (TypeError, ValueError):
return None
try:
return int(float(expr.this))
except (TypeError, ValueError):
return None
if isinstance(expr, exp.Neg):
value = self._extract_int_literal(expr.this)
return None if value is None else -value
if isinstance(expr, exp.Cast):
return self._extract_int_literal(expr.this)
if isinstance(expr, exp.Paren):
Expand Down
File renamed without changes.
142 changes: 142 additions & 0 deletions tests/test_analyze_redshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import pytest
from sqlglot import dialects

from analytics_query_analyzer.analyzer import analyze


# Catalog fixture handed to ``analyze`` by the tests in this module.
# Nesting is: top-level name -> "shop" -> table -> column -> type string.
# NOTE(review): the top-level keys look like project/database names and
# "super" like Redshift's semi-structured column type -- confirm.
schema = {
    "production": {
        "shop": {
            "orders": {
                "id": "int64",
                "ordered_at": "timestamp",
                "user_id": "int64",
                "payment_amount": "int64",
                "payment_method": "string",
                "items": "super",
            },
            "users": {"id": "int64", "name": "varchar"},
            "items": {
                "id": "int64",
                "name": "varchar",
                "brand": "super",
            },
        },
    },
    # Second top-level entry exposing only a ``users`` table.
    "development": {
        "shop": {
            "users": {"id": "int64", "name": "varchar"},
        },
    },
}


# Scenario table for the parametrized test in this module.  Every entry has:
#   "name"     -- label used as the pytest id,
#   "sql"      -- the query text handed to ``analyze``,
#   "expected" -- the exact result: a dict keyed by dotted table name whose
#                 values are sets of column names ({} when nothing is
#                 expected to be reported).
# The bodies of the triple-quoted SQL strings are part of the test input
# and are kept flush-left on purpose.
test_cases = [
    {"name": "not referencing a table", "sql": "select 1", "expected": {}},
    {
        "name": "referencing a table but not columns",
        "sql": "select count(1) from shop.orders",
        "expected": {},
    },
    {
        "name": "simple column reference",
        "sql": "select user_id from shop.orders",
        "expected": {"production.shop.orders": {"user_id"}},
    },
    {
        "name": "qualifying a table with a project",
        "sql": "select user_id from production.shop.orders",
        "expected": {"production.shop.orders": {"user_id"}},
    },
    {
        "name": "qualifying a table with a non-default project",
        "sql": "select id from development.shop.users",
        "expected": {"development.shop.users": {"id"}},
    },
    {
        "name": "where clause reference",
        "sql": "select count(1) from shop.orders where ordered_at >= '2026-01-01'",
        "expected": {"production.shop.orders": {"ordered_at"}},
    },
    {
        "name": "join reference",
        "sql": "select count(1) from shop.orders join shop.users on orders.user_id = users.id",
        "expected": {
            "production.shop.orders": {"user_id"},
            "production.shop.users": {"id"},
        },
    },
    {
        "name": "referencing a column in an ORDER BY clause",
        "sql": "select user_id from shop.orders order by payment_amount desc",
        "expected": {"production.shop.orders": {"user_id", "payment_amount"}},
    },
    {
        "name": "wildcard pattern",
        "sql": "select * from shop.users",
        # ``select *`` is expected to report every column the schema
        # fixture declares for the table.
        "expected": {
            "production.shop.users": set(schema["production"]["shop"]["users"])
        },
    },
    {
        "name": "using a wildcard with COUNT",
        "sql": "select count(*) from shop.users",
        "expected": {},
    },
    {
        "name": "selecting multiple columns",
        "sql": "select id, name from shop.users",
        "expected": {"production.shop.users": {"id", "name"}},
    },
    {
        "name": "CTE pattern",
        "sql": """
with amount_by_method as (
select
payment_method,
sum(payment_amount) as total_amount
from
shop.orders
group by
1
)
select
*
from
amount_by_method
""",
        "expected": {"production.shop.orders": {"payment_method", "payment_amount"}},
    },
    {
        "name": "referencing a field of a super column",
        "sql": "select brand['category'] from shop.items",
        "expected": {"production.shop.items": {"brand"}},
    },
    {
        "name": "referencing multiple fields of a super column",
        "sql": "select brand['category'], brand['name'] from shop.items",
        "expected": {"production.shop.items": {"brand"}},
    },
    {
        "name": "referencing a field in a super column filter",
        "sql": """
select
count(1)
from
shop.orders
where
json_extract_path_text(items, 'amount') is not null
""",
        "expected": {"production.shop.orders": {"items"}},
    },
]


@pytest.mark.parametrize(
    ("sql", "expected"),
    [
        pytest.param(case["sql"], case["expected"], id=case["name"])
        for case in test_cases
    ],
)
def test_analyze_case(sql, expected):
    """Check that analysing ``sql`` under the Redshift dialect yields ``expected``."""
    # "production" is passed as the last argument; the expected keys show
    # unqualified tables being reported under that name.
    assert analyze(dialects.Redshift, sql, schema, "production") == expected
File renamed without changes.
Loading