feat: tpu training with cc12m dataset
AshishKumar4 committed Aug 1, 2024
1 parent df2eca3 commit f61bb9d
Showing 9 changed files with 2,857 additions and 3,008 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,3 +7,6 @@ good models
.env
.env
tensorboard
wandb
gcs_mount
datacache
3,349 changes: 453 additions & 2,896 deletions Diffusion flax linen on TPUs.ipynb

Large diffs are not rendered by default.

105 changes: 104 additions & 1 deletion Diffusion flax linen.ipynb
@@ -200,6 +200,109 @@
" return embed_pooled, embed_labels_full"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_dataset_tf(data_name=\"oxford_flowers102\", batch_size=64, image_scale=256):\n",
" import tensorflow as tf\n",
"\n",
" if os.path.exists(f\"./datacache/{data_name}_labels.pkl\"):\n",
" print(\"Loading labels from cache\")\n",
" with open(f\"./datacache/{data_name}_labels.pkl\", \"rb\") as f:\n",
" import pickle\n",
" embed = pickle.load(f)\n",
" embed_labels = embed[\"embed_labels\"]\n",
" embed_labels_full = embed[\"embed_labels_full\"]\n",
" null_labels = embed[\"null_labels\"]\n",
" null_labels_full = embed[\"null_labels_full\"]\n",
" else:\n",
" print(\"No cache found, generating labels\")\n",
" textlabels = dataToLabelGenMap[data_name]()\n",
" \n",
" model, tokenizer = defaultTextEncodeModel()\n",
"\n",
" embed_labels, embed_labels_full = encodePrompts(textlabels, model, tokenizer)\n",
" embed_labels = embed_labels.tolist()\n",
" embed_labels_full = embed_labels_full.tolist()\n",
" \n",
" null_labels, null_labels_full = encodePrompts([\"\"], model, tokenizer)\n",
" null_labels = null_labels.tolist()[0]\n",
" null_labels_full = null_labels_full.tolist()[0]\n",
" \n",
" os.makedirs(\"./datacache\", exist_ok=True)\n",
" with open(f\"./datacache/{data_name}_labels.pkl\", \"wb\") as f:\n",
" import pickle\n",
" pickle.dump({\n",
" \"embed_labels\": embed_labels,\n",
" \"embed_labels_full\": embed_labels_full,\n",
" \"null_labels\": null_labels,\n",
" \"null_labels_full\": null_labels_full\n",
" }, f)\n",
" \n",
" embed_labels = tf.convert_to_tensor([np.array(i, dtype=np.float16) for i in embed_labels])\n",
" embed_labels_full = tf.convert_to_tensor(np.array([np.array(i, dtype=np.float16) for i in embed_labels_full]))\n",
" null_labels = np.array(null_labels, dtype=np.float16)\n",
" null_labels_full = np.array(null_labels_full, dtype=np.float16)\n",
" \n",
" def labelizer(labelidx:int) -> np.array:\n",
" label_pooled = embed_labels[labelidx]\n",
" label_seq = embed_labels_full[labelidx]\n",
" return label_pooled, label_seq\n",
" \n",
" def augmenter(image_scale=256, method=\"area\"):\n",
" @tf.function()\n",
" def augment(sample):\n",
" image = (\n",
" tf.cast(sample[\"image\"], tf.float32) - 127.5\n",
" ) / 127.5\n",
" image = tf.image.resize(\n",
" image, [image_scale, image_scale], method=method, antialias=True\n",
" )\n",
" image = tf.image.random_flip_left_right(image)\n",
" image = tf.image.random_contrast(image, 0.999, 1.05)\n",
" image = tf.image.random_brightness(image, 0.2)\n",
"\n",
" image = tf.clip_by_value(image, -1.0, 1.0)\n",
" labelidx = sample[\"label\"]\n",
" label, label_seq = labelizer(labelidx)\n",
" # image, label = move2gpu(image, label)\n",
" return {'image':image, 'label':label, 'label_seq':label_seq}\n",
" return augment\n",
"\n",
" # Load CelebA Dataset\n",
" data: tf.data.Dataset = tfds.load(data_name, split=\"all\", shuffle_files=True)\n",
" train_len = len(data)\n",
" final_data = (\n",
" data\n",
" .cache() # Cache after augmenting to avoid recomputation\n",
" .map(\n",
" augmenter(image_scale, method=\"area\"),\n",
" num_parallel_calls=tf.data.AUTOTUNE,\n",
" )\n",
" .repeat() # Repeats the dataset indefinitely\n",
" .shuffle(4096) # Ensure this is adequate for your dataset size\n",
" .batch(batch_size, drop_remainder=True)\n",
" .prefetch(tf.data.experimental.AUTOTUNE)\n",
" )\n",
" \n",
" def get_trainset():\n",
" return final_data.as_numpy_iterator()\n",
" \n",
" return {\n",
" \"train\": get_trainset,\n",
" \"train_len\": train_len,\n",
" \"batch_size\": batch_size,\n",
" \"null_labels\": null_labels,\n",
" \"null_labels_full\": null_labels_full,\n",
" \"embed_labels\": embed_labels,\n",
" \"embed_labels_full\": embed_labels_full,\n",
" \n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 4,
@@ -8709,7 +8812,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,
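For orientation, here is a minimal sketch of how the new get_dataset_tf helper from this notebook might be consumed in a training loop. The steps-per-epoch arithmetic and the train_step placeholder are illustrative assumptions, not part of the commit:

# Sketch: consuming the pipeline returned by get_dataset_tf.
# Assumes the notebook's earlier cells (numpy, tfds, text-encoder helpers) have run.
data = get_dataset_tf(data_name="oxford_flowers102", batch_size=64, image_scale=128)

batches = data["train"]()  # numpy iterator over the repeated, shuffled dataset
steps_per_epoch = data["train_len"] // data["batch_size"]  # illustrative

for step in range(steps_per_epoch):
    batch = next(batches)
    images = batch["image"]      # (64, 128, 128, 3) float32 in [-1, 1]
    pooled = batch["label"]      # pooled text embedding per sample
    seqs = batch["label_seq"]    # full-sequence embeddings for cross-attention
    # train_step(state, images, pooled, seqs)  # hypothetical training step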
8 changes: 8 additions & 0 deletions datasets/cc12m downloader.sh
@@ -0,0 +1,8 @@
#!/bin/bash

img2dataset --url_list ./datacache/cc12m.csv --input_format "csv" \
    --url_col "image_url" --caption_col "caption" --output_format arrayrecord \
    --output_folder gs://flaxdiff-datasets/arrayrecord/cc12m --processes_count 240 --thread_count 64 --image_size 256 \
    --enable_wandb True --disallowed_header_directives '[]' --compute_hash None --max_shard_retry 3 --timeout 60

./gcsfuse.sh DATASET_GCS_BUCKET=flaxdiff-datasets MOUNT_PATH=gcs_mount
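
After the mount step, a quick sanity check of the uploaded shards might look like the sketch below. The glob pattern, shard extension, and the assumption that each record holds one serialized sample depend on the img2dataset version and its arrayrecord writer, so treat this as illustrative:

# Sketch: counting records in the mounted ArrayRecord shards.
# Shard extension and payload encoding are assumptions to verify.
import glob
from array_record.python.array_record_data_source import ArrayRecordDataSource

shards = sorted(glob.glob("gcs_mount/arrayrecord/cc12m/*.array_record"))
source = ArrayRecordDataSource(shards)
print(f"{len(shards)} shards, {len(source)} records")
print(f"first record is {len(source[0])} bytes")  # raw serialized bytes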