Skip to content

Commit 7f8ce63

Browse files
Merge pull request #4 from janelia-cellmap/main
Add features from cellmap fork
2 parents 4b4e685 + c2d205d commit 7f8ce63

File tree

8 files changed

+591
-116
lines changed

8 files changed

+591
-116
lines changed

funlib/persistence/arrays/datasets.py

Lines changed: 370 additions & 52 deletions
Large diffs are not rendered by default.

funlib/persistence/graphs/pgsql_graph_database.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,8 @@ def __init__(
6969
nodes_table=nodes_table,
7070
edges_table=edges_table,
7171
endpoint_names=endpoint_names,
72-
node_attrs=node_attrs,
73-
edge_attrs=edge_attrs,
72+
node_attrs=node_attrs, # type: ignore
73+
edge_attrs=edge_attrs, # type: ignore
7474
)
7575

7676
def _drop_tables(self) -> None:
@@ -101,12 +101,12 @@ def _create_tables(self) -> None:
101101
f"{self.nodes_table_name}({self.position_attribute})"
102102
)
103103

104-
columns = list(self.edge_attrs.keys())
104+
columns = list(self.edge_attrs.keys()) # type: ignore
105105
types = list([self.__sql_type(t) for t in self.edge_attrs.values()])
106106
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
107107
self.__exec(
108108
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
109-
f"{self.endpoint_names[0]} BIGINT not null, "
109+
f"{self.endpoint_names[0]} BIGINT not null, " # type: ignore
110110
f"{self.endpoint_names[1]} BIGINT not null, "
111111
f"{' '.join([c + ',' for c in column_types])}"
112112
f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"

funlib/persistence/graphs/sql_graph_database.py

Lines changed: 27 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,9 @@ def __init__(
8181
node_attrs: Optional[dict[str, AttributeType]] = None,
8282
edge_attrs: Optional[dict[str, AttributeType]] = None,
8383
):
84-
assert mode in self.valid_modes, f"Mode '{mode}' not in allowed modes {self.valid_modes}"
84+
assert (
85+
mode in self.valid_modes
86+
), f"Mode '{mode}' not in allowed modes {self.valid_modes}"
8587
self.mode = mode
8688

8789
if mode in self.read_modes:
@@ -135,8 +137,8 @@ def get(value, default):
135137

136138
self.directed = get(directed, False)
137139
self.total_roi = get(
138-
total_roi,
139-
Roi((None,) * self.ndims, (None,) * self.ndims))
140+
total_roi, Roi((None,) * self.ndims, (None,) * self.ndims)
141+
)
140142
self.nodes_table_name = get(nodes_table, "nodes")
141143
self.edges_table_name = get(edges_table, "edges")
142144
self.endpoint_names = get(endpoint_names, ["u", "v"])
@@ -229,7 +231,7 @@ def read_graph(
229231
edges = self.read_edges(
230232
roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter
231233
)
232-
u, v = self.endpoint_names
234+
u, v = self.endpoint_names # type: ignore
233235
try:
234236
edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges]
235237
except KeyError as e:
@@ -336,11 +338,7 @@ def read_nodes(
336338

337339
nodes = [
338340
self._columns_to_node_attrs(
339-
{
340-
key: val
341-
for key, val in zip(read_columns, values)
342-
},
343-
read_attrs
341+
{key: val for key, val in zip(read_columns, values)}, read_attrs
344342
)
345343
for values in self._select_query(select_statement)
346344
]
@@ -375,11 +373,11 @@ def read_edges(
375373
return []
376374

377375
node_ids = ", ".join([str(node["id"]) for node in nodes])
378-
node_condition = f"{self.endpoint_names[0]} IN ({node_ids})"
376+
node_condition = f"{self.endpoint_names[0]} IN ({node_ids})" # type: ignore
379377

380378
logger.debug("Reading nodes in roi %s" % roi)
381379
# TODO: AND vs OR here
382-
desired_columns = ", ".join(self.endpoint_names + list(self.edge_attrs.keys()))
380+
desired_columns = ", ".join(self.endpoint_names + list(self.edge_attrs.keys())) # type: ignore
383381
select_statement = (
384382
f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE "
385383
+ node_condition
@@ -390,7 +388,7 @@ def read_edges(
390388
)
391389
)
392390

393-
edge_attrs = self.endpoint_names + (
391+
edge_attrs = self.endpoint_names + ( # type: ignore
394392
list(self.edge_attrs.keys()) if read_attrs is None else read_attrs
395393
)
396394
attr_filter = attr_filter if attr_filter is not None else {}
@@ -401,7 +399,7 @@ def read_edges(
401399
{
402400
key: val
403401
for key, val in zip(
404-
self.endpoint_names + list(self.edge_attrs.keys()), values
402+
self.endpoint_names + list(self.edge_attrs.keys()), values # type: ignore
405403
)
406404
if key in edge_attrs
407405
}
@@ -486,8 +484,8 @@ def update_edges(
486484
if not roi.contains(pos_u):
487485
logger.debug(
488486
(
489-
f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}},"
490-
+ f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}"
487+
f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore
488+
+ f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore
491489
).format(u, v, data, roi)
492490
)
493491
continue
@@ -497,7 +495,7 @@ def update_edges(
497495
update_statement = (
498496
f"UPDATE {self.edges_table_name} SET "
499497
f"{', '.join(setters)} WHERE "
500-
f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}"
498+
f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore
501499
)
502500

503501
self._update_query(update_statement, commit=False)
@@ -528,10 +526,7 @@ def write_nodes(
528526
pos = self.__get_node_pos(data)
529527
if roi is not None and not roi.contains(pos):
530528
continue
531-
values.append(
532-
[node_id]
533-
+ [data.get(attr, None) for attr in attrs]
534-
)
529+
values.append([node_id] + [data.get(attr, None) for attr in attrs])
535530

536531
if len(values) == 0:
537532
logger.debug("No nodes to insert in %s", roi)
@@ -602,12 +597,13 @@ def __load_metadata(self, metadata):
602597

603598
# simple attributes
604599
for attr_name in [
605-
"position_attribute",
606-
"directed",
607-
"nodes_table_name",
608-
"edges_table_name",
609-
"endpoint_names",
610-
"ndims"]:
600+
"position_attribute",
601+
"directed",
602+
"nodes_table_name",
603+
"edges_table_name",
604+
"endpoint_names",
605+
"ndims",
606+
]:
611607

612608
if getattr(self, attr_name) is None:
613609
setattr(self, attr_name, metadata[attr_name])
@@ -657,7 +653,7 @@ def __remove_keys(self, dictionary, keys):
657653

658654
def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]:
659655
try:
660-
return Coordinate(n[self.position_attribute])
656+
return Coordinate(n[self.position_attribute]) # type: ignore
661657
except KeyError:
662658
return None
663659

@@ -681,11 +677,13 @@ def __attr_query(self, attrs: dict[str, Any]) -> str:
681677
def __roi_query(self, roi: Roi) -> str:
682678
query = "WHERE "
683679
pos_attr = self.position_attribute
684-
for dim in range(self.ndims):
680+
for dim in range(self.ndims): # type: ignore
685681
if dim > 0:
686682
query += " AND "
687683
if roi.begin[dim] is not None and roi.end[dim] is not None:
688-
query += f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}"
684+
query += (
685+
f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}"
686+
)
689687
elif roi.begin[dim] is not None:
690688
query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}"
691689
elif roi.begin[dim] is not None:

