Skip to content

Commit

Permalink
Move MultiprocessIterator to experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Feb 17, 2024
1 parent a05aef7 commit 4035273
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 130 deletions.
35 changes: 18 additions & 17 deletions jax_dataloader/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@
'jax_dataloader/datasets.py'),
'jax_dataloader.datasets.Dataset.__len__': ( 'dataset.html#dataset.__len__',
'jax_dataloader/datasets.py')},
'jax_dataloader.experimental.multi_processing': { 'jax_dataloader.experimental.multi_processing.EpochIterator': ( 'experimental/mp.html#epochiterator',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.__del__': ( 'experimental/mp.html#epochiterator.__del__',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.__init__': ( 'experimental/mp.html#epochiterator.__init__',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.__iter__': ( 'experimental/mp.html#epochiterator.__iter__',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.__next__': ( 'experimental/mp.html#epochiterator.__next__',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.close': ( 'experimental/mp.html#epochiterator.close',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.get_data': ( 'experimental/mp.html#epochiterator.get_data',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.EpochIterator.run': ( 'experimental/mp.html#epochiterator.run',
'jax_dataloader/experimental/multi_processing.py'),
'jax_dataloader.experimental.multi_processing.chunk': ( 'experimental/mp.html#chunk',
'jax_dataloader/experimental/multi_processing.py')},
'jax_dataloader.imports': {},
'jax_dataloader.loaders.base': { 'jax_dataloader.loaders.base.BaseDataLoader': ( 'loader.base.html#basedataloader',
'jax_dataloader/loaders/base.py'),
Expand All @@ -60,23 +78,6 @@
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.EpochIterator': ( 'loader.jax.html#epochiterator',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator': ( 'loader.jax.html#multiprocessiterator',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.__del__': ( 'loader.jax.html#multiprocessiterator.__del__',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.__init__': ( 'loader.jax.html#multiprocessiterator.__init__',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.__iter__': ( 'loader.jax.html#multiprocessiterator.__iter__',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.__next__': ( 'loader.jax.html#multiprocessiterator.__next__',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.close': ( 'loader.jax.html#multiprocessiterator.close',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.get_data': ( 'loader.jax.html#multiprocessiterator.get_data',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.MultiprocessIterator.run': ( 'loader.jax.html#multiprocessiterator.run',
'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.chunk': ('loader.jax.html#chunk', 'jax_dataloader/loaders/jax.py'),
'jax_dataloader.loaders.jax.to_jax_dataset': ( 'loader.jax.html#to_jax_dataset',
'jax_dataloader/loaders/jax.py')},
'jax_dataloader.loaders.tensorflow': { 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow': ( 'loader.tf.html#dataloadertensorflow',
Expand Down
Empty file.
74 changes: 74 additions & 0 deletions jax_dataloader/experimental/multi_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/experimental/mp.ipynb.

# %% ../../nbs/experimental/mp.ipynb 1
from __future__ import print_function, division, annotations
from ..imports import *
from ..datasets import ArrayDataset, JAXDataset
from ..loaders import BaseDataLoader
from ..utils import get_config, asnumpy
from ..tests import *
import jax_dataloader as jdl
from threading import Thread, Event
from queue import Queue, Full
import multiprocessing as mp
import weakref

# %% auto 0
__all__ = ['chunk', 'EpochIterator']

# %% ../../nbs/experimental/mp.ipynb 2
def chunk(seq: Sequence, size: int) -> List[Sequence]:
    """Split `seq` into consecutive slices of at most `size` items.

    The last slice holds the remainder and may be shorter than `size`.
    An empty `seq` yields an empty list.

    Args:
        seq: Any sliceable sequence supporting `len()`.
        size: Maximum number of items per slice; must be positive.

    Returns:
        A list of slices of `seq`, in order.

    Raises:
        ValueError: If `size` is not positive.
    """
    # A negative step would make the range empty and silently drop every
    # item; a zero step raises an opaque "range() arg 3 must not be zero".
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


# %% ../../nbs/experimental/mp.ipynb 3
class EpochIterator(Thread):
    """[WIP] Epoch iterator that prefetches batches on a background thread.

    Despite the module name, the work is done by a `threading.Thread`, not by
    a separate process. The worker walks over `chunk(indices, batch_size)`,
    loads `data[batch_idx]` for each index batch and pushes the result into a
    bounded queue; `__next__` pops from that queue on the consumer side.

    `None` is the end-of-epoch marker placed in the queue, so a `data`
    object whose indexing returns `None` ends the iteration early.

    Args:
        data: Indexable dataset; `data[batch_idx]` must return one batch.
        batch_size: Number of indices per batch.
        indices: Sample indices for this epoch, in iteration order.
    """

    # Seconds a blocked `put` waits before re-checking `terminate_event`.
    _POLL_INTERVAL = 0.5
    # Exception raised inside the worker, handed over to `__next__`.
    _worker_error = None

    def __init__(self, data, batch_size: int, indices: Sequence[int]):
        # Daemon thread: a running thread is kept referenced by the threading
        # module, so `__del__` cannot be relied on to stop an abandoned
        # iterator; a non-daemon worker would then block interpreter exit.
        super().__init__(daemon=True)
        self.data = data
        batches = chunk(indices, batch_size)
        self.iter_idx = iter(batches)
        self.output_queue = Queue(5)  # TODO: maxsize
        self.terminate_event = Event()
        self.start()

    def _put(self, item) -> bool:
        """Put `item` in the queue, giving up once the iterator is closed.

        Returns:
            True if the item was queued, False if `terminate_event` was set
            before a free slot became available.
        """
        while not self.terminate_event.is_set():
            try:
                self.output_queue.put(
                    item, block=True, timeout=self._POLL_INTERVAL
                )
                return True
            except Full:
                # Queue still full: loop to re-check the terminate event.
                continue
        return False

    def run(self):
        """Worker loop: load batches until exhausted, closed, or failed."""
        try:
            delivered = True
            while delivered:
                delivered = self._put(self.get_data())
        except StopIteration:
            # `iter_idx` is exhausted: normal end of the epoch.
            pass
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently,
            # which would leave `__next__` blocked on an empty queue.
            self._worker_error = exc

        if not self._put(None):
            # Closed while running: never block here, but still try to wake
            # a consumer that may be waiting in `output_queue.get()`.
            try:
                self.output_queue.put_nowait(None)
            except Full:
                pass

    def __next__(self):
        """Return the next batch.

        Raises:
            StopIteration: When the epoch is exhausted or the iterator has
                been closed (repeated calls keep raising it).
            Exception: Whatever the worker raised while loading a batch.
        """
        if self.terminate_event.is_set():
            raise StopIteration
        result = self.output_queue.get()
        if result is None:
            self.close()
            if self._worker_error is not None:
                raise self._worker_error
            raise StopIteration
        return result

    def __iter__(self):
        return self

    def __del__(self):
        # `__init__` may have failed before `terminate_event` was assigned.
        if hasattr(self, "terminate_event"):
            self.close()

    def close(self):
        """Ask the worker to stop and mark the iterator as finished."""
        self.terminate_event.set()

    def get_data(self):
        """Load the next batch; raises StopIteration when none are left."""
        batch_idx = next(self.iter_idx)
        batch = self.data[batch_idx]
        return batch

55 changes: 3 additions & 52 deletions jax_dataloader/loaders/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
from queue import Queue

# %% auto 0
__all__ = ['chunk', 'EpochIterator', 'MultiprocessIterator', 'to_jax_dataset', 'DataLoaderJAX']
__all__ = ['EpochIterator', 'to_jax_dataset', 'DataLoaderJAX']

# %% ../../nbs/loader.jax.ipynb 4
def chunk(seq: Sequence, size: int) -> List[Sequence]:
return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


# %% ../../nbs/loader.jax.ipynb 5
def EpochIterator(
data,
batch_size: int,
Expand All @@ -29,51 +24,7 @@ def EpochIterator(
idx = indices[i:i+batch_size]
yield data[idx]

# %% ../../nbs/loader.jax.ipynb 6
class MultiprocessIterator(Thread):
"""[WIP] Multiprocessing Epoch Iterator"""

def __init__(self, data, batch_size: int, indices=None):
super().__init__()
self.data = data
indices = np.arange(len(data)) if indices is None else indices
batches = chunk(indices, batch_size)
self.iter_idx = iter(batches)
self.output_queue = Queue() # TODO: maxsize
self.terminate_event = Event()
self.start()

def run(self):
try:
while True:
result = self.get_data()
self.output_queue.put(result)
except StopIteration:
self.output_queue.put(None)

def __next__(self):
result = self.output_queue.get()
if result is None:
self.close()
raise StopIteration()
return result

def __iter__(self):
return self

def __del__(self):
self.close()

def close(self):
self.terminate_event.set()

def get_data(self):
batch_idx = next(self.iter_idx)
batch = self.data[batch_idx]
return batch


# %% ../../nbs/loader.jax.ipynb 7
# %% ../../nbs/loader.jax.ipynb 5
@dispatch
def to_jax_dataset(dataset: JAXDataset):
if isinstance(dataset, ArrayDataset):
Expand All @@ -84,7 +35,7 @@ def to_jax_dataset(dataset: JAXDataset):
def to_jax_dataset(dataset: HFDataset):
return dataset.with_format('numpy')

# %% ../../nbs/loader.jax.ipynb 8
# %% ../../nbs/loader.jax.ipynb 6
class DataLoaderJAX(BaseDataLoader):

@typecheck
Expand Down
105 changes: 105 additions & 0 deletions nbs/experimental/mp.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp experimental.multi_processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from __future__ import print_function, division, annotations\n",
"from jax_dataloader.imports import *\n",
"from jax_dataloader.datasets import ArrayDataset, JAXDataset\n",
"from jax_dataloader.loaders import BaseDataLoader\n",
"from jax_dataloader.utils import get_config, asnumpy\n",
"from jax_dataloader.tests import *\n",
"import jax_dataloader as jdl\n",
"from threading import Thread, Event\n",
"from queue import Queue, Full\n",
"import multiprocessing as mp\n",
"import weakref"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def chunk(seq: Sequence, size: int) -> List[Sequence]:\n",
" return [seq[pos:pos + size] for pos in range(0, len(seq), size)] \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class EpochIterator(Thread):\n",
" \"\"\"[WIP] Multiprocessing Epoch Iterator\"\"\"\n",
" \n",
" def __init__(self, data, batch_size: int, indices: Sequence[int]):\n",
" super().__init__()\n",
" self.data = data\n",
" batches = chunk(indices, batch_size)\n",
" self.iter_idx = iter(batches)\n",
" self.output_queue = Queue(5) # TODO: maxsize\n",
" self.terminate_event = Event()\n",
" self.start()\n",
"\n",
" def run(self):\n",
" try:\n",
" while True:\n",
" # get data\n",
" result = self.get_data()\n",
" # put result in queue\n",
" while True:\n",
" try: \n",
" self.output_queue.put(result, block=True, timeout=0.5)\n",
" break\n",
" except Full: pass\n",
" \n",
" if self.terminate_event.is_set(): return \n",
"\n",
" except StopIteration:\n",
" self.output_queue.put(None)\n",
"\n",
" def __next__(self):\n",
" result = self.output_queue.get()\n",
" if result is None:\n",
" self.close()\n",
" raise StopIteration()\n",
" return result\n",
" \n",
" def __iter__(self):\n",
" return self\n",
" \n",
" def __del__(self):\n",
" self.close()\n",
"\n",
" def close(self):\n",
" self.terminate_event.set()\n",
"\n",
" def get_data(self):\n",
" batch_idx = next(self.iter_idx)\n",
" batch = self.data[batch_idx]\n",
" return batch\n"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
61 changes: 0 additions & 61 deletions nbs/loader.jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,6 @@
"from queue import Queue"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def chunk(seq: Sequence, size: int) -> List[Sequence]:\n",
" return [seq[pos:pos + size] for pos in range(0, len(seq), size)] \n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -87,56 +76,6 @@
" yield data[idx]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class MultiprocessIterator(Thread):\n",
" \"\"\"[WIP] Multiprocessing Epoch Iterator\"\"\"\n",
" \n",
" def __init__(self, data, batch_size: int, indices=None):\n",
" super().__init__()\n",
" self.data = data\n",
" indices = np.arange(len(data)) if indices is None else indices\n",
" batches = chunk(indices, batch_size)\n",
" self.iter_idx = iter(batches)\n",
" self.output_queue = Queue() # TODO: maxsize\n",
" self.terminate_event = Event()\n",
" self.start()\n",
"\n",
" def run(self):\n",
" try:\n",
" while True:\n",
" result = self.get_data()\n",
" self.output_queue.put(result)\n",
" except StopIteration:\n",
" self.output_queue.put(None)\n",
"\n",
" def __next__(self):\n",
" result = self.output_queue.get()\n",
" if result is None:\n",
" self.close()\n",
" raise StopIteration()\n",
" return result\n",
" \n",
" def __iter__(self):\n",
" return self\n",
" \n",
" def __del__(self):\n",
" self.close()\n",
"\n",
" def close(self):\n",
" self.terminate_event.set()\n",
"\n",
" def get_data(self):\n",
" batch_idx = next(self.iter_idx)\n",
" batch = self.data[batch_idx]\n",
" return batch\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 4035273

Please sign in to comment.