diff --git a/ats/tests/test_anomaly_detectors.py b/ats/tests/test_anomaly_detectors.py index f1862f9..171deaa 100644 --- a/ats/tests/test_anomaly_detectors.py +++ b/ats/tests/test_anomaly_detectors.py @@ -1,12 +1,11 @@ import os import unittest -import numpy as np import pandas as pd from ..anomaly_detectors.naive import MinMaxAnomalyDetector, ZScoreAnomalyDetector from ..anomaly_detectors.ml.ifsom import IFSOMAnomalyDetector from ..anomaly_detectors.stat.robust import _COMNHARAnomalyDetector -from ..utils import generate_timeseries_df, load_isp_format_wide_df, wide_df_to_list_of_timeseries_df, timeseries_df_to_list_of_timeseries_df +from ..utils import generate_timeseries_df, load_isp_format_wide_df, wide_df_to_list_of_timeseries_df, timeseries_df_to_list_of_timeseries_df, ensure_full_reproducibility from ..anomaly_detectors.stat.support_functions import generate_contaminated_dataframe # Setup logging @@ -68,7 +67,7 @@ def test_zscore(self): class TestStatAnomalyDetectors(unittest.TestCase): def setUp(self): - np.random.seed(0) + ensure_full_reproducibility() def test_robust_on_multivariate(self): @@ -118,7 +117,7 @@ def test_robust_on_list_of_univariate(self): class TestMLAnomalyDetectors(unittest.TestCase): def setUp(self): - np.random.seed(0) + ensure_full_reproducibility() def test_ifsom(self): diff --git a/ats/utils.py b/ats/utils.py index 6ac1c5f..f667ead 100644 --- a/ats/utils.py +++ b/ats/utils.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """Utilities""" +import os +import random import pandas as pd import numpy as np import plotly.express as px @@ -430,3 +432,14 @@ def timeseries_df_to_list_of_timeseries_df(timeseries_df, anomaly_labels=False): def list_of_timeseries_df_to_timeseries_df(list_of_timeseries_df): return pd.concat(list_of_timeseries_df, axis=1) +def ensure_full_reproducibility(seed=0): + random.seed(seed) + np.random.seed(seed) + try: + import tensorflow as tf + except ImportError: + pass + else: + tf.random.set_seed(seed) + tf.config.experimental.enable_op_determinism() + os.environ["TF_DETERMINISTIC_OPS"] = "1"