From 9991cca193dd551ae2b0b25fedb956b8ce86acb0 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Sun, 8 Sep 2024 02:01:50 -0400 Subject: [PATCH] feat: modified uvit arch --- datasets/dataset preparations copy.ipynb | 1491 ---------------------- datasets/dataset preparations.ipynb | 789 ++++++++---- flaxdiff/models/simple_vit.py | 26 +- training.py | 2 +- 4 files changed, 539 insertions(+), 1769 deletions(-) delete mode 100644 datasets/dataset preparations copy.ipynb diff --git a/datasets/dataset preparations copy.ipynb b/datasets/dataset preparations copy.ipynb deleted file mode 100644 index 483a601..0000000 --- a/datasets/dataset preparations copy.ipynb +++ /dev/null @@ -1,1491 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import webdataset as wds\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import augmax\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import grain.python as pygrain\n", - "from typing import Any, Dict, List, Tuple\n", - "import numpy as np\n", - "from functools import partial\n", - "import tqdm \n", - "\n", - "import fsspec\n", - "import json\n", - "\n", - "import os\n", - "from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel\n", - "\n", - "from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk\n", - "from datasets.utils.file_utils import get_datasets_user_agent\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "from functools import partial\n", - "import io\n", - "import urllib\n", - "\n", - "import PIL.Image\n", - "import cv2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "USER_AGENT = get_datasets_user_agent()\n", - "\n", - "\n", - "def fetch_single_image(image_url, timeout=None, retries=0):\n", - " for _ in range(retries + 1):\n", - " try:\n", - " request = urllib.request.Request(\n", - " image_url,\n", - " data=None,\n", - " headers={\"user-agent\": USER_AGENT},\n", - " )\n", - " with urllib.request.urlopen(request, timeout=timeout) as req:\n", - " image = PIL.Image.open(io.BytesIO(req.read()))\n", - " break\n", - " except Exception:\n", - " image = None\n", - " return image\n", - "\n", - "denormalizeImage = lambda x: (x + 1.0) * 127.5\n", - "\n", - "def plotImages(imgs, fig_size=(8, 8), dpi=100):\n", - " fig = plt.figure(figsize=fig_size, dpi=dpi)\n", - " imglen = imgs.shape[0]\n", - " for i in range(imglen):\n", - " plt.subplot(fig_size[0], fig_size[1], i + 1)\n", - " plt.imshow(jnp.astype(denormalizeImage(imgs[i, :, :, :]), jnp.uint8))\n", - " plt.axis(\"off\")\n", - " plt.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Filtering pipeline for various datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def dataMapper(map: Dict[str, Any]):\n", - " def _map(sample) -> Dict[str, Any]:\n", - " return {\n", - " \"url\": sample[map[\"url\"]],\n", - " \"caption\": sample[map[\"caption\"]],\n", - " }\n", - " return _map\n", - "\n", - "def imageFetcher():\n", - " def fetch_images(batch, num_threads, timeout=None, retries=0):\n", - " fetch_single_image_with_args = partial(fetch_single_image, timeout=timeout, retries=retries)\n", - " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " batch[\"image\"] = list(executor.map(fetch_single_image_with_args, batch[\"url\"]))\n", - " return batch\n", - " return fetch_images\n", - "\n", - 
"def mapDataset(dataset, args, mapper=dataMapper, workers=16, batch_size=10000, should_remove_columns=True, fn_kwargs={}):\n", - " if should_remove_columns:\n", - " remove_columns = dataset.column_names\n", - " else:\n", - " remove_columns = None\n", - " return dataset.map(mapper(*args), batched=True, batch_size=batch_size, remove_columns=remove_columns, num_proc=workers, fn_kwargs=fn_kwargs) " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e357e8fa8418439e8d2d0a8e23f3d1c5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Map (num_proc=16): 0%| | 0/12096809 [00:00 value[\"max\"]:\n", - " return False\n", - " return True\n", - " return _filter\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c5d3eddced904acca1ddd5625e84d5ed", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Filter (num_proc=64): 0%| | 0/746972269 [00:00 1\u001b[0m leonardo \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbigdata-pw/leonardo\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mall\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m leonardoMap \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_url\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 5\u001b[0m }\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/load.py:2616\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2613\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 2615\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2616\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2617\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2618\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2619\u001b[0m \u001b[43m 
\u001b[38;5;129;01min\u001b[39;00m async_results]\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/utils/py_utils.py:718\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m pool_changed:\n\u001b[1;32m 717\u001b[0m \u001b[38;5;66;03m# we get the result in case there's an error to raise\u001b[39;00m\n\u001b[0;32m--> 718\u001b[0m [\u001b[43masync_result\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.05\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m async_result \u001b[38;5;129;01min\u001b[39;00m async_results]\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/multiprocess/pool.py:774\u001b[0m, in \u001b[0;36mApplyResult.get\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 772\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 774\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n", - "\u001b[0;31mDatasetGenerationError\u001b[0m: An error occurred while generating the dataset" - ] - } - ], - "source": [ - "leonardo = load_dataset(\"bigdata-pw/leonardo\", split=\"all\", num_proc=64)\n", - "leonardoMap = {\n", - " \"url\": \"image_url\",\n", - " \"caption\": \"caption\",\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "heavyFilterMap = {\n", - " \"like_count\": {\"min\": 1, \"max\": 100000000},\n", - "}\n", - "\n", - "def leonardoFilter(filterMap):\n", - " def _filter(sample):\n", - " # if len(sample['negative_prompt']) != 0:\n", - " # return False\n", - " for key, value in filterMap.items():\n", - " if sample[key] < value[\"min\"] or sample[key] > value[\"max\"]:\n", - " return False\n", - " return True\n", - " return _filter\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "leonardoLiked = leonardo.filter(leonardoFilter(heavyFilterMap), num_proc=120)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "final_data = mapDataset(leonardoLiked, ({\n", - " \"url\":\"url\",\n", - " \"caption\":\"text\"\n", - " },), batch_size=1000000, workers=None)\n", - "\n", - "final_data.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/leonardo-liked\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d522bd677e184c57855719a8d6013f31", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Resolving data files: 0%| | 0/958 [00:00 2:\n", - " return\n", - " # check if the variance is too low\n", - " if np.std(image) < 1e-4:\n", - " return\n", - " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", - " downscale = max(original_width, original_height) > max(image_shape)\n", - " interpolation = downscale_interpolation if downscale else upscale_interpolation\n", - " image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)\n", - " image = A.pad(\n", - " image,\n", - " min_height=image_shape[0],\n", - " 
min_width=image_shape[1],\n", - " border_mode=cv2.BORDER_CONSTANT,\n", - " value=[255, 255, 255],\n", - " )\n", - " data_queue.put({\n", - " \"url\": url,\n", - " \"caption\": caption,\n", - " \"image\": image\n", - " })\n", - " except Exception as e:\n", - " error_queue.put({\n", - " \"url\": url,\n", - " \"caption\": caption,\n", - " \"error\": str(e)\n", - " })\n", - " \n", - "def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n", - " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n", - " \n", - "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n", - " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n", - " shard_len = len(dataset) // num_workers\n", - " print(f\"Local Shard lengths: {shard_len}\")\n", - " with multiprocessing.Pool(num_workers) as pool:\n", - " iteration = 0\n", - " while True:\n", - " # Repeat forever\n", - " dataset = dataset.shuffle(seed=iteration)\n", - " shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]\n", - " pool.map(map_batch_fn, shards)\n", - " iteration += 1\n", - " \n", - "class ImageBatchIterator:\n", - " def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n", - " self.dataset = dataset\n", - " self.num_workers = num_workers\n", - " self.batch_size = batch_size\n", - " loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n", - " self.thread = threading.Thread(target=loader, args=(dataset))\n", - " self.thread.start()\n", - " \n", - " def __iter__(self):\n", - " return self\n", - " \n", - " def __next__(self):\n", - " def fetcher(_):\n", - " return data_queue.get()\n", - " with ThreadPoolExecutor(max_workers=self.batch_size) as executor:\n", - " batch = list(executor.map(fetcher, range(self.batch_size)))\n", - " return batch\n", - " \n", - " def __del__(self):\n", - " self.thread.join()\n", - " \n", - " def __len__(self):\n", - " return len(self.dataset) // self.batch_size\n", - " \n", - "def default_collate(batch):\n", - " urls = [sample[\"url\"] for sample in batch]\n", - " captions = [sample[\"caption\"] for sample in batch]\n", - " images = np.stack([sample[\"image\"] for sample in batch], axis=0)\n", - " return {\n", - " \"url\": urls,\n", - " \"caption\": captions,\n", - " \"image\": images,\n", - " }\n", - " \n", - "def dataMapper(map: Dict[str, Any]):\n", - " def _map(sample) -> Dict[str, Any]:\n", - " return {\n", - " \"url\": sample[map[\"url\"]],\n", - " \"caption\": sample[map[\"caption\"]],\n", - " }\n", - " return _map\n", - "\n", - "class OnlineStreamingDataLoader():\n", - " def __init__(\n", - " self, \n", - " dataset, \n", - " batch_size=64, \n", - " num_workers=16, \n", - " num_threads=512,\n", - " default_split=\"all\",\n", - " pre_map_maker=dataMapper, \n", - " pre_map_def={\n", - " \"url\": \"URL\",\n", - " \"caption\": \"TEXT\",\n", - " },\n", - " global_process_count=1,\n", - " global_process_index=0,\n", - " prefetch=1000,\n", - " collate_fn=default_collate,\n", - " ):\n", - " if isinstance(dataset, str):\n", - " dataset_path = dataset\n", - " print(\"Loading dataset from path\")\n", - " dataset = load_dataset(dataset_path, split=default_split)\n", - " elif isinstance(dataset, list):\n", - " if 
isinstance(dataset[0], str):\n", - " print(\"Loading multiple datasets from paths\")\n", - " dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]\n", - " else:\n", - " print(\"Concatenating multiple datasets\")\n", - " dataset = concatenate_datasets(dataset)\n", - " dataset = dataset.map(pre_map_maker(pre_map_def))\n", - " self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)\n", - " print(f\"Dataset length: {len(dataset)}\")\n", - " self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)\n", - " self.collate_fn = collate_fn\n", - " \n", - " # Launch a thread to load batches in the background\n", - " self.batch_queue = queue.Queue(prefetch)\n", - " \n", - " def batch_loader():\n", - " for batch in self.iterator:\n", - " self.batch_queue.put(batch)\n", - " \n", - " self.loader_thread = threading.Thread(target=batch_loader)\n", - " self.loader_thread.start()\n", - " \n", - " def __iter__(self):\n", - " return self\n", - " \n", - " def __next__(self):\n", - " return self.collate_fn(self.batch_queue.get())\n", - " # return self.collate_fn(next(self.iterator))\n", - " \n", - " def __len__(self):\n", - " return len(self.dataset) // self.batch_size\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from flaxdiff.data.online_loader import OnlineStreamingDataLoader" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, default_split=\"train\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataloader.batch_queue.qsize()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_queue.qsize()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "error_queue.qsize()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in tqdm.tqdm(range(0, 2000)):\n", - " batch = next(dataloader)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def parallel_loading(dataset):\n", - " dataset.map(map_batch_fn, num_proc=64, batched=True, batch_size=64, fn_kwargs={\"num_threads\": 64})\n", - " \n", - "thread = threading.Thread(target=parallel_loading, args=(mscoco_fused,))\n", - "thread.start()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch.utils.data import Dataset, DataLoader\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "import aiohttp\n", - "from io import BytesIO\n", - "import asyncio\n", - "from PIL import Image\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class URLDataset(Dataset):\n", - " def __init__(self, data):\n", - " self.data = data\n", - " \n", - " async def fetch_image(self, url):\n", - " async with aiohttp.ClientSession() as session:\n", - " async with session.get(url) as response:\n", - " image_data = await 
response.read()\n", - " image = Image.open(BytesIO(image_data))\n", - " return image\n", - " \n", - " def __getitem__(self, index):\n", - " data = self.data[index]\n", - " url, caption = data['url'], data['caption']\n", - " loop = asyncio.get_event_loop()\n", - " image = loop.run_until_complete(self.fetch_image(url))\n", - " # Preprocess image and return along with the caption\n", - " image = image.resize((256, 256)) # Example resize\n", - " return image, caption\n", - " \n", - " def __len__(self):\n", - " return len(self.data)\n", - "\n", - "# Example usage\n", - "dataset = URLDataset(mscoco_fused)\n", - "data_loader = DataLoader(dataset, batch_size=256, num_workers=8, prefetch_factor=2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in tqdm.tqdm(data_loader):\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class CustomDataset(Dataset):\n", - " def __init__(self, dataset):\n", - " self.dataset = dataset\n", - " \n", - " def __len__(self):\n", - " return len(self.dataset)\n", - " \n", - " def __getitem__(self, idx):\n", - " url = self.dataset[idx]['url']\n", - " caption = self.dataset[idx]['caption']\n", - " image = fetch_single_image(url) # Assuming fetch_single_image is defined elsewhere\n", - " return {\n", - " \"url\": url,\n", - " \"caption\": caption,\n", - " \"image\": image\n", - " }\n", - "\n", - "def collate_fn(batch):\n", - " # Custom collation logic if needed\n", - " print(batch)\n", - " # urls = [item[\"url\"] for item in batch]\n", - " # fetch_single_image_with_args = partial(fetch_single_image, timeout=10, retries=3)\n", - " # with ThreadPoolExecutor(max_workers=len(batch)) as executor:\n", - " # images = list(executor.map(fetch_single_image_with_args, urls))\n", - " \n", - " # return {\n", - " # \"url\": urls,\n", - " # \"caption\": [item[\"caption\"] for item in batch],\n", - " # \"image\": images\n", - " # }\n", - " \n", - "# Assuming mscoco_fused is your dataset\n", - "dataset = CustomDataset(mscoco_fused)\n", - "data_loader = DataLoader(dataset, batch_size=512, num_workers=8, collate_fn=collate_fn, prefetch_factor=100)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in tqdm.tqdm(data_loader):\n", - " # print(i)\n", - " # break\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "queue.qsize()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install arrayqueues" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with multiprocessing.Manager() as manager:\n", - "img_queue = manager.Queue()\n", - "process = multiprocessing.Process(target=parallel_image_loader, args=(mscoco_fused, img_queue, 8))\n", - "process.start()\n", - "process.join()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import multiprocessing\n", - "from multiprocessing import shared_memory\n", - "import numpy as np\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "from datasets import Dataset\n", - "import threading\n", - "\n", - "def create_shared_array(shape, dtype):\n", - " \"\"\"Create a shared numpy array.\"\"\"\n", - " nbytes = np.prod(shape) * np.dtype(dtype).itemsize\n", - " shm = 
shared_memory.SharedMemory(create=True, size=nbytes)\n", - " array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)\n", - " return shm, array\n", - "\n", - "def map_fn(url, caption, shared_array, shared_index, lock, shape, dtype):\n", - " image = fetch_single_image(url) # Assuming fetch_single_image is defined elsewhere\n", - " with lock:\n", - " index = shared_index.value\n", - " shared_array[index] = np.frombuffer(image, dtype=dtype).reshape(shape) # Store image in shared memory\n", - " shared_index.value += 1 # Move to the next index\n", - " # Save additional info (url, caption) if necessary\n", - "\n", - "def map_batch_fn(batch, shared_array, shared_index, lock, shape, dtype, num_threads=64):\n", - " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " executor.map(\n", - " map_fn, \n", - " batch[\"url\"], \n", - " batch['caption'], \n", - " [shared_array] * len(batch[\"url\"]), \n", - " [shared_index] * len(batch[\"url\"]), \n", - " [lock] * len(batch[\"url\"]), \n", - " [shape] * len(batch[\"url\"]), \n", - " [dtype] * len(batch[\"url\"])\n", - " )\n", - "\n", - "def parallel_image_loader(dataset: Dataset, shared_array, shared_index, lock, shape, dtype, num_workers: int = 8):\n", - " batch_len = len(dataset) // num_workers\n", - " batches = [dataset[i * batch_len:(i + 1) * batch_len] for i in range(num_workers)]\n", - " with multiprocessing.Pool(num_workers) as pool:\n", - " pool.starmap(\n", - " map_batch_fn, \n", - " [(batch, shared_array, shared_index, lock, shape, dtype) for batch in batches]\n", - " )\n", - "\n", - "class ImageBatchIterator:\n", - " def __init__(self, dataset: Dataset, num_workers: int = 8, batch_size: int = 64, image_shape=(224, 224, 3), dtype=np.uint8):\n", - " self.dataset = dataset\n", - " self.num_workers = num_workers\n", - " self.batch_size = batch_size\n", - " self.image_shape = image_shape\n", - " self.dtype = dtype\n", - " \n", - " # Create shared memory array\n", - " self.shm, self.shared_array = create_shared_array((len(dataset),) + image_shape, dtype)\n", - " self.shared_index = multiprocessing.Value('i', 0) # Shared index counter\n", - " self.lock = multiprocessing.Lock() # Lock for safe indexing\n", - " \n", - " self.thread = threading.Thread(target=parallel_image_loader, args=(\n", - " dataset, self.shared_array, self.shared_index, self.lock, image_shape, dtype, num_workers))\n", - " self.thread.start()\n", - " \n", - " def __iter__(self):\n", - " return self\n", - " \n", - " def __next__(self):\n", - " if self.shared_index.value < self.batch_size:\n", - " raise StopIteration\n", - " \n", - " batch_start = max(0, self.shared_index.value - self.batch_size)\n", - " batch_end = self.shared_index.value\n", - " batch = self.shared_array[batch_start:batch_end]\n", - " return batch\n", - " \n", - " def __del__(self):\n", - " self.thread.join()\n", - " self.shm.close()\n", - " self.shm.unlink() # Free shared memory when done\n", - " \n", - " def __len__(self):\n", - " return len(self.dataset) // self.batch_size\n", - "\n", - "# Example usage:\n", - "dataset = ImageBatchIterator(mscoco_fused, num_workers=16, batch_size=64, image_shape=(224, 224, 3))\n", - "for i in tqdm.tqdm(range(0, 100)):\n", - " batch = next(dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in tqdm.tqdm(range(0, 100)):\n", - " batch = next(dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - 
"kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/datasets/dataset preparations.ipynb b/datasets/dataset preparations.ipynb index 2334d64..483a601 100644 --- a/datasets/dataset preparations.ipynb +++ b/datasets/dataset preparations.ipynb @@ -110,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -140,41 +140,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0571b694e010404390eee7a2ec5d2c65", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading data: 0%| | 0.00/18.3M [00:00 1\u001b[0m leonardo \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbigdata-pw/leonardo\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mall\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m leonardoMap \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_url\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 5\u001b[0m }\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/load.py:2616\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2613\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 2615\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2616\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2617\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2618\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2619\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 
\u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m pool_changed:\n\u001b[1;32m 717\u001b[0m \u001b[38;5;66;03m# we get the result in case there's an error to raise\u001b[39;00m\n\u001b[0;32m--> 718\u001b[0m [\u001b[43masync_result\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.05\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m async_result \u001b[38;5;129;01min\u001b[39;00m async_results]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/multiprocess/pool.py:774\u001b[0m, in \u001b[0;36mApplyResult.get\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 772\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 774\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n", + "\u001b[0;31mDatasetGenerationError\u001b[0m: An error occurred while generating the dataset" + ] + } + ], + "source": [ + "leonardo = load_dataset(\"bigdata-pw/leonardo\", split=\"all\", num_proc=64)\n", + "leonardoMap = {\n", + " \"url\": \"image_url\",\n", + " \"caption\": \"caption\",\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "heavyFilterMap = {\n", + " \"like_count\": {\"min\": 1, \"max\": 100000000},\n", + "}\n", + "\n", + "def leonardoFilter(filterMap):\n", + " def _filter(sample):\n", + " # if len(sample['negative_prompt']) != 0:\n", + " # return False\n", + " for key, value in filterMap.items():\n", + " if sample[key] < value[\"min\"] or sample[key] > value[\"max\"]:\n", + " return False\n", + " return True\n", + " return _filter\n", + " " ] }, { @@ -520,34 +820,128 @@ "metadata": {}, "outputs": [], "source": [ - "final_data[0]" + "leonardoLiked = leonardo.filter(leonardoFilter(heavyFilterMap), num_proc=120)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "# Data Loading Experiments" + "final_data = mapDataset(leonardoLiked, ({\n", + " \"url\":\"url\",\n", + " \"caption\":\"text\"\n", + " },), batch_size=1000000, workers=None)\n", + "\n", + "final_data.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/leonardo-liked\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d522bd677e184c57855719a8d6013f31", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/958 [00:00 2.4:\n", - " print(f\"Wrong aspect ratio {url}\")\n", + " if max(original_height, original_width) / min(original_height, original_width) > 2:\n", " return\n", " # check if the variance is too low\n", " if np.std(image) < 1e-4:\n", - " print(f\"Low variance {url}\")\n", " return\n", " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", " downscale = max(original_width, original_height) > max(image_shape)\n", " interpolation = downscale_interpolation if downscale else upscale_interpolation\n", - "\n", - " image = image_processor(\n", - " image, image_shape, interpolation=interpolation)\n", - " \n", - " print(f\"Processed {url}\")\n", 
- "\n", + " image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)\n", + " image = A.pad(\n", + " image,\n", + " min_height=image_shape[0],\n", + " min_width=image_shape[1],\n", + " border_mode=cv2.BORDER_CONSTANT,\n", + " value=[255, 255, 255],\n", + " )\n", " data_queue.put({\n", " \"url\": url,\n", " \"caption\": caption,\n", - " \"image\": image,\n", - " \"original_height\": original_height,\n", - " \"original_width\": original_width,\n", + " \"image\": image\n", " })\n", " except Exception as e:\n", - " print(f\"Error processing {url}\", e)\n", - " # error_queue.put_nowait({\n", - " # \"url\": url,\n", - " # \"caption\": caption,\n", - " # \"error\": str(e)\n", - " # })\n", - " pass\n", - "\n", - "\n", - "def map_batch(\n", - " batch, num_threads=256, image_shape=(256, 256), \n", - " min_image_shape=(128, 128),\n", - " timeout=15, retries=3, image_processor=default_image_processor,\n", - " upscale_interpolation=cv2.INTER_CUBIC,\n", - " downscale_interpolation=cv2.INTER_AREA,\n", - "):\n", - " try:\n", - " map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,\n", - " timeout=timeout, retries=retries, image_processor=image_processor,\n", - " upscale_interpolation=upscale_interpolation,\n", - " downscale_interpolation=downscale_interpolation)\n", - " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " executor.map(map_sample_fn, batch[\"url\"], batch['caption'])\n", - " except Exception as e:\n", - " print(f\"Error processing batch\", e)\n", - " # error_queue.put_nowait({\n", - " # \"batch\": batch,\n", - " # \"error\": str(e)\n", - " # })\n", - " pass\n", - "\n", - "\n", - "def parallel_image_loader(\n", - " dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), \n", - " min_image_shape=(128, 128),\n", - " num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,\n", - " upscale_interpolation=cv2.INTER_CUBIC,\n", - " downscale_interpolation=cv2.INTER_AREA,\n", - "):\n", - " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape, \n", - " min_image_shape=min_image_shape,\n", - " timeout=timeout, retries=retries, image_processor=image_processor,\n", - " upscale_interpolation=upscale_interpolation,\n", - " downscale_interpolation=downscale_interpolation)\n", + " error_queue.put({\n", + " \"url\": url,\n", + " \"caption\": caption,\n", + " \"error\": str(e)\n", + " })\n", + " \n", + "def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n", + " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", + " executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n", + " \n", + "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n", + " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n", " shard_len = len(dataset) // num_workers\n", " print(f\"Local Shard lengths: {shard_len}\")\n", " with multiprocessing.Pool(num_workers) as pool:\n", " iteration = 0\n", " while True:\n", " # Repeat forever\n", - " shards = [dataset[i*shard_len:(i+1)*shard_len]\n", - " for i in range(num_workers)]\n", - " print(f\"mapping {len(shards)} shards\")\n", + " dataset = dataset.shuffle(seed=iteration)\n", + " shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]\n", " pool.map(map_batch_fn, shards)\n", " iteration += 1\n", - " print(f\"Shuffling dataset with 
seed {iteration}\")\n", - " dataset = dataset.shuffle(seed=iteration)\n", - " # Clear the error queue\n", - " # while not error_queue.empty():\n", - " # error_queue.get_nowait()\n", - "\n", - "\n", + " \n", "class ImageBatchIterator:\n", - " def __init__(\n", - " self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), \n", - " min_image_shape=(128, 128),\n", - " num_workers: int = 8, num_threads=256, timeout=15, retries=3, \n", - " image_processor=default_image_processor,\n", - " upscale_interpolation=cv2.INTER_CUBIC,\n", - " downscale_interpolation=cv2.INTER_AREA,\n", - " ):\n", + " def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n", " self.dataset = dataset\n", " self.num_workers = num_workers\n", " self.batch_size = batch_size\n", - " loader = partial(parallel_image_loader, num_threads=num_threads,\n", - " image_shape=image_shape,\n", - " min_image_shape=min_image_shape, \n", - " num_workers=num_workers, \n", - " timeout=timeout, retries=retries, image_processor=image_processor,\n", - " upscale_interpolation=upscale_interpolation,\n", - " downscale_interpolation=downscale_interpolation)\n", - " self.thread = threading.Thread(target=loader, args=(dataset,))\n", + " loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n", + " self.thread = threading.Thread(target=loader, args=(dataset))\n", " self.thread.start()\n", - "\n", + " \n", " def __iter__(self):\n", " return self\n", - "\n", + " \n", " def __next__(self):\n", " def fetcher(_):\n", " return data_queue.get()\n", " with ThreadPoolExecutor(max_workers=self.batch_size) as executor:\n", " batch = list(executor.map(fetcher, range(self.batch_size)))\n", " return batch\n", - "\n", + " \n", " def __del__(self):\n", " self.thread.join()\n", - "\n", + " \n", " def __len__(self):\n", " return len(self.dataset) // self.batch_size\n", - "\n", - "\n", + " \n", "def default_collate(batch):\n", " urls = [sample[\"url\"] for sample in batch]\n", " captions = [sample[\"caption\"] for sample in batch]\n", @@ -750,8 +1071,7 @@ " \"caption\": captions,\n", " \"image\": images,\n", " }\n", - "\n", - "\n", + " \n", "def dataMapper(map: Dict[str, Any]):\n", " def _map(sample) -> Dict[str, Any]:\n", " return {\n", @@ -760,18 +1080,15 @@ " }\n", " return _map\n", "\n", - "\n", "class OnlineStreamingDataLoader():\n", " def __init__(\n", - " self,\n", - " dataset,\n", - " batch_size=64,\n", - " image_shape=(256, 256),\n", - " min_image_shape=(128, 128),\n", - " num_workers=16,\n", + " self, \n", + " dataset, \n", + " batch_size=64, \n", + " num_workers=16, \n", " num_threads=512,\n", " default_split=\"all\",\n", - " pre_map_maker=dataMapper,\n", + " pre_map_maker=dataMapper, \n", " pre_map_def={\n", " \"url\": \"URL\",\n", " \"caption\": \"TEXT\",\n", @@ -780,61 +1097,44 @@ " global_process_index=0,\n", " prefetch=1000,\n", " collate_fn=default_collate,\n", - " timeout=15,\n", - " retries=3,\n", - " image_processor=default_image_processor,\n", - " upscale_interpolation=cv2.INTER_CUBIC,\n", - " downscale_interpolation=cv2.INTER_AREA,\n", " ):\n", " if isinstance(dataset, str):\n", " dataset_path = dataset\n", " print(\"Loading dataset from path\")\n", - " if \"gs://\" in dataset:\n", - " dataset = load_from_disk(dataset_path)\n", - " else:\n", - " dataset = load_dataset(dataset_path, split=default_split)\n", + " dataset = load_dataset(dataset_path, split=default_split)\n", " elif isinstance(dataset, list):\n", " 
if isinstance(dataset[0], str):\n", " print(\"Loading multiple datasets from paths\")\n", - " dataset = [load_from_disk(dataset_path) if \"gs://\" in dataset_path else load_dataset(\n", - " dataset_path, split=default_split) for dataset_path in dataset]\n", - " print(\"Concatenating multiple datasets\")\n", - " dataset = concatenate_datasets(dataset)\n", - " dataset = dataset.shuffle(seed=0)\n", - " dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)\n", - " self.dataset = dataset.shard(\n", - " num_shards=global_process_count, index=global_process_index)\n", + " dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]\n", + " else:\n", + " print(\"Concatenating multiple datasets\")\n", + " dataset = concatenate_datasets(dataset)\n", + " dataset = dataset.map(pre_map_maker(pre_map_def))\n", + " self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)\n", " print(f\"Dataset length: {len(dataset)}\")\n", - " self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,\n", - " min_image_shape=min_image_shape,\n", - " num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,\n", - " timeout=timeout, retries=retries, image_processor=image_processor,\n", - " upscale_interpolation=upscale_interpolation,\n", - " downscale_interpolation=downscale_interpolation)\n", - " self.batch_size = batch_size\n", - "\n", + " self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)\n", + " self.collate_fn = collate_fn\n", + " \n", " # Launch a thread to load batches in the background\n", " self.batch_queue = queue.Queue(prefetch)\n", - "\n", + " \n", " def batch_loader():\n", " for batch in self.iterator:\n", - " try:\n", - " self.batch_queue.put(collate_fn(batch))\n", - " except Exception as e:\n", - " print(\"Error processing batch\", e)\n", - "\n", + " self.batch_queue.put(batch)\n", + " \n", " self.loader_thread = threading.Thread(target=batch_loader)\n", " self.loader_thread.start()\n", - "\n", + " \n", " def __iter__(self):\n", " return self\n", - "\n", + " \n", " def __next__(self):\n", - " return self.batch_queue.get()\n", + " return self.collate_fn(self.batch_queue.get())\n", " # return self.collate_fn(next(self.iterator))\n", - "\n", + " \n", " def __len__(self):\n", - " return len(self.dataset)" + " return len(self.dataset) // self.batch_size\n", + " " ] }, { @@ -843,101 +1143,52 @@ "metadata": {}, "outputs": [], "source": [ - "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=4, num_threads=128)" + "from flaxdiff.data.online_loader import OnlineStreamingDataLoader" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from flaxdiff.data.online_loader import OnlineStreamingDataLoader" + "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, default_split=\"train\")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading multiple datasets from paths\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Concatenating multiple datasets\n", - "Dataset length: 15055574\n", - "Local Shard lengths: 940973\n" - ] - } - ], + "outputs": [], "source": [ - "dataloader = 
OnlineStreamingDataLoader([\n", - " \"gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017\"\n", - " ], batch_size=16, num_workers=16, num_threads=512, default_split=\"train\")" + "dataloader.batch_queue.qsize()" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "dataloader.batch_queue.qsize()" + "data_queue.qsize()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "data_queue.qsize()" + "error_queue.qsize()" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/100 [00:00