Commit

feat: latest changes
AshishKumar4 committed Aug 7, 2024
1 parent 4798f53 commit 38bec68
Showing 4 changed files with 90 additions and 85 deletions.
100 changes: 48 additions & 52 deletions evaluate.ipynb

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions flaxdiff/trainer/diffusion_trainer.py
@@ -125,14 +125,14 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
local_rng_state = RandomMarkovState(subkey)

images = batch['image']
+ images = jnp.array(images, dtype=jnp.bfloat16)
+ # normalize image
+ images = (images - 127.5) / 127.5

if autoencoder is not None:
    # Convert the images to latent space
-     local_rng_state, rngs = local_rng_state.get_random_key()
-     images = autoencoder.encode(images, rngs)
- else:
-     # normalize image
-     images = (images - 127.5) / 127.5
+     # local_rng_state, rngs = local_rng_state.get_random_key()
+     images = autoencoder.encode(images)#, rngs)

output = text_embedder(
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
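Note: the hunk above now normalizes pixels and casts to bfloat16 before the optional latent-space encode, and drops the explicit RNG argument to autoencoder.encode. A minimal sketch of that preprocessing path, assuming images arrive as uint8 arrays in [0, 255] and that the autoencoder exposes an encode(images) method as shown above:

```python
import jax.numpy as jnp

def preprocess_images(batch, autoencoder=None):
    # Cast to bfloat16 and map [0, 255] pixels to [-1, 1], matching the hunk above.
    images = jnp.asarray(batch['image'], dtype=jnp.bfloat16)
    images = (images - 127.5) / 127.5
    if autoencoder is not None:
        # Convert the images to latent space; the RNG argument is dropped in this commit.
        images = autoencoder.encode(images)
    return images
```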
@@ -199,4 +199,3 @@ def boolean_string(s):
if type(s) == bool:
return s
return s == 'True'
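
The boolean_string helper above exists so that CLI flags such as --use_self_and_cross=False parse to real booleans rather than truthy strings. A hypothetical argparse wiring (the actual argument definitions in training.py are not part of this hunk):

```python
import argparse

def boolean_string(s):
    # Accept real booleans as well as the strings 'True' / 'False'.
    if type(s) == bool:
        return s
    return s == 'True'

parser = argparse.ArgumentParser()
parser.add_argument('--use_self_and_cross', type=boolean_string, default=False)
print(parser.parse_args(['--use_self_and_cross=False']).use_self_and_cross)  # False
```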

12 changes: 8 additions & 4 deletions flaxdiff/trainer/simple_trainer.py
@@ -96,6 +96,7 @@ def __init__(self,
wandb_config: Dict[str, Any] = None,
distributed_training: bool = None,
checkpoint_base_path: str = "./checkpoints",
checkpoint_epoch: int = None,
):
if distributed_training is None or distributed_training is True:
# Auto-detect if we are running on multiple devices
@@ -141,7 +142,7 @@ def __init__(self,
self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

if load_from_checkpoint:
- latest_epoch, old_state, old_best_state, rngstate = self.load()
+ latest_epoch, old_state, old_best_state, rngstate = self.load(checkpoint_epoch)
else:
latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None

@@ -234,8 +235,11 @@ def tensorboard_path(self):
os.makedirs(path)
return path

- def load(self):
-     epoch = self.checkpointer.latest_step()
+ def load(self, checkpoint_epoch):
+     if checkpoint_epoch is not None:
+         epoch = checkpoint_epoch
+     else:
+         epoch = self.checkpointer.latest_step()
print("Loading model from checkpoint", epoch)
ckpt = self.checkpointer.restore(epoch)
state = ckpt['state']
@@ -245,7 +249,7 @@ def load(self):
self.best_loss = ckpt['best_loss']
print(
f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
- return epoch, state, best_state, rngstate
+ return epoch + 1, state, best_state, rngstate

def save(self, epoch=0):
print(f"Saving model at epoch {epoch}")
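Taken together, the new checkpoint_epoch argument, the updated load() signature, and the epoch + 1 return value let a run resume from an explicitly chosen checkpoint instead of only the latest one. A self-contained toy sketch of that selection logic (the checkpointer stub and the numbers are illustrative, not from the repository):

```python
class ToyCheckpointer:
    """Stand-in for the trainer's checkpointer; only latest_step() is mimicked."""
    def latest_step(self):
        return 20

def resolve_resume_epoch(checkpointer, checkpoint_epoch=None):
    # Same rule as the new load(): an explicit epoch wins, otherwise use the latest step.
    epoch = checkpoint_epoch if checkpoint_epoch is not None else checkpointer.latest_step()
    # load() now returns epoch + 1, so the training loop resumes on the following epoch.
    return epoch + 1

print(resolve_resume_epoch(ToyCheckpointer()))      # 21: resume after the latest checkpoint
print(resolve_resume_epoch(ToyCheckpointer(), 12))  # 13: resume after a chosen checkpoint
```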
52 changes: 29 additions & 23 deletions training.py
@@ -414,6 +414,7 @@ def __init__(self,
wandb_config: Dict[str, Any] = None,
distributed_training: bool = None,
checkpoint_base_path: str = "./checkpoints",
checkpoint_epoch: int = None,
):
if distributed_training is None or distributed_training is True:
# Auto-detect if we are running on multiple devices
@@ -430,6 +431,7 @@ def __init__(self,


if wandb_config is not None and jax.process_index() == 0:
import wandb
run = wandb.init(**wandb_config)
self.wandb = run
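The wandb import is now deferred to the process-0 branch shown above. What goes into wandb_config is not defined in this diff; a hypothetical dict using standard wandb.init keyword arguments might look like:

```python
# Hypothetical config; project, name, and values are illustrative, not from the repository.
wandb_config = {
    "project": "flaxdiff",
    "name": "laiona_coco-128px",
    "config": {"batch_size": 256, "image_size": 128, "learning_rate": 2.7e-4},
}
# Passed through the trainer, this reaches: run = wandb.init(**wandb_config)
```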

@@ -445,7 +447,7 @@ def __init__(self,
self.wandb.define_metric("train/best_loss", step_metric="train/epoch")

if checkpoint_id is None:
- self.checkpoint_id = name.replace(' ', '_').lower()
+ self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
else:
self.checkpoint_id = checkpoint_id

@@ -458,7 +460,7 @@ def __init__(self,
self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

if load_from_checkpoint:
- latest_epoch, old_state, old_best_state, rngstate = self.load()
+ latest_epoch, old_state, old_best_state, rngstate = self.load(checkpoint_epoch)
else:
latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None

@@ -551,8 +553,11 @@ def tensorboard_path(self):
os.makedirs(path)
return path

- def load(self):
-     epoch = self.checkpointer.latest_step()
+ def load(self, checkpoint_epoch):
+     if checkpoint_epoch is not None:
+         epoch = checkpoint_epoch
+     else:
+         epoch = self.checkpointer.latest_step()
print("Loading model from checkpoint", epoch)
ckpt = self.checkpointer.restore(epoch)
state = ckpt['state']
@@ -562,7 +567,7 @@ def load(self):
self.best_loss = ckpt['best_loss']
print(
f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
- return epoch, state, best_state, rngstate
+ return epoch + 1, state, best_state, rngstate

def save(self, epoch=0):
print(f"Saving model at epoch {epoch}")
@@ -846,9 +851,10 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
local_rng_state = RandomMarkovState(subkey)

images = batch['image']

images = jnp.array(images, dtype=jnp.float32)
# normalize image
images = (images - 127.5) / 127.5

if autoencoder is not None:
# Convert the images to latent space
# local_rng_state, rngs = local_rng_state.get_random_key()
@@ -1002,7 +1008,7 @@ def main(args):

jax.distributed.initialize()

- jax.config.update('jax_threefry_partitionable', True)
+ # jax.config.update('jax_threefry_partitionable', True)
print(f"Number of devices: {jax.device_count()}")
print(f"Local devices: {jax.local_devices()}")

@@ -1200,29 +1206,29 @@ def main(args):
python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/'\
--checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\
- --epochs=40 --batch_size=256 --image_size=256 \
- --learning_rate=1e-4 --num_res_blocks=3 \
- --use_self_and_cross=False --precision=default --attention_heads=16\
- --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_flaxdiff-0-1-7_LDM'\
- --optimizer=adamw --autoencoder=stable_diffusion --feature_depths 128 256 512 512
+ --epochs=40 --batch_size=256 --image_size=128 \
+ --learning_rate=2.7e-4 --num_res_blocks=3 \
+ --use_self_and_cross=False --dtype=bfloat16 --precision=high --attention_heads=16\
+ --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_flaxdiff-0-1-8_lr-{learning_rate}_prec-{precision}_dtype-{dtype}_res-{num_res_blocks}'\
+ --optimizer=adamw --learning_rate_peak=4e-4 --learning_rate_end=1e-4 --learning_rate_warmup_steps=5000
for tpu-v4-64
- python3 training.py --dataset=combined_aesthetic --dataset_path='/home/mrwhite0racle/gcs_mount/'\
+ python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/'\
--checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\
- --epochs=40 --batch_size=512 --image_size=512 \
- --learning_rate=9e-5 --num_res_blocks=3 \
- --use_self_and_cross=False --dtype=bfloat16 --precision=default --attention_heads=16\
- --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_flaxdiff-0-1-7'\
- --optimizer=adamw
+ --epochs=40 --batch_size=512 --image_size=128 \
+ --learning_rate=2e-4 --num_res_blocks=3 \
+ --use_self_and_cross=False --dtype=float32 --precision=high --attention_heads=16\
+ --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_flaxdiff-0-1-8_lr-{learning_rate}_prec-{precision}_dtype-{dtype}_res-{num_res_blocks}'\
+ --optimizer=adamw --learning_rate_peak=4e-4 --learning_rate_end=1e-4 --learning_rate_warmup_steps=5000
for tpu-v4-16
python3 training.py --dataset=aesthetic_coyo --dataset_path='/home/mrwhite0racle/gcs_mount/'\
--checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\
- --epochs=40 --batch_size=64 --image_size=128 \
- --learning_rate=1e-4 --num_res_blocks=3 \
- --use_self_and_cross=False --precision=default --attention_heads=16\
- --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-16_flaxdiff-0-1-7'\
- --optimizer=adamw
+ --epochs=40 --batch_size=128 --image_size=128 \
+ --learning_rate=2e-4 --num_res_blocks=3 \
+ --use_self_and_cross=False --dtype=float32 --precision=high --attention_heads=16\
+ --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-16_flaxdiff-0-1-8_lr-{learning_rate}_prec-{precision}_dtype-{dtype}_res-{num_res_blocks}'\
+ --optimizer=adamw --learning_rate_peak=4e-4 --learning_rate_end=1e-4 --learning_rate_warmup_steps=5000
"""
