Skip to content

Commit

Permalink
ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchun12 committed Jul 21, 2024
1 parent 518b659 commit 28ed116
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
4 changes: 3 additions & 1 deletion loader_scripts/fix_run_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
SERVICE_ACCOUNT_INFO = json.loads(os.environ["GOOGLE_SQLMESH_CREDENTIALS"])

# Authenticate with BigQuery using environment variable
credentials = service_account.Credentials.from_service_account_info(SERVICE_ACCOUNT_INFO)
credentials = service_account.Credentials.from_service_account_info(
SERVICE_ACCOUNT_INFO
)
client = bigquery.Client(credentials=credentials, project="sqlmesh-public-demo")

# Define the table
Expand Down
5 changes: 3 additions & 2 deletions loader_scripts/load_raw_events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This is a public demo script to generate demo data """
"""This is a public demo script to generate demo data"""

import pandas as pd
import uuid
from datetime import datetime, timedelta
Expand Down Expand Up @@ -88,5 +89,5 @@ def append_to_bigquery_table(
table_name="tcloud_raw_data.raw_events",
num_rows=20,
end_date="2024-07-16",
project_id="sqlmesh-public-demo"
project_id="sqlmesh-public-demo",
)
23 changes: 11 additions & 12 deletions macros/gen_surrogate_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@

from sqlglot import exp
from sqlmesh import macro
from pydantic import BaseModel, Field


class SurrogateKeyInput(BaseModel):
field_list: list[str] = Field(..., min_length=2)


@macro("gen_surrogate_key")
Expand All @@ -22,13 +17,15 @@ def gen_surrogate_key(evaluator, field_list: list) -> exp.SHA2:
Example:
- gen_surrogate_key(["field1", "field2"])
- In a SQL model: select @gen_surrogate_key([orders.order_id, orders.customer_id]) as surrogate_key from orders
Returns: An expression (SQLGlot) representing the SQL for the generated surrogate key.
"""

if len(field_list) < 2:
raise ValueError("At least two fields are required to generate a surrogate key.")

raise ValueError(
"At least two fields are required to generate a surrogate key."
)

# Assuming columns are instances of a specific class, e.g., Column
if not all(isinstance(field, exp.Column) for field in field_list):
raise ValueError("All fields must be column objects.")
Expand All @@ -38,14 +35,16 @@ def gen_surrogate_key(evaluator, field_list: list) -> exp.SHA2:
expressions = []
for i, field in enumerate(field_list):
coalesce_expression = exp.Coalesce(
this=exp.cast(expression=field, to='TEXT'), # Adjusted to use the field directly
expressions=exp.Literal.string(default_null_value)
this=exp.cast(
expression=field, to="TEXT"
), # Adjusted to use the field directly
expressions=exp.Literal.string(default_null_value),
)
expressions.append(coalesce_expression)
if i < len(field_list) - 1: # Add separator except for the last element
expressions.append(exp.Literal.string('-'))
expressions.append(exp.Literal.string("-"))

concat_exp = exp.Concat(expressions=expressions)
hash_exp = exp.SHA2(this=concat_exp, length=exp.Literal.number(256))

return hash_exp
return hash_exp
30 changes: 19 additions & 11 deletions macros/tests/test_gen_surrogate_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ def test_gen_surrogate_key_valid():
field_list = [exp.Column(this="field1"), exp.Column(this="field2")]
expected_sql = "SHA256(CONCAT(COALESCE(CAST(field1 AS STRING), '_null_'), '-', COALESCE(CAST(field2 AS STRING), '_null_')))"

result = gen_surrogate_key(evaluator=None, field_list=field_list).sql(
"bigquery"
)
result = gen_surrogate_key(evaluator=None, field_list=field_list).sql("bigquery")

assert result == expected_sql


def test_gen_surrogate_key_multiple_fields():
field_list = [exp.Column(this="field1"), exp.Column(this="field2"), exp.Column(this="field3")]
field_list = [
exp.Column(this="field1"),
exp.Column(this="field2"),
exp.Column(this="field3"),
]
expected_sql = "SHA256(CONCAT(COALESCE(CAST(field1 AS STRING), '_null_'), '-', COALESCE(CAST(field2 AS STRING), '_null_'), '-', COALESCE(CAST(field3 AS STRING), '_null_')))"

result = gen_surrogate_key(evaluator=None, field_list=field_list).sql(
"bigquery"
)
result = gen_surrogate_key(evaluator=None, field_list=field_list).sql("bigquery")

assert result == expected_sql

Expand All @@ -29,17 +29,25 @@ def test_gen_surrogate_key_error_single_field():
with pytest.raises(ValueError) as excinfo:
gen_surrogate_key(evaluator=None, field_list=[exp.Column(this="field1")])

assert "At least two fields are required to generate a surrogate key." in str(excinfo.value)
assert "At least two fields are required to generate a surrogate key." in str(
excinfo.value
)


def test_gen_surrogate_key_error_empty_list():
with pytest.raises(ValueError) as excinfo:
gen_surrogate_key(evaluator=None, field_list=[])

assert "At least two fields are required to generate a surrogate key." in str(excinfo.value)
assert "At least two fields are required to generate a surrogate key." in str(
excinfo.value
)


def test_gen_surrogate_key_error_non_column_fields():
with pytest.raises(ValueError) as excinfo:
gen_surrogate_key(evaluator=None, field_list=[exp.Column(this="field1"), 2, exp.Column(this="field3")])
gen_surrogate_key(
evaluator=None,
field_list=[exp.Column(this="field1"), 2, exp.Column(this="field3")],
)

assert "All fields must be column objects." in str(excinfo.value)
assert "All fields must be column objects." in str(excinfo.value)

0 comments on commit 28ed116

Please sign in to comment.