From 7381b3b22e06aa17a4a869a758f959046e89ed7f Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 16 Nov 2023 21:09:18 -0500 Subject: [PATCH] Fix mypy complaints --- .github/workflows/mypy.yaml | 3 +-- funlib/persistence/arrays/datasets.py | 4 ++-- funlib/persistence/graphs/graph_database.py | 4 ++-- funlib/persistence/graphs/pgsql_graph_database.py | 11 +++++++---- funlib/persistence/graphs/sql_graph_database.py | 6 +++--- funlib/persistence/graphs/sqlite_graph_database.py | 4 ++-- pyproject.toml | 2 +- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 93cb874..9535ed0 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,6 +15,5 @@ jobs: uses: actions/checkout@v2 - name: mypy run: | - pip install . - pip install --upgrade mypy + pip install '.[dev]' mypy funlib/persistence diff --git a/funlib/persistence/arrays/datasets.py b/funlib/persistence/arrays/datasets.py index 7d4f4db..4bf817a 100644 --- a/funlib/persistence/arrays/datasets.py +++ b/funlib/persistence/arrays/datasets.py @@ -9,7 +9,7 @@ import logging import os import shutil -from typing import Optional +from typing import Optional, Union logger = logging.getLogger(__name__) @@ -180,7 +180,7 @@ def prepare_ds( write_roi: Roi = None, write_size: Coordinate = None, num_channels: Optional[int] = None, - compressor: str = "default", + compressor: Union[str, dict] = "default", delete: bool = False, force_exact_write_size: bool = False, ) -> Array: diff --git a/funlib/persistence/graphs/graph_database.py b/funlib/persistence/graphs/graph_database.py index 5d5896e..9716ea9 100644 --- a/funlib/persistence/graphs/graph_database.py +++ b/funlib/persistence/graphs/graph_database.py @@ -33,7 +33,7 @@ def __getitem__(self, roi) -> Graph: @property @abstractmethod - def node_attrs(self) -> list[str]: + def node_attrs(self) -> dict[str, type]: """ Return the node attributes supported by the database. """ @@ -41,7 +41,7 @@ def node_attrs(self) -> list[str]: @property @abstractmethod - def edge_attrs(self) -> list[str]: + def edge_attrs(self) -> dict[str, type]: """ Return the edge attributes supported by the database. """ diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index 467ac27..59de797 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -23,7 +23,7 @@ def __init__( total_roi: Optional[Roi] = None, nodes_table: str = "nodes", edges_table: str = "edges", - endpoint_names: Optional[tuple[str, str]] = None, + endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, type]] = None, edge_attrs: Optional[dict[str, type]] = None, ): @@ -122,16 +122,19 @@ def _store_metadata(self, metadata) -> None: "metadata", ["value"], [[json.dumps(metadata)]], fail_if_exists=True ) - def _read_metadata(self) -> dict[str, Any]: + def _read_metadata(self) -> Optional[dict[str, Any]]: try: self.__exec("SELECT value FROM metadata") except psycopg2.errors.UndefinedTable: self.connection.rollback() return None - metadata = self.cur.fetchone()[0] + result = self.cur.fetchone() + if result is not None: + metadata = result[0] + return json.loads(metadata) - return json.loads(metadata) + return None def _select_query(self, query) -> Iterable[Any]: self.__exec(query) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 5b927a7..ede8a03 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -66,7 +66,7 @@ def __init__( total_roi: Optional[Roi] = None, nodes_table: str = "nodes", edges_table: str = "edges", - endpoint_names: Optional[tuple[str, str]] = None, + endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, type]] = None, edge_attrs: Optional[dict[str, type]] = None, ): @@ -101,7 +101,7 @@ def _store_metadata(self, metadata) -> None: pass @abstractmethod - def _read_metadata(self) -> dict[str, Any]: + def _read_metadata(self) -> Optional[dict[str, Any]]: pass @abstractmethod @@ -217,7 +217,7 @@ def edge_attrs(self) -> dict[str, type]: return self._edge_attrs if self._edge_attrs is not None else {} @edge_attrs.setter - def edge_attrs(self, value: Optional[Iterable[str]]) -> None: + def edge_attrs(self, value: dict[str, type]) -> None: self._edge_attrs = value def read_nodes( diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index ddacb3e..6ed1c3f 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -21,7 +21,7 @@ def __init__( total_roi: Optional[Roi] = None, nodes_table: str = "nodes", edges_table: str = "edges", - endpoint_names: Optional[tuple[str, str]] = None, + endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, type]] = None, edge_attrs: Optional[dict[str, type]] = None, ): @@ -87,7 +87,7 @@ def _store_metadata(self, metadata): with open(self.meta_collection, "w") as f: json.dump(metadata, f) - def _read_metadata(self) -> dict[str, Any]: + def _read_metadata(self) -> Optional[dict[str, Any]]: if not self.meta_collection.exists(): return None diff --git a/pyproject.toml b/pyproject.toml index 0fc52a4..8b540c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ version = { attr = "funlib.persistence.__version__" } [project.optional-dependencies] -dev = ['coverage>=5.0.3', 'pytest', 'black', 'mypy'] +dev = ['coverage>=5.0.3', 'pytest', 'black', 'mypy', 'types-psycopg2'] [tool.black] target_version = ['py39', 'py310', 'py311']