Skip to content

Commit

Permalink
Improve Type Annotations Coverage (#162)
Browse files Browse the repository at this point in the history
* Add some type annotations

* More type annotation.

* Freedom from typlessness.

* More annotations.

* Clean up linter issues
  • Loading branch information
peterallenwebb authored Jul 4, 2024
1 parent d25c29a commit feda22d
Show file tree
Hide file tree
Showing 26 changed files with 343 additions and 274 deletions.
5 changes: 3 additions & 2 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from typing import List, Mapping, Optional

from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX
from dbt_common.record import Recorder


class InvocationContext:
def __init__(self, env: Mapping[str, str]):
self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)}
self._env_secrets: Optional[List[str]] = None
self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)}
self.recorder = None
self.recorder: Optional[Recorder] = None
# This class will also eventually manage the invocation_id, flags, event manager, etc.

@property
Expand All @@ -32,7 +33,7 @@ def env_secrets(self) -> List[str]:
_INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR")


def reliably_get_invocation_var() -> ContextVar:
def reliably_get_invocation_var() -> ContextVar[InvocationContext]:
invocation_var: Optional[ContextVar] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)
Expand Down
8 changes: 4 additions & 4 deletions dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional
from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple
import re
import jsonschema
from dataclasses import fields, Field
Expand Down Expand Up @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError):


class DateTimeSerialization(SerializationStrategy):
def serialize(self, value) -> str:
def serialize(self, value: datetime) -> str:
out = value.isoformat()
# Assume UTC if timezone is missing
if value.tzinfo is None:
Expand Down Expand Up @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]:

# copied from hologram. Used in tests
@classmethod
def _get_field_names(cls):
def _get_field_names(cls) -> List[str]:
return [element[1] for element in cls._get_fields()]


Expand All @@ -152,7 +152,7 @@ def validate(cls, value):

# These classes must be in this order or it doesn't work
class StrEnum(str, SerializableType, Enum):
def __str__(self):
def __str__(self) -> str:
return self.value

# https://docs.python.org/3.6/library/enum.html#using-automatic-values
Expand Down
6 changes: 3 additions & 3 deletions dbt_common/exceptions/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import builtins
from typing import List, Any, Optional
from typing import Any, List, Optional
import os

from dbt_common.constants import SECRET_ENV_PREFIX
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(self, msg: str):
self.msg = scrub_secrets(msg, env_secrets())

@property
def type(self):
def type(self) -> str:
return "Internal"

def process_stack(self):
Expand All @@ -59,7 +59,7 @@ def process_stack(self):

return lines

def __str__(self):
def __str__(self) -> str:
if hasattr(self.msg, "split"):
split_msg = self.msg.split("\n")
else:
Expand Down
4 changes: 2 additions & 2 deletions dbt_common/helper_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class NVEnum(StrEnum):
novalue = "novalue"

def __eq__(self, other):
def __eq__(self, other) -> bool:
return isinstance(other, NVEnum)


Expand Down Expand Up @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool:
item_name in self.include or self.include in self.INCLUDE_ALL
) and item_name not in self.exclude

def _validate_items(self, items: List[str]):
def _validate_items(self, items: List[str]) -> None:
pass


Expand Down
8 changes: 4 additions & 4 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, Type

from dbt_common.context import get_invocation_context


