Skip to content

Commit

Permalink
Add more types for function arguments and return values (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
syou6162 authored Feb 13, 2024
1 parent f553725 commit 4db04d2
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 35 deletions.
24 changes: 14 additions & 10 deletions dbterd/adapters/algos/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from typing import Dict, List

import click

Expand All @@ -9,9 +10,10 @@
TEST_META_RELATIONSHIP_TYPE,
)
from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def get_tables_from_metadata(data=[], **kwargs):
def get_tables_from_metadata(data=[], **kwargs) -> List[Table]:
"""Extract tables from dbt metadata
Args:
Expand Down Expand Up @@ -47,7 +49,7 @@ def get_tables_from_metadata(data=[], **kwargs):
return tables


def get_tables(manifest, catalog, **kwargs):
def get_tables(manifest: Manifest, catalog: Catalog, **kwargs) -> List[Table]:
"""Extract tables from dbt artifacts
Args:
Expand Down Expand Up @@ -94,7 +96,9 @@ def get_tables(manifest, catalog, **kwargs):
return tables


def enrich_tables_from_relationships(tables, relationships):
def enrich_tables_from_relationships(
tables: List[Table], relationships: List[Ref]
) -> List[Table]:
"""Fullfill columns in Table due to `select *`
Args:
Expand Down Expand Up @@ -180,7 +184,7 @@ def get_table_from_metadata(model_metadata, exposures=[], **kwargs) -> Table:


def get_table(
node_name, manifest_node, catalog_node=None, exposures=[], **kwargs
node_name: str, manifest_node, catalog_node=None, exposures=[], **kwargs
) -> Table:
"""Construct a single Table object
Expand Down Expand Up @@ -313,7 +317,7 @@ def get_node_exposures_from_metadata(data=[], **kwargs):
return exposures


def get_node_exposures(manifest):
def get_node_exposures(manifest: Manifest) -> List[Dict[str, str]]:
"""Get the mapping of table name and exposure name
Args:
Expand Down Expand Up @@ -349,7 +353,7 @@ def get_table_name(format: str, **kwargs) -> str:
return ".".join([kwargs.get(x.lower()) or "KEYNOTFOUND" for x in format.split(".")])


def get_relationships_from_metadata(data=[], **kwargs) -> list[Ref]:
def get_relationships_from_metadata(data=[], **kwargs) -> List[Ref]:
"""Extract relationships from Metadata result list on test relationship
Args:
Expand Down Expand Up @@ -409,7 +413,7 @@ def get_relationships_from_metadata(data=[], **kwargs) -> list[Ref]:
return get_unique_refs(refs=refs)


def get_relationships(manifest, **kwargs):
def get_relationships(manifest: Manifest, **kwargs) -> List[Ref]:
"""Extract relationships from dbt artifacts based on test relationship
Args:
Expand Down Expand Up @@ -482,7 +486,7 @@ def get_unique_refs(refs: list[Ref] = []) -> list[Ref]:
return distinct_list


def get_algo_rule(**kwargs):
def get_algo_rule(**kwargs) -> Dict[str, str]:
"""Extract rule from the --algo option
Args:
Expand Down Expand Up @@ -517,7 +521,7 @@ def get_algo_rule(**kwargs):
return rules


def get_table_map_from_metadata(test_node, **kwargs):
def get_table_map_from_metadata(test_node, **kwargs) -> List[str]:
"""Get the table map with order of [to, from] guaranteed
(for Metadata)
Expand Down Expand Up @@ -570,7 +574,7 @@ def get_table_map_from_metadata(test_node, **kwargs):
return list(reversed(test_parents))


def get_table_map(test_node, **kwargs):
def get_table_map(test_node, **kwargs) -> List[str]:
"""Get the table map with order of [to, from] guaranteed
Args:
Expand Down
11 changes: 8 additions & 3 deletions dbterd/adapters/algos/test_relationship.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import List, Tuple, Union

from dbterd.adapters.algos import base
from dbterd.adapters.filter import is_selected_table
from dbterd.adapters.meta import Ref
from dbterd.adapters.meta import Ref, Table
from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def parse_metadata(data, **kwargs):
def parse_metadata(data, **kwargs) -> Tuple[List[Table], List[Ref]]:
"""Get all information (tables, relationships) needed for building diagram
(from Metadata)
Expand Down Expand Up @@ -47,7 +50,9 @@ def parse_metadata(data, **kwargs):
return (tables, relationships)


def parse(manifest, catalog, **kwargs):
def parse(
manifest: Manifest, catalog: Union[str, Catalog], **kwargs
) -> Tuple[List[Table], List[Ref]]:
"""Get all information (tables, relationships) needed for building diagram
Args:
Expand Down
18 changes: 9 additions & 9 deletions dbterd/adapters/filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import sys
from fnmatch import fnmatch
from typing import List
from typing import List, Optional, Tuple

