Skip to content

Commit

Permalink
feat: synced changes
Browse files: browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 8, 2024
1 parent 38bec68 commit d019a7e
Show file tree
Hide file tree
Showing 5 changed files with 611 additions and 107 deletions.
4 changes: 2 additions & 2 deletions datasets/cc12m downloader.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

img2dataset --url_list ./cc12m.tsv --input_format "tsv"\
--url_col "image_url" --caption_col "caption" --output_format arrayrecord\
--output_folder gs://flaxdiff-datasets-regional/arrayrecord/cc12m --processes_count 64\
--thread_count 64 --image_size 256\
--output_folder gs://flaxdiff-datasets-regional/arrayrecord2/cc12m --processes_count 64\
--thread_count 64 --image_size 256 --number_sample_per_shard 50000 --min_image_size 100 \
--enable_wandb True --disallowed_header_directives '[]' --compute_hash None --max_shard_retry 3 --timeout 60
2 changes: 1 addition & 1 deletion datasets/custom datasets downloader.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
img2dataset --url_list $1 --input_format "parquet"\
--url_col "url" --caption_col "caption" --output_format arrayrecord\
--output_folder $2 --processes_count 64\
--thread_count 64 --image_size 256 --min_image_size 100 \
--thread_count 64 --image_size 256 --min_image_size 100 --number_sample_per_shard 40000 \
--enable_wandb True --disallowed_header_directives '[]' --compute_hash None --max_shard_retry 3 --timeout 60

# gs://flaxdiff-datasets-regional/arrayrecord/laion-aesthetics-12m+mscoco-2017
50 changes: 27 additions & 23 deletions datasets/dataset preparations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mrwhite0racle/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import webdataset as wds\n",
"import jax\n",
Expand Down Expand Up @@ -119,14 +110,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading readme: 100%|██████████████████████████████████████████████████████████████████████████████████| 4.01k/4.01k [00:00<00:00, 17.8MB/s]\n"
"Downloading readme: 100%|██████████| 4.01k/4.01k [00:00<00:00, 12.7MB/s]\n",
"Downloading data: 100%|██████████| 2.44G/2.44G [01:20<00:00, 30.2MB/s]\n",
"Generating train split: 100%|██████████| 12096809/12096809 [00:11<00:00, 1093969.94 examples/s]\n",
"Map (num_proc=16): 100%|██████████| 12096809/12096809 [00:49<00:00, 246633.91 examples/s] \n"
]
}
],
Expand All @@ -142,9 +136,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading data: 100%|██████████| 18.3M/18.3M [00:01<00:00, 12.3MB/s]\n",
"Generating train split: 100%|██████████| 591753/591753 [00:00<00:00, 2517731.17 examples/s]\n",
"Map (num_proc=16): 100%|██████████| 591753/591753 [00:02<00:00, 262603.89 examples/s] \n"
]
}
],
"source": [
"mscoco = load_dataset(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\")\n",
"mscoco_fused = mscoco['train']\n",
Expand All @@ -154,16 +158,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"fused_data = concatenate_datasets([laion12m6_fused, mscoco_fused])"
"fused_data = concatenate_datasets([laion12m6_fused, mscoco_fused, mscoco_fused])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"notebookRunGroups": {
"groupValue": "1"
Expand All @@ -176,7 +180,7 @@
"12688562"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -206,14 +210,14 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating parquet from Arrow format: 100%|████████████████████████████████████████████████████████████████| 12689/12689 [00:15<00:00, 829.59ba/s]\n"
"Creating parquet from Arrow format: 100%|██████████| 12689/12689 [00:20<00:00, 622.45ba/s] \n"
]
},
{
Expand All @@ -222,7 +226,7 @@
"2454709042"
]
},
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -479,7 +483,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
660 changes: 580 additions & 80 deletions evaluate.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion flaxdiff/models/autoencoder/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def encode(self, images, rngkey: jax.random.PRNGKey = None):
log_std = jnp.clip(log_std, -30, 20)
std = jnp.exp(0.5 * log_std)
latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
print("Sampled")
# print("Sampled")
else:
# return the mean
latents, _ = jnp.split(latents, 2, axis=-1)
Expand Down

0 comments on commit d019a7e

Please sign in to comment.