Skip to content

Commit

Permalink
Update tests and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Jan 21, 2025
1 parent aca4a66 commit fc08bd8
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 7 deletions.
6 changes: 5 additions & 1 deletion python_modules/libraries/dagster-dlt/dagster_dlt/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,11 @@ def _run(
"""
asset_key_dlt_source_resource_mapping = {
dagster_dlt_translator.get_asset_spec(DltResourceTranslatorData(resource=dlt_source_resource, destination=dlt_pipeline.destination)).key: dlt_source_resource
dagster_dlt_translator.get_asset_spec(
DltResourceTranslatorData(
resource=dlt_source_resource, destination=dlt_pipeline.destination
)
).key: dlt_source_resource
for dlt_source_resource in dlt_source.selected_resources.values()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dagster import (
AssetExecutionContext,
AssetKey,
AssetSpec,
AutoMaterializePolicy,
AutoMaterializeRule,
AutomationCondition,
Expand All @@ -23,6 +24,7 @@
from dagster._core.definitions.metadata.table import TableColumn, TableSchema
from dagster._core.definitions.tags import build_kind_tag, has_kind
from dagster_dlt import DagsterDltResource, DagsterDltTranslator, dlt_assets
from dagster_dlt.translator import DltResourceTranslatorData
from dlt import Pipeline
from dlt.common.destination import Destination
from dlt.extract.resource import DltResource
Expand Down Expand Up @@ -191,6 +193,28 @@ def example_pipeline_assets(


def test_multi_asset_names_do_not_conflict(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(key=AssetKey("custom_" + data.resource.name))

@dlt_assets(dlt_source=pipeline(), dlt_pipeline=dlt_pipeline, name="multi_asset_name1")
def assets1():
pass

@dlt_assets(
dlt_source=pipeline(),
dlt_pipeline=dlt_pipeline,
name="multi_asset_name2",
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def assets2():
pass

assert Definitions(assets=[assets1, assets2])


def test_multi_asset_names_do_not_conflict_legacy(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_key(self, resource: DltResource) -> AssetKey:
return AssetKey("custom_" + resource.name)
Expand All @@ -211,7 +235,7 @@ def assets2():
assert Definitions(assets=[assets1, assets2])


def test_get_materialize_policy(dlt_pipeline: Pipeline):
def test_get_materialize_policy_legacy(dlt_pipeline: Pipeline):
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_auto_materialize_policy(
self, resource: DltResource
Expand All @@ -232,7 +256,7 @@ def assets():
assert "0 1 * * *" in str(item)


def test_get_automation_condition(dlt_pipeline: Pipeline):
def test_get_automation_condition_legacy(dlt_pipeline: Pipeline):
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_automation_condition(self, resource: DltResource) -> Optional[AutomationCondition]:
return AutomationCondition.eager() | AutomationCondition.on_cron("0 1 * * *")
Expand All @@ -249,7 +273,7 @@ def assets():
assert "0 1 * * *" in str(item)


def test_get_automation_condition_converts_auto_materialize_policy(
def test_get_automation_condition_converts_auto_materialize_policy_legacy(
dlt_pipeline: Pipeline,
):
class CustomDagsterDltTranslator(DagsterDltTranslator):
Expand Down Expand Up @@ -401,6 +425,52 @@ def example_pipeline_assets(


def test_asset_metadata(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
metadata_by_resource_name = {
"repos": {"mode": "upsert", "primary_key": "id"},
"repo_issues": {"mode": "upsert", "primary_key": ["repo_id", "issue_id"]},
}

def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.merge_attributes(
metadata=self.metadata_by_resource_name.get(data.resource.name, {})
)

@dlt_assets(
dlt_source=pipeline(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def example_pipeline_assets(
context: AssetExecutionContext, dlt_pipeline_resource: DagsterDltResource
):
yield from dlt_pipeline_resource.run(context=context)

first_asset_metadata = next(iter(example_pipeline_assets.metadata_by_key.values()))
dagster_dlt_source = first_asset_metadata.get("dagster_dlt/source")
dagster_dlt_pipeline = first_asset_metadata.get("dagster_dlt/pipeline")
dagster_dlt_translator = first_asset_metadata.get("dagster_dlt/translator")

assert example_pipeline_assets.metadata_by_key == {
AssetKey("dlt_pipeline_repos"): {
"dagster_dlt/source": dagster_dlt_source,
"dagster_dlt/pipeline": dagster_dlt_pipeline,
"dagster_dlt/translator": dagster_dlt_translator,
"mode": "upsert",
"primary_key": "id",
},
AssetKey("dlt_pipeline_repo_issues"): {
"dagster_dlt/source": dagster_dlt_source,
"dagster_dlt/pipeline": dagster_dlt_pipeline,
"dagster_dlt/translator": dagster_dlt_translator,
"mode": "upsert",
"primary_key": ["repo_id", "issue_id"],
},
}


def test_asset_metadata_legacy(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
metadata_by_resource_name = {
"repos": {"mode": "upsert", "primary_key": "id"},
Expand Down Expand Up @@ -473,6 +543,23 @@ async def main():


def test_with_asset_key_replacements(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(key=default_spec.key.with_prefix("prefix"))

@dlt_assets(
dlt_source=dlt_source(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...

assert my_dlt_assets.asset_deps.keys()
assert all(key.has_prefix(["prefix"]) for key in my_dlt_assets.asset_deps.keys())


def test_with_asset_key_replacements_legacy(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_key(self, resource: DltResource) -> AssetKey:
return super().get_asset_key(resource).with_prefix("prefix")
Expand All @@ -489,6 +576,23 @@ def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...


def test_with_deps_replacements(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(deps=[])

@dlt_assets(
dlt_source=dlt_source(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...

assert my_dlt_assets.asset_deps.keys()
assert all(not deps for deps in my_dlt_assets.asset_deps.values())


def test_with_deps_replacements_legacy(dlt_pipeline: Pipeline) -> None:
class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_deps_asset_keys(self, _) -> Sequence[AssetKey]:
return []
Expand All @@ -507,6 +611,25 @@ def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...
def test_with_description_replacements(dlt_pipeline: Pipeline) -> None:
expected_description = "customized description"

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(description=expected_description)

@dlt_assets(
dlt_source=dlt_source(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...

for description in my_dlt_assets.descriptions_by_key.values():
assert description == expected_description


def test_with_description_replacements_legacy(dlt_pipeline: Pipeline) -> None:
expected_description = "customized description"

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_description(self, _) -> Optional[str]:
return expected_description
Expand All @@ -525,6 +648,25 @@ def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...
def test_with_metadata_replacements(dlt_pipeline: Pipeline) -> None:
expected_metadata = {"customized": "metadata"}

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.merge_attributes(metadata=expected_metadata)

@dlt_assets(
dlt_source=dlt_source(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...

for metadata in my_dlt_assets.metadata_by_key.values():
assert metadata["customized"] == "metadata"


def test_with_metadata_replacements_legacy(dlt_pipeline: Pipeline) -> None:
expected_metadata = {"customized": "metadata"}

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_metadata(self, _) -> Optional[Mapping[str, Any]]:
return expected_metadata
Expand All @@ -540,7 +682,7 @@ def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...
assert metadata["customized"] == "metadata"


def test_with_group_replacements(dlt_pipeline: Pipeline) -> None:
def test_with_group_replacements_legacy(dlt_pipeline: Pipeline) -> None:
expected_group = "customized_group"

class CustomDagsterDltTranslator(DagsterDltTranslator):
Expand All @@ -557,8 +699,7 @@ def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...
for group in my_dlt_assets.group_names_by_key.values():
assert group == expected_group


def test_with_owner_replacements(dlt_pipeline: Pipeline) -> None:
def test_with_owner_replacements_legacy(dlt_pipeline: Pipeline) -> None:
expected_owners = ["custom@custom.com"]

class CustomDagsterDltTranslator(DagsterDltTranslator):
Expand Down Expand Up @@ -587,6 +728,33 @@ def test_with_tag_replacements(dlt_pipeline: Pipeline) -> None:
**build_kind_tag("test"),
}

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_spec(self, data: DltResourceTranslatorData) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(tags=expected_tags, kinds={"dlt", "test"})

@dlt_assets(
dlt_source=dlt_source(),
dlt_pipeline=dlt_pipeline,
dagster_dlt_translator=CustomDagsterDltTranslator(),
)
def my_dlt_assets(dlt_pipeline_resource: DagsterDltResource): ...

for tags in my_dlt_assets.tags_by_key.values():
assert tags == expected_tags


def test_with_tag_replacements_legacy(dlt_pipeline: Pipeline) -> None:
custom_tags = {
"customized": "tag",
}

expected_tags = {
**custom_tags,
**build_kind_tag("dlt"),
**build_kind_tag("test"),
}

class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_tags(self, _) -> Optional[Mapping[str, str]]:
return expected_tags
Expand Down

0 comments on commit fc08bd8

Please sign in to comment.