from dbterd.adapters.meta import Table

RULE_FUNC_PREFIX = "is_satisfied_by_"


def has_unsupported_rule(rules: List[str] = []) -> bool:
def has_unsupported_rule(rules: List[str] = []) -> Tuple[bool, Optional[str]]:
"""Verify if existing the unsupported selection rule
Args:
Expand All @@ -32,7 +32,7 @@ def is_selected_table(
select_rules: List[str] = [],
exclude_rules: List[str] = [],
resource_types: List[str] = ["model"],
):
) -> bool:
"""Check if Table is selected with defined selection criteria
Args:
Expand Down Expand Up @@ -61,7 +61,7 @@ def is_selected_table(
return selected and not excluded


def evaluate_rule(table: Table, rule: str):
def evaluate_rule(table: Table, rule: str) -> bool:
"""Evaluate selection/exclusion single rule with AND logic applied
Args:
Expand All @@ -86,7 +86,7 @@ def evaluate_rule(table: Table, rule: str):
return all(results)


def is_satisfied_by_name(table: Table, rule: str = ""):
def is_satisfied_by_name(table: Table, rule: str = "") -> bool:
"""Evaluate rule by Name
Args:
Expand All @@ -101,7 +101,7 @@ def is_satisfied_by_name(table: Table, rule: str = ""):
return table.node_name.startswith(rule)


def is_satisfied_by_exact(table: Table, rule: str = ""):
def is_satisfied_by_exact(table: Table, rule: str = "") -> bool:
"""Evaluate rule by model name with exact match
Args:
Expand All @@ -116,7 +116,7 @@ def is_satisfied_by_exact(table: Table, rule: str = ""):
return table.node_name.lower() == rule


def is_satisfied_by_schema(table: Table, rule: str = ""):
def is_satisfied_by_schema(table: Table, rule: str = "") -> bool:
"""Evaluate rule by Schema name
Args:
Expand All @@ -137,7 +137,7 @@ def is_satisfied_by_schema(table: Table, rule: str = ""):
)


def is_satisfied_by_wildcard(table: Table, rule: str = "*"):
def is_satisfied_by_wildcard(table: Table, rule: str = "*") -> bool:
"""Evaluate rule by Wildcard (Unix Style)
Args:
Expand All @@ -152,7 +152,7 @@ def is_satisfied_by_wildcard(table: Table, rule: str = "*"):
return fnmatch(table.node_name, rule)


def is_satisfied_by_exposure(table: Table, rule: str = ""):
def is_satisfied_by_exposure(table: Table, rule: str = "") -> bool:
"""Evaluate rule by dbt Exposure name
Args:
Expand Down
7 changes: 5 additions & 2 deletions dbterd/adapters/targets/d2/d2_test_relationship.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Tuple

from dbterd.adapters.algos import test_relationship
from dbterd.types import Catalog, Manifest


def run(manifest, catalog, **kwargs):
def run(manifest: Manifest, catalog: Catalog, **kwargs) -> Tuple[str, str]:
"""Parse dbt artifacts and export D2 file
Args:
Expand All @@ -14,7 +17,7 @@ def run(manifest, catalog, **kwargs):
return ("output.d2", parse(manifest, catalog, **kwargs))


def parse(manifest, catalog, **kwargs):
def parse(manifest: Manifest, catalog: Catalog, **kwargs) -> str:
"""Get the D2 content from dbt artifacts
Args:
Expand Down
6 changes: 4 additions & 2 deletions dbterd/adapters/targets/dbml/dbml_test_relationship.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from typing import Tuple

from dbterd.adapters.algos import test_relationship
from dbterd.types import Catalog, Manifest


def run(manifest, catalog, **kwargs):
def run(manifest: Manifest, catalog: Catalog, **kwargs) -> Tuple[str, str]:
"""Parse dbt artifacts and export DBML file
Args:
Expand All @@ -16,7 +18,7 @@ def run(manifest, catalog, **kwargs):
return ("output.dbml", parse(manifest, catalog, **kwargs))


def parse(manifest, catalog, **kwargs):
def parse(manifest: Manifest, catalog: Catalog, **kwargs) -> str:
"""Get the DBML content from dbt artifacts
Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Tuple

from dbterd.adapters.algos import test_relationship
from dbterd.types import Catalog, Manifest


