Skip to content

Commit

Permalink
Merge pull request #25 from mam-dev/strict-mypy
Browse files Browse the repository at this point in the history
Make mypy strict
  • Loading branch information
bunny-therapist authored Jun 16, 2023
2 parents bc3854d + f7e9bc5 commit 36c000d
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 72 deletions.
18 changes: 13 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,24 @@ usefixtures = ["requests_mock"]
testpaths = ["test"]

[tool.mypy]
files = ["src", "test"]
warn_no_return = true
warn_return_any = true
warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true
warn_unreachable = true
files = ["src", "test"]

[[tool.mypy.overrides]]
module = 'py'
ignore_missing_imports = true
check_untyped_defs = true
disallow_any_generics = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
no_implicit_reexport = true
strict_equality = true
strict_concatenate = true

[tool.ruff]
src = ["src", "test"]
Expand Down
56 changes: 36 additions & 20 deletions src/security_constraints/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,27 @@
import argparse
import dataclasses
import enum
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Set, get_type_hints
import sys
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Set,
Tuple,
Type,
get_type_hints,
)

if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover (<py311)
else:
from typing_extensions import Self # pragma: no cover (>=py311)

if TYPE_CHECKING: # pragma: no cover
import sys

if TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
Expand All @@ -27,18 +43,18 @@ class SeverityLevel(str, enum.Enum):
LOW = "LOW"

@classmethod
def _missing_(cls, value: object) -> Optional["SeverityLevel"]:
def _missing_(cls: Type[Self], value: object) -> Optional[Self]:
# Makes instantiation case-insensitive
if isinstance(value, str):
for member in cls:
if member.value == value.upper():
return member
return None

def get_higher_or_equal_severities(self) -> Set["SeverityLevel"]:
def get_higher_or_equal_severities(self: Self) -> Set[Self]:
"""Get a set containing this SeverityLevel and all higher ones."""
return {
SeverityLevel(value)
type(self)(value)
for value in type(self).__members__.values()
if self.severity_score <= SeverityLevel(value).severity_score
}
Expand Down Expand Up @@ -68,16 +84,16 @@ def _compare_as_int(self, method_name: str, other: Any) -> bool:
comparison_method = getattr(self.severity_score, method_name)
return comparison_method(other.severity_score) # type: ignore[no-any-return]

def __gt__(self, other) -> bool:
def __gt__(self, other: Any) -> bool:
return self._compare_as_int("__gt__", other)

def __lt__(self, other) -> bool:
def __lt__(self, other: Any) -> bool:
return self._compare_as_int("__lt__", other)

def __ge__(self, other) -> bool:
def __ge__(self, other: Any) -> bool:
return self._compare_as_int("__ge__", other)

def __le__(self, other) -> bool:
def __le__(self, other: Any) -> bool:
return self._compare_as_int("__le__", other)

def __str__(self) -> str:
Expand All @@ -90,12 +106,12 @@ class ArgumentNamespace(argparse.Namespace):
dump_config: bool
debug: bool
version: bool
output: Optional[IO]
output: Optional[IO[str]]
ignore_ids: List[str]
config: Optional[str]
min_severity: SeverityLevel

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any) -> None:
# Makes it so that no attributes except those type hinted above can be set.
if key not in get_type_hints(self):
raise AttributeError(f"No attribute named '{key}'")
Expand Down Expand Up @@ -126,9 +142,9 @@ class Configuration:
ignore_ids: Set[str] = dataclasses.field(default_factory=set)
min_severity: SeverityLevel = dataclasses.field(default=SeverityLevel.CRITICAL)

def to_dict(self) -> Dict:
def _dict_factory(data):
def convert(obj):
def to_dict(self) -> Dict[str, Any]:
def _dict_factory(data: List[Tuple[str, Any]]) -> Dict[str, Any]:
def convert(obj: Any) -> Any:
if isinstance(obj, enum.Enum):
# Use values for Enums
return obj.value
Expand All @@ -142,7 +158,7 @@ def convert(obj):
return dataclasses.asdict(self, dict_factory=_dict_factory)

@classmethod
def from_dict(cls, in_dict: Dict) -> "Configuration":
def from_dict(cls: Type[Self], in_dict: Dict[str, Any]) -> Self:
kwargs: _ConfigurationKwargs = {}
if "ignore_ids" in in_dict:
kwargs["ignore_ids"] = set(in_dict["ignore_ids"])
Expand All @@ -151,14 +167,14 @@ def from_dict(cls, in_dict: Dict) -> "Configuration":
return cls(**kwargs)

@classmethod
def from_args(cls, args: ArgumentNamespace) -> "Configuration":
def from_args(cls: Type[Self], args: ArgumentNamespace) -> Self:
return cls(
ignore_ids=set(args.ignore_ids),
min_severity=args.min_severity,
)

@classmethod
def merge(cls, *config: "Configuration") -> "Configuration":
def merge(cls: Type[Self], *config: Self) -> Self:
"""Merge multiple Configurations into a new one."""
all_ignore_ids_entries = (c.ignore_ids for c in config)
all_min_severity_entries = (c.min_severity for c in config)
Expand All @@ -168,7 +184,7 @@ def merge(cls, *config: "Configuration") -> "Configuration":
)

