-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
1752 lines (1495 loc) · 72.2 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from typing import Any, Tuple, Mapping, Callable, List, Dict
from functools import partial
import flax.experimental
import flax.jax_utils
import flax.training
import flax.training.dynamic_scale
import jax.experimental.multihost_utils
import orbax
import orbax.checkpoint
import flax.jax_utils
import wandb.util
import wandb.wandb_run
from flaxdiff.models.common import kernel_init
from flaxdiff.models.simple_unet import Unet
from flaxdiff.models.simple_vit import UViT
import jax.experimental.pallas.ops.tpu.flash_attention
from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform
from flaxdiff.schedulers import CosineNoiseSchedule, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler
import struct as st
import flax
import tqdm
from flax import linen as nn
import jax
from typing import Dict, Callable, Sequence, Any, Union
from dataclasses import field
import jax.numpy as jnp
import grain.python as pygrain
import numpy as np
import augmax
import matplotlib.pyplot as plt
from clu import metrics
from flax.training import train_state # Useful dataclass to keep train state
import optax
from flax import struct # Flax dataclasses
import time
import os
from datetime import datetime
from flax.training import orbax_utils
import functools
import json
# For CLIP
from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel
import wandb
import cv2
import argparse
import resource
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
from termcolor import colored
import warnings
import traceback
# Silence every Python warning for the whole run (same filter entry as
# warnings.filterwarnings("ignore")).
warnings.simplefilter("ignore")
#####################################################################################################################
################################################# Initialization ####################################################
#####################################################################################################################
# Turn off HuggingFace tokenizers' internal parallelism; the data pipelines
# below run their own worker processes/threads.
os.environ.update(TOKENIZERS_PARALLELISM="false")
class RandomClass():
    """Mutable holder of a JAX PRNG key that hands out fresh subkeys on demand."""

    def __init__(self, rng: jax.random.PRNGKey):
        self.rng = rng

    def get_random_key(self):
        """Split the held key, keep one half as the new state and return the other."""
        self.rng, subkey = jax.random.split(self.rng)
        return subkey

    def get_sigmas(self, steps):
        # NOTE(review): reads self.theta_min / self.theta_max / self.kappa, none of
        # which RandomClass sets, so this raises AttributeError unless a subclass or
        # caller assigns them. Presumably copied from a noise scheduler — confirm.
        return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) / self.kappa

    def reset_random_key(self, seed: int = 42):
        """Replace the held key with ``jax.random.PRNGKey(seed)`` (default seed 42)."""
        self.rng = jax.random.PRNGKey(seed)
class MarkovState(struct.PyTreeNode):
    """Base class for pytree state threaded functionally through training steps."""
    pass


class RandomMarkovState(MarkovState):
    """PRNG state that returns a new state object on every draw instead of mutating."""

    rng: jax.random.PRNGKey

    def get_random_key(self):
        """Return ``(next_state, subkey)``; ``self`` is left untouched."""
        halves = jax.random.split(self.rng)
        next_rng, subkey = halves[0], halves[1]
        return RandomMarkovState(next_rng), subkey
# termcolor colour name per process index 0-7; presumably used to tell per-host
# log lines apart in multi-process runs.
PROCESS_COLOR_MAP = dict(enumerate([
    "green",
    "yellow",
    "magenta",
    "cyan",
    "white",
    "light_blue",
    "light_red",
    "light_cyan",
]))
def _build_global_shape_and_sharding(
local_shape: tuple[int, ...], global_mesh: Mesh
) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
return global_shape, sharding
def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
    """Turn this host's slice of a batch into a globally sharded ``jax.Array``.

    ``path`` is the key path supplied by ``tree_map_with_path`` and is ignored.
    ``array`` is split evenly along axis 0, one piece per local device of
    ``global_mesh``, and the pieces become the local shards of an array whose
    leading dimension spans every process.

    Raises:
        ValueError: if axis 0 does not divide evenly by the local device count.
    """
    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
    local_devices = global_mesh.local_devices
    try:
        pieces = np.split(array, len(local_devices), axis=0)
    except ValueError as split_err:
        raise ValueError(
            f"Unable to put to devices shape {array.shape} with "
            f"local device count {len(local_devices)} "
        ) from split_err
    on_device = jax.device_put(pieces, local_devices)
    return jax.make_array_from_single_device_arrays(global_shape, sharding, on_device)
def convert_to_global_tree(global_mesh, pytree):
    """Map ``form_global_array`` over every leaf of ``pytree`` on ``global_mesh``."""
    leaf_to_global = partial(form_global_array, global_mesh=global_mesh)
    return jax.tree_util.tree_map_with_path(leaf_to_global, pytree)
#####################################################################################################################
################################################## Data Pipeline ####################################################
#####################################################################################################################
def defaultTextEncodeModel(backend="jax", modelname="openai/clip-vit-large-patch14"):
    """Load a CLIP text encoder and its matching tokenizer.

    Args:
        backend: ``"jax"`` loads ``FlaxCLIPTextModel`` in bfloat16; any other value
            loads the PyTorch ``CLIPTextModel``.
        modelname: HuggingFace model id or local path (defaults to CLIP ViT-L/14,
            the model used throughout this file).

    Returns:
        ``(model, tokenizer)``.
    """
    if backend == "jax":
        model = FlaxCLIPTextModel.from_pretrained(modelname, dtype=jnp.bfloat16)
    else:
        model = CLIPTextModel.from_pretrained(modelname)
    # Tokenizers have no dtype; passing one only pollutes the tokenizer's saved
    # init kwargs, so none is given.
    tokenizer = AutoTokenizer.from_pretrained(modelname)
    return model, tokenizer
