Skip to content

Commit

Permalink
chore: typing and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravencentric committed Sep 18, 2024
1 parent d95b168 commit 918bfd6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/pynyaa/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pynyaa._compat import IntEnum, StrEnum
from pynyaa._types import CategoryID, CategoryLiteral, SortByLiteral
from pynyaa._utils import get_category_id_from_name
from pynyaa._utils import _get_category_id_from_name


class BaseStrEnum(StrEnum):
Expand All @@ -17,7 +17,7 @@ def _missing_(cls, value: object) -> Self:
for member in cls:
if member.value.casefold() == str(value).casefold():
return member
message = f"'{value}' is not a valid {type(cls)}"
message = f"'{value}' is not a valid {cls.__name__}"
raise ValueError(message)


Expand Down Expand Up @@ -63,7 +63,7 @@ def id(self) -> CategoryID:
This ID corresponds to the category as seen in the URL
`https://nyaa.si/?f=0&c=1_2&q=`, where `c=1_2` is the ID for `Anime - English-translated`.
"""
return get_category_id_from_name(self.value)
return _get_category_id_from_name(self.value)

@overload
@classmethod
Expand Down
10 changes: 6 additions & 4 deletions src/pynyaa/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
P = ParamSpec("P")
T = TypeVar("T")

# functools.cache destroys the signature of whatever it wraps, so we use this to fix it.
# This is to only "fool" typecheckers and IDEs, this doesn't exist at runtime.
def cache(user_function: Callable[P, T], /) -> Callable[P, T]: ... # type: ignore


Expand All @@ -22,16 +24,16 @@ def get_user_cache_path() -> Path:


@overload
def get_category_id_from_name(key: CategoryName) -> CategoryID: ...
def _get_category_id_from_name(key: CategoryName) -> CategoryID: ...


@overload
def get_category_id_from_name(key: str) -> CategoryID: ...
def _get_category_id_from_name(key: str) -> CategoryID: ...


@cache
def get_category_id_from_name(key: CategoryName | str) -> str:
mapping = {
def _get_category_id_from_name(key: CategoryName | str) -> CategoryID:
mapping: dict[str, CategoryID] = {
# All, c=0_0
"All": "0_0",
# Anime, c=1_X
Expand Down
10 changes: 10 additions & 0 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from pynyaa import Category, Filter, SortBy


def test_category_value_error() -> None:
with pytest.raises(ValueError):
Category.get("asdadadsad", "invalid default")


def test_sortby_value_error() -> None:
with pytest.raises(ValueError):
SortBy.get("asdadadsad", "invalid default")


@pytest.mark.parametrize(
"category, expected_id",
[
Expand Down

0 comments on commit 918bfd6

Please sign in to comment.