class Record:
"""An instance of this abstract Record class represents a request made by dbt
Expand Down Expand Up @@ -295,9 +293,11 @@ def record_function_inner(func_to_record):
return func_to_record

@functools.wraps(func_to_record)
def record_replay_wrapper(*args, **kwargs):
recorder: Recorder = None
def record_replay_wrapper(*args, **kwargs) -> Any:
recorder: Optional[Recorder] = None
try:
from dbt_common.context import get_invocation_context

recorder = get_invocation_context().recorder
except LookupError:
pass
Expand Down
30 changes: 16 additions & 14 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import re
from typing import List
from typing import List, Iterable

import dbt_common.exceptions.base
from dbt_common.exceptions import VersionsNotCompatibleError
Expand Down Expand Up @@ -74,7 +74,7 @@ def _cmp(a, b):

@dataclass
class VersionSpecifier(VersionSpecification):
def to_version_string(self, skip_matcher=False):
def to_version_string(self, skip_matcher: bool = False) -> str:
prerelease = ""
build = ""
matcher = ""
Expand All @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False):
)

@classmethod
def from_version_string(cls, version_string):
def from_version_string(cls, version_string: str) -> "VersionSpecifier":
match = _VERSION_REGEX.match(version_string)

if not match:
Expand All @@ -104,7 +104,7 @@ def from_version_string(cls, version_string):

return cls.from_dict(matched)

def __str__(self):
def __str__(self) -> str:
return self.to_version_string()

def to_range(self) -> "VersionRange":
Expand Down Expand Up @@ -192,32 +192,32 @@ def compare(self, other):

return 0

def __lt__(self, other):
def __lt__(self, other) -> bool:
return self.compare(other) == -1

def __gt__(self, other):
def __gt__(self, other) -> bool:
return self.compare(other) == 1

def __eq___(self, other):
def __eq___(self, other) -> bool:
return self.compare(other) == 0

def __cmp___(self, other):
return self.compare(other)

@property
def is_unbounded(self):
def is_unbounded(self) -> bool:
return False

@property
def is_lower_bound(self):
def is_lower_bound(self) -> bool:
return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL]

@property
def is_upper_bound(self):
def is_upper_bound(self) -> bool:
return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL]

@property
def is_exact(self):
def is_exact(self) -> bool:
return self.matcher == Matchers.EXACT

@classmethod
Expand Down Expand Up @@ -418,7 +418,7 @@ def reduce_versions(*args):
return to_return


def versions_compatible(*args):
def versions_compatible(*args) -> bool:
if len(args) == 1:
return True

Expand All @@ -429,7 +429,7 @@ def versions_compatible(*args):
return False


def find_possible_versions(requested_range, available_versions):
def find_possible_versions(requested_range, available_versions: Iterable[str]):
possible_versions = []

for version_string in available_versions:
Expand All @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions):
return [v.to_version_string(skip_matcher=True) for v in sorted_versions]


def resolve_to_specific_version(requested_range, available_versions):
def resolve_to_specific_version(
requested_range, available_versions: Iterable[str]
) -> Optional[str]:
max_version = None
max_version_string = None

Expand Down
6 changes: 3 additions & 3 deletions dbt_common/utils/casting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This is useful for proto generated classes in particular, since
# the default for protobuf for strings is the empty string, so
# Optional[str] types don't work for generated Python classes.
from typing import Optional
from typing import Any, Dict, Optional


def cast_to_str(string: Optional[str]) -> str:
Expand All @@ -18,8 +18,8 @@ def cast_to_int(integer: Optional[int]) -> int:
return integer


def cast_dict_to_dict_of_strings(dct):
new_dct = {}
def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]:
new_dct: Dict[str, str] = {}
for k, v in dct.items():
new_dct[str(k)] = str(v)
return new_dct
9 changes: 6 additions & 3 deletions dbt_common/utils/executor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import concurrent.futures
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Protocol, Optional

from dbt_common.context import get_invocation_context, reliably_get_invocation_var
from dbt_common.context import (
get_invocation_context,
reliably_get_invocation_var,
InvocationContext,
)


class ConnectingExecutor(concurrent.futures.Executor):
Expand Down Expand Up @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol):
threads: Optional[int]


def _thread_initializer(invocation_context: ContextVar) -> None:
def _thread_initializer(invocation_context: InvocationContext) -> None:
invocation_var = reliably_get_invocation_var()
invocation_var.set(invocation_context)

Expand Down
8 changes: 5 additions & 3 deletions dbt_common/utils/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
DOCS_PREFIX = "dbt_docs__"


def get_dbt_macro_name(name):
def get_dbt_macro_name(name) -> str:
if name is None:
raise DbtInternalError("Got None for a macro name!")
return f"{MACRO_PREFIX}{name}"


def get_dbt_docs_name(name):
def get_dbt_docs_name(name) -> str:
if name is None:
raise DbtInternalError("Got None for a doc name!")
return f"{DOCS_PREFIX}{name}"


def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True):
def get_materialization_macro_name(
materialization_name, adapter_type=None, with_prefix=True
) -> str:
if adapter_type is None:
adapter_type = "default"
name = f"materialization_{materialization_name}_{adapter_type}"
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/test_agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@


class TestAgateHelper(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.tempdir = mkdtemp()

def tearDown(self):
def tearDown(self) -> None:
rmtree(self.tempdir)

def test_from_csv(self):
def test_from_csv(self) -> None:
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_DATA.encode("utf-8"))
Expand All @@ -61,7 +61,7 @@ def test_from_csv(self):
for idx, row in enumerate(tbl):
self.assertEqual(list(row), EXPECTED[idx])

def test_bom_from_csv(self):
def test_bom_from_csv(self) -> None:
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8"))
Expand All @@ -70,7 +70,7 @@ def test_bom_from_csv(self):
for idx, row in enumerate(tbl):
self.assertEqual(list(row), EXPECTED[idx])

def test_from_csv_all_reserved(self):
def test_from_csv_all_reserved(self) -> None:
path = os.path.join(self.tempdir, "input.csv")
with open(path, "wb") as fp:
fp.write(SAMPLE_CSV_DATA.encode("utf-8"))
Expand All @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self):
for expected, row in zip(EXPECTED_STRINGS, tbl):
self.assertEqual(list(row), expected)

def test_from_data(self):
def test_from_data(self) -> None:
column_names = ["a", "b", "c", "d", "e", "f", "g"]
data = [
{
Expand All @@ -106,7 +106,7 @@ def test_from_data(self):
for idx, row in enumerate(tbl):
self.assertEqual(list(row), EXPECTED[idx])

def test_datetime_formats(self):
def test_datetime_formats(self) -> None:
path = os.path.join(self.tempdir, "input.csv")
datetimes = [
"20180806T11:33:29.000Z",
Expand All @@ -120,7 +120,7 @@ def test_datetime_formats(self):
tbl = agate_helper.from_csv(path, ())
self.assertEqual(tbl[0][0], expected)

def test_merge_allnull(self):
def test_merge_allnull(self) -> None:
t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
result = agate_helper.merge_tables([t1, t2])
Expand All @@ -130,7 +130,7 @@ def test_merge_allnull(self):
assert isinstance(result.column_types[2], agate_helper.Integer)
self.assertEqual(len(result), 4)

def test_merge_mixed(self):
def test_merge_mixed(self) -> None:
t1 = agate_helper.table_from_rows(
[(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d")
)
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_merge_mixed(self):
assert isinstance(result.column_types[3], agate.data_types.Number)
self.assertEqual(len(result), 6)

def test_nocast_string_types(self):
def test_nocast_string_types(self) -> None:
# String fields should not be coerced into a representative type
# See: https://github.com/dbt-labs/dbt-core/issues/2984

Expand All @@ -202,7 +202,7 @@ def test_nocast_string_types(self):
for i, row in enumerate(tbl):
self.assertEqual(list(row), expected[i])

def test_nocast_bool_01(self):
def test_nocast_bool_01(self) -> None:
# True and False values should not be cast to 1 and 0, and vice versa
# See: https://github.com/dbt-labs/dbt-core/issues/4511

Expand Down
Loading

0 comments on commit feda22d

Please sign in to comment.