
Commit

feat: latest training script
AshishKumar4 committed Aug 17, 2024
1 parent f0bfd6c commit ae85f41
Showing 1 changed file with 56 additions and 23 deletions: training.py
@@ -11,6 +11,7 @@
import wandb.util
import wandb.wandb_run
from flaxdiff.models.simple_unet import Unet
from flaxdiff.models.simple_vit import UViT
import jax.experimental.pallas.ops.tpu.flash_attention
from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform
from flaxdiff.schedulers import CosineNoiseSchedule, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler
@@ -1316,6 +1317,7 @@ def boolean_string(s):
parser.add_argument('--noise_schedule', type=str, default='edm',
choices=['cosine', 'karras', 'edm'], help='Noise schedule')

parser.add_argument('--architecture', type=str, choices=["unet", "uvit"], default="unet", help='Architecture to use')
parser.add_argument('--emb_features', type=int, default=256, help='Embedding features')
parser.add_argument('--feature_depths', type=int, nargs='+', default=[64, 128, 256, 512], help='Feature depths')
parser.add_argument('--attention_heads', type=int, default=8, help='Number of attention heads')
@@ -1331,6 +1333,11 @@ def boolean_string(s):
parser.add_argument('--num_middle_res_blocks', type=int, default=1, help='Number of middle residual blocks')
parser.add_argument('--activation', type=str, default='swish', help='activation to use')

parser.add_argument('--patch_size', type=int, default=16, help='Patch size for the transformer if using UViT')
parser.add_argument('--num_layers', type=int, default=12, help='Number of layers in the transformer if using UViT')
parser.add_argument('--num_heads', type=int, default=12, help='Number of heads in the transformer if using UViT')
parser.add_argument('--mlp_ratio', type=int, default=4, help='MLP ratio in the transformer if using UViT')

parser.add_argument('--dtype', type=str, default=None, help='dtype to use')
parser.add_argument('--precision', type=str, default=None, help='precision to use', choices=['high', 'default', 'highest', 'None', None])

@@ -1486,21 +1493,48 @@ def main(args):
autoencoder = StableDiffusionVAE(**autoencoder_opts)
INPUT_CHANNELS = 4
DIFFUSION_INPUT_SIZE = DIFFUSION_INPUT_SIZE // 8

CONFIG = {
"model": {
"emb_features": args.emb_features,
"feature_depths": args.feature_depths,
"attention_configs": attention_configs,
"num_res_blocks": args.num_res_blocks,
"num_middle_res_blocks": args.num_middle_res_blocks,
"dtype": DTYPE,
"precision": PRECISION,
"activation": args.activation,
"output_channels": INPUT_CHANNELS,
"norm_groups": args.norm_groups,
"named_norms": args.named_norms,

model_config = {
"emb_features": args.emb_features,
"dtype": DTYPE,
"precision": PRECISION,
"activation": args.activation,
"output_channels": INPUT_CHANNELS,
"norm_groups": args.norm_groups,
}

MODEL_ARCHITECUTRES = {
"unet": {
"class": Unet,
"kwargs": {
"feature_depths": args.feature_depths,
"attention_configs": attention_configs,
"num_res_blocks": args.num_res_blocks,
"num_middle_res_blocks": args.num_middle_res_blocks,
"named_norms": args.named_norms,
},
},
"uvit": {
"class": UViT,
"kwargs": {
"patch_size": args.patch_size,
"num_layers": args.num_layers,
"num_heads": args.num_heads,
"dropout_rate": 0.1,
"use_projection": False,
},
}
}

model_architecture = MODEL_ARCHITECUTRES[args.architecture]['class']
model_config.update(MODEL_ARCHITECUTRES[args.architecture]['kwargs'])

if args.architecture == 'uvit':
model_config['emb_features'] = 768

CONFIG = {
"model": model_config,
"architecture": args.architecture,
"dataset": {
"name": dataset_name,
"length": datalen,
@@ -1535,9 +1569,8 @@ def main(args):

print("Experiment_Name:", experiment_name)

model_config = CONFIG['model']
model_config['activation'] = ACTIVATION_MAP[model_config['activation']]
unet = Unet(**model_config)
unet = model_architecture(**model_config)

learning_rate = CONFIG['learning_rate']
optimizer = OPTIMIZER_MAP[args.optimizer]
@@ -1630,14 +1663,14 @@ def main(args):
for tpu-v4-64
python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\
python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\
--checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\
--epochs=40 --batch_size=512 --image_size=512 --learning_rate=4e-5 \
--num_res_blocks=4 --emb_features 512 --feature_depths 128 256 512 512 --norm_groups 0 --only_pure_attention=True --use_self_and_cross=False \
--dtype=bfloat16 --precision=default --attention_heads=16\
--experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_ldm_combined_online-bigger'\
--epochs=40 --batch_size=512 --image_size=128 --learning_rate=9e-5 \
--architecture=uvit --num_layers=12 --emb_features=768 --norm_groups 0 --num_heads=12 \
--dtype=bfloat16 --precision=default \
--experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_uvit_combined_30m'\
--optimizer=adamw --clip_grads 0.5 \
--learning_rate_schedule=cosine --learning_rate_peak=2.7e-4 --learning_rate_end=9e-5 --learning_rate_warmup_steps=10000 --learning_rate_decay_epochs=2\
--optimizer=adamw --autoencoder=stable_diffusion --clip_grads 0.5
--load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_30m/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_NEW_ARCH_combined_30'
@@ -1649,7 +1682,7 @@
python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\
--checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\
--epochs=40 --batch_size=256 --image_size=128 \
--learning_rate=4e-5 --num_res_blocks=3 \
--learning_rate=9e-5 --architecture=uvit --num_layers=12 \
--use_self_and_cross=False --precision=default --dtype=bfloat16 --attention_heads=16\
--experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_flaxdiff-0-1-10__new-combined_30m'\
--optimizer=adamw --feature_depths 128 256 512 512 --use_dynamic_scale=True\

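For reference, the dispatch pattern introduced above (`MODEL_ARCHITECUTRES`) maps each `--architecture` choice to a model class plus architecture-specific keyword arguments, which are merged into the shared `model_config` before the model is instantiated (with UViT additionally overriding `emb_features` to 768). Below is a minimal, self-contained sketch of that pattern; the dataclasses and the `build_model` helper are illustrative stand-ins, not the actual flaxdiff `Unet`/`UViT` signatures.

    # Standalone sketch of the architecture-dispatch pattern in this commit.
    # DummyUnet/DummyUViT stand in for flaxdiff's Unet/UViT classes.
    from dataclasses import dataclass

    @dataclass
    class DummyUnet:
        emb_features: int
        num_res_blocks: int

    @dataclass
    class DummyUViT:
        emb_features: int
        num_layers: int
        num_heads: int

    ARCHITECTURES = {
        "unet": {"class": DummyUnet, "kwargs": {"num_res_blocks": 4}},
        "uvit": {"class": DummyUViT, "kwargs": {"num_layers": 12, "num_heads": 12}},
    }

    def build_model(architecture: str, emb_features: int = 256):
        entry = ARCHITECTURES[architecture]
        config = {"emb_features": emb_features}   # shared base config
        config.update(entry["kwargs"])            # merge architecture-specific kwargs
        if architecture == "uvit":
            config["emb_features"] = 768          # UViT uses a fixed embedding width
        return entry["class"](**config)

    print(build_model("unet"))  # DummyUnet(emb_features=256, num_res_blocks=4)
    print(build_model("uvit"))  # DummyUViT(emb_features=768, num_layers=12, num_heads=12)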