feat: optimizer, lr schedule, checkpoint base args in training.py
AshishKumar4 committed Aug 5, 2024
1 parent 5b2b71a commit a91a3ec
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions training.py
@@ -391,6 +391,7 @@ def __init__(self,
param_transforms: Callable = None,
wandb_config: Dict[str, Any] = None,
distributed_training: bool = None,
+ checkpoint_base_path: str = "./checkpoints",
):
if distributed_training is None or distributed_training is True:
# Auto-detect if we are running on multiple devices
@@ -403,6 +404,7 @@ def __init__(self,
self.name = name
self.loss_fn = loss_fn
self.input_shapes = input_shapes
+ self.checkpoint_base_path = checkpoint_base_path

if wandb_config is not None and jax.process_index() == 0:
run = wandb.init(**wandb_config)
@@ -508,7 +510,7 @@ def get_rngstate(self):

def checkpoint_path(self):
experiment_name = self.name
- path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+ path = os.path.join(self.checkpoint_base_path, experiment_name)
if not os.path.exists(path):
os.makedirs(path)
return path
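For reference, a minimal standalone sketch of the new path resolution introduced here; the function name and example values below are illustrative only, not part of the trainer class:

import os

def resolve_checkpoint_path(checkpoint_base_path: str, experiment_name: str) -> str:
    # Mirrors the updated checkpoint_path(): join the configurable base directory
    # with the experiment name and create the directory on first use.
    path = os.path.join(checkpoint_base_path, experiment_name)
    if not os.path.exists(path):
        os.makedirs(path)
    return path

# e.g. a run named "my_experiment" under the default base directory
print(resolve_checkpoint_path("./checkpoints", "my_experiment"))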
@@ -666,8 +668,8 @@ def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
return epoch_loss, current_step, train_state, rng_state

while self.latest_epoch < epochs:
- self.latest_epoch += 1
current_epoch = self.latest_epoch
+ self.latest_epoch += 1
print(f"\nEpoch {current_epoch}/{epochs}")
start_time = time.time()
epoch_loss = 0
@@ -901,8 +903,6 @@ def boolean_string(s):
parser.add_argument('--dataset_path', type=str,
default='/home/mrwhite0racle/gcs_mount/arrayrecord/cc12m', help="Dataset location path")

- parser.add_argument('--learning_rate', type=float,
- default=2e-4, help='Learning rate')
parser.add_argument('--noise_schedule', type=str, default='edm',
choices=['cosine', 'karras', 'edm'], help='Noise schedule')

@@ -928,6 +928,18 @@ def boolean_string(s):
parser.add_argument('--dataset_test', type=boolean_string,
default=False, help='Run the dataset iterator for 3000 steps for testing/benchmarking')

+ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Checkpoint directory')
+
+ parser.add_argument('--optimizer', type=str, default='adamw',
+ choices=['adam', 'adamw', 'lamb'], help='Optimizer to use')
+ parser.add_argument('--optimizer_opts', type=str, default='{}', help='Optimizer options as a dictionary')
+ parser.add_argument('--learning_rate_schedule', type=str, default=None, choices=[None, 'cosine'], help='Learning rate schedule')
+ parser.add_argument('--learning_rate', type=float,
+ default=2.7e-4, help='Initial Learning rate')
+ parser.add_argument('--learning_rate_peak', type=float, default=3e-4, help='Learning rate peak')
+ parser.add_argument('--learning_rate_end', type=float, default=2e-4, help='Learning rate end')
+ parser.add_argument('--learning_rate_warmup_steps', type=int, default=10000, help='Learning rate warmup steps')
+
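Since the new --optimizer_opts flag is consumed via json.loads further down in the diff, its value must be a valid JSON dictionary. A hedged sketch of how the parsed options end up as keyword arguments to the chosen optax factory; the specific weight_decay and b1 values are illustrative, not defaults from the commit:

import json
import optax

# Illustrative CLI values, e.g. --optimizer=adamw --optimizer_opts='{"weight_decay": 0.01, "b1": 0.9}'
optimizer_opts = json.loads('{"weight_decay": 0.01, "b1": 0.9}')

# The parsed dict is forwarded as kwargs to the selected optax constructor.
solver = optax.adamw(2.7e-4, **optimizer_opts)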
def main(args):
resource.setrlimit(
resource.RLIMIT_CORE,
@@ -960,6 +972,12 @@ def main(args):
'swish': jax.nn.swish,
'mish': jax.nn.mish,
}

+ OPTIMIZER_MAP = {
+ 'adam' : optax.adam,
+ 'adamw' : optax.adamw,
+ 'lamb' : optax.lamb,
+ }
+
DTYPE = DTYPE_MAP[args.dtype]
PRECISION = PRECISION_MAP[args.precision]
@@ -1059,7 +1077,14 @@ def main(args):
unet = Unet(**model_config)

learning_rate = CONFIG['learning_rate']
- solver = optax.adam(learning_rate)
+ optimizer = OPTIMIZER_MAP[args.optimizer]
+ optimizer_opts = json.loads(args.optimizer_opts)
+ if args.learning_rate_schedule == 'cosine':
+ learning_rate = optax.warmup_cosine_decay_schedule(
+ init_value=learning_rate, peak_value=args.learning_rate_peak, warmup_steps=args.learning_rate_warmup_steps,
+ decay_steps=batches, end_value=args.learning_rate_end,
+ )
+ solver = optimizer(learning_rate, **optimizer_opts)
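optax optimizer factories accept a schedule callable in place of a scalar learning rate, which is what lets the solver = optimizer(learning_rate, **optimizer_opts) call above work whether learning_rate is still the float default or the cosine schedule. A minimal sketch with the new CLI defaults; the total_steps value is a stand-in for `batches`, which the script computes from the dataset:

import optax

total_steps = 100_000  # stand-in for `batches` (dataset-dependent)

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=2.7e-4,    # --learning_rate
    peak_value=3e-4,      # --learning_rate_peak
    warmup_steps=10_000,  # --learning_rate_warmup_steps
    decay_steps=total_steps,
    end_value=2e-4,       # --learning_rate_end
)

# Any optax factory (adam, adamw, lamb) takes the schedule in place of a float.
solver = optax.lamb(lr_schedule)

# The schedule is a plain callable over the step count:
print(lr_schedule(0), lr_schedule(10_000))  # LR at step 0 and at the end of warmup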

wandb_config = {
"project": "flaxdiff",
@@ -1077,7 +1102,8 @@
sigma_data=edm_schedule.sigma_data),
load_from_checkpoint=args.load_from_checkpoint,
wandb_config=wandb_config,
- distributed_training=args.distributed_training,
+ distributed_training=args.distributed_training,
+ checkpoint_base_path=args.checkpoint_dir,
)

if trainer.distributed_training:
@@ -1094,8 +1120,10 @@

"""
python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/arrayrecord/laion-aesthetics-12m+mscoco-2017'\
- --epochs=40 --batch_size=64 \
+ --epochs=40 --batch_size=256 \
--learning_rate=2.7e-4 --num_res_blocks=3 \
--use_self_and_cross=False --dtype=bfloat16 --precision=high --attention_heads=16\
- --experiment_name='batch 64 v4-16 host laiona_coco'"
+ --experiment_name='batch 256 v4-16 host laiona_coco with lr schedule'\
+ --learning_rate_schedule=cosine --learning_rate_peak=5e-4 --learning_rate_end=1e-4 --learning_rate_warmup_steps=10000\
+ --optimizer=lamb
"""
