From 40352739710d636ae2c3ef97392cf4a21ad8d3bd Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Sat, 17 Feb 2024 11:20:35 -0500 Subject: [PATCH] Move MultiprocessIterator to experimental --- jax_dataloader/_modidx.py | 35 +++--- jax_dataloader/experimental/__init__.py | 0 .../experimental/multi_processing.py | 74 ++++++++++++ jax_dataloader/loaders/jax.py | 55 +-------- nbs/experimental/mp.ipynb | 105 ++++++++++++++++++ nbs/loader.jax.ipynb | 61 ---------- 6 files changed, 200 insertions(+), 130 deletions(-) create mode 100644 jax_dataloader/experimental/__init__.py create mode 100644 jax_dataloader/experimental/multi_processing.py create mode 100644 nbs/experimental/mp.ipynb diff --git a/jax_dataloader/_modidx.py b/jax_dataloader/_modidx.py index b314a42..06daffc 100644 --- a/jax_dataloader/_modidx.py +++ b/jax_dataloader/_modidx.py @@ -37,6 +37,24 @@ 'jax_dataloader/datasets.py'), 'jax_dataloader.datasets.Dataset.__len__': ( 'dataset.html#dataset.__len__', 'jax_dataloader/datasets.py')}, + 'jax_dataloader.experimental.multi_processing': { 'jax_dataloader.experimental.multi_processing.EpochIterator': ( 'experimental/mp.html#epochiterator', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.__del__': ( 'experimental/mp.html#epochiterator.__del__', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.__init__': ( 'experimental/mp.html#epochiterator.__init__', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.__iter__': ( 'experimental/mp.html#epochiterator.__iter__', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.__next__': ( 'experimental/mp.html#epochiterator.__next__', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.close': ( 'experimental/mp.html#epochiterator.close', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.get_data': ( 'experimental/mp.html#epochiterator.get_data', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.EpochIterator.run': ( 'experimental/mp.html#epochiterator.run', + 'jax_dataloader/experimental/multi_processing.py'), + 'jax_dataloader.experimental.multi_processing.chunk': ( 'experimental/mp.html#chunk', + 'jax_dataloader/experimental/multi_processing.py')}, 'jax_dataloader.imports': {}, 'jax_dataloader.loaders.base': { 'jax_dataloader.loaders.base.BaseDataLoader': ( 'loader.base.html#basedataloader', 'jax_dataloader/loaders/base.py'), @@ -60,23 +78,6 @@ 'jax_dataloader/loaders/jax.py'), 'jax_dataloader.loaders.jax.EpochIterator': ( 'loader.jax.html#epochiterator', 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator': ( 'loader.jax.html#multiprocessiterator', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.__del__': ( 'loader.jax.html#multiprocessiterator.__del__', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.__init__': ( 'loader.jax.html#multiprocessiterator.__init__', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.__iter__': ( 'loader.jax.html#multiprocessiterator.__iter__', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.__next__': ( 'loader.jax.html#multiprocessiterator.__next__', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.close': ( 'loader.jax.html#multiprocessiterator.close', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.get_data': ( 'loader.jax.html#multiprocessiterator.get_data', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.MultiprocessIterator.run': ( 'loader.jax.html#multiprocessiterator.run', - 'jax_dataloader/loaders/jax.py'), - 'jax_dataloader.loaders.jax.chunk': ('loader.jax.html#chunk', 'jax_dataloader/loaders/jax.py'), 'jax_dataloader.loaders.jax.to_jax_dataset': ( 'loader.jax.html#to_jax_dataset', 'jax_dataloader/loaders/jax.py')}, 'jax_dataloader.loaders.tensorflow': { 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow': ( 'loader.tf.html#dataloadertensorflow', diff --git a/jax_dataloader/experimental/__init__.py b/jax_dataloader/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jax_dataloader/experimental/multi_processing.py b/jax_dataloader/experimental/multi_processing.py new file mode 100644 index 0000000..ddaca5d --- /dev/null +++ b/jax_dataloader/experimental/multi_processing.py @@ -0,0 +1,74 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/experimental/mp.ipynb. + +# %% ../../nbs/experimental/mp.ipynb 1 +from __future__ import print_function, division, annotations +from ..imports import * +from ..datasets import ArrayDataset, JAXDataset +from ..loaders import BaseDataLoader +from ..utils import get_config, asnumpy +from ..tests import * +import jax_dataloader as jdl +from threading import Thread, Event +from queue import Queue, Full +import multiprocessing as mp +import weakref + +# %% auto 0 +__all__ = ['chunk', 'EpochIterator'] + +# %% ../../nbs/experimental/mp.ipynb 2 +def chunk(seq: Sequence, size: int) -> List[Sequence]: + return [seq[pos:pos + size] for pos in range(0, len(seq), size)] + + +# %% ../../nbs/experimental/mp.ipynb 3 +class EpochIterator(Thread): + """[WIP] Multiprocessing Epoch Iterator""" + + def __init__(self, data, batch_size: int, indices: Sequence[int]): + super().__init__() + self.data = data + batches = chunk(indices, batch_size) + self.iter_idx = iter(batches) + self.output_queue = Queue(5) # TODO: maxsize + self.terminate_event = Event() + self.start() + + def run(self): + try: + while True: + # get data + result = self.get_data() + # put result in queue + while True: + try: + self.output_queue.put(result, block=True, timeout=0.5) + break + except Full: pass + + if self.terminate_event.is_set(): return + + except StopIteration: + self.output_queue.put(None) + + def __next__(self): + result = self.output_queue.get() + if result is None: + self.close() + raise StopIteration() + return result + + def __iter__(self): + return self + + def __del__(self): + self.close() + + def close(self): + self.terminate_event.set() + + def get_data(self): + batch_idx = next(self.iter_idx) + batch = self.data[batch_idx] + return batch + diff --git a/jax_dataloader/loaders/jax.py b/jax_dataloader/loaders/jax.py index 76b3ac2..7f15942 100644 --- a/jax_dataloader/loaders/jax.py +++ b/jax_dataloader/loaders/jax.py @@ -12,14 +12,9 @@ from queue import Queue # %% auto 0 -__all__ = ['chunk', 'EpochIterator', 'MultiprocessIterator', 'to_jax_dataset', 'DataLoaderJAX'] +__all__ = ['EpochIterator', 'to_jax_dataset', 'DataLoaderJAX'] # %% ../../nbs/loader.jax.ipynb 4 -def chunk(seq: Sequence, size: int) -> List[Sequence]: - return [seq[pos:pos + size] for pos in range(0, len(seq), size)] - - -# %% ../../nbs/loader.jax.ipynb 5 def EpochIterator( data, batch_size: int, @@ -29,51 +24,7 @@ def EpochIterator( idx = indices[i:i+batch_size] yield data[idx] -# %% ../../nbs/loader.jax.ipynb 6 -class MultiprocessIterator(Thread): - """[WIP] Multiprocessing Epoch Iterator""" - - def __init__(self, data, batch_size: int, indices=None): - super().__init__() - self.data = data - indices = np.arange(len(data)) if indices is None else indices - batches = chunk(indices, batch_size) - self.iter_idx = iter(batches) - self.output_queue = Queue() # TODO: maxsize - self.terminate_event = Event() - self.start() - - def run(self): - try: - while True: - result = self.get_data() - self.output_queue.put(result) - except StopIteration: - self.output_queue.put(None) - - def __next__(self): - result = self.output_queue.get() - if result is None: - self.close() - raise StopIteration() - return result - - def __iter__(self): - return self - - def __del__(self): - self.close() - - def close(self): - self.terminate_event.set() - - def get_data(self): - batch_idx = next(self.iter_idx) - batch = self.data[batch_idx] - return batch - - -# %% ../../nbs/loader.jax.ipynb 7 +# %% ../../nbs/loader.jax.ipynb 5 @dispatch def to_jax_dataset(dataset: JAXDataset): if isinstance(dataset, ArrayDataset): @@ -84,7 +35,7 @@ def to_jax_dataset(dataset: JAXDataset): def to_jax_dataset(dataset: HFDataset): return dataset.with_format('numpy') -# %% ../../nbs/loader.jax.ipynb 8 +# %% ../../nbs/loader.jax.ipynb 6 class DataLoaderJAX(BaseDataLoader): @typecheck diff --git a/nbs/experimental/mp.ipynb b/nbs/experimental/mp.ipynb new file mode 100644 index 0000000..457bcf1 --- /dev/null +++ b/nbs/experimental/mp.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp experimental.multi_processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from __future__ import print_function, division, annotations\n", + "from jax_dataloader.imports import *\n", + "from jax_dataloader.datasets import ArrayDataset, JAXDataset\n", + "from jax_dataloader.loaders import BaseDataLoader\n", + "from jax_dataloader.utils import get_config, asnumpy\n", + "from jax_dataloader.tests import *\n", + "import jax_dataloader as jdl\n", + "from threading import Thread, Event\n", + "from queue import Queue, Full\n", + "import multiprocessing as mp\n", + "import weakref" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def chunk(seq: Sequence, size: int) -> List[Sequence]:\n", + " return [seq[pos:pos + size] for pos in range(0, len(seq), size)] \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class EpochIterator(Thread):\n", + " \"\"\"[WIP] Multiprocessing Epoch Iterator\"\"\"\n", + " \n", + " def __init__(self, data, batch_size: int, indices: Sequence[int]):\n", + " super().__init__()\n", + " self.data = data\n", + " batches = chunk(indices, batch_size)\n", + " self.iter_idx = iter(batches)\n", + " self.output_queue = Queue(5) # TODO: maxsize\n", + " self.terminate_event = Event()\n", + " self.start()\n", + "\n", + " def run(self):\n", + " try:\n", + " while True:\n", + " # get data\n", + " result = self.get_data()\n", + " # put result in queue\n", + " while True:\n", + " try: \n", + " self.output_queue.put(result, block=True, timeout=0.5)\n", + " break\n", + " except Full: pass\n", + " \n", + " if self.terminate_event.is_set(): return \n", + "\n", + " except StopIteration:\n", + " self.output_queue.put(None)\n", + "\n", + " def __next__(self):\n", + " result = self.output_queue.get()\n", + " if result is None:\n", + " self.close()\n", + " raise StopIteration()\n", + " return result\n", + " \n", + " def __iter__(self):\n", + " return self\n", + " \n", + " def __del__(self):\n", + " self.close()\n", + "\n", + " def close(self):\n", + " self.terminate_event.set()\n", + "\n", + " def get_data(self):\n", + " batch_idx = next(self.iter_idx)\n", + " batch = self.data[batch_idx]\n", + " return batch\n" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/loader.jax.ipynb b/nbs/loader.jax.ipynb index 5d9996c..3bc2673 100644 --- a/nbs/loader.jax.ipynb +++ b/nbs/loader.jax.ipynb @@ -59,17 +59,6 @@ "from queue import Queue" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "def chunk(seq: Sequence, size: int) -> List[Sequence]:\n", - " return [seq[pos:pos + size] for pos in range(0, len(seq), size)] \n" - ] - }, { "cell_type": "code", "execution_count": null, @@ -87,56 +76,6 @@ " yield data[idx]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "class MultiprocessIterator(Thread):\n", - " \"\"\"[WIP] Multiprocessing Epoch Iterator\"\"\"\n", - " \n", - " def __init__(self, data, batch_size: int, indices=None):\n", - " super().__init__()\n", - " self.data = data\n", - " indices = np.arange(len(data)) if indices is None else indices\n", - " batches = chunk(indices, batch_size)\n", - " self.iter_idx = iter(batches)\n", - " self.output_queue = Queue() # TODO: maxsize\n", - " self.terminate_event = Event()\n", - " self.start()\n", - "\n", - " def run(self):\n", - " try:\n", - " while True:\n", - " result = self.get_data()\n", - " self.output_queue.put(result)\n", - " except StopIteration:\n", - " self.output_queue.put(None)\n", - "\n", - " def __next__(self):\n", - " result = self.output_queue.get()\n", - " if result is None:\n", - " self.close()\n", - " raise StopIteration()\n", - " return result\n", - " \n", - " def __iter__(self):\n", - " return self\n", - " \n", - " def __del__(self):\n", - " self.close()\n", - "\n", - " def close(self):\n", - " self.terminate_event.set()\n", - "\n", - " def get_data(self):\n", - " batch_idx = next(self.iter_idx)\n", - " batch = self.data[batch_idx]\n", - " return batch\n" - ] - }, { "cell_type": "code", "execution_count": null,