From e83333ff167d6e502f3d6b93c0bb7ad3b04d133e Mon Sep 17 00:00:00 2001
From: Luke
Date: Mon, 2 Oct 2023 10:41:26 +0200
Subject: [PATCH] File cleanups: remove unused lines and order imports based
 on where they come from, e.g. standard library/local/third-party

---
 src/analyze.py                         | 15 ++++++++++-----
 src/arg_parser.py                      |  6 ++++++
 src/data_generator.py                  | 10 +++++++---
 src/inference.py                       | 12 ------------
 src/loghi_custom_callback.py           |  7 ++++++-
 src/main.py                            | 11 -----------
 src/model.py                           |  1 -
 src/sample_processor.py                |  8 +++++++-
 src/utils.py                           |  7 +++++++
 src/visualize_filter_result.py         | 22 ++++++++++++----------
 src/visualize_network.py               | 18 ++++++++++--------
 src/visualize_network_saliency.py      | 15 ++++++++++-----
 src/visualize_siamese_network_dense.py | 17 ++++++++++++-----
 13 files changed, 87 insertions(+), 62 deletions(-)

diff --git a/src/analyze.py b/src/analyze.py
index a3f292c2..aa38f592 100644
--- a/src/analyze.py
+++ b/src/analyze.py
@@ -1,17 +1,22 @@
-from __future__ import division
-from __future__ import print_function
+# Imports
+# > Standard Library
+from __future__ import division
+from __future__ import print_function
 
 import os
 import sys
 import math
 import pickle
 import copy
-import numpy as np
-import cv2
-import matplotlib.pyplot as plt
 
+# > Local dependencies
 from model import Model
 from sample_processor import preprocess
+
+# > Third party libraries
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
 
 # constants like filepaths
 class Constants:
diff --git a/src/arg_parser.py b/src/arg_parser.py
index 95b44294..0d483294 100644
--- a/src/arg_parser.py
+++ b/src/arg_parser.py
@@ -1,5 +1,11 @@
+# Imports
+
+# > Standard Library
 import argparse
 
+# > Local dependencies
+
+# > Third party libraries
 
 def get_arg_parser():
     parser = argparse.ArgumentParser(
diff --git a/src/data_generator.py b/src/data_generator.py
index 7a729d8d..56d954ad 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -1,8 +1,12 @@
-# import the necessary packages
+# Imports
 
-# from helpers import benchmark
-import tensorflow as tf
+# > Standard Library
 import random
+
+# > Local dependencies
+
+# > Third party libraries
+import tensorflow as tf
 import elasticdeform.tf as etf
 import tensorflow_addons as tfa
 
diff --git a/src/inference.py b/src/inference.py
index b4391757..d5954832 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -38,8 +38,6 @@ def main():
     batchSize = 1
     imgSize = (1024, 32, 1)
     maxTextLen = 128
-    epochs = 10
-    learning_rate = 0.0001
     # load training data, create TF model
     charlist = open(FilePaths.fnCharList).read()
     print(charlist)
@@ -90,18 +88,8 @@ def decode_batch_predictions(pred):
     _, ax = plt.subplots(4, 4, figsize=(15, 5))
 
     for i in range(len(pred_texts)):
-        # for i in range(16):
         print(orig_texts[i].strip())
         print(pred_texts[i].strip())
-
-# img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
-# img = img.T
-# title = f"Prediction: {pred_texts[i].strip()}"
-# ax[i // 4, i % 4].imshow(img, cmap="gray")
-# ax[i // 4, i % 4].set_title(title)
-# ax[i // 4, i % 4].axis("off")
-# plt.show()
-
 
 if __name__ == '__main__':
     main()
diff --git a/src/loghi_custom_callback.py b/src/loghi_custom_callback.py
index 3db6d067..4217a7c0 100644
--- a/src/loghi_custom_callback.py
+++ b/src/loghi_custom_callback.py
@@ -1,7 +1,12 @@
+# Imports
+
+# > Standard Library
 import os
-from tensorflow import keras
 import json
 
+# > Local dependencies
+# > Third party libraries
+from tensorflow import keras
 
 
 class LoghiCustomCallback(keras.callbacks.Callback):
diff --git a/src/main.py b/src/main.py
index 294f251d..92765f5f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -96,7 +96,6 @@ def main():
         exit(1)
     with open(charlist_location) as file:
         char_list = list(char for char in file.read())
-    # char_list = sorted(list(char_list))
     print("using charlist")
    print("length charlist: " + str(len(char_list)))
     print(char_list)
@@ -194,7 +193,6 @@ def main():
 
     if args.existing_model:
         print('using existing model as base: ' + args.existing_model)
-        # if not args.replace_final_layer:
 
         if args.replace_recurrent_layer:
             model = replace_recurrent_layer(model,
@@ -306,7 +304,6 @@ def main():
             output=args.output,
             model_name='encoder12',
             steps_per_epoch=args.steps_per_epoch,
-            # num_workers=args.num_workers,
             max_queue_size=args.max_queue_size,
             early_stopping_patience=args.early_stopping_patience,
             output_checkpoints=args.output_checkpoints,
@@ -397,18 +394,13 @@ def main():
             predictions = prediction_model.predict(batch[0])
             predicted_texts = decode_batch_predictions(predictions, utilsObject, args.greedy,
                                                        args.beam_width, args.num_oov_indices)
-
-            # preds = utils.softmax(preds)
             predsbeam = tf.transpose(predictions, perm=[1, 0, 2])
 
             if wbs:
                 print('computing wbs...')
                 label_str = wbs.compute(predsbeam)
                 char_str = []  # decoded texts for batch
-                # print(len(label_str))
-                # print(label_str)
                 for curr_label_str in label_str:
-                    # print(len(curr_label_str))
                     s = ''.join([chars[label] for label in curr_label_str])
                     char_str.append(s)
                     print(s)
@@ -424,7 +416,6 @@ def main():
             for i in range(len(prediction)):
                 confidence = prediction[i][0]
                 predicted_text = prediction[i][1]
-                # for i in range(16):
                 original_text = orig_texts[i].strip().replace('', '')
                 predicted_text = predicted_text.strip().replace('', '')
                 original_text = remove_tags(original_text)
@@ -455,7 +446,6 @@ def main():
                 filename = loader.get_item(
                     'validation', (batch_no * args.batch_size) + i)
                 print('\n' + filename)
-                # print(predicted_simple)
                 print(original_text)
                 print(predicted_text)
                 if wbs:
@@ -476,7 +466,6 @@ def main():
                       + ' total_orig: ' + str(len(original_text))
                       + ' total_pred: ' + str(len(predicted_text))
                       + ' errors: ' + str(current_editdistance))
-        # print(cer)
         print("avg editdistance: " +
               str(totaleditdistance / float(totallength)))
         print("avg editdistance lower: " +
diff --git a/src/model.py b/src/model.py
index b58cf604..676223b3 100644
--- a/src/model.py
+++ b/src/model.py
@@ -143,7 +143,6 @@ def update_state(self, y_true, y_pred, sample_weight=None):
 
         self.wer_accumulator.assign_add(correct_words_amount)
         self.counter.assign_add(K.cast(len(y_true), 'float32'))
-        # self.counter.assign_add(10)
 
     def result(self):
         return tf.math.divide_no_nan(self.wer_accumulator, self.counter)
diff --git a/src/sample_processor.py b/src/sample_processor.py
index 348e0bca..02b73f2d 100644
--- a/src/sample_processor.py
+++ b/src/sample_processor.py
@@ -1,7 +1,13 @@
+# Imports
+
+# > Standard Library
 from __future__ import division
 from __future__ import print_function
-
 import random
+
+# > Local dependencies
+
+# > Third party libraries
 import numpy as np
 import cv2
 
diff --git a/src/utils.py b/src/utils.py
index 35217d41..ca027485 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,3 +1,10 @@
+# Imports
+
+# > Standard Library
+
+# > Local dependencies
+
+# > Third party libraries
 import tensorflow as tf
 import numpy as np
 from keras.models import Model
diff --git a/src/visualize_filter_result.py b/src/visualize_filter_result.py
index 4cc43647..f9cc0f57 100644
--- a/src/visualize_filter_result.py
+++ b/src/visualize_filter_result.py
@@ -1,21 +1,23 @@
-from tensorflow.keras.utils import get_custom_objects
-import os
+# Imports
+# > Standard Library
+import random
+import argparse
+import os
 
-from config import *
-import utils
-
-from data_loader import DataLoaderNew
+# > Local dependencies
+from data_loader import DataLoader
 from model import CERMetric, WERMetric, CTCLoss
 from utils import *
+from config import *
 
-import tensorflow.keras as keras
+# > Third party libraries
+import tensorflow.keras as keras
 import numpy as np
 import tensorflow as tf
-import random
-import argparse
 from matplotlib import pyplot as plt
-import tensorflow_addons as tfa
+from tensorflow.keras.utils import get_custom_objects
+
 
 
 # disable GPU for now, because it is already running on my dev machine
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
diff --git a/src/visualize_network.py b/src/visualize_network.py
index a4aebc08..fa191976 100644
--- a/src/visualize_network.py
+++ b/src/visualize_network.py
@@ -1,23 +1,25 @@
+# Imports
+
+# > Standard Library
 import math
 import os
+import random
+import argparse
 
-import utils
+# > Local dependencies
 from model import CTCLoss, CERMetric, WERMetric
 from utils import *
-
-
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-os.environ['TF_DETERMINISTIC_OPS'] = '0'
-
 from config import *
+
+# > Third party libraries
 import tensorflow.keras as keras
 import numpy as np
 import tensorflow as tf
-import random
-import argparse
 import tensorflow_addons as tfa
 from tensorflow.keras.utils import get_custom_objects
 
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ['TF_DETERMINISTIC_OPS'] = '0'
 
 def compute_loss(input_image, filter_index):
     activation = feature_extractor(input_image)
diff --git a/src/visualize_network_saliency.py b/src/visualize_network_saliency.py
index b9a1eee4..3c0ac613 100644
--- a/src/visualize_network_saliency.py
+++ b/src/visualize_network_saliency.py
@@ -1,17 +1,22 @@
-import metrics
+# Imports
+
+# > Standard Library
+import random
+import argparse
+
+# > Local dependencies
 from config import *
+
+
+# > Third party libraries
 import tensorflow.keras as keras
 import tensorflow.keras.backend as K
 import numpy as np
 import tensorflow as tf
-import random
-import argparse
 from matplotlib import pyplot as plt
 from tf_keras_vis.saliency import Saliency
 from tf_keras_vis.utils import normalize
 from tensorflow.keras.preprocessing.image import load_img
-# from keras.utils.generic_utils import get_custom_objects
-from tensorflow.keras.utils import get_custom_objects
 
 # from dataset_ecodices import DatasetEcodices
 # from dataset_iisg import DatasetIISG
diff --git a/src/visualize_siamese_network_dense.py b/src/visualize_siamese_network_dense.py
index 42350313..bb933728 100644
--- a/src/visualize_siamese_network_dense.py
+++ b/src/visualize_siamese_network_dense.py
@@ -1,10 +1,17 @@
-from config import *
-import metrics
-from matplotlib import pyplot as plt
-from tf_keras_vis.utils import num_of_gpus
+# Imports
+
+# > Standard Library
+import argparse
+
+# > Local dependencies
+import metrics
+from config import *
+
+# > Third party libraries
 import tensorflow.keras as keras
 import tensorflow as tf
-import argparse
+from matplotlib import pyplot as plt
+from tf_keras_vis.utils import num_of_gpus
 from keras.utils.generic_utils import get_custom_objects
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
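
Note on the convention applied above: every touched module now opens with an
"# Imports" banner followed by three labeled groups, and any "from __future__"
imports stay ahead of all other imports because Python only honors them at the
very top of a module (analyze.py and sample_processor.py keep them there for
that reason). A minimal sketch of the resulting layout, with illustrative
module names rather than code taken from this repository:

# Imports

# > Standard Library
from __future__ import print_function
import os
import random

# > Local dependencies
from model import Model

# > Third party libraries
import numpy as np
import tensorflow as tf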