Skip to content

Commit

Permalink
fix: even more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 2, 2024
1 parent c77a653 commit 3a21c70
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,9 @@ def data_source():
def labelizer_oxford_flowers102(path):
with open(path, "r") as f:
textlabels = [i.strip() for i in f.readlines()]
import tensorflow as tf
textlabels = tf.convert_to_tensor(textlabels)

def load_labels(sample):
return textlabels[sample['label']]
return textlabels[int(sample['label'])]
return load_labels

def tfds_augmenters(image_scale, method):
Expand All @@ -157,11 +155,11 @@ def __init__(self, *args, **kwargs):
self.caption_processor = CaptionProcessor(tensor_type="np")

def map(self, element) -> Dict[str, jnp.array]:
image = element['image'].numpy()
image = element['image']
image = cv2.resize(image, (image_scale, image_scale),
interpolation=cv2.INTER_AREA)
# image = (image - 127.5) / 127.5
caption = labelizer(element).decode('utf-8')
caption = labelizer(element)
results = self.caption_processor(caption)
return {
"image": image,
Expand Down Expand Up @@ -913,7 +911,8 @@ def main(args):
"x": (IMAGE_SIZE, IMAGE_SIZE, 3),
"temb": (),
"textcontext": (77, 768)
}
},
"arguments": vars(args),
}

text_encoders = defaultTextEncodeModel()
Expand Down Expand Up @@ -957,7 +956,6 @@ def main(args):
wandb_config = {
"project": "flaxdiff",
"config": CONFIG,
"arguments": vars(args),
"name": experiment_name,
}

Expand Down

0 comments on commit 3a21c70

Please sign in to comment.