Skip to content

Commit

Permalink
Remove redundant model.py functions
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Sep 22, 2023
1 parent 3cea407 commit 978f3d1
Showing 1 changed file with 4 additions and 33 deletions.
37 changes: 4 additions & 33 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import keras.backend as K
from keras.callbacks import ReduceLROnPlateau
import tensorflow as tf
from tensorflow import keras, Tensor
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Add, Conv2D, ELU, BatchNormalization
from tensorflow.python.ops import math_ops, array_ops, ctc_ops
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.keras import backend_config
Expand All @@ -25,37 +24,6 @@
epsilon = backend_config.epsilon


def elu_bn(inputs: Tensor) -> Tensor:
elu = ELU()(inputs)
bn = BatchNormalization()(elu)
return bn


def residual_block(x, downsample, filters, kernel_size, initializer) -> Tensor:
y = Conv2D(kernel_size=kernel_size,
strides=((1, 1) if not downsample else (2, 2)),
filters=filters,
padding="same",
activation='elu',
kernel_initializer=initializer)(x)
y = Conv2D(kernel_size=kernel_size,
strides=(1, 1),
filters=filters,
padding="same",
activation='elu',
kernel_initializer=initializer)(y)
if downsample:
x = Conv2D(kernel_size=(1, 1),
strides=(2, 2),
filters=filters,
padding="same",
activation='elu',
kernel_initializer=initializer)(x)
out = Add()([x, y])
out = elu_bn(out)
return out


def ctc_batch_cost(y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
Arguments:
Expand Down Expand Up @@ -89,6 +57,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
ignore_longer_outputs_than_inputs=True),
1)


class CERMetric(tf.keras.metrics.Metric):
"""
A custom Keras metric to compute the Character Error Rate
Expand Down Expand Up @@ -277,6 +246,8 @@ def replace_final_layer(model, number_characters, model_name, use_mask=False):
return model

# # Train the model


def train_batch(model, train_dataset, validation_dataset, epochs, output, model_name, steps_per_epoch=None,
early_stopping_patience=20, num_workers=20, max_queue_size=256, output_checkpoints=False,
metadata=None, charlist=None):
Expand Down

0 comments on commit 978f3d1

Please sign in to comment.