funlib/persistence/graphs/sqlite_graph_database.py

Lines changed: 11 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,7 @@ def __init__(
5151
def node_array_columns(self):
5252
if not self._node_array_columns:
5353
self._node_array_columns = {
54-
attr: [
55-
f"{attr}_{d}" for d in range(attr_type.size)
56-
]
54+
attr: [f"{attr}_{d}" for d in range(attr_type.size)]
5755
for attr, attr_type in self.node_attrs.items()
5856
if isinstance(attr_type, Vec)
5957
}
@@ -63,9 +61,7 @@ def node_array_columns(self):
6361
def edge_array_columns(self):
6462
if not self._edge_array_columns:
6563
self._edge_array_columns = {
66-
attr: [
67-
f"{attr}_{d}" for d in range(attr_type.size)
68-
]
64+
attr: [f"{attr}_{d}" for d in range(attr_type.size)]
6965
for attr, attr_type in self.edge_attrs.items()
7066
if isinstance(attr_type, Vec)
7167
}
@@ -100,16 +96,16 @@ def _create_tables(self) -> None:
10096
f"{', '.join(node_columns)}"
10197
")"
10298
)
103-
if self.ndims > 1:
99+
if self.ndims > 1: # type: ignore
104100
position_columns = self.node_array_columns[self.position_attribute]
105101
else:
106102
position_columns = self.position_attribute
107103
self.cur.execute(
108104
f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})"
109105
)
110106
edge_columns = [
111-
f"{self.endpoint_names[0]} INTEGER not null",
112-
f"{self.endpoint_names[1]} INTEGER not null",
107+
f"{self.endpoint_names[0]} INTEGER not null", # type: ignore
108+
f"{self.endpoint_names[1]} INTEGER not null", # type: ignore
113109
]
114110
for attr in self.edge_attrs.keys():
115111
if attr in self.edge_array_columns:
@@ -119,7 +115,7 @@ def _create_tables(self) -> None:
119115
self.cur.execute(
120116
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
121117
+ f"{', '.join(edge_columns)}"
122-
+ f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
118+
+ f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" # type: ignore
123119
+ ")"
124120
)
125121

