From 3a21c7020ce564d4a65ec735ac54a8b6ec9d7ff9 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Fri, 2 Aug 2024 12:42:38 +0000 Subject: [PATCH] fix: even more fixes --- training.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/training.py b/training.py index 82de6b6..a7ab86c 100644 --- a/training.py +++ b/training.py @@ -142,11 +142,9 @@ def data_source(): def labelizer_oxford_flowers102(path): with open(path, "r") as f: textlabels = [i.strip() for i in f.readlines()] - import tensorflow as tf - textlabels = tf.convert_to_tensor(textlabels) def load_labels(sample): - return textlabels[sample['label']] + return textlabels[int(sample['label'])] return load_labels def tfds_augmenters(image_scale, method): @@ -157,11 +155,11 @@ def __init__(self, *args, **kwargs): self.caption_processor = CaptionProcessor(tensor_type="np") def map(self, element) -> Dict[str, jnp.array]: - image = element['image'].numpy() + image = element['image'] image = cv2.resize(image, (image_scale, image_scale), interpolation=cv2.INTER_AREA) # image = (image - 127.5) / 127.5 - caption = labelizer(element).decode('utf-8') + caption = labelizer(element) results = self.caption_processor(caption) return { "image": image, @@ -913,7 +911,8 @@ def main(args): "x": (IMAGE_SIZE, IMAGE_SIZE, 3), "temb": (), "textcontext": (77, 768) - } + }, + "arguments": vars(args), } text_encoders = defaultTextEncodeModel() @@ -957,7 +956,6 @@ def main(args): wandb_config = { "project": "flaxdiff", "config": CONFIG, - "arguments": vars(args), "name": experiment_name, }