Skip to content

Commit

Permalink
Fixed a few more bugs in the dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 12, 2024
1 parent 95cfc6b commit dbdff91
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 31 deletions.
84 changes: 61 additions & 23 deletions datasets/dataset preparations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,40 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38ceb8843f1246f1901cd6f2bfdd9957",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/128 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83f7ba8a5d96497d92f5719366d2237d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/352 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"coyo700 = load_dataset(\"kakaobrain/coyo-700m\", num_proc=32)"
"coyo700 = load_dataset(\"kakaobrain/coyo-700m\", num_proc=64)"
]
},
{
Expand Down Expand Up @@ -479,7 +508,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -581,7 +610,7 @@
" self.num_workers = num_workers\n",
" self.batch_size = batch_size\n",
" loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n",
" self.thread = threading.Thread(target=loader, args=(dataset, num_workers))\n",
" self.thread = threading.Thread(target=loader, args=(dataset))\n",
" self.thread.start()\n",
" \n",
" def __iter__(self):\n",
Expand Down Expand Up @@ -677,7 +706,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from flaxdiff.data.online_loader import OnlineStreamingDataLoader"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -693,21 +731,6 @@
"text": [
"Dataset length: 591753\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-7:\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n",
" self.run()\n",
" File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 766, in run_closure\n",
" _threading_Thread_run(self)\n",
" File \"/usr/lib/python3.10/threading.py\", line 953, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"TypeError: parallel_image_loader() got multiple values for argument 'num_workers'\n"
]
}
],
"source": [
Expand Down Expand Up @@ -777,14 +800,29 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2000/2000 [00:37<00:00, 53.41it/s]\n"
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-11:\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n",
" self.run()\n",
" File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 766, in run_closure\n",
" _threading_Thread_run(self)\n",
" File \"/usr/lib/python3.10/threading.py\", line 953, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"TypeError: parallel_image_loader() got multiple values for argument 'num_threads'\n"
]
}
],
Expand Down
14 changes: 7 additions & 7 deletions flaxdiff/data/online_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256
self.num_workers = num_workers
self.batch_size = batch_size
loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
self.thread = threading.Thread(target=loader, args=(dataset))
self.thread = threading.Thread(target=loader, args=(dataset,))
self.thread.start()

def __iter__(self):
Expand All @@ -131,7 +131,7 @@ def __del__(self):

def __len__(self):
return len(self.dataset) // self.batch_size

def default_collate(batch):
urls = [sample["url"] for sample in batch]
captions = [sample["caption"] for sample in batch]
Expand Down Expand Up @@ -177,14 +177,14 @@ def __init__(
if isinstance(dataset[0], str):
print("Loading multiple datasets from paths")
dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
else:
print("Concatenating multiple datasets")
dataset = concatenate_datasets(dataset)
dataset = dataset.map(pre_map_maker(pre_map_def))
print("Concatenating multiple datasets")
dataset = concatenate_datasets(dataset)
dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
print(f"Dataset length: {len(dataset)}")
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
self.collate_fn = collate_fn
self.batch_size = batch_size

# Launch a thread to load batches in the background
self.batch_queue = queue.Queue(prefetch)
Expand All @@ -204,5 +204,5 @@ def __next__(self):
# return self.collate_fn(next(self.iterator))

def __len__(self):
return len(self.dataset) // self.batch_size
return len(self.dataset)

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.14',
version='0.1.15',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit dbdff91

Please sign in to comment.