def encodePrompts(prompts, model, tokenizer=None):
    """Encode text prompts with a CLIP-style text model.

    Args:
        prompts: list of strings.
        model: callable text encoder. If ``None``, the default CLIP model and
            tokenizer are loaded via ``defaultTextEncodeModel`` (this replaces any
            ``tokenizer`` passed in).
        tokenizer: tokenizer for ``model``; if ``None`` the CLIP ViT-L/14
            tokenizer is loaded.

    Returns:
        ``(pooled, full)``: the model's ``pooler_output`` (pooled / EOS-token
        states) and ``last_hidden_state``.
    """
    if model is None:
        model, tokenizer = defaultTextEncodeModel()
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # Pad/truncate every prompt to the tokenizer's fixed context length.
    inputs = tokenizer(prompts, padding="max_length",
                       max_length=tokenizer.model_max_length, truncation=True, return_tensors="np")
    outputs = model(input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'])
    return outputs.pooler_output, outputs.last_hidden_state
class CaptionProcessor:
    """Tokenize captions into fixed-length ``input_ids`` / ``attention_mask``.

    ``tensor_type`` is forwarded to the tokenizer as ``return_tensors``.
    """

    def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.tensor_type = tensor_type

    def __call__(self, caption):
        tok = self.tokenizer
        encoded = tok(
            caption,
            padding="max_length",
            max_length=tok.model_max_length,
            truncation=True,
            return_tensors=self.tensor_type,
        )
        return dict(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            caption=caption,
        )

    def __repr__(self):
        return f"{type(self).__name__}()"
# -----------------------------------------------------------------------------------------------#
# Oxford flowers and other TFDS datasources ----------------------------------------------------#
# -----------------------------------------------------------------------------------------------#
def data_source_tfds(name, use_tf=True, split="all"):
    """Return a factory that opens the TFDS dataset ``name``.

    The factory takes one positional argument (a path override, ignored) so it
    can be called the same way as the GCS factories. With ``use_tf`` it returns
    ``tfds.load(..., shuffle_files=True)``; otherwise ``tfds.data_source(...,
    try_gcs=False)``.
    """
    import tensorflow_datasets as tfds

    def _loaded_source(path_override):
        return tfds.load(name, split=split, shuffle_files=True)

    def _random_access_source(path_override):
        return tfds.data_source(name, split=split, try_gcs=False)

    return _loaded_source if use_tf else _random_access_source
def labelizer_oxford_flowers102(path):
    """Read one label name per line from ``path``; return a sample -> label-name mapper.

    The mapper indexes the names with ``int(sample['label'])``.
    """
    with open(path, "r") as label_file:
        names = [line.strip() for line in label_file]

    def load_labels(sample):
        return names[int(sample['label'])]

    return load_labels
def tfds_augmenters(image_scale, method,
                    labels_path="/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
    """Build a grain ``MapTransform`` class for Oxford Flowers 102 samples.

    Args:
        image_scale: output side length; images are resized to a square.
        method: accepted for signature parity with ``gcs_augmenters``; unused.
        labels_path: label-name file passed to ``labelizer_oxford_flowers102``
            (defaults to the author's local TFDS install).

    Returns:
        A ``pygrain.MapTransform`` subclass whose ``map`` returns ``image``,
        ``input_ids`` and ``attention_mask`` for one element.
    """
    labelizer = labelizer_oxford_flowers102(labels_path)
    # NOTE(review): interpolation depends on the target size only, not on whether
    # each source image is being up- or down-scaled — confirm this is intended.
    interpolation = cv2.INTER_CUBIC if image_scale > 256 else cv2.INTER_AREA

    class augmenters(pygrain.MapTransform):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.caption_processor = CaptionProcessor(tensor_type="np")

        def map(self, element) -> Dict[str, jnp.array]:
            image = cv2.resize(element['image'], (image_scale, image_scale),
                               interpolation=interpolation)
            caption = labelizer(element)
            results = self.caption_processor(caption)
            # [0] drops the leading axis the tokenizer adds for a single caption.
            return {
                "image": image,
                "input_ids": results['input_ids'][0],
                "attention_mask": results['attention_mask'][0],
            }

    return augmenters
# -----------------------------------------------------------------------------------------------#
# CC12m and other GCS data sources --------------------------------------------------------------#
# -----------------------------------------------------------------------------------------------#
def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
    """Return a factory that opens every ArrayRecord shard under ``<base>/<source>``.

    The factory takes the mount point ``base`` and returns a
    ``pygrain.ArrayRecordDataSource`` over all files whose name contains
    ``'array_record'``. Shard paths are sorted because ``os.listdir`` order is
    arbitrary, and the sampler built in ``get_dataset_grain``
    (``IndexSampler`` + ``ShardByJaxProcess``) needs every process to see the
    same record order.
    """
    def data_source(base="/home/mrwhite0racle/gcs_mount"):
        records_path = os.path.join(base, source)
        records = sorted(
            os.path.join(records_path, name)
            for name in os.listdir(records_path)
            if 'array_record' in name
        )
        return pygrain.ArrayRecordDataSource(records)
    return data_source
def data_source_combined_gcs(
        sources=()):
    """Return a factory opening the ArrayRecord shards of several sources as one source.

    ``sources`` are sub-directories of the mount point ``base``. They are
    concatenated in the given order, and repeats are kept. Within each source the
    shard paths are sorted so every process sees the same global record order
    (``os.listdir`` order is arbitrary). A tuple default avoids a shared mutable
    default.
    """
    def data_source(base="/home/mrwhite0racle/gcs_mount"):
        records = []
        for source in sources:
            records_path = os.path.join(base, source)
            records.extend(sorted(
                os.path.join(records_path, name)
                for name in os.listdir(records_path)
                if 'array_record' in name
            ))
        return pygrain.ArrayRecordDataSource(records)
    return data_source
def unpack_dict_of_byte_arrays(packed_data):
    """Decode a buffer of length-prefixed ``(key, value)`` records into a dict.

    Record layout: native-order unsigned int (``'I'``) key length, UTF-8 key
    bytes, ``'I'`` value length, value bytes. Values are slices of
    ``packed_data``; a repeated key keeps its last value.
    """
    length_field = st.Struct('I')
    field_size = length_field.size
    result = {}
    pos = 0
    end = len(packed_data)
    while pos < end:
        (key_len,) = length_field.unpack_from(packed_data, pos)
        pos += field_size
        key = packed_data[pos:pos + key_len].decode('utf-8')
        pos += key_len
        (value_len,) = length_field.unpack_from(packed_data, pos)
        pos += field_size
        result[key] = packed_data[pos:pos + value_len]
        pos += value_len
    return result
def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
    """Convert a BGR image to RGB and resize it to ``image_scale`` x ``image_scale``.

    Args:
        image: image array in OpenCV BGR channel order (assumed 3-channel —
            confirm for ``IMREAD_UNCHANGED`` inputs).
        image_scale: output side length in pixels.
        method: OpenCV interpolation flag. Non-int values, such as the
            ``jax.image.ResizeMethod`` that ``get_dataset_grain`` forwards through
            ``gcs_augmenters``, are not OpenCV flags and fall back to
            ``cv2.INTER_AREA``.
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    interpolation = method if isinstance(method, int) else cv2.INTER_AREA
    return cv2.resize(image, (image_scale, image_scale), interpolation=interpolation)
def gcs_augmenters(image_scale, method):
    """Build a grain ``MapTransform`` class for packed ArrayRecord samples.

    Each element is a packed record (see ``unpack_dict_of_byte_arrays``) with
    encoded image bytes under ``'jpg'`` and UTF-8 caption bytes under ``'txt'``.
    ``map`` returns ``image``, ``input_ids`` and ``attention_mask``.
    """
    def labelizer(sample):
        return sample['txt']

    class augmenters(pygrain.MapTransform):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.caption_processor = CaptionProcessor(tensor_type="np")
            self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)

        def map(self, element) -> Dict[str, jnp.array]:
            fields = unpack_dict_of_byte_arrays(element)
            encoded = np.asarray(bytearray(fields['jpg']), dtype="uint8")
            decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
            image = self.image_augmenter(decoded)
            caption = labelizer(fields).decode('utf-8')
            tokens = self.caption_processor(caption)
            return {
                "image": image,
                "input_ids": tokens['input_ids'][0],
                "attention_mask": tokens['attention_mask'][0],
            }

    return augmenters
# Configure the following for your datasets.
# Registry for get_dataset_grain: each name maps to a data-source factory (called
# with a base path) and an augmenter factory (called with image_scale, method).
# NOTE(review): some combined lists repeat a source (aestheticCoyo); presumably
# deliberate oversampling — confirm before deduplicating.
datasetMap = dict(
    oxford_flowers102=dict(
        source=data_source_tfds("oxford_flowers102", use_tf=False),
        augmenter=tfds_augmenters,
    ),
    cc12m=dict(
        source=data_source_gcs('arrayrecord2/cc12m'),
        augmenter=gcs_augmenters,
    ),
    laiona_coco=dict(
        source=data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
        augmenter=gcs_augmenters,
    ),
    aesthetic_coyo=dict(
        source=data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
        augmenter=gcs_augmenters,
    ),
    combined_aesthetic=dict(
        source=data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/cc12m',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        augmenter=gcs_augmenters,
    ),
    laiona_coco_coyo=dict(
        source=data_source_combined_gcs([
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        augmenter=gcs_augmenters,
    ),
    combined_30m=dict(
        source=data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecord2/cc12m',
            'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
            "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
        ]),
        augmenter=gcs_augmenters,
    ),
)
def batch_mesh_map(mesh):
    """Build a grain ``MapTransform`` class that converts each host-local batch
    into globally sharded arrays on ``mesh`` via ``convert_to_global_tree``."""
    class augmenters(pygrain.MapTransform):
        # No custom __init__: the inherited constructor is used unchanged.
        def map(self, batch) -> Dict[str, jnp.array]:
            return convert_to_global_tree(mesh, batch)

    return augmenters
def get_dataset_grain(
        data_name="cc12m",
        batch_size=64,
        image_scale=256,
        count=None,
        num_epochs=None,
        method=jax.image.ResizeMethod.LANCZOS3,
        worker_count=32,
        read_thread_count=64,
        read_buffer_size=50,
        worker_buffer_size=20,
        seed=0,
        dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Build a grain training pipeline for a dataset registered in ``datasetMap``.

    Returns a dict with:
        train: zero-arg callable creating a fresh ``pygrain.DataLoader``.
        train_len: number of records in the data source.
        local_batch_size / global_batch_size: per-process and total batch size.
        null_labels / null_labels_full: float16 text embeddings of the empty
            prompt (pooled / per-token), presumably for unconditional guidance.
        model / tokenizer: the text encoder and tokenizer that produced them.
    """
    entry = datasetMap[data_name]
    data_source = entry["source"](dataset_source)
    augmenter = entry["augmenter"](image_scale, method)

    local_batch_size = batch_size // jax.process_count()

    model, tokenizer = defaultTextEncodeModel()
    empty_pooled, empty_full = encodePrompts([""], model, tokenizer)
    null_labels = np.array(empty_pooled[0], dtype=np.float16)
    null_labels_full = np.array(empty_full[0], dtype=np.float16)

    sampler = pygrain.IndexSampler(
        num_records=len(data_source) if count is None else count,
        shuffle=True,
        seed=seed,
        num_epochs=num_epochs,
        shard_options=pygrain.ShardByJaxProcess(),
    )

    def get_trainset():
        operations = [
            augmenter(),
            pygrain.Batch(local_batch_size, drop_remainder=True),
        ]
        return pygrain.DataLoader(
            data_source=data_source,
            sampler=sampler,
            operations=operations,
            worker_count=worker_count,
            read_options=pygrain.ReadOptions(read_thread_count, read_buffer_size),
            worker_buffer_size=worker_buffer_size,
        )

    return {
        "train": get_trainset,
        "train_len": len(data_source),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
        "null_labels": null_labels,
        "null_labels_full": null_labels_full,
        "model": model,
        "tokenizer": tokenizer,
    }
# -----------------------------------------------------------------------------------------------#
# Dataloader for directly streaming images from urls --------------------------------------------#
# -----------------------------------------------------------------------------------------------#
import albumentations as A
from flaxdiff.data.online_loader import OnlineStreamingDataLoader, dataMapper, \
default_collate, load_dataset, concatenate_datasets, \
ImageBatchIterator, default_image_processor, load_from_disk
import threading
import queue
def default_image_processor(
    image, image_shape,
    min_image_shape=(128, 128),
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
):
    """Validate an image and letterbox it to ``image_shape``.

    Returns ``(processed, original_height, original_width)``. ``processed`` is
    ``None`` when the image is rejected:
      * not a 3-D array with 3 channels, reported as size ``(0, 0)``;
      * shorter side below ``min(min_image_shape)``;
      * aspect ratio above 2.4;
      * pixel standard deviation below 1e-5;
      * any exception while processing, reported as size ``(0, 0)``.
    Accepted images are resized so their longer side is ``max(image_shape)``,
    using the down- or up-scaling interpolation, then padded with white to at
    least ``image_shape``.
    """
    try:
        pixels = np.array(image)
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            return None, 0, 0
        height, width = pixels.shape[:2]
        short_side, long_side = min(height, width), max(height, width)
        # Short-circuit keeps the original check order (size, aspect, variance).
        rejected = (
            short_side < min(min_image_shape)
            or long_side / short_side > 2.4
            or np.std(pixels) < 1e-5
        )
        if rejected:
            return None, height, width
        target_long = max(image_shape)
        if long_side > target_long:
            interpolation = downscale_interpolation
        else:
            interpolation = upscale_interpolation
        pixels = A.longest_max_size(pixels, target_long, interpolation=interpolation)
        pixels = A.pad(
            pixels,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        return pixels, height, width
    except Exception:
        return None, 0, 0
def default_feature_extractor(sample):
    """Pull the image URL and caption from a dataset row with varying column names.

    The first key present wins. URL keys: ``url``, ``URL``, ``image_url``.
    Caption keys: ``caption``, ``CAPTION``, ``txt``, ``TEXT``, ``text``.
    A missing field is reported with a print and returned as ``None``.
    """
    url_keys = ("url", "URL", "image_url")
    caption_keys = ("caption", "CAPTION", "txt", "TEXT", "text")

    url_key = next((k for k in url_keys if k in sample), None)
    if url_key is None:
        print("No url found in sample, skipping", sample.keys())

    caption_key = next((k for k in caption_keys if k in sample), None)
    if caption_key is None:
        print("No caption found in sample, skipping", sample.keys())

    return {
        "url": None if url_key is None else sample[url_key],
        "caption": None if caption_key is None else sample[caption_key],
    }
class OnlineStreamingDataLoader():
    """Background-prefetching batch loader over a (sharded) HuggingFace dataset.

    ``dataset`` may be a dataset object, a path (``gs://`` paths use
    ``load_from_disk``, others ``load_dataset``), or a list of paths or datasets
    that is concatenated. The dataset is shuffled with seed 0 and sharded by
    process. A daemon thread pulls batches from ``ImageBatchIterator``
    (presumably fetching images from the rows' URLs — see that class), applies
    ``collate_fn`` and pushes the results onto a queue of size ``prefetch``.
    Iteration pops from that queue and stops once the source is exhausted.
    ``pre_map_maker`` and ``pre_map_def`` are currently unused.
    """

    # Pushed by the producer thread when it finishes, so consumers stop instead of blocking.
    _END = object()

    def __init__(
        self,
        dataset,
        batch_size=64,
        image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
        timeout=15,
        retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            if "gs://" in dataset:
                dataset = load_from_disk(dataset_path)
            else:
                dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
                    dataset_path, split=default_split) for dataset_path in dataset]
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.shuffle(seed=0)
        self.dataset = dataset.shard(
            num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
                                           min_image_shape=min_image_shape,
                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
                                           timeout=timeout, retries=retries, image_processor=image_processor,
                                           upscale_interpolation=upscale_interpolation,
                                           downscale_interpolation=downscale_interpolation,
                                           feature_extractor=feature_extractor)
        self.batch_size = batch_size
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            try:
                for batch in self.iterator:
                    try:
                        collated = collate_fn(batch)
                    except Exception as e:
                        print("Error collating batch", e)
                        continue
                    # Collate functions may log a failure and return None; never hand
                    # that to the training loop.
                    if collated is not None:
                        self.batch_queue.put(collated)
            finally:
                # Signal exhaustion (or a crashed iterator) so __next__ can stop.
                self.batch_queue.put(self._END)

        # Daemon so a producer blocked on a full queue cannot keep the process alive.
        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
        self.loader_thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        batch = self.batch_queue.get()
        if batch is self._END:
            # Put it back so any later next() also stops instead of blocking forever.
            self.batch_queue.put(self._END)
            raise StopIteration
        return batch

    def __len__(self):
        return len(self.dataset)
# Registry of streaming datasets for get_dataset_online. "source" lists the dataset
# paths that OnlineStreamingDataLoader loads (gs:// paths via load_from_disk),
# concatenates in list order and then shuffles with a fixed seed. Entry order
# therefore affects the resulting sample order.
# NOTE(review): some paths repeat (leonardo-liked-1.8m x5, cc3m x2) — presumably
# deliberate oversampling; confirm before deduplicating.
onlineDatasetMap = {
    "combined_online": {
        "source": [
            # Disabled sources:
            # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
            # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
            # "dclure/laion-aesthetics-12m-umap",
            "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
            "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc12m",
            "gs://flaxdiff-datasets-regional/datasets/playground-liked",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
            # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
        ]
    }
}
def generate_collate_fn(tokenizer):
    """Return a collate function that stacks images and tokenizes captions.

    NOTE(review): ``tokenizer`` is unused; captions go through a fresh
    ``CaptionProcessor`` (CLIP ViT-L/14 tokenizer) instead.
    On failure the returned function prints the error, the per-sample image
    shapes and a traceback, then returns ``None``.
    """
    caption_processor = CaptionProcessor(tensor_type="np")

    def default_collate(batch):
        try:
            tokens = caption_processor([sample["caption"] for sample in batch])
            images = np.stack([sample["image"] for sample in batch], axis=0)
            return dict(
                image=images,
                input_ids=tokens['input_ids'],
                attention_mask=tokens['attention_mask'],
            )
        except Exception as e:
            print("Error in collate function", e, [sample["image"].shape for sample in batch])
            traceback.print_exc()
        return None

    return default_collate
def get_dataset_online(
        data_name="combined_online",
        batch_size=64,
        image_scale=256,
        count=None,
        num_epochs=None,
        method=jax.image.ResizeMethod.LANCZOS3,
        worker_count=32,
        read_thread_count=64,
        read_buffer_size=50,
        worker_buffer_size=20,
        seed=0,
        dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Build a URL-streaming training pipeline for a dataset in ``onlineDatasetMap``.

    The signature mirrors ``get_dataset_grain``. ``count``, ``num_epochs``,
    ``method``, ``read_buffer_size``, ``seed`` and ``dataset_source`` are accepted
    but unused. ``worker_count`` and ``read_thread_count`` size the loader, and
    ``worker_buffer_size`` is the prefetch queue depth.

    Returns the same dict layout as ``get_dataset_grain``. Its ``train`` callable
    optionally takes a ``Mesh``; when given, it yields globally sharded batches
    produced by a background thread.
    """
    local_batch_size = batch_size // jax.process_count()

    model, tokenizer = defaultTextEncodeModel()
    null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
    null_labels = np.array(null_labels[0], dtype=np.float16)
    null_labels_full = np.array(null_labels_full[0], dtype=np.float16)

    sources = onlineDatasetMap[data_name]["source"]
    dataloader = OnlineStreamingDataLoader(
        sources,
        batch_size=local_batch_size,
        num_workers=worker_count,
        num_threads=read_thread_count,
        image_shape=(image_scale, image_scale),
        global_process_count=jax.process_count(),
        global_process_index=jax.process_index(),
        prefetch=worker_buffer_size,
        collate_fn=generate_collate_fn(tokenizer),
        default_split="train",
    )

    def get_trainset(mesh: Mesh = None):
        if mesh is None:
            return dataloader

        class dataLoaderWithMesh:
            """Re-emit ``dataloader`` batches as global arrays on ``mesh``."""

            # Pushed when the source is exhausted so consumers stop instead of blocking.
            _END = object()

            def __init__(self, dataloader, mesh):
                self.dataloader = dataloader
                self.mesh = mesh
                self.tmp_queue = queue.Queue(worker_buffer_size)

                def batch_loader():
                    try:
                        for batch in self.dataloader:
                            try:
                                self.tmp_queue.put(convert_to_global_tree(mesh, batch))
                            except Exception as e:
                                print("Error processing batch", e)
                    finally:
                        self.tmp_queue.put(self._END)

                # Daemon so a producer blocked on a full queue cannot keep the process alive.
                self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
                self.loader_thread.start()

            def __iter__(self):
                return self

            def __next__(self):
                item = self.tmp_queue.get()
                if item is self._END:
                    self.tmp_queue.put(self._END)
                    raise StopIteration
                return item

        return dataLoaderWithMesh(dataloader, mesh)

    return {
        "train": get_trainset,
        "train_len": len(dataloader) * jax.process_count(),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
        "null_labels": null_labels,
        "null_labels_full": null_labels_full,
        "model": model,
        "tokenizer": tokenizer,
    }
#####################################################################################################################
############################################### Training Pipeline ###################################################
#####################################################################################################################
@struct.dataclass
class Metrics(metrics.Collection):
    """clu metric collection stored on SimpleTrainState and updated by compute_metrics."""
    # Fed from the `logits`/`labels` outputs passed in compute_metrics.
    accuracy: metrics.Accuracy
    # Running average of the `loss` model output.
    loss: metrics.Average.from_output('loss')
# Train state for SimpleTrainer: flax's TrainState plus a metrics collection and
# an optional dynamic loss scale.
class SimpleTrainState(train_state.TrainState):
    # clu metrics collection (Metrics.empty() when created by generate_states).
    metrics: Metrics
    # A DynamicScale when use_dynamic_scale=True, otherwise None (see generate_states).
    dynamic_scale: flax.training.dynamic_scale.DynamicScale
class SimpleTrainer:
state: SimpleTrainState
best_state: SimpleTrainState
best_loss: float
model: nn.Module
ema_decay: float = 0.999
def __init__(self,
             model: nn.Module,
             input_shapes: Dict[str, Tuple[int]],
             optimizer: optax.GradientTransformation,
             rngs: jax.random.PRNGKey,
             train_state: SimpleTrainState = None,
             name: str = "Simple",
             load_from_checkpoint: str = None,
             checkpoint_suffix: str = "",
             loss_fn=optax.l2_loss,
             param_transforms: Callable = None,
             wandb_config: Dict[str, Any] = None,
             distributed_training: bool = None,
             checkpoint_base_path: str = "./checkpoints",
             checkpoint_step: int = None,
             use_dynamic_scale: bool = False,
             ):
    """Set up the device mesh, optional W&B run, checkpoint manager, RNG state and train states.

    Args:
        model: flax module to train.
        input_shapes: per-input shapes without the batch dimension; used to build
            all-ones dummy inputs for ``model.init``.
        optimizer: optax transformation for the train state.
        rngs: PRNG key for the RNG state (ignored when a checkpoint provides one).
        train_state: pre-built state; if given it becomes both ``state`` and
            ``best_state`` and ``generate_states`` is skipped.
        name: run name; also determines the checkpoint directory.
        load_from_checkpoint: checkpoint directory to restore from.
        checkpoint_suffix: string appended to this trainer's checkpoint path.
        loss_fn: loss whose mean is minimised in the train step.
        param_transforms: forwarded to ``generate_states``, which ignores it.
        wandb_config: kwargs for ``wandb.init``; only process 0 creates a run.
        distributed_training: ``None``/``True`` auto-detect (more than one
            device); ``False`` disables.
        checkpoint_base_path: root directory for checkpoints.
        checkpoint_step: step to restore (latest when ``None``).
        use_dynamic_scale: attach a flax ``DynamicScale`` to the state.
    """
    if distributed_training is None or distributed_training is True:
        # Auto-detect if we are running on multiple devices
        distributed_training = jax.device_count() > 1
        # NOTE(review): the 1-D 'data' mesh is built even on a single device,
        # where distributed_training ends up False.
        self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
    else:
        self.mesh = None
    self.distributed_training = distributed_training
    self.model = model
    self.name = name
    self.loss_fn = loss_fn
    self.input_shapes = input_shapes
    self.checkpoint_base_path = checkpoint_base_path
    # NOTE(review): self.wandb is only assigned on process 0 with a config; on other
    # processes the attribute does not exist.
    if wandb_config is not None and jax.process_index() == 0:
        run = wandb.init(**wandb_config)
        self.wandb = run
        # define our custom x axis metric
        self.wandb.define_metric("train/step")
        self.wandb.define_metric("train/epoch")
        self.wandb.define_metric("train/loss", step_metric="train/step")
        self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
        self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
        self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
        self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
    # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Asynchronous writer; save() explicitly waits for each write to finish.
    async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
    options = orbax.checkpoint.CheckpointManagerOptions(
        max_to_keep=4, create=True)
    self.checkpointer = orbax.checkpoint.CheckpointManager(
        self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
    if load_from_checkpoint is not None:
        latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
    else:
        latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
    self.latest_step = latest_step
    # A restored rngstate is a dict of RandomMarkovState fields.
    if rngstate:
        self.rngstate = RandomMarkovState(**rngstate)
    else:
        self.rngstate = RandomMarkovState(rngs)
    self.rngstate, subkey = self.rngstate.get_random_key()
    if train_state == None:
        state, best_state = self.generate_states(
            optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
        )
        # NOTE(review): init_state() resets self.best_loss to 1e9, discarding the
        # best_loss that self.load() restored from the checkpoint.
        self.init_state(state, best_state)
    else:
        self.state = train_state
        self.best_state = train_state
        self.best_loss = 1e9
def get_input_ones(self):
    """Return all-ones dummy inputs, one per ``input_shapes`` entry, with batch size 1."""
    dummy = {}
    for input_name, shape in self.input_shapes.items():
        dummy[input_name] = jnp.ones((1,) + tuple(shape))
    return dummy
def generate_states(
    self,
    optimizer: optax.GradientTransformation,
    rngs: jax.random.PRNGKey,
    existing_state: dict = None,
    existing_best_state: dict = None,
    model: nn.Module = None,
    param_transforms: Callable = None,
    use_dynamic_scale: bool = False
) -> Tuple[SimpleTrainState, SimpleTrainState]:
    """Create the training state and the best-so-far state.

    Args:
        optimizer: optax transformation attached to the state.
        rngs: PRNG key; a subkey of it initialises fresh parameters.
        existing_state: restored checkpoint dict whose ``'params'`` replace a
            fresh ``model.init``.
        existing_best_state: restored checkpoint dict whose ``'params'`` seed
            ``best_state``.
        model: module to initialise/apply; defaults to ``self.model`` (``None``
            previously crashed on ``model.apply``).
        param_transforms: accepted for API compatibility; unused.
        use_dynamic_scale: attach a flax ``DynamicScale`` (loss scaling).

    Returns:
        ``(state, best_state)``; ``best_state`` is ``state`` itself unless
        ``existing_best_state`` is given.
    """
    print("Generating states for SimpleTrainer")
    if model is None:
        model = self.model
    _, init_key = jax.random.split(rngs)
    if existing_state is None:
        params = model.init(init_key, **self.get_input_ones())
    else:
        params = existing_state['params']
    state = SimpleTrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        metrics=Metrics.empty(),
        dynamic_scale=flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None,
    )
    if existing_best_state is None:
        return state, state
    return state, state.replace(params=existing_best_state['params'])
def init_state(
    self,
    state: SimpleTrainState,
    best_state: SimpleTrainState,
):
    """Install ``state``/``best_state`` and reset ``best_loss`` to the 1e9 sentinel."""
    self.best_loss = 1e9
    self.state, self.best_state = state, best_state
def get_state(self):
    """Return a NumPy-leaf copy of the current train state (used by save())."""
    current = self.state
    return self.get_np_tree(current)
def get_best_state(self):
    """Return a NumPy-leaf copy of the best-so-far train state (used by save())."""
    best = self.best_state
    return self.get_np_tree(best)
def get_rngstate(self):
    """Return a NumPy-leaf copy of the trainer's RandomMarkovState (used by save())."""
    rng_state = self.rngstate
    return self.get_np_tree(rng_state)
def get_np_tree(self, pytree):
    """Return ``pytree`` with every leaf converted by ``np.array`` (a host copy)."""
    return jax.tree_util.tree_map(np.array, pytree)
def checkpoint_path(self):
    """Return ``<checkpoint_base_path>/<name>`` and create the directory if needed.

    The name is lower-cased with spaces replaced by underscores.
    ``exist_ok=True`` replaces the old exists-then-create check, which could
    raise ``FileExistsError`` when several processes start simultaneously.
    """
    path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
    os.makedirs(path, exist_ok=True)
    return path
def tensorboard_path(self):
    """Return ``./tensorboard/<name>`` as an absolute path, creating it if needed.

    ``exist_ok=True`` avoids the exists-then-create race between processes.
    """
    path = os.path.join(os.path.abspath('./tensorboard'), self.name)
    os.makedirs(path, exist_ok=True)
    return path
def load(self, checkpoint_path=None, checkpoint_step=None):
    """Restore a checkpoint and return its parts.

    Uses ``self.checkpointer`` when ``checkpoint_path`` is ``None``; otherwise it
    opens a manager on ``checkpoint_path`` with ``create=False``. It restores
    ``checkpoint_step``, or the latest step when that is ``None``. It also sets
    ``self.best_loss`` from the checkpoint, replacing a stored 0 with 1e9.

    Returns:
        ``(epoch, step, state, best_state, rngstate)`` as raw restored trees;
        ``epoch`` falls back to ``step`` for checkpoints without an ``'epoch'``.
    """
    if checkpoint_path is None:
        manager = self.checkpointer
    else:
        manager = orbax.checkpoint.CheckpointManager(
            checkpoint_path,
            orbax.checkpoint.PyTreeCheckpointer(),
            orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=False),
        )
    step = manager.latest_step() if checkpoint_step is None else checkpoint_step
    print("Loading model from checkpoint at step ", step)
    ckpt = manager.restore(step)
    state, best_state, rngstate = ckpt['state'], ckpt['best_state'], ckpt['rngs']

    self.best_loss = ckpt['best_loss']
    if self.best_loss == 0:
        # A zero best loss can only come from a faulty save; fall back to the sentinel.
        self.best_loss = 1e9
    # Older checkpoints stored no epoch; they counted epochs in place of steps.
    current_epoch = ckpt.get('epoch', step)
    print(
        f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
    return current_epoch, step, state, best_state, rngstate
def save(self, epoch=0, step=0, state=None, rngstate=None):
    """Write a checkpoint at ``step`` and block until the write has finished.

    ``state``/``rngstate`` default to the trainer's own. Failures are printed,
    not raised: "Error saving checkpoint outer" when building the payload fails,
    "Error saving checkpoint" when the write itself fails.
    """
    print(f"Saving model at epoch {epoch} step {step}")
    try:
        ckpt = {
            'rngs': self.get_np_tree(rngstate) if rngstate is not None else self.get_rngstate(),
            'state': self.get_np_tree(state) if state is not None else self.get_state(),
            'best_state': self.get_best_state(),
            'best_loss': np.array(self.best_loss),
            'epoch': epoch,
        }
    except Exception as e:
        print("Error saving checkpoint outer", e)
        return
    try:
        save_args = orbax_utils.save_args_from_target(ckpt)
        self.checkpointer.save(step, ckpt, save_kwargs={'save_args': save_args}, force=True)
        self.checkpointer.wait_until_finished()
    except Exception as e:
        print("Error saving checkpoint", e)
def _define_train_step(self, **kwargs):
    """Build the per-step update function.

    The returned ``train_step(train_state, rng_state, batch, local_device_indexes)``
    minimises ``mean(loss_fn(model.apply(params, batch['image']), batch['label']))``.
    When distributed, it averages gradients over the 'data' axis before applying
    them. It returns ``(train_state, loss, rng_state)``; ``rng_state`` is passed
    through unchanged. ``kwargs`` and ``local_device_indexes`` are unused.
    """
    model = self.model
    loss_fn = self.loss_fn
    distributed_training = self.distributed_training
    def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
        """Train for a single step."""
        images = batch['image']
        labels = batch['label']
        def model_loss(params):
            preds = model.apply(params, images)
            expected_output = labels
            nloss = loss_fn(preds, expected_output)
            loss = jnp.mean(nloss)
            return loss
        loss, grads = jax.value_and_grad(model_loss)(train_state.params)
        if distributed_training:
            # Average gradients across the 'data' mesh axis.
            grads = jax.lax.pmean(grads, "data")
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, loss, rng_state
    if distributed_training:
        # State and rng replicated; batch and device indexes split along 'data'.
        train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
        # NOTE(review): jax.pmap over a shard_map maps over devices a second time;
        # jax.jit(train_step) looks like the intent — confirm. The source lost its
        # indentation, so placing this line inside the `if` is a reconstruction.
        train_step = jax.pmap(train_step)
    return train_step
def _define_compute_metrics(self):
    """Build a jitted ``compute_metrics(state, batch)`` that merges the batch's loss
    and accuracy into ``state.metrics`` and returns the updated state."""
    model, loss_fn = self.model, self.loss_fn

    @jax.jit
    def compute_metrics(state: SimpleTrainState, batch):
        logits = model.apply(state.params, batch['image'])
        labels = batch['label']
        batch_loss = jnp.mean(loss_fn(logits, labels))
        update = state.metrics.single_from_model_output(
            loss=batch_loss, logits=logits, labels=labels)
        return state.replace(metrics=state.metrics.merge(update))

    return compute_metrics
def summary(self):
    """Print the model's ``tabulate`` table for all-ones dummy inputs."""
    dummy_inputs = self.get_input_ones()
    table = self.model.tabulate(
        jax.random.key(0),
        **dummy_inputs,
        console_kwargs={"width": 200, "force_jupyter": True},
    )
    print(table)
def config(self):
    """Return the trainer's model, state, name and input shapes as a dict."""
    return dict(
        model=self.model,
        state=self.state,
        name=self.name,
        input_shapes=self.input_shapes,
    )
def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
    """Create a flax TensorBoard writer under ``tensorboard_path()`` and log the
    trainer config plus run hyper-parameters as hparams."""
    from flax.metrics import tensorboard
    writer = tensorboard.SummaryWriter(self.tensorboard_path())
    hparams = dict(self.config())
    hparams.update(steps_per_epoch=steps_per_epoch, epochs=epochs, batch_size=batch_size)
    writer.hparams(hparams)
    return writer
def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
train_ds = iter(data['train']())
if 'test' in data:
test_ds = data['test']
else:
test_ds = None
train_step = self._define_train_step(**train_step_args)
compute_metrics = self._define_compute_metrics()