diff --git a/paddle_geometric/data/collate.py b/paddle_geometric/data/collate.py index 5d43672..3ff1e8e 100644 --- a/paddle_geometric/data/collate.py +++ b/paddle_geometric/data/collate.py @@ -15,6 +15,7 @@ import paddle from paddle import Tensor +import paddle_geometric.typing from paddle_geometric import EdgeIndex, Index from paddle_geometric.data.data import BaseData from paddle_geometric.data.storage import BaseStorage, NodeStorage @@ -137,9 +138,18 @@ def _collate( if increment: incs = get_incs(key, values, data_list, stores) if incs.ndim > 1 or int(incs[-1]) != 0: - values = [value + inc for value, inc in zip(values, incs)] + values = [value + inc.to(value.device) if hasattr(inc, 'to') else value + inc + for value, inc in zip(values, incs)] else: incs = None + + if getattr(elem, 'is_nested', False): + tensors = [] + for nested_tensor in values: + tensors.extend(paddle.unbind(nested_tensor, axis=0)) + value = paddle.nn.utils.rnn.pad_sequence(tensors) + return value, slices, incs + value = paddle.concat(values, axis=cat_dim or 0) if increment and isinstance(value, Index) and values[0].is_sorted: # Check whether the whole `Index` is sorted: diff --git a/paddle_geometric/data/database.py b/paddle_geometric/data/database.py index 0d8e659..4317c17 100644 --- a/paddle_geometric/data/database.py +++ b/paddle_geometric/data/database.py @@ -4,15 +4,18 @@ from dataclasses import dataclass from functools import cached_property from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - +from paddle_geometric.utils.mixin import CastMixin import paddle from paddle import Tensor from tqdm import tqdm import pickle # Used for serializing and deserializing complex objects +from paddle_geometric import EdgeIndex, Index +from paddle_geometric.edge_index import SortOrder + @dataclass -class TensorInfo: +class TensorInfo(CastMixin): """Describes the type information of a tensor, including data type, size, whether it's an index or an edge index.""" dtype: paddle.dtype @@ -41,11 +44,18 @@ def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]: valid_keys = {'dtype', 'size', 'is_index', 'is_edge_index'} if len(set(value.keys()) | valid_keys) != len(valid_keys): return value - return TensorInfo(**value) + return TensorInfo.cast(value) Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]] +SORT_ORDER_TO_INDEX: Dict[Optional[SortOrder], int] = { + None: -1, + SortOrder.ROW: 0, + SortOrder.COL: 1, +} +INDEX_TO_SORT_ORDER = {v: k for k, v in SORT_ORDER_TO_INDEX.items()} + class Database(ABC): """Abstract base class for a database that supports inserting and retrieving data. @@ -56,15 +66,57 @@ def __init__(self, schema: Schema = object) -> None: schema_dict = self._to_dict(schema) self.schema: Dict[Union[str, int], Any] = schema_dict + @abstractmethod + def connect(self) -> None: + """Connects to the database. + Databases will automatically connect on instantiation. 
+ """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Closes the connection to the database.""" + raise NotImplementedError + @abstractmethod def insert(self, index: int, data: Any) -> None: """Insert data at a specified index.""" raise NotImplementedError - def multi_insert(self, indices: Union[Sequence[int], slice], data_list: Sequence[Any]) -> None: + def multi_insert( + self, + indices: Union[Sequence[int], Tensor, slice, range], + data_list: Sequence[Any], + batch_size: Optional[int] = None, + log: bool = False, + ) -> None: """Insert multiple data entries at specified indices.""" if isinstance(indices, slice): indices = self.slice_to_range(indices) + + length = min(len(indices), len(data_list)) + batch_size = length if batch_size is None else batch_size + + if log and length > batch_size: + desc = f'Insert {length} entries' + offsets = tqdm(range(0, length, batch_size), desc=desc) + else: + offsets = range(0, length, batch_size) + + for start in offsets: + self._multi_insert( + indices[start:start + batch_size], + data_list[start:start + batch_size], + ) + + def _multi_insert( + self, + indices: Union[Sequence[int], Tensor, range], + data_list: Sequence[Any], + ) -> None: + """Internal method for batch insertion.""" + if isinstance(indices, Tensor): + indices = indices.tolist() for index, data in zip(indices, data_list): self.insert(index, data) @@ -73,10 +125,28 @@ def get(self, index: int) -> Any: """Retrieve data from a specified index.""" raise NotImplementedError - def multi_get(self, indices: Union[Sequence[int], slice]) -> List[Any]: + def multi_get( + self, + indices: Union[Sequence[int], Tensor, slice, range], + batch_size: Optional[int] = None, + ) -> List[Any]: """Retrieve data from multiple indices.""" if isinstance(indices, slice): indices = self.slice_to_range(indices) + + length = len(indices) + batch_size = length if batch_size is None else batch_size + + data_list: List[Any] = [] + for start in range(0, length, batch_size): + chunk_indices = indices[start:start + batch_size] + data_list.extend(self._multi_get(chunk_indices)) + return data_list + + def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]: + """Internal method for batch retrieval.""" + if isinstance(indices, Tensor): + indices = indices.tolist() return [self.get(index) for index in indices] @staticmethod @@ -99,13 +169,20 @@ def __len__(self) -> int: """Return the number of entries in the database.""" raise NotImplementedError - def __getitem__(self, key: Union[int, Sequence[int], slice]) -> Union[Any, List[Any]]: + def __getitem__( + self, + key: Union[int, Sequence[int], Tensor, slice, range], + ) -> Union[Any, List[Any]]: """Retrieve data using index or slice.""" if isinstance(key, int): return self.get(key) return self.multi_get(key) - def __setitem__(self, key: Union[int, Sequence[int], slice], value: Union[Any, Sequence[Any]]) -> None: + def __setitem__( + self, + key: Union[int, Sequence[int], Tensor, slice, range], + value: Union[Any, Sequence[Any]], + ) -> None: """Insert data using index or slice.""" if isinstance(key, int): self.insert(key, value) @@ -126,6 +203,9 @@ class SQLiteDatabase(Database): """ def __init__(self, path: str, name: str, schema: Schema = object) -> None: super().__init__(schema) + + warnings.filterwarnings('ignore', '.*given buffer is not writable.*') + import sqlite3 self.path = path self.name = name @@ -133,11 +213,16 @@ def __init__(self, path: str, name: str, schema: Schema = object) -> None: self._cursor: 
Optional[sqlite3.Cursor] = None self.connect() - # Create table if it does not exist - schema_str = ", ".join( - f"{key} BLOB NOT NULL" for key in self.schema.keys() - ) - query = f"CREATE TABLE IF NOT EXISTS {self.name} (id INTEGER PRIMARY KEY, {schema_str})" + # Create the table (if it does not exist) by mapping the Python schema + # to the corresponding SQL schema: + sql_schema = ',\n'.join([ + f' {col_name} {self._to_sql_type(type_info)}' + for col_name, type_info in zip(self._col_names, self.schema.values()) + ]) + query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n' + f' id INTEGER PRIMARY KEY,\n' + f'{sql_schema}\n' + f')') self.cursor.execute(query) def connect(self) -> None: @@ -170,32 +255,136 @@ def cursor(self) -> Any: def insert(self, index: int, data: Any) -> None: """Insert a single data entry.""" - query = f"INSERT INTO {self.name} (id, {', '.join(self.schema.keys())}) VALUES (?, {', '.join(['?'] * len(self.schema))})" + query = (f'INSERT INTO {self.name} ' + f'(id, {self._joined_col_names}) ' + f'VALUES (?, {self._dummies})') self.cursor.execute(query, (index, *self._serialize(data))) self.connection.commit() + def _multi_insert( + self, + indices: Union[Sequence[int], Tensor, range], + data_list: Sequence[Any], + ) -> None: + if isinstance(indices, Tensor): + indices = indices.tolist() + + data_list = [(index, *self._serialize(data)) + for index, data in zip(indices, data_list)] + + query = (f'INSERT INTO {self.name} ' + f'(id, {self._joined_col_names}) ' + f'VALUES (?, {self._dummies})') + self.cursor.executemany(query, data_list) + self.connection.commit() + def get(self, index: int) -> Any: """Retrieve a single data entry.""" - query = f"SELECT {', '.join(self.schema.keys())} FROM {self.name} WHERE id = ?" + query = (f'SELECT {self._joined_col_names} FROM {self.name} ' + f'WHERE id = ?') self.cursor.execute(query, (index,)) row = self.cursor.fetchone() if row is None: raise KeyError(f"Index {index} not found in database") return self._deserialize(row) + def multi_get( + self, + indices: Union[Sequence[int], Tensor, slice, range], + batch_size: Optional[int] = None, + ) -> List[Any]: + if isinstance(indices, slice): + indices = self.slice_to_range(indices) + elif isinstance(indices, Tensor): + indices = indices.tolist() + + join_table_name = f'{self.name}__join' + query = (f'CREATE TEMP TABLE {join_table_name} (\n' + f' id INTEGER,\n' + f' row_id INTEGER\n' + f')') + self.cursor.execute(query) + + query = f'INSERT INTO {join_table_name} (id, row_id) VALUES (?, ?)' + self.cursor.executemany(query, zip(indices, range(len(indices)))) + self.connection.commit() + + query = (f'SELECT {self._joined_col_names} ' + f'FROM {self.name} INNER JOIN {join_table_name} ' + f'ON {self.name}.id = {join_table_name}.id ' + f'ORDER BY {join_table_name}.row_id') + self.cursor.execute(query) + + if batch_size is None: + data_list = self.cursor.fetchall() + else: + data_list = [] + while True: + chunk_list = self.cursor.fetchmany(size=batch_size) + if len(chunk_list) == 0: + break + data_list.extend(chunk_list) + + query = f'DROP TABLE {join_table_name}' + self.cursor.execute(query) + + return [self._deserialize(data) for data in data_list] + def __len__(self) -> int: """Get the total number of entries in the database.""" query = f"SELECT COUNT(*) FROM {self.name}" self.cursor.execute(query) return self.cursor.fetchone()[0] + # Helper functions ######################################################## + + @cached_property + def _col_names(self) -> List[str]: + return [f'COL_{key}' for key in 
self.schema.keys()]
+
+    @cached_property
+    def _joined_col_names(self) -> str:
+        return ', '.join(self._col_names)
+
+    @cached_property
+    def _dummies(self) -> str:
+        return ', '.join(['?'] * len(self.schema.keys()))
+
+    def _to_sql_type(self, type_info: Any) -> str:
+        if type_info == int:
+            return 'INTEGER NOT NULL'
+        if type_info == float:
+            return 'FLOAT'
+        if type_info == str:
+            return 'TEXT NOT NULL'
+        return 'BLOB NOT NULL'
+
     def _serialize(self, data: Any) -> List[bytes]:
         """Serialize data into a byte stream."""
-        return [pickle.dumps(data.get(key)) for key in self.schema.keys()]
+        # Handle both dict-like data and single tensor/data objects:
+        if isinstance(data, dict):
+            return [pickle.dumps(data.get(key)) for key in self.schema.keys()]
+        elif len(self.schema) == 1 and 0 in self.schema:
+            # Single-object schema `{0: type}`: `data` is a single value
+            # (e.g., a tensor):
+            return [pickle.dumps(data)]
+        else:
+            # Fallback: try to access `data` as a dict, or use it directly:
+            return [
+                pickle.dumps(data.get(key) if hasattr(data, 'get') else data)
+                for key in self.schema.keys()
+            ]
 
-    def _deserialize(self, row: Tuple[bytes]) -> Dict[str, Any]:
+    def _deserialize(self, row: Tuple[bytes]) -> Any:
         """Deserialize a byte stream into original data."""
-        return {key: pickle.loads(value) for key, value in zip(self.schema.keys(), row)}
+        result = {
+            key: pickle.loads(value)
+            for key, value in zip(self.schema.keys(), row)
+        }
+
+        # If the schema holds only the single key `0`, return the value
+        # itself instead of a dictionary:
+        if len(result) == 1 and 0 in result:
+            return result[0]
+        return result
 
 
 class RocksDatabase(Database):
@@ -242,6 +431,15 @@ def get(self, index: int) -> Any:
         """Retrieve a single data entry."""
         return self._deserialize(self.db[self.to_key(index)])
 
+    def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]:
+        """Batched key lookup for RocksDB."""
+        if isinstance(indices, Tensor):
+            indices = indices.tolist()
+
+        # `rocksdict.Rdict` fetches keys one at a time:
+        return [self._deserialize(self.db[self.to_key(index)])
+                for index in indices]
+
     def _serialize(self, data: Any) -> bytes:
         """Serialize data into a byte stream."""
         buffer = io.BytesIO()
diff --git a/paddle_geometric/data/dataset.py b/paddle_geometric/data/dataset.py
index 46c6b55..dee0a2a 100644
--- a/paddle_geometric/data/dataset.py
+++ b/paddle_geometric/data/dataset.py
@@ -19,6 +19,8 @@
 import paddle
 from paddle import Tensor
+
+
 from paddle_geometric.data.data import BaseData
 from paddle_geometric.io import fs
diff --git a/paddle_geometric/data/extract.py b/paddle_geometric/data/extract.py
index a87788a..7e3bb9d 100644
--- a/paddle_geometric/data/extract.py
+++ b/paddle_geometric/data/extract.py
@@ -28,7 +28,7 @@ def extract_tar(
     """
     maybe_log(path, log)
     with tarfile.open(path, mode) as f:
-        f.extractall(folder)
+        f.extractall(folder, filter='data')
 
 
 def extract_zip(path: str, folder: str, log: bool = True) -> None:
diff --git a/paddle_geometric/data/temporal.py b/paddle_geometric/data/temporal.py
index 00c73d7..9c4a527 100644
--- a/paddle_geometric/data/temporal.py
+++ b/paddle_geometric/data/temporal.py
@@ -188,6 +188,19 @@ def size(
         size = (int(self.src.max()), int(self.dst.max()))
         return size if dim is None else size[dim]
 
+    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
+        return 0
+
+    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
+        if 'batch' in key and isinstance(value, Tensor):
+            return int(value.max().item()) + 1
+        elif key in ['src', 'dst']:
+            return self.num_nodes
+        else:
+            return 0
+
+
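+    # Reviewer note: a minimal, hypothetical sketch of what these two hooks
+    # imply during batching (assuming `Batch.from_data_list` handles
+    # `TemporalData`). With `data.num_nodes == 6`:
+    #
+    #   batch = Batch.from_data_list([data, data])
+    #   # Every attribute is concatenated along dim 0 (`__cat_dim__`), and
+    #   # `src`/`dst` of the second store are offset by 6 (`__inc__`).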
########################################################################### + def train_val_test_split(self, val_ratio: float = 0.15, test_ratio: float = 0.15): """Splits the data into training, validation, and test sets based on @@ -205,6 +218,22 @@ def __repr__(self) -> str: cls = self.__class__.__name__ info = ', '.join([size_repr(k, v) for k, v in self._store.items()]) return f'{cls}({info})' + ########################################################################### + + def coalesce(self): + raise NotImplementedError + + def has_isolated_nodes(self) -> bool: + raise NotImplementedError + + def has_self_loops(self) -> bool: + raise NotImplementedError + + def is_undirected(self) -> bool: + raise NotImplementedError + + def is_directed(self) -> bool: + raise NotImplementedError ############################################################################### diff --git a/test/data/test_batch.py b/test/data/test_batch.py new file mode 100644 index 0000000..afd3c9c --- /dev/null +++ b/test/data/test_batch.py @@ -0,0 +1,615 @@ +import os.path as osp +import tempfile + +import numpy as np +import paddle +import pytest + +import paddle_geometric +from paddle_geometric import EdgeIndex, Index +from paddle_geometric.data import Batch, Data, HeteroData +from paddle_geometric.testing import withPackage + + +def get_random_edge_index(num_src, num_dst, num_edges): + """ + Generate random edge index tensor for testing. + + Args: + num_src: Number of source nodes + num_dst: Number of destination nodes + num_edges: Number of edges + + Returns: + Edge index tensor of shape (2, num_edges) + """ + edge_index = paddle.stack([ + paddle.randint(0, num_src, (num_edges,)) + for _ in range(2) + ]) + # Set second dimension to destination range + edge_index[1] = paddle.randint(0, num_dst, (num_edges,)) + return edge_index + + +def test_batch_basic(): + paddle_geometric.set_debug(True) + + x = paddle.to_tensor([1.0, 2.0, 3.0]) + edge_index = paddle.to_tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + data1 = Data(x=x, y=1, edge_index=edge_index, string='1', array=['1', '2'], + num_nodes=3) + + x = paddle.to_tensor([1.0, 2.0]) + edge_index = paddle.to_tensor([[0, 1], [1, 0]]) + data2 = Data(x=x, y=2, edge_index=edge_index, string='2', + array=['3', '4', '5'], num_nodes=2) + + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + edge_index = paddle.to_tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + data3 = Data(x=x, y=3, edge_index=edge_index, string='3', + array=['6', '7', '8', '9'], num_nodes=4) + + # Test single data batch + batch = Batch.from_data_list([data1]) + assert str(batch) == ('DataBatch(x=[3], edge_index=[2, 4], y=[1], ' + 'string=[1], array=[1], num_nodes=3, batch=[3], ' + 'ptr=[2])') + assert batch.num_graphs == len(batch) == 1 + assert batch.x.tolist() == [1, 2, 3] + assert batch.y.tolist() == [1] + assert batch.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert batch.string == ['1'] + assert batch.array == [['1', '2']] + assert batch.num_nodes == 3 + assert batch.batch.tolist() == [0, 0, 0] + assert batch.ptr.tolist() == [0, 3] + + # Test multi-data batch with follow_batch + batch = Batch.from_data_list([data1, data2, data3], follow_batch=['string']) + + assert str(batch) == ('DataBatch(x=[9], edge_index=[2, 12], y=[3], ' + 'string=[3], string_batch=[3], string_ptr=[4], ' + 'array=[3], num_nodes=9, batch=[9], ptr=[4])') + assert batch.num_graphs == len(batch) == 3 + assert batch.x.tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] + assert batch.y.tolist() == [1, 2, 3] + assert batch.edge_index.tolist() == [[0, 1, 
1, 2, 3, 4, 5, 6, 6, 7, 7, 8], + [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7]] + assert batch.string == ['1', '2', '3'] + assert batch.string_batch.tolist() == [0, 1, 2] + assert batch.string_ptr.tolist() == [0, 1, 2, 3] + assert batch.array == [['1', '2'], ['3', '4', '5'], ['6', '7', '8', '9']] + assert batch.num_nodes == 9 + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert batch.ptr.tolist() == [0, 3, 5, 9] + + # Test batch indexing + assert str(batch[0]) == ("Data(x=[3], edge_index=[2, 4], y=[1], " + "string='1', array=[2], num_nodes=3)") + assert str(batch[1]) == ("Data(x=[2], edge_index=[2, 2], y=[1], " + "string='2', array=[3], num_nodes=2)") + assert str(batch[2]) == ("Data(x=[4], edge_index=[2, 6], y=[1], " + "string='3', array=[4], num_nodes=4)") + + # Test batch selection + assert len(batch.index_select([1, 0])) == 2 + assert len(batch.index_select(paddle.to_tensor([1, 0]))) == 2 + assert len(batch.index_select(paddle.to_tensor([True, False]))) == 1 + assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2 + assert len(batch.index_select(np.array([True, False]))) == 1 + assert len(batch[:2]) == 2 + + # Test to_data_list + data_list = batch.to_data_list() + assert len(data_list) == 3 + + assert len(data_list[0]) == 6 + assert data_list[0].x.tolist() == [1, 2, 3] + assert data_list[0].y.tolist() == [1] + assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert data_list[0].string == '1' + assert data_list[0].array == ['1', '2'] + assert data_list[0].num_nodes == 3 + + assert len(data_list[1]) == 6 + assert data_list[1].x.tolist() == [1, 2] + assert data_list[1].y.tolist() == [2] + assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]] + assert data_list[1].string == '2' + assert data_list[1].array == ['3', '4', '5'] + assert data_list[1].num_nodes == 2 + + assert len(data_list[2]) == 6 + assert data_list[2].x.tolist() == [1, 2, 3, 4] + assert data_list[2].y.tolist() == [3] + assert data_list[2].edge_index.tolist() == [[0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2]] + assert data_list[2].string == '3' + assert data_list[2].array == ['6', '7', '8', '9'] + assert data_list[2].num_nodes == 4 + + paddle_geometric.set_debug(False) + + +def test_index(): + """Test Index type preservation during batching.""" + index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) + index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) + + data1 = Data(index=index1, num_nodes=3) + data2 = Data(index=index2, num_nodes=4) + + batch = Batch.from_data_list([data1, data2]) + + assert len(batch) == 2 + # Paddle's equal() returns bool directly, not a tensor + assert bool(batch.batch.equal(paddle.to_tensor([0, 0, 0, 1, 1, 1, 1]))) + assert bool(batch.ptr.equal(paddle.to_tensor([0, 3, 7]))) + assert isinstance(batch.index, Index) + assert bool(batch.index.equal(paddle.to_tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6]))) + assert batch.index.dim_size == 7 + assert batch.index.is_sorted + + # Test unbatching preserves Index attributes + for i, index in enumerate([index1, index2]): + data = batch[i] + assert isinstance(data.index, Index) + assert bool(data.index.equal(index)) + assert data.index.dim_size == index.dim_size + assert data.index.is_sorted == index.is_sorted + + +def test_edge_index(): + """Test EdgeIndex type preservation during batching.""" + edge_index1 = EdgeIndex( + [[0, 1, 1, 2], [1, 0, 2, 1]], + sparse_size=(3, 3), + sort_order='row', + is_undirected=True, + ) + edge_index2 = EdgeIndex( + [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]], + sparse_size=(4, 
4), + sort_order='col', + ) + + data1 = Data(edge_index=edge_index1) + data2 = Data(edge_index=edge_index2) + + batch = Batch.from_data_list([data1, data2]) + + assert len(batch) == 2 + # Paddle's equal() returns bool directly, not a tensor + assert bool(batch.batch.equal(paddle.to_tensor([0, 0, 0, 1, 1, 1, 1]))) + assert bool(batch.ptr.equal(paddle.to_tensor([0, 3, 7]))) + assert isinstance(batch.edge_index, EdgeIndex) + assert bool(batch.edge_index.equal( + paddle.to_tensor([ + [0, 1, 1, 2, 4, 3, 5, 4, 6, 5], + [1, 0, 2, 1, 3, 4, 4, 5, 5, 6], + ]))) + assert batch.edge_index.sparse_size() == (7, 7) + assert batch.edge_index.sort_order is None + assert not batch.edge_index.is_undirected + + # Test unbatching preserves EdgeIndex attributes + for i, edge_index in enumerate([edge_index1, edge_index2]): + data = batch[i] + assert isinstance(data.edge_index, EdgeIndex) + assert bool(data.edge_index.equal(edge_index)) + assert data.edge_index.sparse_size() == edge_index.sparse_size() + assert data.edge_index.sort_order == edge_index.sort_order + assert data.edge_index.is_undirected == edge_index.is_undirected + +def test_batch_with_paddle_coo_tensor(): + # Paddle requires sparse_dim parameter for to_sparse_coo() + x = paddle.to_tensor([[1.0], [2.0], [3.0]]).to_sparse_coo(sparse_dim=2) + data1 = Data(x=x) + + x = paddle.to_tensor([[1.0], [2.0]]).to_sparse_coo(sparse_dim=2) + data2 = Data(x=x) + + x = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo(sparse_dim=2) + data3 = Data(x=x) + + batch = Batch.from_data_list([data1]) + assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])') + assert batch.num_graphs == len(batch) == 1 + assert batch.x.to_dense().tolist() == [[1], [2], [3]] + assert batch.batch.tolist() == [0, 0, 0] + assert batch.ptr.tolist() == [0, 3] + + batch = Batch.from_data_list([data1, data2, data3]) + + assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])') + assert batch.num_graphs == len(batch) == 3 + assert batch.x.to_dense().reshape([-1]).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert batch.ptr.tolist() == [0, 3, 5, 9] + + assert str(batch[0]) == ("Data(x=[3, 1])") + assert str(batch[1]) == ("Data(x=[2, 1])") + assert str(batch[2]) == ("Data(x=[4, 1])") + + data_list = batch.to_data_list() + assert len(data_list) == 3 + + assert len(data_list[0]) == 1 + assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]] + + assert len(data_list[1]) == 1 + assert data_list[1].x.to_dense().tolist() == [[1], [2]] + + assert len(data_list[2]) == 1 + assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]] + + +def test_batching_with_new_dimension(): + """Test batching with custom __cat_dim__ returning None.""" + paddle_geometric.set_debug(True) + + class MyData(Data): + def __cat_dim__(self, key, value, *args, **kwargs): + if key == 'foo': + return None + else: + return super().__cat_dim__(key, value, *args, **kwargs) + + x1 = paddle.to_tensor([1, 2, 3], dtype=paddle.float32) + foo1 = paddle.randn(4) + y1 = paddle.to_tensor(1) + + x2 = paddle.to_tensor([1, 2], dtype=paddle.float32) + foo2 = paddle.randn(4) + y2 = paddle.to_tensor(2) + + batch = Batch.from_data_list( + [MyData(x=x1, foo=foo1, y=y1), + MyData(x=x2, foo=foo2, y=y2)]) + + assert str(batch) == ('MyDataBatch(x=[5], y=[2], foo=[2, 4], batch=[5], ' + 'ptr=[3])') + assert batch.num_graphs == len(batch) == 2 + assert batch.x.tolist() == [1, 2, 3, 1, 2] + assert batch.foo.shape == [2, 4] + assert batch.foo[0].tolist() == 
foo1.tolist()
+    assert batch.foo[1].tolist() == foo2.tolist()
+    assert batch.y.tolist() == [1, 2]
+    assert batch.batch.tolist() == [0, 0, 0, 1, 1]
+    assert batch.ptr.tolist() == [0, 3, 5]
+    assert batch.num_graphs == 2
+
+    data = batch[0]
+    assert str(data) == ('MyData(x=[3], y=[1], foo=[4])')
+    data = batch[1]
+    assert str(data) == ('MyData(x=[2], y=[1], foo=[4])')
+
+    paddle_geometric.set_debug(False)
+
+
+def test_pickling(tmp_path):
+    """Test pickling and unpickling of batch objects."""
+    import pickle
+
+    data = Data(x=paddle.randn(5, 16))
+    batch = Batch.from_data_list([data, data, data, data])
+    assert id(batch._store._parent()) == id(batch)
+    assert batch.num_nodes == 20
+
+    # Use Python's pickle instead of paddle.save/load, since Paddle's
+    # save/load may not support custom objects properly:
+    path = osp.join(tmp_path, 'batch.pkl')
+    with open(path, 'wb') as f:
+        pickle.dump(batch, f)
+
+    assert id(batch._store._parent()) == id(batch)
+    assert batch.num_nodes == 20
+
+    with open(path, 'rb') as f:
+        batch = pickle.load(f)
+
+    assert id(batch._store._parent()) == id(batch)
+    assert batch.num_nodes == 20
+
+    assert batch.__class__.__name__ == 'DataBatch'
+    assert batch.num_graphs == len(batch) == 4
+
+
+def test_recursive_batch():
+    """Test recursive batching with dict and list attributes."""
+    data1 = Data(
+        x={
+            '1': paddle.randn(10, 32),
+            '2': paddle.randn(20, 48)
+        },
+        edge_index=[
+            get_random_edge_index(30, 30, 50),
+            get_random_edge_index(30, 30, 70)
+        ],
+        num_nodes=30,
+    )
+
+    data2 = Data(
+        x={
+            '1': paddle.randn(20, 32),
+            '2': paddle.randn(40, 48)
+        },
+        edge_index=[
+            get_random_edge_index(60, 60, 80),
+            get_random_edge_index(60, 60, 90)
+        ],
+        num_nodes=60,
+    )
+
+    batch = Batch.from_data_list([data1, data2])
+
+    assert batch.num_graphs == len(batch) == 2
+    assert batch.num_nodes == 90
+
+    assert paddle.allclose(batch.x['1'],
+                           paddle.concat([data1.x['1'], data2.x['1']], axis=0)).item()
+    assert paddle.allclose(batch.x['2'],
+                           paddle.concat([data1.x['2'], data2.x['2']], axis=0)).item()
+    assert (batch.edge_index[0].tolist() == paddle.concat(
+        [data1.edge_index[0], data2.edge_index[0] + 30], axis=1).tolist())
+    assert (batch.edge_index[1].tolist() == paddle.concat(
+        [data1.edge_index[1], data2.edge_index[1] + 30], axis=1).tolist())
+    assert batch.batch.shape == [90]
+    assert batch.ptr.shape == [3]
+
+    out1 = batch[0]
+    assert len(out1) == 3
+    assert out1.num_nodes == 30
+    assert paddle.allclose(out1.x['1'], data1.x['1'])
+    assert paddle.allclose(out1.x['2'], data1.x['2'])
+    assert out1.edge_index[0].tolist() == data1.edge_index[0].tolist()
+    assert out1.edge_index[1].tolist() == data1.edge_index[1].tolist()
+
+    out2 = batch[1]
+    assert len(out2) == 3
+    assert out2.num_nodes == 60
+    assert paddle.allclose(out2.x['1'], data2.x['1'])
+    assert paddle.allclose(out2.x['2'], data2.x['2'])
+    assert out2.edge_index[0].tolist() == data2.edge_index[0].tolist()
+    assert out2.edge_index[1].tolist() == data2.edge_index[1].tolist()
+
+
+def test_batching_of_batches():
+    """Test batching of already batched data."""
+    data = Data(x=paddle.randn(2, 16))
+    batch = Batch.from_data_list([data, data])
+
+    batch = Batch.from_data_list([batch, batch])
+    assert batch.num_graphs == len(batch) == 2
+    assert batch.x[0:2].tolist() == data.x.tolist()
+    assert batch.x[2:4].tolist() == data.x.tolist()
+    assert batch.x[4:6].tolist() == data.x.tolist()
+    assert batch.x[6:8].tolist() == data.x.tolist()
+    assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
+
+
+def test_hetero_batch():
+    e1 = ('p', 'a')
+    e2
= ('a', 'p') + data1 = HeteroData() + data1['p'].x = paddle.randn(100, 128) + data1['a'].x = paddle.randn(200, 128) + data1[e1].edge_index = get_random_edge_index(100, 200, 500) + data1[e1].edge_attr = paddle.randn(500, 32) + data1[e2].edge_index = get_random_edge_index(200, 100, 400) + data1[e2].edge_attr = paddle.randn(400, 32) + + data2 = HeteroData() + data2['p'].x = paddle.randn(50, 128) + data2['a'].x = paddle.randn(100, 128) + data2[e1].edge_index = get_random_edge_index(50, 100, 300) + data2[e1].edge_attr = paddle.randn(300, 32) + data2[e2].edge_index = get_random_edge_index(100, 50, 200) + data2[e2].edge_attr = paddle.randn(200, 32) + + batch = Batch.from_data_list([data1, data2]) + + assert batch.num_graphs == len(batch) == 2 + assert batch.num_nodes == 450 + + assert paddle.allclose(batch['p'].x[:100], data1['p'].x) + assert paddle.allclose(batch['a'].x[:200], data1['a'].x) + assert paddle.allclose(batch['p'].x[100:], data2['p'].x) + assert paddle.allclose(batch['a'].x[200:], data2['a'].x) + assert (batch[e1].edge_index.tolist() == paddle.concat([ + data1[e1].edge_index, + data2[e1].edge_index + paddle.to_tensor([[100], [200]]) + ], axis=1).tolist()) + assert paddle.allclose( + batch[e1].edge_attr, + paddle.concat([data1[e1].edge_attr, data2[e1].edge_attr], axis=0)) + assert (batch[e2].edge_index.tolist() == paddle.concat([ + data1[e2].edge_index, + data2[e2].edge_index + paddle.to_tensor([[200], [100]]) + ], axis=1).tolist()) + assert paddle.allclose( + batch[e2].edge_attr, + paddle.concat([data1[e2].edge_attr, data2[e2].edge_attr], axis=0)) + assert batch['p'].batch.shape == [150] + assert batch['p'].ptr.shape == [3] + assert batch['a'].batch.shape == [300] + assert batch['a'].ptr.shape == [3] + + out1 = batch[0] + assert len(out1) == 3 + assert out1.num_nodes == 300 + assert paddle.allclose(out1['p'].x, data1['p'].x) + assert paddle.allclose(out1['a'].x, data1['a'].x) + assert out1[e1].edge_index.tolist() == data1[e1].edge_index.tolist() + assert paddle.allclose(out1[e1].edge_attr, data1[e1].edge_attr) + assert out1[e2].edge_index.tolist() == data1[e2].edge_index.tolist() + assert paddle.allclose(out1[e2].edge_attr, data1[e2].edge_attr) + + out2 = batch[1] + assert len(out2) == 3 + assert out2.num_nodes == 150 + assert paddle.allclose(out2['p'].x, data2['p'].x) + assert paddle.allclose(out2['a'].x, data2['a'].x) + assert out2[e1].edge_index.tolist() == data2[e1].edge_index.tolist() + assert paddle.allclose(out2[e1].edge_attr, data2[e1].edge_attr) + assert out2[e2].edge_index.tolist() == data2[e2].edge_index.tolist() + assert paddle.allclose(out2[e2].edge_attr, data2[e2].edge_attr) + + +def test_pair_data_batching(): + class PairData(Data): + def __inc__(self, key, value, *args, **kwargs): + if key == 'edge_index_s': + return self.x_s.shape[0] + if key == 'edge_index_t': + return self.x_t.shape[0] + return super().__inc__(key, value, *args, **kwargs) + + x_s = paddle.randn(5, 16) + edge_index_s = paddle.to_tensor([ + [0, 0, 0, 0], + [1, 2, 3, 4], + ]) + x_t = paddle.randn(4, 16) + edge_index_t = paddle.to_tensor([ + [0, 0, 0], + [1, 2, 3], + ]) + + data = PairData(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t, + edge_index_t=edge_index_t) + batch = Batch.from_data_list([data, data]) + + assert paddle.allclose(batch.x_s, paddle.concat([x_s, x_s], axis=0)) + assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5], + [1, 2, 3, 4, 6, 7, 8, 9]] + + assert paddle.allclose(batch.x_t, paddle.concat([x_t, x_t], axis=0)) + assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4], 
+ [1, 2, 3, 5, 6, 7]] + + +def test_batch_with_empty_list(): + """Test batching with empty list attributes.""" + x = paddle.randn(4, 1) + edge_index = paddle.to_tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + data = Data(x=x, edge_index=edge_index, nontensor=[]) + + batch = Batch.from_data_list([data, data]) + assert batch.nontensor == [[], []] + assert batch[0].nontensor == [] + assert batch[1].nontensor == [] + + +def test_nested_follow_batch(): + """Test follow_batch with nested structures.""" + def tr(n, m): + return paddle.rand((n, m)) + + d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)}, + x=tr(10, 5)) + d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)}, + x=tr(11, 5)) + d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)}, + x=tr(9, 5)) + d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)}, + x=tr(8, 5)) + + data_list = [d1, d2, d3, d4] + + batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a']) + + assert batch.xs[0].shape == [19, 3] + assert batch.xs[1].shape == [56, 4] + assert batch.xs[2].shape == [7, 2] + assert batch.a['aa'].shape == [25, 3] + + assert len(batch.xs_batch) == 3 + assert len(batch.a_batch) == 1 + + assert batch.xs_batch[0].tolist() == \ + [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + assert batch.xs_batch[1].tolist() == \ + [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16 + assert batch.xs_batch[2].tolist() == \ + [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1 + + assert batch.a_batch['aa'].tolist() == \ + [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8 + + +# ============================================================================ +# Sparse Tensor Tests +# ============================================================================ + +@withPackage('paddle_sparse') +def test_batch_with_sparse_tensor(): + """Test batching with sparse tensors.""" + from paddle_geometric.typing import SparseTensor + + x = SparseTensor.from_dense(paddle.to_tensor([[1.0], [2.0], [3.0]])) + edge_index = paddle.to_tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + adj = SparseTensor.from_edge_index(edge_index) + data1 = Data(x=x, adj=adj) + + x = SparseTensor.from_dense(paddle.to_tensor([[1.0], [2.0]])) + edge_index = paddle.to_tensor([[0, 1], [1, 0]]) + adj = SparseTensor.from_edge_index(edge_index) + data2 = Data(x=x, adj=adj) + + x = SparseTensor.from_dense(paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]])) + edge_index = paddle.to_tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + adj = SparseTensor.from_edge_index(edge_index) + data3 = Data(x=x, adj=adj) + + batch = Batch.from_data_list([data1, data2, data3]) + + assert batch.num_graphs == len(batch) == 3 + + data_list = batch.to_data_list() + assert len(data_list) == 3 + + +@pytest.mark.skip( + reason="Paddle does not have a nested_tensor API equivalent to " + "torch.nested.nested_tensor. " + "Paddle uses pad_sequence for variable-length sequences instead." 
+) + + +@withPackage('paddle_sparse') +def test_batch_with_sparse_tensor_coo(): + + from paddle_geometric.typing import SparseTensor + + # Create data with COO sparse tensors + x1 = paddle.to_tensor([[1.0], [2.0], [3.0]]).to_sparse_coo() + edge_index1 = paddle.to_tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + adj1 = SparseTensor.from_edge_index(edge_index1) + data1 = Data(x=x1, adj=adj1) + + x2 = paddle.to_tensor([[1.0], [2.0]]).to_sparse_coo() + edge_index2 = paddle.to_tensor([[0, 1], [1, 0]]) + adj2 = SparseTensor.from_edge_index(edge_index2) + data2 = Data(x=x2, adj=adj2) + + # Batch the data + batch = Batch.from_data_list([data1, data2]) + + # Verify batch properties + assert batch.num_graphs == len(batch) == 2 + assert batch.x.shape[0] == 5 # 3 + 2 + + # Verify sparse format is preserved + assert batch.x.is_sparse() + + # Verify edge indices are correctly concatenated + assert batch.adj.coo()[0].tolist() == [0, 1, 1, 2, 3, 4] + assert batch.adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 4] + + # Test unbatching + data_list = batch.to_data_list() + assert len(data_list) == 2 + assert data_list[0].x.to_dense().tolist() == [[1.0], [2.0], [3.0]] + assert data_list[1].x.to_dense().tolist() == [[1.0], [2.0]] diff --git a/test/data/test_data.py b/test/data/test_data.py index 3652d34..58a8535 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -592,3 +592,102 @@ def test_data_time_handling(num_nodes, num_edges): out = out.sort_by_time() assert paddle.equal_all(out.time, data.time.repeat_interleave(2)).item() + + +def test_data_connected_components(): + data = Data() + data.x = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0], [5.0]]) + data.y = paddle.to_tensor([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2], [4.1, 4.2], + [5.1, 5.2]]) + data.edge_index = paddle.to_tensor([[0, 1, 2, 3], [1, 0, 3, 2]], + dtype=paddle.int64) + + split_data = data.connected_components() + assert isinstance(split_data, list) + assert len(split_data) == 3 + + assert paddle.equal_all(split_data[0].x, paddle.to_tensor([[1.0], [2.0]])).item() + assert paddle.equal_all(split_data[0].y, paddle.to_tensor([[1.1, 1.2], [2.1, 2.2]])).item() + assert paddle.equal_all(split_data[0].edge_index, paddle.to_tensor([[0, 1], [1, 0]])).item() + + assert paddle.equal_all(split_data[1].x, paddle.to_tensor([[3.0], [4.0]])).item() + assert paddle.equal_all(split_data[1].y, paddle.to_tensor([[3.1, 3.2], [4.1, 4.2]])).item() + assert paddle.equal_all(split_data[1].edge_index, paddle.to_tensor([[0, 1], [1, 0]])).item() + + assert paddle.equal_all(split_data[2].x, paddle.to_tensor([[5.0]])).item() + assert paddle.equal_all(split_data[2].y, paddle.to_tensor([[5.1, 5.2]])).item() + assert paddle.equal_all(split_data[2].edge_index, + paddle.to_tensor([[], []], dtype=paddle.int64)).item() + + +def test_data_find_parent(): + + # Case 1: Parent does not exist + data = Data() + data._parents = {} + data._ranks = {} + node = 1 + assert data._find_parent(node) == node + assert data._parents == {1: 1} + assert data._ranks == {1: 0} + + # Case 2: Parent exists + data._parents[node] = 0 + assert data._find_parent(node) == 0 + + +def test_data_union(): + + # Setup: two nodes in different sets + data = Data() + data._parents = {} + data._ranks = {} + node1 = 1 + node2 = 2 + + # Initially, both nodes are their own parents with rank 0 + assert data._find_parent(node1) == node1 + assert data._find_parent(node2) == node2 + data._ranks[node1] = 0 + data._ranks[node2] = 0 + + # Union them: node2 should now point to node1, and node1's rank increases + data._union(node1, 
node2) + assert data._find_parent(node1) == node1 + assert data._find_parent(node2) == node1 + assert data._ranks[node1] == 1 + + # Add a third node with higher rank and union with node1 + node3 = 3 + data._parents[node3] = node3 + data._ranks[node3] = 2 + data._union(node1, node3) + # node1's parent should now be node3, since node3 has higher rank + assert data._find_parent(node1) == node3 + assert data._find_parent(node3) == node3 + + # Add a fourth node with lower rank and union with node3 + node4 = 4 + data._parents[node4] = node4 + data._ranks[node4] = 0 + data._union(node3, node4) + assert data._find_parent(node4) == node3 + assert data._find_parent(node3) == node3 + + # Union of already connected nodes should not change anything + prev_ranks = data._ranks.copy() + prev_parents = data._parents.copy() + data._union(node1, node3) + assert data._ranks == prev_ranks + assert data._parents == prev_parents + + +def test_data_inc(): + data = Data(edge_index=paddle.to_tensor([[0, 1], [1, 0]])) + with pytest.warns(UserWarning, match="Unable to accurately infer"): + assert data.__inc__('edge_index', data.edge_index) == 2 + + data = Data(index=paddle.empty([2, 0], dtype=paddle.int64)) + with pytest.raises(RuntimeError, match="Unable to infer"): + with pytest.warns(UserWarning, match="Unable to accurately infer"): + data.__inc__('index', data.edge_index) diff --git a/test/data/test_database.py b/test/data/test_database.py new file mode 100644 index 0000000..dd017fc --- /dev/null +++ b/test/data/test_database.py @@ -0,0 +1,301 @@ +import math +import os.path as osp + +import paddle +import pytest + +from paddle_geometric import EdgeIndex, Index +from paddle_geometric.data import Data, RocksDatabase, SQLiteDatabase +from paddle_geometric.data.database import TensorInfo +from paddle_geometric.testing import has_package, withPackage + +AVAILABLE_DATABASES = [] +if has_package('sqlite3'): + AVAILABLE_DATABASES.append(SQLiteDatabase) +if has_package('rocksdict'): + AVAILABLE_DATABASES.append(RocksDatabase) + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +@pytest.mark.parametrize('batch_size', [None, 1]) +def test_database_single_tensor(tmp_path, Database, batch_size): + kwargs = dict(path=osp.join(tmp_path, 'storage.db')) + if Database == SQLiteDatabase: + kwargs['name'] = 'test_table' + + db = Database(**kwargs) + assert db.schema == {0: object} + + try: + assert len(db) == 0 + assert str(db) == f'{Database.__name__}(0)' + except NotImplementedError: + assert str(db) == f'{Database.__name__}()' + + data = paddle.randn(5) + db.insert(0, data) + try: + assert len(db) == 1 + except NotImplementedError: + pass + assert paddle.equal_all(db.get(0), data).item() + + indices = paddle.to_tensor([1, 2]) + data_list = paddle.randn(2, 5) + db.multi_insert(indices, data_list, batch_size=batch_size) + try: + assert len(db) == 3 + except NotImplementedError: + pass + out_list = db.multi_get(indices, batch_size=batch_size) + assert isinstance(out_list, list) + assert len(out_list) == 2 + assert paddle.equal_all(out_list[0], data_list[0]).item() + assert paddle.equal_all(out_list[1], data_list[1]).item() + + db.close() + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +def test_database_schema(tmp_path, Database): + kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} + + path = osp.join(tmp_path, 'tuple_storage.db') + schema = (int, float, str, dict(dtype=paddle.float32, size=(2, -1)), object) + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 0: 
int, + 1: float, + 2: str, + 3: TensorInfo(dtype=paddle.float32, size=(2, -1)), + 4: object, + } + + data1 = (1, 0.1, 'a', paddle.randn(2, 8), Data(x=paddle.randn(8))) + data2 = (2, float('inf'), 'b', paddle.randn(2, 16), Data(x=paddle.randn(8))) + data3 = (3, float('nan'), 'c', paddle.randn(2, 32), Data(x=paddle.randn(8))) + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert out[0] == data[0] + if math.isnan(data[1]): + assert math.isnan(out[1]) + else: + assert out[1] == data[1] + assert out[2] == data[2] + assert paddle.equal_all(out[3], data[3]).item() + assert isinstance(out[4], Data) and len(out[4]) == 1 + assert paddle.equal_all(out[4].x, data[4].x).item() + + db.close() + + path = osp.join(tmp_path, 'dict_storage.db') + schema = { + 'int': int, + 'float': float, + 'str': str, + 'tensor': dict(dtype=paddle.float32, size=(2, -1)), + 'data': object + } + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 'int': int, + 'float': float, + 'str': str, + 'tensor': TensorInfo(dtype=paddle.float32, size=(2, -1)), + 'data': object, + } + + data1 = { + 'int': 1, + 'float': 0.1, + 'str': 'a', + 'tensor': paddle.randn(2, 8), + 'data': Data(x=paddle.randn(1, 8)), + } + data2 = { + 'int': 2, + 'float': 0.2, + 'str': 'b', + 'tensor': paddle.randn(2, 16), + 'data': Data(x=paddle.randn(2, 8)), + } + data3 = { + 'int': 3, + 'float': 0.3, + 'str': 'c', + 'tensor': paddle.randn(2, 32), + 'data': Data(x=paddle.randn(3, 8)), + } + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert out['int'] == data['int'] + assert out['float'] == data['float'] + assert out['str'] == data['str'] + assert paddle.equal_all(out['tensor'], data['tensor']).item() + assert isinstance(out['data'], Data) and len(out['data']) == 1 + assert paddle.equal_all(out['data'].x, data['data'].x).item() + + db.close() + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +def test_index(tmp_path, Database): + kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} + + path = osp.join(tmp_path, 'tuple_storage.db') + schema = dict(dtype=paddle.int64, is_index=True) + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 0: TensorInfo(dtype=paddle.int64, is_index=True), + } + + index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) + index2 = Index([0, 1, 1, 2, 2, 3], dim_size=None, is_sorted=True) + index3 = Index([], dtype=paddle.int64) + + db.insert(0, index1) + db.multi_insert([1, 2], [index2, index3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, index in zip([out1, out2, out3], [index1, index2, index3]): + assert index.equal(out).item() + assert index.dtype == out.dtype + assert index.dim_size == out.dim_size + assert index.is_sorted == out.is_sorted + + db.close() + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +def test_edge_index(tmp_path, Database): + kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} + + path = osp.join(tmp_path, 'tuple_storage.db') + schema = dict(dtype=paddle.int64, is_edge_index=True) + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 0: TensorInfo(dtype=paddle.int64, size=(2, -1), is_edge_index=True), + } + + adj1 = EdgeIndex( + [[0, 1, 1, 2], [1, 0, 2, 1]], + sparse_size=(3, 3), + 
sort_order='row', + is_undirected=True, + ) + adj2 = EdgeIndex( + [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]], + sparse_size=(4, 4), + sort_order='col', + ) + adj3 = EdgeIndex([[], []], dtype=paddle.int64) + + db.insert(0, adj1) + db.multi_insert([1, 2], [adj2, adj3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, adj in zip([out1, out2, out3], [adj1, adj2, adj3]): + assert adj.equal(out).item() + assert adj.dtype == out.dtype + assert adj.sparse_size() == out.sparse_size() + assert adj.sort_order == out.sort_order + assert adj.is_undirected == out.is_undirected + + db.close() + + +@withPackage('sqlite3') +def test_database_syntactic_sugar(tmp_path): + path = osp.join(tmp_path, 'storage.db') + db = SQLiteDatabase(path, name='test_table') + + data = paddle.randn(5, 16) + db[0] = data[0] + db[1:3] = data[1:3] + db[paddle.to_tensor([3, 4])] = data[paddle.to_tensor([3, 4])] + assert len(db) == 5 + + assert paddle.equal_all(db[0], data[0]).item() + assert paddle.equal_all(paddle.stack(db[:3], axis=0), data[:3]).item() + assert paddle.equal_all(paddle.stack(db[3:], axis=0), data[3:]).item() + assert paddle.equal_all(paddle.stack(db[1::2], axis=0), data[1::2]).item() + assert paddle.equal_all(paddle.stack(db[[4, 3]], axis=0), data[[4, 3]]).item() + assert paddle.equal_all( + paddle.stack(db[paddle.to_tensor([4, 3])], axis=0), + data[paddle.to_tensor([4, 3])], + ).item() + assert paddle.equal_all( + paddle.stack(db[paddle.to_tensor([4, 4])], axis=0), + data[paddle.to_tensor([4, 4])], + ).item() + + +if __name__ == '__main__': + import argparse + import tempfile + import time + + parser = argparse.ArgumentParser() + parser.add_argument('--numel', type=int, default=100_000) + parser.add_argument('--batch_size', type=int, default=256) + args = parser.parse_args() + + data = paddle.randn(args.numel, 128) + tmp_dir = tempfile.TemporaryDirectory() + + path = osp.join(tmp_dir.name, 'sqlite.db') + sqlite_db = SQLiteDatabase(path, name='test_table') + t = time.perf_counter() + sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True) + print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds') + + path = osp.join(tmp_dir.name, 'rocks.db') + rocks_db = RocksDatabase(path) + t = time.perf_counter() + rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True) + print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds') + + def in_memory_get(data): + index = paddle.randint(0, args.numel, (args.batch_size, )) + return data[index] + + def db_get(db): + index = paddle.randint(0, args.numel, (args.batch_size, )) + return db[index] + + # Paddle Geometric doesn't have benchmark utility like torch_geometric.profile.benchmark + # Implementing simple benchmarking here + num_steps = 50 + num_warmups = 5 + + for _ in range(num_warmups): + in_memory_get(data) + db_get(sqlite_db) + db_get(rocks_db) + + for name, fn, db_arg in [('In-Memory', in_memory_get, data), + ('SQLite', db_get, sqlite_db), + ('RocksDB', db_get, rocks_db)]: + t = time.perf_counter() + for _ in range(num_steps): + if db_arg is data: + fn(db_arg) + else: + fn(db_arg) + print(f'{name}: {time.perf_counter() - t:.6f} seconds [{num_steps} steps]') + + tmp_dir.cleanup() diff --git a/test/data/test_datapipes.py b/test/data/test_datapipes.py new file mode 100644 index 0000000..85de783 --- /dev/null +++ b/test/data/test_datapipes.py @@ -0,0 +1,56 @@ +import pytest + +import paddle + +from paddle_geometric.data import Data +from paddle_geometric.data.datapipes import DatasetAdapter +from 
paddle_geometric.loader import DataLoader
+from paddle_geometric.testing import withPackage
+from paddle_geometric.utils import to_smiles
+
+
+@pytest.fixture()
+def dataset_adapter() -> DatasetAdapter:
+    x = paddle.randn(3, 8)
+    edge_index = paddle.to_tensor([[0, 1, 1], [1, 0, 2]])
+    data = Data(x=x, edge_index=edge_index)
+    return DatasetAdapter([data, data, data, data])
+
+
+def test_dataset_adapter(dataset_adapter):
+    loader = DataLoader(dataset_adapter, batch_size=2)
+    batch = next(iter(loader))
+    # `Tensor.shape` is a list in Paddle, so compare against a list:
+    assert batch.x.shape == [6, 8]
+    assert len(loader) == 2
+
+    # Test sharding:
+    dataset_adapter.apply_sharding(2, 0)
+    assert len([data for data in dataset_adapter]) == 2
+
+    assert dataset_adapter.is_shardable()
+
+
+def test_datapipe_batch_graphs(dataset_adapter):
+    dp = dataset_adapter.batch_graphs(batch_size=2)
+    assert len(dp) == 2
+    batch = next(iter(dp))
+    assert batch.x.shape == [6, 8]
+
+
+def test_functional_transform(dataset_adapter):
+    assert next(iter(dataset_adapter)).is_directed()
+    dataset_adapter = dataset_adapter.to_undirected()
+    assert next(iter(dataset_adapter)).is_undirected()
+
+
+@withPackage('rdkit')
+def test_datapipe_parse_smiles():
+    smiles = 'F/C=C/F'
+
+    dp = DatasetAdapter([smiles])
+    dp = dp.parse_smiles()
+    assert to_smiles(next(iter(dp))) == smiles
+
+    dp = DatasetAdapter([{'abc': smiles, 'cba': '1.0'}])
+    dp = dp.parse_smiles(smiles_key='abc', target_key='cba')
+    assert to_smiles(next(iter(dp))) == smiles
diff --git a/test/data/test_hypergraph_data.py b/test/data/test_hypergraph_data.py
new file mode 100644
index 0000000..fa33a0c
--- /dev/null
+++ b/test/data/test_hypergraph_data.py
@@ -0,0 +1,172 @@
+import pytest
+
+import paddle
+
+import paddle_geometric
+from paddle_geometric.data.hypergraph_data import HyperGraphData
+from paddle_geometric.loader import DataLoader
+
+
+def test_hypergraph_data():
+    paddle_geometric.set_debug(True)
+
+    x = paddle.to_tensor([[1, 3, 5, 7], [2, 4, 6, 8], [7, 8, 9, 10]],
+                         dtype=paddle.float32).t()
+    edge_index = paddle.to_tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3],
+                                   [0, 0, 0, 1, 1, 1, 2, 2, 2]])
+    data = HyperGraphData(x=x, edge_index=edge_index)
+    data.validate(raise_on_error=True)
+
+    assert data.num_nodes == 4
+    assert data.num_edges == 3
+
+    assert data.node_attrs() == ['x']
+    assert data.edge_attrs() == ['edge_index']
+
+    assert data.x.tolist() == x.tolist()
+    assert data['x'].tolist() == x.tolist()
+    assert data.get('x').tolist() == x.tolist()
+    assert data.get('y', 2) == 2
+    assert data.get('y', None) is None
+
+    assert sorted(data.keys()) == ['edge_index', 'x']
+    assert len(data) == 2
+    assert 'x' in data and 'edge_index' in data and 'pos' not in data
+
+    D = data.to_dict()
+    assert len(D) == 2
+    assert 'x' in D and 'edge_index' in D
+
+    D = data.to_namedtuple()
+    assert len(D) == 2
+    assert D.x is not None and D.edge_index is not None
+
+    assert data.__cat_dim__('x', data.x) == 0
+    assert data.__cat_dim__('edge_index', data.edge_index) == -1
+    assert data.__inc__('x', data.x) == 0
+    assert paddle.equal_all(
+        data.__inc__('edge_index', data.edge_index),
+        paddle.to_tensor([[data.num_nodes], [data.num_edges]])).item()
+    data_list = [data, data]
+    loader = DataLoader(data_list, batch_size=2)
+    batch = next(iter(loader))
+    batched_edge_index = batch.edge_index
+    assert batched_edge_index.tolist() == [[
+        0, 1, 2, 1, 2, 3, 0, 2, 3, 4, 5, 6, 5, 6, 7, 4, 6, 7
+    ], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]]
+
+    assert not data.x.is_contiguous()
+    data.contiguous()
+    assert data.x.is_contiguous()
+
+    assert not
data.is_coalesced() + data = data.coalesce() + assert data.is_coalesced() + + clone = data.clone() + assert clone != data + assert len(clone) == len(data) + assert clone.x.data_ptr() != data.x.data_ptr() + assert clone.x.tolist() == data.x.tolist() + assert clone.edge_index.data_ptr() != data.edge_index.data_ptr() + assert clone.edge_index.tolist() == data.edge_index.tolist() + + data['x'] = x + 1 + assert data.x.tolist() == (x + 1).tolist() + + assert str(data) == 'HyperGraphData(x=[4, 3], edge_index=[2, 9])' + + dictionary = {'x': data.x, 'edge_index': data.edge_index} + data = HyperGraphData.from_dict(dictionary) + assert sorted(data.keys()) == ['edge_index', 'x'] + + assert not data.has_isolated_nodes() + # assert not data.has_self_loops() + # assert data.is_undirected() + # assert not data.is_directed() + + assert data.num_nodes == 4 + assert data.num_edges == 3 + with pytest.warns(UserWarning, match='deprecated'): + assert data.num_faces is None + assert data.num_node_features == 3 + assert data.num_features == 3 + + data.edge_attr = paddle.randn(data.num_edges, 2) + assert data.num_edge_features == 2 + assert data.is_edge_attr('edge_attr') + data.edge_attr = None + + data.x = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes == 4 + + data.edge_index = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes is None + assert data.num_edges == 0 + + data.num_nodes = 4 + assert data.num_nodes == 4 + + data = HyperGraphData(x=x, attribute=x) + assert len(data) == 2 + assert data.x.tolist() == x.tolist() + assert data.attribute.tolist() == x.tolist() + + face = paddle.to_tensor([[0, 1], [1, 2], [2, 3]]) + data = HyperGraphData(num_nodes=4, face=face) + with pytest.warns(UserWarning, match='deprecated'): + assert data.num_faces == 2 + assert data.num_nodes == 4 + + data = HyperGraphData(title='test') + assert str(data) == "HyperGraphData(title='test')" + assert data.num_node_features == 0 + # assert data.num_edge_features == 0 + + key = value = 'test_value' + data[key] = value + assert data[key] == value + del data[value] + del data[value] # Deleting unset attributes should work as well. 
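+    # After deletion, `get` falls back to its default (`None` here), while
+    # untouched attributes such as `title` remain accessible: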
+ + assert data.get(key) is None + assert data.get('title') == 'test' + + paddle_geometric.set_debug(False) + + +def test_hypergraphdata_subgraph(): + x = paddle.arange(5) + y = paddle.to_tensor([0.]) + edge_index = paddle.to_tensor([[0, 1, 3, 2, 4, 0, 3, 4, 2, 1, 2, 3], + [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3]]) + edge_attr = paddle.rand(4, 2) + data = HyperGraphData(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr, + num_nodes=5) + + out = data.subgraph(paddle.to_tensor([1, 2, 4])) + assert len(out) == 5 + assert paddle.equal_all(out.x, paddle.to_tensor([1, 2, 4])).item() + assert paddle.equal_all(out.y, data.y).item() + assert out.edge_index.tolist() == [[1, 2, 2, 1, 0, 1], [0, 0, 1, 1, 2, 2]] + assert paddle.equal_all(out.edge_attr, edge_attr[[1, 2, 3]]).item() + assert out.num_nodes == 3 + + # Test unordered selection: + out = data.subgraph(paddle.to_tensor([3, 1, 2])) + assert len(out) == 5 + assert paddle.equal_all(out.x, paddle.to_tensor([3, 1, 2])).item() + assert paddle.equal_all(out.y, data.y).item() + assert out.edge_index.tolist() == [[0, 2, 0, 2, 1, 2, 0], + [0, 0, 1, 1, 2, 2, 2]] + assert paddle.equal_all(out.edge_attr, edge_attr[[1, 2, 3]]).item() + assert out.num_nodes == 3 + + out = data.subgraph(paddle.to_tensor([False, False, False, True, True])) + assert len(out) == 5 + assert paddle.equal_all(out.x, paddle.arange(3, 5)).item() + assert paddle.equal_all(out.y, data.y).item() + assert out.edge_index.tolist() == [[0, 1, 0, 1], [0, 0, 1, 1]] + assert paddle.equal_all(out.edge_attr, edge_attr[[1, 2]]).item() + assert out.num_nodes == 2 diff --git a/test/data/test_on_disk_dataset.py b/test/data/test_on_disk_dataset.py new file mode 100644 index 0000000..ddd6c8d --- /dev/null +++ b/test/data/test_on_disk_dataset.py @@ -0,0 +1,111 @@ +import os.path as osp +from typing import Any, Dict + +import paddle + +from paddle_geometric.data import Data, OnDiskDataset +from paddle_geometric.testing import withPackage + + +@withPackage('sqlite3') +def test_pickle(tmp_path): + dataset = OnDiskDataset(tmp_path) + assert len(dataset) == 0 + assert str(dataset) == 'OnDiskDataset(0)' + assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db')) + + data_list = [ + Data( + x=paddle.randn(5, 8), + edge_index=paddle.randint(0, 5, (2, 16)), + num_nodes=5, + ) for _ in range(4) + ] + + dataset.append(data_list[0]) + assert len(dataset) == 1 + + dataset.extend(data_list[1:]) + assert len(dataset) == 4 + + out = dataset.get(0) + assert paddle.equal_all(out.x, data_list[0].x).item() + assert paddle.equal_all(out.edge_index, data_list[0].edge_index).item() + assert out.num_nodes == data_list[0].num_nodes + + out_list = dataset.multi_get([1, 2, 3]) + for out, data in zip(out_list, data_list[1:]): + assert paddle.equal_all(out.x, data.x).item() + assert paddle.equal_all(out.edge_index, data.edge_index).item() + assert out.num_nodes == data.num_nodes + + dataset.close() + + # Test persistence of datasets: + dataset = OnDiskDataset(tmp_path) + assert len(dataset) == 4 + + out = dataset.get(0) + assert paddle.equal_all(out.x, data_list[0].x).item() + assert paddle.equal_all(out.edge_index, data_list[0].edge_index).item() + assert out.num_nodes == data_list[0].num_nodes + + dataset.close() + + +@withPackage('sqlite3') +def test_custom_schema(tmp_path): + class CustomSchemaOnDiskDataset(OnDiskDataset): + def __init__(self, root: str): + schema = { + 'x': dict(dtype=paddle.float32, size=(-1, 8)), + 'edge_index': dict(dtype=paddle.int64, size=(2, -1)), + 'num_nodes': int, + } + self.serialize_count = 0 
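+                # Test-only counters: the assertions below use these to
+                # verify that the custom hooks are actually invoked.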
+            self.deserialize_count = 0
+            super().__init__(root, schema=schema)
+
+        def serialize(self, data: Data) -> Dict[str, Any]:
+            self.serialize_count += 1
+            return data.to_dict()
+
+        def deserialize(self, mapping: Dict[str, Any]) -> Any:
+            self.deserialize_count += 1
+            return Data.from_dict(mapping)
+
+    dataset = CustomSchemaOnDiskDataset(tmp_path)
+    assert len(dataset) == 0
+    assert str(dataset) == 'CustomSchemaOnDiskDataset(0)'
+    assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db'))
+
+    data_list = [
+        Data(
+            x=paddle.randn([5, 8]),
+            edge_index=paddle.randint(0, 5, (2, 16)),
+            num_nodes=5,
+        ) for _ in range(4)
+    ]
+
+    dataset.append(data_list[0])
+    assert dataset.serialize_count == 1
+    assert len(dataset) == 1
+
+    dataset.extend(data_list[1:])
+    assert dataset.serialize_count == 4
+    assert len(dataset) == 4
+
+    out = dataset.get(0)
+    assert dataset.deserialize_count == 1
+    assert paddle.equal_all(out.x, data_list[0].x).item()
+    assert paddle.equal_all(out.edge_index, data_list[0].edge_index).item()
+    assert out.num_nodes == data_list[0].num_nodes
+
+    out_list = dataset.multi_get([1, 2, 3])
+    assert dataset.deserialize_count == 4
+    for out, data in zip(out_list, data_list[1:]):
+        assert paddle.equal_all(out.x, data.x).item()
+        assert paddle.equal_all(out.edge_index, data.edge_index).item()
+        assert out.num_nodes == data.num_nodes
+
+    dataset.close()
diff --git a/test/data/test_remote_backend_utils.py b/test/data/test_remote_backend_utils.py
new file mode 100644
index 0000000..b8432e6
--- /dev/null
+++ b/test/data/test_remote_backend_utils.py
@@ -0,0 +1,34 @@
+import pytest
+
+import paddle
+
+from paddle_geometric.data import HeteroData
+from paddle_geometric.data.remote_backend_utils import num_nodes, size
+from paddle_geometric.testing import (
+    MyFeatureStore,
+    MyGraphStore,
+    get_random_edge_index,
+)
+
+
+@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
+@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
+def test_num_nodes_size(FeatureStore, GraphStore):
+    feature_store = FeatureStore()
+    graph_store = GraphStore()
+
+    # Infer num nodes from features:
+    x = paddle.arange(100)
+    feature_store.put_tensor(x, group_name='x', attr_name='x', index=None)
+    assert num_nodes(feature_store, graph_store, 'x') == 100
+
+    # Infer num nodes and size from edges:
+    xy = get_random_edge_index(100, 50, 20)
+    graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout='coo',
+                               size=(100, 50))
+    assert num_nodes(feature_store, graph_store, 'y') == 50
+    assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50)
+
+    # Throw an error if we cannot infer for an unknown node type:
+    with pytest.raises(ValueError, match="Unable to accurately infer"):
+        _ = num_nodes(feature_store, graph_store, 'z')
diff --git a/test/data/test_temporal.py b/test/data/test_temporal.py
new file mode 100644
index 0000000..994bdcf
--- /dev/null
+++ b/test/data/test_temporal.py
@@ -0,0 +1,137 @@
+import copy
+
+import paddle
+
+from paddle_geometric.data import TemporalData
+
+
+def get_temporal_data(num_events, msg_channels):
+    return TemporalData(
+        src=paddle.arange(num_events),
+        dst=paddle.arange(num_events, num_events * 2),
+        t=paddle.arange(0, num_events * 1000, step=1000),  # strictly increasing timestamps
+        msg=paddle.randn([num_events, msg_channels]),
+        y=paddle.randint(0, 2, (num_events, )),
+    )
+
+
+def test_temporal_data():
+    data = get_temporal_data(num_events=3, msg_channels=16)
+    assert str(data) == ("TemporalData(src=[3], dst=[3], t=[3], "
+                         "msg=[3, 16], y=[3])")
+
+    assert data.num_nodes == 6
+    assert data.num_events == data.num_edges == len(data) == 3
+
+    assert data.src.tolist() == [0, 1, 2]
+    assert data['src'].tolist() == [0, 1, 2]
+
+    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]
+    data.edge_index = 'edge_index'
+    assert data.edge_index == 'edge_index'
+    del data.edge_index
+    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]
+
+    assert sorted(data.keys()) == ['dst', 'msg', 'src', 't', 'y']
+    assert sorted(data.to_dict().keys()) == sorted(data.keys())
+
+    data_tuple = data.to_namedtuple()
+    assert len(data_tuple) == 5
+    assert data_tuple.src is not None
+    assert data_tuple.dst is not None
+    assert data_tuple.t is not None
+    assert data_tuple.msg is not None
+    assert data_tuple.y is not None
+
+    assert data.__cat_dim__('src', data.src) == 0
+    assert data.__inc__('src', data.src) == 6
+
+    clone = data.clone()
+    assert clone != data
+    assert len(clone) == len(data)
+    assert clone.src.data_ptr() != data.src.data_ptr()
+    assert clone.src.tolist() == data.src.tolist()
+    assert clone.dst.data_ptr() != data.dst.data_ptr()
+    assert clone.dst.tolist() == data.dst.tolist()
+
+    deepcopy = copy.deepcopy(data)
+    assert deepcopy != data
+    assert len(deepcopy) == len(data)
+    assert deepcopy.src.data_ptr() != data.src.data_ptr()
+    assert deepcopy.src.tolist() == data.src.tolist()
+    assert deepcopy.dst.data_ptr() != data.dst.data_ptr()
+    assert deepcopy.dst.tolist() == data.dst.tolist()
+
+    key = value = 'test_value'
+    data[key] = value
+    assert data[key] == value
+    assert data.test_value == value
+    del data[key]
+    del data[key]  # Deleting unset attributes should work as well.
+
+    assert data.get(key, 10) == 10
+
+    assert len([event for event in data]) == 3
+
+    assert len([attr for attr in data()]) == 5
+
+    assert data.size() == (2, 5)
+
+    del data.src
+    assert 'src' not in data
+
+
+def test_train_val_test_split():
+    data = get_temporal_data(num_events=100, msg_channels=16)
+
+    train_data, val_data, test_data = data.train_val_test_split(
+        val_ratio=0.2, test_ratio=0.15)
+
+    assert len(train_data) == 65
+    assert len(val_data) == 20
+    assert len(test_data) == 15
+
+    assert train_data.t.max() < val_data.t.min()
+    assert val_data.t.max() < test_data.t.min()
+
+
+def test_temporal_indexing():
+    data = get_temporal_data(num_events=10, msg_channels=16)
+
+    elem = data[0]
+    assert isinstance(elem, TemporalData)
+    assert len(elem) == 1
+    assert elem.src.tolist() == data.src[0:1].tolist()
+    assert elem.dst.tolist() == data.dst[0:1].tolist()
+    assert elem.t.tolist() == data.t[0:1].tolist()
+    assert elem.msg.tolist() == data.msg[0:1].tolist()
+    assert elem.y.tolist() == data.y[0:1].tolist()
+
+    subset = data[0:5]
+    assert isinstance(subset, TemporalData)
+    assert len(subset) == 5
+    assert subset.src.tolist() == data.src[0:5].tolist()
+    assert subset.dst.tolist() == data.dst[0:5].tolist()
+    assert subset.t.tolist() == data.t[0:5].tolist()
+    assert subset.msg.tolist() == data.msg[0:5].tolist()
+    assert subset.y.tolist() == data.y[0:5].tolist()
+
+    index = [0, 4, 8]
+    subset = data[paddle.to_tensor(index)]
+    assert isinstance(subset, TemporalData)
+    assert len(subset) == 3
+    assert subset.src.tolist() == data.src[0::4].tolist()
+    assert subset.dst.tolist() == data.dst[0::4].tolist()
+    assert subset.t.tolist() == data.t[0::4].tolist()
+    assert subset.msg.tolist() == data.msg[0::4].tolist()
+    assert subset.y.tolist() == data.y[0::4].tolist()
+
+    mask = [True, False, True, False, True, False, True, False, True, False]
+    subset = data[paddle.to_tensor(mask)]
+    assert isinstance(subset, TemporalData)
+    assert len(subset) == 5
+    assert subset.src.tolist() == data.src[0::2].tolist()
+    assert subset.dst.tolist() == data.dst[0::2].tolist()
+    assert subset.t.tolist() == data.t[0::2].tolist()
+    assert subset.msg.tolist() == data.msg[0::2].tolist()
+    assert subset.y.tolist() == data.y[0::2].tolist()
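For reference, a minimal standalone sketch of the event-indexing semantics the tests above pin down (not part of the patch; it only assumes the `TemporalData` behavior exercised in `test_temporal_indexing`):

# Hedged usage sketch: boolean masks, integer index tensors, and slices all
# select events from a TemporalData and return a new TemporalData.
import paddle

from paddle_geometric.data import TemporalData

data = TemporalData(
    src=paddle.arange(10),
    dst=paddle.arange(10, 20),
    t=paddle.arange(0, 10000, step=1000),
    msg=paddle.randn([10, 16]),
)

# A boolean mask keeps every event whose mask entry is True:
every_other = data[paddle.to_tensor([True, False] * 5)]
assert len(every_other) == 5
assert every_other.t.tolist() == data.t[0::2].tolist()

# A plain slice selects a contiguous range of events:
first_half = data[0:5]
assert len(first_half) == 5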