-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_to_tflite.py
37 lines (30 loc) · 1.4 KB
/
convert_to_tflite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import tensorflow as tf
import os
# Enable GPU memory growth.
# Iterating over the detected devices (instead of indexing [0]) keeps the
# script usable on CPU-only machines, where the list is empty, and applies
# the same setting to every GPU, since TensorFlow requires memory growth to
# be configured identically across GPUs.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# Source Keras model (.h5) and destination TFLite file, relative to the
# current working directory. os.path.join with a single argument returns the
# path unchanged.
input_model_path = os.path.join("model/model-unet.h5")
output_model_path = os.path.join("model/model-unet.tflite")
# Metric Function
class MaxMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU variant that reduces both inputs with argmax over the last axis.

    NOTE(review): presumably y_true is one-hot and y_pred holds per-class
    scores, so the argmax yields class indices -- confirm against training code.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate IoU statistics from the argmax of the targets and predictions.

        Args:
            y_true: Ground-truth tensor; reduced with argmax over its last axis.
            y_pred: Prediction tensor; reduced with argmax over its last axis.
            sample_weight: Optional weights, forwarded unchanged to the parent.

        Returns:
            Whatever the parent ``MeanIoU.update_state`` returns.
        """
        true_labels = tf.argmax(y_true, axis=-1)
        predicted_labels = tf.argmax(y_pred, axis=-1)
        return super().update_state(true_labels, predicted_labels, sample_weight)
# Loss Function
def dice_loss(y_true, y_pred, num_classes=2):
    """Return a negated, class-averaged Dice-style overlap score.

    For each channel index in ``range(num_classes)`` the function computes
    ``(sum(t * p) + eps) / (sum(t) + sum(p) + eps)`` over the flattened
    channel, adds the ratios up, and returns ``-2 / num_classes`` times
    that total.

    Args:
        y_true: Ground-truth tensor, indexed as ``y_true[:, :, :, c]``.
        y_pred: Prediction tensor, indexed the same way.
            NOTE(review): presumably both are (batch, height, width, classes)
            -- confirm against the training pipeline.
        num_classes: Number of channels along the last axis to include.

    Returns:
        A scalar tensor; more overlap gives a more negative value.
    """
    backend = tf.keras.backend
    eps = backend.epsilon()  # keeps the ratio defined when a channel is empty

    def _channel_ratio(channel):
        # Smoothed overlap / total mass for a single channel.
        truth_flat = backend.flatten(y_true[:, :, :, channel])
        pred_flat = backend.flatten(y_pred[:, :, :, channel])
        overlap = backend.sum(truth_flat * pred_flat)
        total = backend.sum(truth_flat) + backend.sum(pred_flat)
        return (overlap + eps) / (total + eps)

    score = 0
    for channel in range(num_classes):
        score += _channel_ratio(channel)
    return -2./num_classes * score
# Load the Keras model. The custom loss and metric are passed by name so
# load_model can resolve them.
custom_objects = {'dice_loss': dice_loss, 'MaxMeanIoU': MaxMeanIoU}
model = tf.keras.models.load_model(
    input_model_path,
    custom_objects=custom_objects,
)

# Convert the tf.keras model to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the converted model to disk in binary mode.
with open(output_model_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)