@@ -142,7 +138,7 @@ def _select_query(self, query):
142138
#
143139
# If SQL dialects allow array element access, they start counting at 1.
144140
# We don't want that, we start counting at 0 like normal people.
145-
query = re.sub(r'\[(\d+)\]', lambda m: "_" + str(int(m.group(1)) - 1), query)
141+
query = re.sub(r"\[(\d+)\]", lambda m: "_" + str(int(m.group(1)) - 1), query)
146142

147143
try:
148144
return self.cur.execute(query)
@@ -201,9 +197,7 @@ def _node_attrs_to_columns(self, attrs):
201197
for attr in attrs:
202198
attr_type = self.node_attrs[attr]
203199
if isinstance(attr_type, Vec):
204-
columns += [
205-
f"{attr}_{d}" for d in range(attr_type.size)
206-
]
200+
columns += [f"{attr}_{d}" for d in range(attr_type.size)]
207201
else:
208202
columns.append(attr)
209203
return columns
@@ -213,8 +207,7 @@ def _columns_to_node_attrs(self, columns, query_attrs):
213207
for attr in query_attrs:
214208
if attr in self.node_array_columns:
215209
value = tuple(
216-
columns[f"{attr}_{d}"]
217-
for d in range(self.node_attrs[attr].size)
210+
columns[f"{attr}_{d}"] for d in range(self.node_attrs[attr].size)
218211
)
219212
else:
220213
value = columns[attr]
@@ -226,9 +219,7 @@ def _edge_attrs_to_columns(self, attrs):
226219
for attr in attrs:
227220
attr_type = self.edge_attrs[attr]
228221
if isinstance(attr_type, Vec):
229-
columns += [
230-
f"{attr}_{d}" for d in range(attr_type.size)
231-
]
222+
columns += [f"{attr}_{d}" for d in range(attr_type.size)]
232223
else:
233224
columns.append(attr)
234225
return columns
@@ -238,8 +229,7 @@ def _columns_to_edge_attrs(self, columns, query_attrs):
238229
for attr in query_attrs:
239230
if attr in self.edge_array_columns:
240231
value = tuple(
241-
columns[f"{attr}_{d}"]
242-
for d in range(self.edge_attrs[attr].size)
232+
columns[f"{attr}_{d}"] for d in range(self.edge_attrs[attr].size)
243233
)
244234
else:
245235
value = columns[attr]

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,4 +11,7 @@ ignore_missing_imports = True
1111
ignore_missing_imports = True
1212

1313
[mypy-h5py.*]
14+
ignore_missing_imports = True
15+
16+
[mypy-psycopg2.*]
1417
ignore_missing_imports = True

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"pymongo",
2525
"numpy",
2626
"h5py",
27-
"psycopg2",
27+
"psycopg2-binary",
2828
]
2929

3030
[tool.setuptools.dynamic]

0 commit comments

Comments
 (0)