Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/analytics_query_analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ def analyze(
dialect: str | type[dialects.Dialect],
sql: str,
schema: dict,
default_project: str,
default_catalog: str,
):
dialect = _resolve_dialect(dialect)
expression = parse_one(sql, read=dialect)

qualified = optimizer.qualify.qualify(
expression,
schema=schema,
catalog=default_project,
catalog=default_catalog,
validate_qualify_columns=False,
)
analyzer = ReferencesAnalyzer(schema, default_project)
analyzer = ReferencesAnalyzer(schema, default_catalog)
return analyzer.analyze(qualified)


def analyze_timespan(
dialect: str | type[dialects.Dialect],
sql: str,
schema: dict,
default_project: str,
default_catalog: str,
current_date_provider=None,
) -> dict:
dialect = _resolve_dialect(dialect)
Expand All @@ -44,10 +44,10 @@ def analyze_timespan(
qualified = optimizer.qualify.qualify(
expression,
schema=schema,
catalog=default_project,
catalog=default_catalog,
validate_qualify_columns=False,
)
analyzer = TimespanAnalyzer(schema, default_project, current_date_provider)
analyzer = TimespanAnalyzer(schema, default_catalog, current_date_provider)
return analyzer.analyze(qualified)


Expand Down
6 changes: 3 additions & 3 deletions src/analytics_query_analyzer/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def resolve_column_path(
column: exp.Column,
scope,
default_project: str,
default_catalog: str,
schema: dict,
column_extractor,
) -> tuple[str, str] | None:
Expand All @@ -25,12 +25,12 @@ def resolve_column_path(
if not nested_column:
return None
return resolve_column_path(
nested_column, source, default_project, schema, column_extractor
nested_column, source, default_catalog, schema, column_extractor
)
if not isinstance(source, exp.Table):
return None

project = source.catalog or default_project
project = source.catalog or default_catalog
dataset = source.db
table = source.name
if not dataset and project in schema:
Expand Down
6 changes: 3 additions & 3 deletions src/analytics_query_analyzer/references_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


class ReferencesAnalyzer:
def __init__(self, schema: dict, default_project: str):
def __init__(self, schema: dict, default_catalog: str):
self.schema = schema
self.default_project = default_project
self.default_catalog = default_catalog

def analyze(self, expression: exp.Expression) -> dict:
references: dict[str, set[str]] = {}
Expand All @@ -17,7 +17,7 @@ def analyze(self, expression: exp.Expression) -> dict:
resolved = resolve_column_path(
column,
scope,
self.default_project,
self.default_catalog,
self.schema,
_extract_column,
)
Expand Down
10 changes: 5 additions & 5 deletions src/analytics_query_analyzer/timespan_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class TimespanAnalyzer:
def __init__(
self,
schema: dict,
default_project: str,
default_catalog: str,
current_date_provider: Callable[[], str | date | datetime] | None = None,
):
self.schema = schema
self.default_project = default_project
self.default_catalog = default_catalog
self.current_date_provider = current_date_provider or date.today

def analyze(self, expression: exp.Expression) -> dict[str, dict[str, str | None]]:
Expand Down Expand Up @@ -141,7 +141,7 @@ def bounds_for_comparison(
if not column:
return {}
resolved = resolve_column_path(
column, scope, self.default_project, self.schema, self.extract_column
column, scope, self.default_catalog, self.schema, self.extract_column
)
if not resolved:
return {}
Expand All @@ -158,7 +158,7 @@ def bounds_for_comparison(
if not column:
return {}
resolved = resolve_column_path(
column, scope, self.default_project, self.schema, self.extract_column
column, scope, self.default_catalog, self.schema, self.extract_column
)
if not resolved:
return {}
Expand Down Expand Up @@ -218,7 +218,7 @@ def bounds_for_comparison(
return {}

resolved = resolve_column_path(
column, scope, self.default_project, self.schema, self.extract_column
column, scope, self.default_catalog, self.schema, self.extract_column
)
if not resolved:
return {}
Expand Down