Skip to content

Commit

Permalink
Added tests for custom loss function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Pablo Rodríguez Flores committed Feb 12, 2024
1 parent f223ce4 commit a088ff6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
23 changes: 10 additions & 13 deletions resources/src/ai/outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,21 @@ def model_loss(self, y_true, y_pred, single_value=True):
Returns:
(tf.Tensor): Weighted loss value or a 3D loss array.
"""
y_true = tf.cast(y_true, tf.float16)
y_pred = tf.cast(y_pred, tf.float16)
y_true = tf.cast(y_true, tf.bfloat16)
y_pred = tf.cast(y_pred, tf.bfloat16)

Check warning on line 171 in resources/src/ai/outliers.py

View check run for this annotation

Codecov / codecov/patch

resources/src/ai/outliers.py#L170-L171

Added lines #L170 - L171 were not covered by tests
num_metrics = len(self.metrics)
num_features = len(self.columns)
is_metric = (tf.range(num_features) < num_metrics)
is_minute = (tf.range(num_features) == num_metrics)
mult_true = tf.where(
is_metric, self.loss_mult_metric * y_true,
tf.where(is_minute, self.loss_mult_minute * y_true, y_true)
)
mult_pred = tf.where(
is_metric, self.loss_mult_metric * y_pred,
tf.where(is_minute, self.loss_mult_minute * y_pred, y_pred)
)
standard_loss = tf.math.log(tf.cosh((mult_true - mult_pred)))
mult_true = tf.where(is_metric, self.loss_mult_metric * y_true, y_true)
mult_true = tf.where(is_minute, self.loss_mult_minute * mult_true, mult_true)
mult_pred = tf.where(is_metric, self.loss_mult_metric * y_pred, y_pred)
mult_pred = tf.where(is_minute, self.loss_mult_minute * mult_pred, mult_pred)
loss = tf.math.abs(mult_true-mult_pred)
loss = loss-tf.math.log(tf.cast(2.0, tf.bfloat16))+tf.math.log1p(tf.math.exp(-2.0*loss))

Check warning on line 181 in resources/src/ai/outliers.py

View check run for this annotation

Codecov / codecov/patch

resources/src/ai/outliers.py#L176-L181

Added lines #L176 - L181 were not covered by tests
if single_value:
standard_loss = tf.reduce_mean(standard_loss)
return standard_loss
loss = tf.reduce_mean(loss)
return loss

Check warning on line 184 in resources/src/ai/outliers.py

View check run for this annotation

Codecov / codecov/patch

resources/src/ai/outliers.py#L183-L184

Added lines #L183 - L184 were not covered by tests


def slice(self, data, index=None):
Expand Down
32 changes: 32 additions & 0 deletions resources/tests/test_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@

import unittest
import os
'''
Start of important OS Variables
'''
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
'''
End of important OS Variables
'''
import sys
import json
import tempfile
import numpy as np
import tensorflow as tf

from resources.src.ai.outliers import Autoencoder

Expand Down Expand Up @@ -120,5 +128,29 @@ def test_scale_descale_identity(self):
descaled_data = self.autoencoder.descale(rescaled_data)
self.assertTrue(np.allclose(descaled_data, rand_data))

def test_loss_execution_single_value(self):
np.random.seed(0)
y_true = tf.random.uniform((32, len(self.autoencoder.columns)), dtype=tf.float16)
y_pred = tf.random.uniform((32, len(self.autoencoder.columns)), dtype=tf.float16)
try:
loss = self.autoencoder.model_loss(y_true, y_pred, single_value=True)
execution_success = True
except Exception as e:
execution_success = False
print(e)
self.assertTrue(execution_success, "model_loss execution failed with an exception.")

def test_loss_execution_3d_array(self):
np.random.seed(0)
y_true = tf.random.uniform((32, len(self.autoencoder.columns)), dtype=tf.float16)
y_pred = tf.random.uniform((32, len(self.autoencoder.columns)), dtype=tf.float16)
try:
loss = self.autoencoder.model_loss(y_true, y_pred, single_value=False)
execution_success = True
except Exception as e:
execution_success = False
print(e)
self.assertTrue(execution_success, "model_loss execution failed with an exception.")

if __name__ == '__main__':
unittest.main()

0 comments on commit a088ff6

Please sign in to comment.