88 changes: 73 additions & 15 deletions src/tracksdata/graph/_sql_graph.py
@@ -1,7 +1,7 @@
import binascii
from collections.abc import Callable, Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

import cloudpickle
import numpy as np
@@ -11,6 +11,7 @@
from polars._typing import SchemaDict
from polars.datatypes.convert import numpy_char_code_to_dtype
from sqlalchemy.orm import DeclarativeBase, Session, aliased, load_only
from sqlalchemy.orm.query import Query
from sqlalchemy.sql.type_api import TypeEngine

from tracksdata.attrs import AttrComparison, split_attr_comps
@@ -26,6 +27,9 @@
from tracksdata.graph._graph_view import GraphView


T = TypeVar("T")


def _is_builtin(obj: Any) -> bool:
"""Check if an object is a built-in type."""
return getattr(obj.__class__, "__module__", None) == "builtins"
@@ -725,7 +729,7 @@ def bulk_add_nodes(
node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id
node_ids.append(node_id)

self._chunked_sa_operation(Session.bulk_insert_mappings, self.Node, nodes)
self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node)

if is_signal_on(self.node_added):
for node_id in node_ids:
@@ -884,7 +888,7 @@ def bulk_add_edges(
return list(result.scalars().all())

else:
self._chunked_sa_operation(Session.bulk_insert_mappings, self.Edge, edges)
self._chunked_sa_write(Session.bulk_insert_mappings, edges, self.Edge)
return None

def add_overlap(
@@ -938,7 +942,7 @@ def bulk_add_overlaps(
overlaps = overlaps.tolist()

overlaps = [{"source_id": source_id, "target_id": target_id} for source_id, target_id in overlaps]
self._chunked_sa_operation(Session.bulk_insert_mappings, self.Overlap, overlaps)
self._chunked_sa_write(Session.bulk_insert_mappings, overlaps, self.Overlap)

def overlaps(
self,
@@ -1037,14 +1041,19 @@ def _get_neighbors(

query = session.query(getattr(self.Edge, node_key), *node_columns)
query = query.join(self.Edge, getattr(self.Edge, neighbor_key) == self.Node.node_id)
if filter_node_ids is not None:
query = query.filter(getattr(self.Edge, node_key).in_(filter_node_ids))

node_df = pl.read_database(
query.statement,
connection=session.connection(),
schema_overrides=self._polars_schema_override(self.Node),
)
if filter_node_ids is None or len(filter_node_ids) == 0:
node_df = pl.read_database(
query.statement,
connection=session.connection(),
schema_overrides=self._polars_schema_override(self.Node),
)
else:
node_df = self._chunked_sa_read(
session,
lambda x: query.filter(getattr(self.Edge, node_key).in_(x)),
filter_node_ids,
self.Node,
)
node_df = unpickle_bytes_columns(node_df)
node_df = self._cast_array_columns(self.Node, node_df)

@@ -1635,13 +1644,13 @@ def _update_table(
LOG.info("update %s table with %d rows", table_class.__table__, len(update_data))
LOG.info("update data sample: %s", update_data[:2])

self._chunked_sa_operation(Session.bulk_update_mappings, table_class, update_data)
self._chunked_sa_write(Session.bulk_update_mappings, update_data, table_class)

def _chunked_sa_operation(
def _chunked_sa_write(
self,
session_op: Callable[[Session, type[DeclarativeBase], list[dict[str, Any]]], None],
table_class: type[DeclarativeBase],
data: list[dict[str, Any]],
table_class: type[DeclarativeBase],
) -> None:
if len(data) == 0:
return
@@ -1654,6 +1663,55 @@ def _chunked_sa_operation(
session_op(session, table_class, data[i : i + chunk_size])
session.commit()

def _chunked_sa_read(
self,
session: Session,
query_filter_op: Callable[[list[T]], Query],
data: list[T],
table_class: type[DeclarativeBase],
) -> pl.DataFrame:
"""
Apply a query filter in chunks and concatenate the results.

Parameters
----------
session : Session
The SQLAlchemy session.
query_filter_op : Callable[[list[T]], Query]
Function that applies a query filter to a chunk of the data and returns a SQLAlchemy Query object.
data : list[T]
List of data to be passed, in chunks, into the query_filter_op function.
table_class : type[DeclarativeBase]
The SQLAlchemy table class.

Examples
--------
```python
data = [1, 2, 3, 4, 5]
query = session.query(self.Node)
query_filter_op = lambda chunk: query.filter(self.Node.node_id.in_(chunk))
data_df = self._chunked_sa_read(session, query_filter_op, data, self.Node)
```

Returns
-------
pl.DataFrame
The data as a Polars DataFrame.
"""
if len(data) == 0:
raise ValueError("Data is empty")

chunk_size = max(1, self._sql_chunk_size())
chunks = []
for i in range(0, len(data), chunk_size):
query = query_filter_op(data[i : i + chunk_size])
data_df = pl.read_database(
query.statement,
connection=session.connection(),
schema_overrides=self._polars_schema_override(table_class),
)
chunks.append(data_df)
return pl.concat(chunks)

def update_node_attrs(
self,
*,
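The two helpers above batch large SQL operations: `_chunked_sa_write` slices the list of mappings and performs the bulk write per slice, and `_chunked_sa_read` applies the `IN (...)` filter per slice and concatenates the resulting DataFrames, presumably to keep each statement under the backend's bound-parameter limit (SQLite fails with "too many SQL variables" when too many values are bound at once). Below is a minimal, self-contained sketch of the same chunked-read pattern, assuming a hypothetical `nodes` table and an arbitrary chunk size of 500 in place of `self._sql_chunk_size()`:

```python
# Sketch of the chunked IN-clause read pattern (assumptions: toy in-memory
# SQLite database, hypothetical "nodes" table, arbitrary chunk size of 500).
import polars as pl
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class Node(Base):
    __tablename__ = "nodes"
    node_id = Column(Integer, primary_key=True)
    t = Column(Integer)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # populate a toy table with 5,000 rows
    session.bulk_insert_mappings(Node, [{"node_id": i, "t": i % 10} for i in range(5_000)])
    session.commit()

    wanted = list(range(2_000))  # more ids than we want to bind in a single statement
    chunk_size = 500
    chunks = []
    for i in range(0, len(wanted), chunk_size):
        # each chunk produces its own SELECT with at most `chunk_size` bound parameters
        query = session.query(Node).filter(Node.node_id.in_(wanted[i : i + chunk_size]))
        chunks.append(pl.read_database(query.statement, connection=session.connection()))
    node_df = pl.concat(chunks)
    print(node_df.shape)  # (2000, 2)
```

The write path is analogous: slice the list of mappings and call `Session.bulk_insert_mappings` (or `bulk_update_mappings`) once per slice, committing after each chunk, as `_chunked_sa_write` does above.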
5 changes: 4 additions & 1 deletion src/tracksdata/graph/_test/test_graph_backends.py
@@ -2497,7 +2497,7 @@ def test_pickle_roundtrip(graph_backend: BaseGraph) -> None:


@pytest.mark.slow
def test_sql_graph_huge_update() -> None:
def test_sql_graph_huge_dataset() -> None:
# test is only executed if `--slow` is passed to pytest
graph = SQLGraph("sqlite", ":memory:")

@@ -2519,6 +2519,9 @@ def test_sql_graph_huge_update() -> None:
node_ids=graph.node_ids(),
)

# test that successors works with a huge dataset
graph.successors(graph.node_ids())


def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None:
pytest.importorskip("traccuracy")