diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py
index 0fdc22f7..982d38c1 100644
--- a/src/tracksdata/graph/_sql_graph.py
+++ b/src/tracksdata/graph/_sql_graph.py
@@ -1,7 +1,7 @@
 import binascii
 from collections.abc import Callable, Sequence
 from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import cloudpickle
 import numpy as np
@@ -11,6 +11,7 @@
 from polars._typing import SchemaDict
 from polars.datatypes.convert import numpy_char_code_to_dtype
 from sqlalchemy.orm import DeclarativeBase, Session, aliased, load_only
+from sqlalchemy.orm.query import Query
 from sqlalchemy.sql.type_api import TypeEngine
 
 from tracksdata.attrs import AttrComparison, split_attr_comps
@@ -26,6 +27,9 @@
 from tracksdata.graph._graph_view import GraphView
 
 
+T = TypeVar("T")
+
+
 def _is_builtin(obj: Any) -> bool:
     """Check if an object is a built-in type."""
     return getattr(obj.__class__, "__module__", None) == "builtins"
@@ -725,7 +729,7 @@ def bulk_add_nodes(
             node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id
             node_ids.append(node_id)
 
-        self._chunked_sa_operation(Session.bulk_insert_mappings, self.Node, nodes)
+        self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node)
 
         if is_signal_on(self.node_added):
             for node_id in node_ids:
@@ -884,7 +888,7 @@ def bulk_add_edges(
             return list(result.scalars().all())
         else:
-            self._chunked_sa_operation(Session.bulk_insert_mappings, self.Edge, edges)
+            self._chunked_sa_write(Session.bulk_insert_mappings, edges, self.Edge)
             return None
 
     def add_overlap(
         self,
@@ -938,7 +942,7 @@ def bulk_add_overlaps(
            overlaps = overlaps.tolist()
 
        overlaps = [{"source_id": source_id, "target_id": target_id} for source_id, target_id in overlaps]
-        self._chunked_sa_operation(Session.bulk_insert_mappings, self.Overlap, overlaps)
+        self._chunked_sa_write(Session.bulk_insert_mappings, overlaps, self.Overlap)
 
     def overlaps(
         self,
@@ -1037,14 +1041,19 @@ def _get_neighbors(
         query = session.query(getattr(self.Edge, node_key), *node_columns)
         query = query.join(self.Edge, getattr(self.Edge, neighbor_key) == self.Node.node_id)
 
-        if filter_node_ids is not None:
-            query = query.filter(getattr(self.Edge, node_key).in_(filter_node_ids))
-
-        node_df = pl.read_database(
-            query.statement,
-            connection=session.connection(),
-            schema_overrides=self._polars_schema_override(self.Node),
-        )
+        if filter_node_ids is None or len(filter_node_ids) == 0:
+            node_df = pl.read_database(
+                query.statement,
+                connection=session.connection(),
+                schema_overrides=self._polars_schema_override(self.Node),
+            )
+        else:
+            node_df = self._chunked_sa_read(
+                session,
+                lambda x: query.filter(getattr(self.Edge, node_key).in_(x)),
+                filter_node_ids,
+                self.Node,
+            )
 
         node_df = unpickle_bytes_columns(node_df)
         node_df = self._cast_array_columns(self.Node, node_df)
@@ -1635,13 +1644,13 @@ def _update_table(
         LOG.info("update %s table with %d rows", table_class.__table__, len(update_data))
         LOG.info("update data sample: %s", update_data[:2])
 
-        self._chunked_sa_operation(Session.bulk_update_mappings, table_class, update_data)
+        self._chunked_sa_write(Session.bulk_update_mappings, update_data, table_class)
 
-    def _chunked_sa_operation(
+    def _chunked_sa_write(
         self,
         session_op: Callable[[Session, type[DeclarativeBase], list[dict[str, Any]]], None],
-        table_class: type[DeclarativeBase],
         data: list[dict[str, Any]],
+        table_class: type[DeclarativeBase],
     ) -> None:
         if len(data) == 0:
             return
@@ -1654,6 +1663,55 @@
                 session_op(session, table_class, data[i : i + chunk_size])
             session.commit()
 
+    def _chunked_sa_read(
+        self,
+        session: Session,
+        query_filter_op: Callable[[list[T]], Query],
+        data: list[T],
+        table_class: type[DeclarativeBase],
+    ) -> pl.DataFrame:
+        """
+        Apply a query filter in chunks and concatenate the results.
+
+        Parameters
+        ----------
+        session : Session
+            The SQLAlchemy session.
+        query_filter_op : Callable[[list[T]], Query]
+            The function that applies a query filter to a chunk of the data. It must return a SQLAlchemy Query object.
+        data : list[T]
+            List of data to be passed into the query_filter_op function.
+        table_class : type[DeclarativeBase]
+            The SQLAlchemy table class.
+
+        Examples
+        --------
+        ```python
+        data = [1, 2, 3, 4, 5]
+        query_filter_op = lambda chunk: query.filter(Node.node_id.in_(chunk))
+        data_df = self._chunked_sa_read(session, query_filter_op, data, Node)
+        ```
+
+        Returns
+        -------
+        pl.DataFrame
+            The data as a Polars DataFrame.
+        """
+        if len(data) == 0:
+            raise ValueError("Data is empty")
+
+        chunk_size = max(1, self._sql_chunk_size())
+        chunks = []
+        for i in range(0, len(data), chunk_size):
+            query = query_filter_op(data[i : i + chunk_size])
+            data_df = pl.read_database(
+                query.statement,
+                connection=session.connection(),
+                schema_overrides=self._polars_schema_override(table_class),
+            )
+            chunks.append(data_df)
+        return pl.concat(chunks)
+
     def update_node_attrs(
         self,
         *,
diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py
index 1420ee44..8ed8eea6 100644
--- a/src/tracksdata/graph/_test/test_graph_backends.py
+++ b/src/tracksdata/graph/_test/test_graph_backends.py
@@ -2497,7 +2497,7 @@ def test_pickle_roundtrip(graph_backend: BaseGraph) -> None:
 
 
 @pytest.mark.slow
-def test_sql_graph_huge_update() -> None:
+def test_sql_graph_huge_dataset() -> None:
     # test is only executed if `--slow` is passed to pytest
     graph = SQLGraph("sqlite", ":memory:")
 
@@ -2519,6 +2519,9 @@ def test_sql_graph_huge_update() -> None:
         node_ids=graph.node_ids(),
     )
 
+    # testing if successors works with a huge dataset
+    graph.successors(graph.node_ids())
+
 
 def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None:
    pytest.importorskip("traccuracy")
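
Note (not part of the patch): the chunked read splits one large `IN (...)` filter into several smaller SELECTs, presumably to keep the number of bound parameters per statement within database limits (SQLite, for example, caps host parameters per query). Below is a minimal, self-contained sketch of the same pattern with SQLAlchemy 2.0 and polars; the `Node` model, `chunked_read` helper, and chunk size are illustrative stand-ins, not tracksdata API.

```python
import polars as pl
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Node(Base):
    # minimal stand-in for an SQL-backed node table
    __tablename__ = "nodes"
    node_id: Mapped[int] = mapped_column(primary_key=True)
    t: Mapped[int] = mapped_column()


def chunked_read(session: Session, node_ids: list[int], chunk_size: int) -> pl.DataFrame:
    """Filter rows by id in fixed-size chunks and concatenate the partial results."""
    chunks = []
    for i in range(0, len(node_ids), chunk_size):
        # each chunk becomes its own SELECT, bounding the number of bound parameters
        query = session.query(Node).filter(Node.node_id.in_(node_ids[i : i + chunk_size]))
        chunks.append(pl.read_database(query.statement, connection=session.connection()))
    return pl.concat(chunks)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Node(node_id=i, t=i % 10) for i in range(1_000)])
    session.commit()
    df = chunked_read(session, list(range(1_000)), chunk_size=200)  # 5 SELECTs instead of 1
    print(df.shape)  # (1000, 2)
```

Each chunk is read into its own DataFrame and the pieces are concatenated at the end, which mirrors what `_chunked_sa_read` does with `query_filter_op` and `_polars_schema_override`.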