Read any number of arrays (#161)
* Pass x_params/y_params positionally

* Generalize PyTorchTileDBDataLoader and TensorflowTileDBDataset to take one or more ArrayParams instead of exactly two
gsakkis authored Jun 30, 2022
1 parent 233d1ee commit 1ddf09f
Showing 8 changed files with 78 additions and 120 deletions.
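For orientation, here is a minimal sketch of the generalized API after this commit: `ArrayParams` objects are passed positionally, one per TileDB array, and each yielded batch is a tuple with one tensor per array. Only the call shape is taken from this diff; the URIs and the third `weights` array are hypothetical.

```python
import tiledb
from tiledb.ml.readers.pytorch import PyTorchTileDBDataLoader
from tiledb.ml.readers.types import ArrayParams

# Hypothetical array URIs; any number of arrays may be passed, not just X and Y.
with tiledb.open("features") as x, tiledb.open("labels") as y, tiledb.open("weights") as w:
    loader = PyTorchTileDBDataLoader(
        ArrayParams(x), ArrayParams(y), ArrayParams(w),  # one ArrayParams per array
        batch_size=64,
        shuffle_buffer_size=128,
    )
    for x_batch, y_batch, w_batch in loader:  # one tensor per array in each batch
        ...
```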
7 changes: 3 additions & 4 deletions examples/readers/pytorch_data_api_tiledb_dense.ipynb
@@ -366,10 +366,9 @@
"\n",
"ctx = tiledb.Ctx({\"sm.memory_budget\": 1024**2})\n",
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
" train_loader = PyTorchTileDBDataLoader(x_params=ArrayParams(x),\n",
" y_params=ArrayParams(y),\n",
" batch_size=64,\n",
" shuffle_buffer_size=128)\n",
" train_loader = PyTorchTileDBDataLoader(\n",
" ArrayParams(x), ArrayParams(y), batch_size=64, shuffle_buffer_size=128\n",
" )\n",
" net = Net(shape=(28, 28))\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)\n",
7 changes: 2 additions & 5 deletions examples/readers/pytorch_data_api_tiledb_sparse.ipynb
@@ -477,11 +477,8 @@
"ctx = tiledb.Ctx({\"sm.memory_budget\": 1024**2, \"py.init_buffer_bytes\": 1024**2})\n",
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
" train_loader = PyTorchTileDBDataLoader(\n",
" x_params=ArrayParams(x),\n",
" y_params=ArrayParams(y),\n",
" batch_size=32,\n",
" csr=False)\n",
"\n",
" ArrayParams(x), ArrayParams(y), batch_size=32, csr=False\n",
" )\n",
" #Number of ratings x (user + movies)\n",
" datashape_x = (100000, 2625)\n",
"\n",
4 changes: 2 additions & 2 deletions examples/readers/tensorflow_data_api_tiledb_dense.ipynb
@@ -350,8 +350,8 @@
"ctx = tiledb.Ctx({\"sm.memory_budget\": 1024**2})\n",
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
" tiledb_dataset = TensorflowTileDBDataset(\n",
" x_params=ArrayParams(array=x, fields=['features']),\n",
" y_params=ArrayParams(array=y, fields=['features']),\n",
" ArrayParams(array=x, fields=['features']),\n",
" ArrayParams(array=y, fields=['features']),\n",
" batch_size=64, shuffle_buffer_size=128\n",
" )\n",
" model.fit(tiledb_dataset, epochs=5)"
4 changes: 2 additions & 2 deletions examples/readers/tensorflow_data_api_tiledb_sparse.ipynb
@@ -432,8 +432,8 @@
"ctx = tiledb.Ctx({\"sm.memory_budget\": 1024**2, \"py.init_buffer_bytes\": 1024**2})\n",
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
" tiledb_dataset = TensorflowTileDBDataset(\n",
" x_params=ArrayParams(array=x, fields=['features']),\n",
" y_params=ArrayParams(array=y, fields=['features']),\n",
" ArrayParams(array=x, fields=['features']),\n",
" ArrayParams(array=y, fields=['features']),\n",
" batch_size=32)\n",
" model = design_model(input_shape=user_movie.shape[1])\n",
" model.fit(tiledb_dataset, epochs=2, batch_size=32)"
22 changes: 7 additions & 15 deletions tests/readers/test_pytorch.py
@@ -36,12 +36,8 @@ def test_dataloader(
tmpdir, y_shape, y_sparse, key_dim_dtype, y_key_dim, num_fields
) as y_kwargs:
dataloader = PyTorchTileDBDataLoader(
-            x_params=ArrayParams(
-                x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
-            ),
-            y_params=ArrayParams(
-                y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
-            ),
+            ArrayParams(x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]),
+            ArrayParams(y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
@@ -78,17 +74,17 @@ def test_unequal_num_keys(
) as y_kwargs:
with pytest.raises(ValueError) as ex:
PyTorchTileDBDataLoader(
-                x_params=ArrayParams(
+                ArrayParams(
x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
),
-                y_params=ArrayParams(
+                ArrayParams(
y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
)
assert "X and Y arrays have different key range" in str(ex.value)
assert "All arrays must have the same key range" in str(ex.value)

@parametrize_for_dataset(num_fields=[0], shuffle_buffer_size=[0], num_workers=[0])
@pytest.mark.parametrize("csr", [True, False])
@@ -119,12 +115,8 @@ def test_dataloader_order(
tmpdir, y_shape, y_sparse, key_dim_dtype, y_key_dim, num_fields
) as y_kwargs:
dataloader = PyTorchTileDBDataLoader(
-            x_params=ArrayParams(
-                x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
-            ),
-            y_params=ArrayParams(
-                y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
-            ),
+            ArrayParams(x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]),
+            ArrayParams(y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
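The assertion change in this file tracks the new error text raised when key ranges differ. A hypothetical minimal reproduction, where `x` and `y_short` are assumed fixtures over mismatched key ranges:

```python
import pytest

# Sketch only: x and y_short are hypothetical arrays with different key ranges.
with pytest.raises(ValueError) as ex:
    PyTorchTileDBDataLoader(ArrayParams(x), ArrayParams(y_short), batch_size=8)
assert "All arrays must have the same key range" in str(ex.value)
```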
22 changes: 7 additions & 15 deletions tests/readers/test_tensorflow.py
@@ -33,12 +33,8 @@ def test_dataset(
tmpdir, y_shape, y_sparse, key_dim_dtype, y_key_dim, num_fields
) as y_kwargs:
dataset = TensorflowTileDBDataset(
-            x_params=ArrayParams(
-                x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
-            ),
-            y_params=ArrayParams(
-                y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
-            ),
+            ArrayParams(x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]),
+            ArrayParams(y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
@@ -75,17 +71,17 @@ def test_unequal_num_keys(
) as y_kwargs:
with pytest.raises(ValueError) as ex:
TensorflowTileDBDataset(
-                x_params=ArrayParams(
+                ArrayParams(
x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
),
-                y_params=ArrayParams(
+                ArrayParams(
y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
)
assert "X and Y arrays have different key range" in str(ex.value)
assert "All arrays must have the same key range" in str(ex.value)

@parametrize_for_dataset(num_fields=[0], shuffle_buffer_size=[0], num_workers=[0])
def test_dataset_order(
@@ -114,12 +110,8 @@ def test_dataset_order(
tmpdir, y_shape, y_sparse, key_dim_dtype, y_key_dim, num_fields
) as y_kwargs:
dataset = TensorflowTileDBDataset(
-            x_params=ArrayParams(
-                x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]
-            ),
-            y_params=ArrayParams(
-                y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]
-            ),
+            ArrayParams(x_kwargs["array"], x_kwargs["key_dim"], x_kwargs["fields"]),
+            ArrayParams(y_kwargs["array"], y_kwargs["key_dim"], y_kwargs["fields"]),
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
num_workers=num_workers,
83 changes: 35 additions & 48 deletions tiledb/ml/readers/pytorch.py
@@ -9,6 +9,7 @@
import scipy.sparse
import sparse
import torch
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info

try:
# torch>=1.10
@@ -22,29 +23,25 @@
from ._tensor_schema import DenseTensorSchema, SparseTensorSchema, TensorSchema
from .types import ArrayParams

- TensorLikeSequence = Union[
+ Tensor = Union[np.ndarray, sparse.COO, scipy.sparse.csr_matrix]
+ TensorSequence = Union[
Sequence[np.ndarray], Sequence[sparse.COO], Sequence[scipy.sparse.csr_matrix]
]
- TensorLikeOrSequence = Union[
-     np.ndarray, sparse.COO, scipy.sparse.csr_matrix, TensorLikeSequence
- ]
- XY = Tuple[TensorLikeOrSequence, TensorLikeOrSequence]
+ TensorOrSequence = Union[Tensor, TensorSequence]
+ OneOrMoreTensorsOrSequences = Union[TensorOrSequence, Tuple[TensorOrSequence, ...]]


def PyTorchTileDBDataLoader(
-     x_params: ArrayParams,
-     y_params: ArrayParams,
-     *,
+     *array_params: ArrayParams,
batch_size: int,
shuffle_buffer_size: int = 0,
prefetch: int = 2,
num_workers: int = 0,
csr: bool = True,
- ) -> torch.utils.data.DataLoader:
+ ) -> DataLoader:
"""Return a DataLoader for loading data from TileDB arrays.
-     :param x_params: TileDB ArrayParams of the features.
-     :param y_params: TileDB ArrayParams of the labels.
+     :param array_params: One or more `ArrayParams` instances, one per TileDB array.
:param batch_size: Size of each batch.
:param shuffle_buffer_size: Number of elements from which this dataset will sample.
:param prefetch: Number of samples loaded in advance by each worker. Not applicable
@@ -54,65 +51,55 @@ def PyTorchTileDBDataLoader(
yielded batches may be shuffled even if `shuffle_buffer_size` is zero.
:param csr: For sparse 2D arrays, whether to return CSR tensors instead of COO.
"""
-     x_schema = _get_tensor_schema(x_params)
-     y_schema = _get_tensor_schema(y_params)
-     if not x_schema.key_range.equal_values(y_schema.key_range):
-         raise ValueError(
-             f"X and Y arrays have different key range: {x_schema.key_range} != {y_schema.key_range}"
-         )
-
-     return torch.utils.data.DataLoader(
-         dataset=_PyTorchTileDBDataset(
-             x_schema=x_schema,
-             y_schema=y_schema,
-             shuffle_buffer_size=shuffle_buffer_size,
-         ),
+     schemas = tuple(map(_get_tensor_schema, array_params))
+     collators = tuple(
+         _get_tensor_collator(params.array, csr, len(schema.fields))
+         for params, schema in zip(array_params, schemas)
+     )
+     collate_fn = _CompositeCollator(*collators) if len(collators) > 1 else collators[0]
+
+     return DataLoader(
+         dataset=_PyTorchTileDBDataset(schemas, shuffle_buffer_size=shuffle_buffer_size),
batch_size=batch_size,
prefetch_factor=prefetch,
num_workers=num_workers,
worker_init_fn=_worker_init,
-         collate_fn=_CompositeCollator(
-             _get_tensor_collator(x_params.array, csr, len(x_schema.fields)),
-             _get_tensor_collator(y_params.array, csr, len(y_schema.fields)),
-         ),
+         collate_fn=collate_fn,
)


- class _PyTorchTileDBDataset(torch.utils.data.IterableDataset[XY]):
-     def __init__(
-         self,
-         x_schema: TensorSchema,
-         y_schema: TensorSchema,
-         shuffle_buffer_size: int = 0,
-     ):
+ class _PyTorchTileDBDataset(IterableDataset[OneOrMoreTensorsOrSequences]):
+     def __init__(self, schemas: Sequence[TensorSchema], shuffle_buffer_size: int = 0):
super().__init__()
-         self.x_schema = x_schema
-         self.y_schema = y_schema
-         self.key_range = x_schema.key_range
+         key_range = schemas[0].key_range
+         if not all(key_range.equal_values(schema.key_range) for schema in schemas[1:]):
+             raise ValueError(f"All arrays must have the same key range: {key_range}")
+         self.schemas = schemas
+         self.key_range = key_range
self._shuffle_buffer_size = shuffle_buffer_size

-     def __iter__(self) -> Iterator[XY]:
-         rows: Iterator[XY] = zip(
-             self._iter_rows(self.x_schema), self._iter_rows(self.y_schema)
-         )
+     def __iter__(self) -> Iterator[OneOrMoreTensorsOrSequences]:
+         rows: Iterator[OneOrMoreTensorsOrSequences]
+         it_rows = tuple(map(self._iter_rows, self.schemas))
+         rows = zip(*it_rows) if len(it_rows) > 1 else it_rows[0]
if self._shuffle_buffer_size > 0:
rows = _iter_shuffled(rows, self._shuffle_buffer_size)
return rows

-     def _iter_rows(self, schema: TensorSchema) -> Iterator[TensorLikeOrSequence]:
+     def _iter_rows(self, schema: TensorSchema) -> Iterator[TensorOrSequence]:
max_weight = schema.max_partition_weight
key_subranges = self.key_range.partition_by_weight(max_weight)
-         batches: Iterable[TensorLikeOrSequence] = schema.iter_tensors(key_subranges)
+         batches: Iterable[TensorOrSequence] = schema.iter_tensors(key_subranges)
if len(schema.fields) == 1:
return (tensor for batch in batches for tensor in batch)
else:
return (tensors for batch in batches for tensors in zip(*batch))


def _worker_init(worker_id: int) -> None:
-     worker_info = torch.utils.data.get_worker_info()
+     worker_info = get_worker_info()
dataset = worker_info.dataset
-     if dataset.x_schema.sparse or dataset.y_schema.sparse:
+     if any(schema.sparse for schema in dataset.schemas):
raise NotImplementedError("https://github.com/pytorch/pytorch/issues/20248")
key_ranges = list(dataset.key_range.partition_by_count(worker_info.num_workers))
dataset.key_range = key_ranges[worker_id]
@@ -127,7 +114,7 @@ def _get_tensor_schema(array_params: ArrayParams) -> TensorSchema:
return SparseTensorSchema.from_array_params(array_params)


- _SingleCollator = Callable[[TensorLikeSequence], torch.Tensor]
+ _SingleCollator = Callable[[TensorSequence], torch.Tensor]


class _CompositeCollator:
@@ -139,7 +126,7 @@ class _CompositeCollator:
def __init__(self, *collators: _SingleCollator):
self._collators = collators

-     def __call__(self, rows: Sequence[TensorLikeSequence]) -> Sequence[torch.Tensor]:
+     def __call__(self, rows: Sequence[TensorSequence]) -> Sequence[torch.Tensor]:
columns = tuple(zip(*rows))
collators = self._collators
assert len(columns) == len(collators)
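The `_CompositeCollator` pattern above is easy to demonstrate standalone: it transposes a batch of rows into columns and applies one collator per column. Per the diff, each array gets its own collator from `_get_tensor_collator`, which is what allows dense and sparse arrays to be mixed in a single loader. A self-contained illustration (not the library's code):

```python
import torch

def composite_collate(rows, collators):
    # Transpose rows -> columns, then collate each column independently,
    # mirroring _CompositeCollator.__call__ above.
    columns = tuple(zip(*rows))
    assert len(columns) == len(collators)
    return tuple(collate(column) for collate, column in zip(collators, columns))

rows = [(0, 0.5), (1, 1.5), (2, 2.5)]  # each row: one value per array
x_batch, y_batch = composite_collate(
    rows,
    collators=(
        lambda col: torch.as_tensor(col),                       # "x" collator
        lambda col: torch.as_tensor(col, dtype=torch.float32),  # "y" collator
    ),
)
print(x_batch)  # tensor([0, 1, 2])
print(y_batch)  # tensor([0.5000, 1.5000, 2.5000])
```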
49 changes: 20 additions & 29 deletions tiledb/ml/readers/tensorflow.py
@@ -10,18 +10,15 @@


def TensorflowTileDBDataset(
-     x_params: ArrayParams,
-     y_params: ArrayParams,
-     *,
+     *array_params: ArrayParams,
batch_size: int,
shuffle_buffer_size: int = 0,
prefetch: int = tf.data.AUTOTUNE,
num_workers: int = 0,
) -> tf.data.Dataset:
"""Return a tf.data.Dataset for loading data from TileDB arrays.
-     :param x_params: TileDB ArrayParams of the features.
-     :param y_params: TileDB ArrayParams of the labels.
+     :param array_params: One or more `ArrayParams` instances, one per TileDB array.
:param batch_size: Size of each batch.
:param shuffle_buffer_size: Number of elements from which this dataset will sample.
:param prefetch: Maximum number of batches that will be buffered when prefetching.
@@ -30,36 +27,30 @@ def TensorflowTileDBDataset(
used to fetch inputs asynchronously and in parallel. Note: when `num_workers` > 1
yielded batches may be shuffled even if `shuffle_buffer_size` is zero.
"""
-     x_schema = _get_tensor_schema(x_params)
-     y_schema = _get_tensor_schema(y_params)
-     if not x_schema.key_range.equal_values(y_schema.key_range):
-         raise ValueError(
-             f"X and Y arrays have different key range: {x_schema.key_range} != {y_schema.key_range}"
-         )
+     schemas = tuple(map(_get_tensor_schema, array_params))
+     key_range = schemas[0].key_range
+     if not all(key_range.equal_values(schema.key_range) for schema in schemas[1:]):
+         raise ValueError(f"All arrays must have the same key range: {key_range}")

-     x_max_weight = x_schema.max_partition_weight
-     y_max_weight = y_schema.max_partition_weight
-     key_ranges = list(x_schema.key_range.partition_by_count(num_workers or 1))
+     max_weights = tuple(schema.max_partition_weight for schema in schemas)
+     key_subranges = tuple(key_range.partition_by_count(num_workers or 1))

def key_range_dataset(key_range_idx: int) -> tf.data.Dataset:
-         x_dataset = tf.data.Dataset.from_generator(
-             lambda i: x_schema.iter_tensors(
-                 key_ranges[i].partition_by_weight(x_max_weight)
-             ),
-             args=(key_range_idx,),
-             output_signature=_get_tensor_specs(x_schema),
-         )
-         y_dataset = tf.data.Dataset.from_generator(
-             lambda i: y_schema.iter_tensors(
-                 key_ranges[i].partition_by_weight(y_max_weight)
-             ),
-             args=(key_range_idx,),
-             output_signature=_get_tensor_specs(y_schema),
+         datasets = tuple(
+             tf.data.Dataset.from_generator(
+                 lambda i, schema=schema, max_weight=max_weight: schema.iter_tensors(
+                     key_subranges[i].partition_by_weight(max_weight)
+                 ),
+                 args=(key_range_idx,),
+                 output_signature=_get_tensor_specs(schema),
+             ).unbatch()
+             for schema, max_weight in zip(schemas, max_weights)
)
-         return tf.data.Dataset.zip((x_dataset.unbatch(), y_dataset.unbatch()))
+         return tf.data.Dataset.zip(datasets) if len(datasets) > 1 else datasets[0]

if num_workers:
-         dataset = tf.data.Dataset.from_tensor_slices(range(len(key_ranges))).interleave(
+         dataset = tf.data.Dataset.from_tensor_slices(range(len(key_subranges)))
+         dataset = dataset.interleave(
key_range_dataset, num_parallel_calls=num_workers, deterministic=False
)
else:
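Two details of the rewritten generator construction are worth illustrating. First, loop variables are bound through lambda default arguments (`schema=schema, max_weight=max_weight`): the original two-generator version captured `x_schema` and `y_schema` directly, but with a loop a naive closure would make every generator read the last schema. Second, the per-array datasets are zipped only when there is more than one. A self-contained sketch of both behaviors, not the library's code:

```python
import tensorflow as tf

# Bind the loop variable as a default argument; a bare closure over `i`
# would see only its final value once the loop has finished.
factories = [lambda i=i: tf.data.Dataset.range(3).map(lambda x: x + 10 * i) for i in range(2)]
datasets = tuple(f() for f in factories)

# Zip only when there are multiple datasets, as in key_range_dataset above.
ds = tf.data.Dataset.zip(datasets) if len(datasets) > 1 else datasets[0]
for a, b in ds:
    print(int(a), int(b))  # prints (0, 10), (1, 11), (2, 12)
```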