def run(manifest, catalog, **kwargs):
def run(manifest: Manifest, catalog: Catalog, **kwargs) -> Tuple[str, str]:
"""Parse dbt artifacts and export GraphViz file
Args:
Expand All @@ -14,7 +17,7 @@ def run(manifest, catalog, **kwargs):
return ("output.graphviz", parse(manifest, catalog, **kwargs))


def parse(manifest, catalog, **kwargs):
def parse(manifest: Manifest, catalog: Catalog, **kwargs) -> str:
"""Get the GraphViz content from dbt artifacts
Args:
Expand Down
7 changes: 4 additions & 3 deletions dbterd/adapters/targets/mermaid/mermaid_test_relationship.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from typing import Optional
from typing import Optional, Tuple

from dbterd.adapters.algos import test_relationship
from dbterd.types import Catalog, Manifest


def run(manifest, catalog, **kwargs):
def run(manifest: Manifest, catalog: Catalog, **kwargs) -> Tuple[str, str]:
"""Parse dbt artifacts and export Mermaid file
Args:
Expand Down Expand Up @@ -67,7 +68,7 @@ def replace_column_type(column_type: str) -> str:
return column_type.replace(" ", "-")


def parse(manifest, catalog, **kwargs):
def parse(manifest: Manifest, catalog: Catalog, **kwargs) -> str:
"""Get the Mermaid content from dbt artifacts
Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Tuple

from dbterd.adapters.algos import test_relationship
from dbterd.types import Catalog, Manifest


def run(manifest, catalog, **kwargs):
def run(manifest: Manifest, catalog: Catalog, **kwargs) -> Tuple[str, str]:
"""Parse dbt artifacts and export PlantUML file
Args:
Expand All @@ -14,7 +17,7 @@ def run(manifest, catalog, **kwargs):
return ("output.plantuml", parse(manifest, catalog, **kwargs))


def parse(manifest, catalog, **kwargs):
def parse(manifest: Manifest, catalog: Catalog, **kwargs) -> str:
"""Get the PlantUML content from dbt artifacts
Args:
Expand Down
5 changes: 3 additions & 2 deletions dbterd/helpers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbt_artifacts_parser import parser

from dbterd.helpers.log import logger
from dbterd.types import Catalog, Manifest


def get_sys_platform(): # pragma: no cover
Expand Down Expand Up @@ -109,7 +110,7 @@ def win_prepare_path(path: str) -> str: # pragma: no cover
return path


def read_manifest(path: str, version: int = None):
def read_manifest(path: str, version: int = None) -> Manifest:
"""Reads in the manifest.json file, with optional version specification
Args:
Expand All @@ -134,7 +135,7 @@ def read_manifest(path: str, version: int = None):
return parse_func(manifest=_dict)


def read_catalog(path: str, version: int = None):
def read_catalog(path: str, version: int = None) -> Catalog:
"""Reads in the catalog.json file, with optional version specification
Args:
Expand Down
31 changes: 31 additions & 0 deletions dbterd/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Union

from dbt_artifacts_parser.parsers.catalog.catalog_v1 import CatalogV1
from dbt_artifacts_parser.parsers.manifest.manifest_v1 import ManifestV1
from dbt_artifacts_parser.parsers.manifest.manifest_v2 import ManifestV2
from dbt_artifacts_parser.parsers.manifest.manifest_v3 import ManifestV3
from dbt_artifacts_parser.parsers.manifest.manifest_v4 import ManifestV4
from dbt_artifacts_parser.parsers.manifest.manifest_v5 import ManifestV5
from dbt_artifacts_parser.parsers.manifest.manifest_v6 import ManifestV6
from dbt_artifacts_parser.parsers.manifest.manifest_v7 import ManifestV7
from dbt_artifacts_parser.parsers.manifest.manifest_v8 import ManifestV8
from dbt_artifacts_parser.parsers.manifest.manifest_v9 import ManifestV9
from dbt_artifacts_parser.parsers.manifest.manifest_v10 import ManifestV10
from dbt_artifacts_parser.parsers.manifest.manifest_v11 import ManifestV11

Manifest = Union[
ManifestV1,
ManifestV2,
ManifestV3,
ManifestV4,
ManifestV5,
ManifestV6,
ManifestV7,
ManifestV8,
ManifestV9,
ManifestV10,
ManifestV11,
]

# If a new version of Catalog is added, replace with `Union[CatalogV1, CatalogV2, ...]`.
Catalog = CatalogV1

0 comments on commit 4db04d2

Please sign in to comment.