@classmethod
def supported_keys(cls) -> List[str]:
def supported_keys(cls: Type[Self]) -> List[str]:
"""Return a list of keys which are supported in the config file."""
return list(cls().to_dict().keys())

Expand All @@ -179,7 +195,7 @@ class PackageConstraints:
Attributes:
package: The name of the package.
specifies: A list of version specifiers, e.g. ">3.0".
specifiers: A list of version specifiers, e.g. ">3.0".
"""

Expand Down
43 changes: 38 additions & 5 deletions src/security_constraints/github_security_advisory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import string
from typing import Any, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

import requests

Expand All @@ -14,6 +14,33 @@
SeverityLevel,
)

if TYPE_CHECKING: # pragma: no cover
import sys

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict

class _GraphQlResponseJson(TypedDict, total=False):
data: Dict[Any, Any]

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard


def _is_graphql_response_json(
response_json: Any,
) -> "TypeGuard[_GraphQlResponseJson]":
return (
isinstance(response_json, dict)
and isinstance(response_json.get("data"), dict)
and all(isinstance(key, str) for key in response_json["data"])
)


LOGGER = logging.getLogger(__name__)

QUERY_TEMPLATE = string.Template(
Expand Down Expand Up @@ -75,11 +102,11 @@ def get_vulnerabilities(
vulnerabilities: List[SecurityVulnerability] = []
more_data_exists = True
while more_data_exists:
json_response: Dict = self._do_graphql_request(
json_response: "_GraphQlResponseJson" = self._do_graphql_request(
severities=severities, after=after
)
try:
json_data: Dict = json_response["data"]
json_data: Dict[str, Any] = json_response["data"]
vulnerabilities.extend(
[
SecurityVulnerability(
Expand Down Expand Up @@ -109,7 +136,7 @@ def get_vulnerabilities(

def _do_graphql_request(
self, severities: Set[SeverityLevel], after: Optional[str] = None
) -> Any:
) -> "_GraphQlResponseJson":
query = QUERY_TEMPLATE.substitute(
first=100,
severities=",".join(sorted([str(severity) for severity in severities])),
Expand All @@ -122,7 +149,11 @@ def _do_graphql_request(
)
try:
response.raise_for_status()
return response.json()
json_content: Any = response.json()
if not _is_graphql_response_json(response_json=json_content):
raise FetchVulnerabilitiesError(
f"Unexpected json data format in response: {json_content}"
)
except requests.HTTPError as error:
LOGGER.error(
"HTTP error (status %s) received from URL %s: %s",
Expand All @@ -134,3 +165,5 @@ def _do_graphql_request(
except requests.JSONDecodeError as error:
LOGGER.error("Could not decode json data in response: %s", response.text)
raise FetchVulnerabilitiesError from error
else:
return json_content
2 changes: 1 addition & 1 deletion src/security_constraints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def main() -> int:
The program exit code as an integer.
"""
output: Optional[IO] = None
output: Optional[IO[str]] = None
try:
args = get_args()
if args.version:
Expand Down
10 changes: 7 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import datetime
from typing import TYPE_CHECKING, Generator
from unittest.mock import Mock

import freezegun
import pytest

if TYPE_CHECKING:
from _pytest.monkeypatch import MonkeyPatch

from security_constraints.common import ArgumentNamespace, SeverityLevel


Expand All @@ -21,23 +25,23 @@ def fixture_arg_namespace() -> ArgumentNamespace:


@pytest.fixture(name="github_token")
def fixture_token_in_env(monkeypatch) -> str:
def fixture_token_in_env(monkeypatch: "MonkeyPatch") -> str:
"""Set SC_GITHUB_TOKEN environment variable and return it."""
token = "3e00409b-f017-4ecc-b7bf-f11f6e2a5693"
monkeypatch.setenv("SC_GITHUB_TOKEN", token)
return token


@pytest.fixture(name="mock_version")
def fixture_mock_version(monkeypatch) -> Mock:
def fixture_mock_version(monkeypatch: "MonkeyPatch") -> Mock:
"""Mock main.version with a mock that returns 'x.y.z'."""
mock_version: Mock = Mock(return_value="x.y.z")
monkeypatch.setattr("security_constraints.main.version", mock_version)
return mock_version


@pytest.fixture(name="frozen_time")
def _fixture_frozen_time():
def _fixture_frozen_time() -> Generator[None, None, None]:
"""Freeze time during the test.
The UTC timestamp will be '1986-04-09T12:11:10.000009Z'.
Expand Down
6 changes: 4 additions & 2 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_severity_level_ge(
assert (first >= second) == expected


def test_argument_namespace_can_be_modified(arg_namespace) -> None:
def test_argument_namespace_can_be_modified(arg_namespace: ArgumentNamespace) -> None:
arg_namespace.dump_config = True
assert arg_namespace.dump_config
arg_namespace.debug = True
Expand All @@ -127,7 +127,9 @@ def test_argument_namespace_can_be_modified(arg_namespace) -> None:
assert arg_namespace.min_severity == SeverityLevel.HIGH


def test_argument_namespace_cannot_be_extended(arg_namespace) -> None:
def test_argument_namespace_cannot_be_extended(
arg_namespace: ArgumentNamespace,
) -> None:
with pytest.raises(AttributeError):
arg_namespace.does_not_exist = True

Expand Down
Loading

0 comments on commit 36c000d

Please sign in to comment.