From ac36037b8be516cb028521ed438a12be113968e4 Mon Sep 17 00:00:00 2001 From: xuhongzuo Date: Fri, 9 Jun 2023 14:01:54 +0800 Subject: [PATCH 1/2] Deep Isolation Forest method implemented --- README.rst | 3 + docs/zreferences.bib | 11 + examples/dif_example.py | 54 +++++ pyod/models/dif.py | 451 ++++++++++++++++++++++++++++++++++++++++ pyod/test/test_dif.py | 133 ++++++++++++ 5 files changed, 652 insertions(+) create mode 100644 examples/dif_example.py create mode 100644 pyod/models/dif.py create mode 100644 pyod/test/test_dif.py diff --git a/README.rst b/README.rst index a993bb3a8..5b1e7d56c 100644 --- a/README.rst +++ b/README.rst @@ -373,6 +373,7 @@ Proximity-Based SOD Subspace Outlier Detection Proximity-Based ROD Rotation-based Outlier Detection 2020 [#Almardeny2020A]_ Outlier Ensembles IForest Isolation Forest 2008 [#Liu2008Isolation]_ Outlier Ensembles INNE Isolation-based Anomaly Detection Using Nearest-Neighbor Ensembles 2018 [#Bandaragoda2018Isolation]_ +Outlier Ensembles DIF Deep Isolation Forest for Anomaly Detection 2023 [#Xu2023Deep]_ Outlier Ensembles FB Feature Bagging 2005 [#Lazarevic2005Feature]_ Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 [#Zhao2019LSCP]_ Outlier Ensembles XGBOD Extreme Boosting Based Outlier Detection **(Supervised)** 2018 [#Zhao2018XGBOD]_ @@ -630,6 +631,8 @@ Reference .. [#Wang2020adVAE] Wang, X., Du, Y., Lin, S., Cui, P., Shen, Y. and Yang, Y., 2019. adVAE: A self-adversarial variational autoencoder with Gaussian anomaly prior knowledge for anomaly detection. *Knowledge-Based Systems*. +.. [#Xu2023Deep] Xu, H., Pang, G., Wang, Y., Wang, Y., 2023. Deep isolation forest for anomaly detection. *IEEE Transactions on Knowledge and Data Engineering*. + .. [#You2017Provable] You, C., Robinson, D.P. and Vidal, R., 2017. Provable self-representation based outlier detection in a union of subspaces. In Proceedings of the IEEE conference on computer vision and pattern recognition. .. [#Zenati2018Adversarially] Zenati, H., Romain, M., Foo, C.S., Lecouat, B. and Chandrasekhar, V., 2018, November. Adversarially learned anomaly detection. In 2018 IEEE International conference on data mining (ICDM) (pp. 727-736). IEEE. 
diff --git a/docs/zreferences.bib b/docs/zreferences.bib index eee409ea2..e22b37f5c 100644 --- a/docs/zreferences.bib +++ b/docs/zreferences.bib @@ -489,4 +489,15 @@ @article{fang2001wrap pages={608--624}, year={2001}, publisher={Elsevier} +} + +@article{xu2023dif, + author={Xu, Hongzuo and Pang, Guansong and Wang, Yijie and Wang, Yongjun}, + journal={IEEE Transactions on Knowledge and Data Engineering}, + title={Deep Isolation Forest for Anomaly Detection}, + year={2023}, + volume={}, + number={}, + pages={1-14}, + doi={10.1109/TKDE.2023.3270293} } \ No newline at end of file diff --git a/examples/dif_example.py b/examples/dif_example.py new file mode 100644 index 000000000..8d64ed6b1 --- /dev/null +++ b/examples/dif_example.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +"""Example of using Deep Isolation Forest for +outlier detection""" +# Author: Hongzuo Xu +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import os +import sys + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) + +from pyod.models.dif import DIF +from pyod.utils.data import generate_data +from pyod.utils.data import evaluate_print + + +if __name__ == "__main__": + contamination = 0.1 # percentage of outliers + n_train = 20000 # number of training points + n_test = 2000 # number of testing points + n_features = 300 # number of features + + # Generate sample data + X_train, X_test, y_train, y_test = \ + generate_data(n_train=n_train, + n_test=n_test, + n_features=n_features, + contamination=contamination, + random_state=42) + + # train AutoEncoder detector + clf_name = 'DIF' + clf = DIF() + clf.fit(X_train) + + # get the prediction labels and outlier scores of the training data + y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) + y_train_scores = clf.decision_scores_ # raw outlier scores + + # get the prediction on the test data + y_test_pred = clf.predict(X_test) # outlier labels (0 or 1) + y_test_scores = clf.decision_function(X_test) # outlier scores + + # evaluate and print the results + print("\nOn Training Data:") + evaluate_print(clf_name, y_train, y_train_scores) + print("\nOn Test Data:") + evaluate_print(clf_name, y_test, y_test_scores) diff --git a/pyod/models/dif.py b/pyod/models/dif.py new file mode 100644 index 000000000..e75f05d1d --- /dev/null +++ b/pyod/models/dif.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- +"""Deep Isolation Forest for Anomaly Detection (DIF) +""" +# Author: Hongzuo Xu +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch +from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted +from sklearn.ensemble import IsolationForest +from sklearn.preprocessing import StandardScaler, MinMaxScaler +from torch.utils.data import DataLoader + +# from pyod.models.base import BaseDetector +# from pyod.utils.utility import check_parameter +# from pyod.utils.torch_utility import get_activation_by_name +from .base import BaseDetector +from ..utils.utility import check_parameter +from ..utils.torch_utility import get_activation_by_name + + +class DIF(BaseDetector): + """Deep Isolation Forest (DIF) is an extension of iForest. It uses deep + representation ensemble to achieve non-linear isolation on original data + space. See :cite:`xu2023dif` + for details. 
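+
+    DIF maps the input data into ``n_ensemble`` representation spaces produced
+    by randomly-initialized neural networks, builds an isolation forest on
+    each representation, and reports the average of the per-representation
+    anomaly scores as the final score.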
+
+    Parameters
+    ----------
+    batch_size : int, optional (default=1000)
+        The number of samples per mini-batch when passing data through the
+        representation networks.
+
+    representation_dim : int, optional (default=20)
+        Dimensionality of the representation space.
+
+    hidden_neurons : list, optional (default=[500, 100])
+        The number of neurons per hidden layer, so the network has the
+        structure [n_features, hidden_neurons[0], hidden_neurons[1],
+        ..., representation_dim].
+
+    hidden_activation : str, optional (default='tanh')
+        Activation function to use for hidden layers.
+        All hidden layers are forced to use the same type of activation.
+        See https://pytorch.org/docs/stable/nn.html for details.
+        Currently only
+        'relu': nn.ReLU()
+        'sigmoid': nn.Sigmoid()
+        'tanh': nn.Tanh()
+        are supported. See pyod/utils/torch_utility.py for details.
+
+    skip_connection : bool, optional (default=False)
+        If True, apply skip connections in the neural network structure.
+
+    n_ensemble : int, optional (default=50)
+        The number of deep representation ensemble members.
+
+    n_estimators : int, optional (default=6)
+        The number of isolation trees in the isolation forest built on each
+        representation.
+
+    max_samples : int, optional (default=256)
+        The number of samples to draw from X to train each base isolation
+        tree.
+
+    contamination : float in (0., 0.5), optional (default=0.1)
+        The amount of contamination of the data set,
+        i.e. the proportion of outliers in the data set. Used when fitting to
+        define the threshold on the decision function.
+
+    random_state : int or None, optional (default=None)
+        If int, random_state is the seed used by the random
+        number generator;
+        If None, the random number generator is the
+        RandomState instance used by `np.random`.
+
+    device : str ('cuda' or 'cpu') or None, optional (default=None)
+        If 'cuda', use GPU acceleration in torch;
+        if 'cpu', use the CPU in torch;
+        if None, automatically determine whether a GPU is available.
+
+
+    Attributes
+    ----------
+    net_lst : list of torch.nn.Module
+        The list of representation neural networks.
+
+    iForest_lst : list of iForest
+        The list of fitted iForest models, one per representation.
+
+    x_reduced_lst : list of numpy array
+        The list of training data representations.
+
+    decision_scores_ : numpy array of shape (n_samples,)
+        The outlier scores of the training data.
+        The higher, the more abnormal. Outliers tend to have higher
+        scores. This value is available once the detector is fitted.
+
+    threshold_ : float
+        The threshold is based on ``contamination``. It is the
+        ``n_samples * contamination`` most abnormal samples in
+        ``decision_scores_``. The threshold is calculated for generating
+        binary outlier labels.
+
+    labels_ : int, either 0 or 1
+        The binary labels of the training data. 0 stands for inliers
+        and 1 for outliers/anomalies. It is generated by applying
+        ``threshold_`` on ``decision_scores_``.
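+
+    Examples
+    --------
+    A minimal usage sketch; ``X_train`` and ``X_test`` are assumed to be
+    numpy arrays of shape ``(n_samples, n_features)``, and a small
+    ``n_ensemble`` is used only to shorten the run::
+
+        from pyod.models.dif import DIF
+
+        clf = DIF(n_ensemble=10)
+        clf.fit(X_train)
+
+        train_scores = clf.decision_scores_            # outlier scores of X_train
+        test_scores = clf.decision_function(X_test)    # outlier scores of X_test
+        test_labels = clf.predict(X_test)              # binary labels (0/1)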
+ """ + + def __init__(self, + batch_size=1000, + representation_dim=20, + hidden_neurons=None, + hidden_activation='tanh', + skip_connection=False, + n_ensemble=50, + n_estimators=6, + max_samples=256, + contamination=0.1, + random_state=None, + device=None): + super(DIF, self).__init__(contamination=contamination) + self.batch_size = batch_size + self.representation_dim = representation_dim + self.hidden_activation = hidden_activation + self.skip_connection = skip_connection + self.hidden_neurons = hidden_neurons + + self.n_ensemble = n_ensemble + self.n_estimators = n_estimators + self.max_samples = max_samples + + self.random_state = random_state + self.device = device + + self.minmax_scaler = None + + # create default calculation device (support GPU if available) + if self.device is None: + self.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu") + + # set random seed + if self.random_state is not None: + torch.manual_seed(self.random_state) + torch.cuda.manual_seed(self.random_state) + torch.cuda.manual_seed_all(self.random_state) + np.random.seed(self.random_state) + + # default values for the amount of hidden neurons + if self.hidden_neurons is None: + self.hidden_neurons = [500, 100] + + def fit(self, X, y=None): + """Fit detector. y is ignored in unsupervised methods. + + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The input samples. + + y : Ignored + Not used, present for API consistency by convention. + + Returns + ------- + self : object + Fitted estimator. + """ + # validate inputs X and y (optional) + X = check_array(X) + self._set_n_classes(y) + + n_samples, n_features = X.shape[0], X.shape[1] + + # conduct min-max normalization before feeding into neural networks + self.minmax_scaler = MinMaxScaler() + self.minmax_scaler.fit(X) + X = self.minmax_scaler.transform(X) + + # prepare neural network parameters + network_params = { + 'n_features': n_features, + 'n_hidden': self.hidden_neurons, + 'n_output': self.representation_dim, + 'activation': self.hidden_activation, + 'skip_connection': self.skip_connection + } + + # iteration + self.net_lst = [] + self.iForest_lst = [] + self.x_reduced_lst = [] + ensemble_seeds = np.random.randint(0, 100000, self.n_ensemble) + for i in range(self.n_ensemble): + # instantiate network class and seed random seed + net = MLPnet(**network_params).to(self.device) + torch.manual_seed(ensemble_seeds[i]) + + # initialize network parameters + for name, param in net.named_parameters(): + if name.endswith('weight'): + torch.nn.init.normal_(param, mean=0., std=1.) + + x_reduced = self._deep_representation(net, X) + + # save network and representations + self.x_reduced_lst.append(x_reduced) + self.net_lst.append(net) + + # perform iForest upon representations + self.iForest_lst.append( + IsolationForest(n_estimators=self.n_estimators, + max_samples=self.max_samples, + random_state=ensemble_seeds[i]) + ) + self.iForest_lst[i].fit(x_reduced) + + self.decision_scores_ = self.decision_function(X) + self._process_decision_scores() + return self + + def decision_function(self, X): + """Predict raw anomaly score of X using the fitted detector. + + The anomaly score of an input sample is computed based on different + detector algorithms. For consistency, outliers are assigned with + larger anomaly scores. + + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The training input samples. Sparse matrices are accepted only + if they are supported by the base estimator. 
+ + Returns + ------- + anomaly_scores : numpy array of shape (n_samples,) + The anomaly score of the input samples. + """ + check_is_fitted(self, ['net_lst', 'iForest_lst', 'x_reduced_lst']) + X = check_array(X) + + # conduct min-max normalization before feeding into neural networks + X = self.minmax_scaler.transform(X) + + testing_n_samples = X.shape[0] + score_lst = np.zeros([self.n_ensemble, testing_n_samples]) + + # iteration + for i in range(self.n_ensemble): + # transform testing data to representation + x_reduced = self._deep_representation(self.net_lst[i], X) + + # calculate outlier scores + scores = _cal_score(x_reduced, self.iForest_lst[i]) + score_lst[i] = scores + + final_scores = np.average(score_lst, axis=0) + return final_scores + + def _deep_representation(self, net, X): + x_reduced = [] + + with torch.no_grad(): + loader = DataLoader(X, batch_size=self.batch_size, drop_last=False, pin_memory=True, shuffle=False) + for batch_x in loader: + batch_x = batch_x.float().to(self.device) + batch_x_reduced = net(batch_x) + x_reduced.append(batch_x_reduced) + + x_reduced = torch.cat(x_reduced).data.cpu().numpy() + x_reduced = StandardScaler().fit_transform(x_reduced) + x_reduced = np.tanh(x_reduced) + return x_reduced + + +class MLPnet(torch.nn.Module): + def __init__(self, n_features, n_hidden=[500, 100], n_output=20, mid_channels=None, + activation='ReLU', bias=False, batch_norm=False, + skip_connection=False): + super(MLPnet, self).__init__() + self.skip_connection = skip_connection + self.n_output = n_output + + num_layers = len(n_hidden) + + if type(activation) == str: + activation = [activation] * num_layers + activation.append(None) + + assert len(activation) == len(n_hidden)+1, 'activation and n_hidden are not matched' + + self.layers = [] + for i in range(num_layers+1): + in_channels, out_channels = self.get_in_out_channels(i, num_layers, n_features, + n_hidden, n_output, skip_connection) + self.layers += [ + LinearBlock(in_channels, out_channels, + bias=bias, batch_norm=batch_norm, + activation=activation[i], + skip_connection=skip_connection if i != num_layers else False) + ] + self.network = torch.nn.Sequential(*self.layers) + + + def forward(self, x): + x = self.network(x) + return x + + @staticmethod + def get_in_out_channels(i, num_layers, n_features, n_hidden, n_output, skip_connection): + if skip_connection is False: + in_channels = n_features if i == 0 else n_hidden[i-1] + out_channels = n_output if i == num_layers else n_hidden[i] + else: + in_channels = n_features if i == 0 else np.sum(n_hidden[:i])+n_features + out_channels = n_output if i == num_layers else n_hidden[i] + return in_channels, out_channels + + +class LinearBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, + activation='Tanh', bias=False, batch_norm=False, + skip_connection=False): + super(LinearBlock, self).__init__() + + self.skip_connection = skip_connection + + self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias) + + if activation is not None: + # self.act_layer = _instantiate_class("torch.nn.modules.activation", activation) + self.act_layer = get_activation_by_name(activation) + else: + self.act_layer = torch.nn.Identity() + + self.batch_norm = batch_norm + if batch_norm is True: + dim = out_channels + self.bn_layer = torch.nn.BatchNorm1d(dim, affine=bias) + + def forward(self, x): + x1 = self.linear(x) + x1 = self.act_layer(x1) + + if self.batch_norm is True: + x1 = self.bn_layer(x1) + + if self.skip_connection: + x1 = torch.cat([x, x1], axis=1) + + 
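+        # x1 is either the transformed features or, when skip connections are
+        # enabled, the input concatenated with the transformed features along
+        # the feature axis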
return x1 + + +def _cal_score(xx, clf): + depths = np.zeros((xx.shape[0], len(clf.estimators_))) + depth_sum = np.zeros(xx.shape[0]) + deviations = np.zeros((xx.shape[0], len(clf.estimators_))) + leaf_samples = np.zeros((xx.shape[0], len(clf.estimators_))) + + for ii, estimator_tree in enumerate(clf.estimators_): + # estimator_population_ind = sample_without_replacement(n_population=xx.shape[0], n_samples=256, + # random_state=estimator_tree.random_state) + # estimator_population = xx[estimator_population_ind] + + tree = estimator_tree.tree_ + n_node = tree.node_count + + if n_node == 1: + continue + + # get feature and threshold of each node in the iTree + # in feature_lst, -2 indicates the leaf node + feature_lst, threshold_lst = tree.feature.copy(), tree.threshold.copy() + + # compute depth and score + leaves_index = estimator_tree.apply(xx) + node_indicator = estimator_tree.decision_path(xx) + + # The number of training samples in each test sample leaf + n_node_samples = estimator_tree.tree_.n_node_samples + + # node_indicator is a sparse matrix with shape (n_samples, n_nodes), indicating the path of input data samples + # each layer would result in a non-zero element in this matrix, + # and then the row-wise summation is the depth of data sample + n_samples_leaf = estimator_tree.tree_.n_node_samples[leaves_index] + d = (np.ravel(node_indicator.sum(axis=1)) + _average_path_length(n_samples_leaf) - 1.0) + depths[:, ii] = d + depth_sum += d + + # decision path of data matrix XX + node_indicator = np.array(node_indicator.todense()) + + # set a matrix with shape [n_sample, n_node], representing the feature value of each sample on each node + # set the leaf node as -2 + value_mat = np.array([xx[i][feature_lst] for i in range(xx.shape[0])]) + value_mat[:, np.where(feature_lst == -2)[0]] = -2 + th_mat = np.array([threshold_lst for _ in range(xx.shape[0])]) + + mat = np.abs(value_mat - th_mat) * node_indicator + + exist = (mat != 0) + dev = mat.sum(axis=1)/(exist.sum(axis=1)+1e-6) + deviations[:, ii] = dev + + scores = 2 ** (-depth_sum / (len(clf.estimators_) * _average_path_length([clf.max_samples_]))) + deviation = np.mean(deviations, axis=1) + leaf_sample = (clf.max_samples_ - np.mean(leaf_samples, axis=1)) / clf.max_samples_ + + scores = scores * deviation + # scores = scores * deviation * leaf_sample + return scores + + +def _average_path_length(n_samples_leaf): + """ + The average path length in a n_samples iTree, which is equal to + the average path length of an unsuccessful BST search since the + latter has the same structure as an isolation tree. + Parameters + ---------- + n_samples_leaf : array-like of shape (n_samples,) + The number of training samples in each test sample leaf, for + each estimators. + + Returns + ------- + average_path_length : ndarray of shape (n_samples,) + """ + + n_samples_leaf = check_array(n_samples_leaf, ensure_2d=False) + + n_samples_leaf_shape = n_samples_leaf.shape + n_samples_leaf = n_samples_leaf.reshape((1, -1)) + average_path_length = np.zeros(n_samples_leaf.shape) + + mask_1 = n_samples_leaf <= 1 + mask_2 = n_samples_leaf == 2 + not_mask = ~np.logical_or(mask_1, mask_2) + + average_path_length[mask_1] = 0. + average_path_length[mask_2] = 1. 
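+    # for leaves with more than two samples, use the standard isolation-forest
+    # estimate c(n) = 2 * (ln(n - 1) + Euler-Mascheroni constant)
+    #                 - 2 * (n - 1) / n
+    # of the average depth of an unsuccessful BST search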
+ average_path_length[not_mask] = ( + 2.0 * (np.log(n_samples_leaf[not_mask] - 1.0) + np.euler_gamma) + - 2.0 * (n_samples_leaf[not_mask] - 1.0) / n_samples_leaf[not_mask] + ) + + return average_path_length.reshape(n_samples_leaf_shape) diff --git a/pyod/test/test_dif.py b/pyod/test/test_dif.py new file mode 100644 index 000000000..14bf52c01 --- /dev/null +++ b/pyod/test/test_dif.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +from __future__ import division +from __future__ import print_function + +import os +import sys +import unittest + +import numpy as np +import torch +from numpy.testing import assert_almost_equal +# noinspection PyProtectedMember +from numpy.testing import assert_equal +from numpy.testing import assert_raises +from sklearn.metrics import roc_auc_score + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from pyod.models.dif import DIF +from pyod.utils.data import generate_data +from pyod.models.auto_encoder_torch import PyODDataset + + +class TestDIF(unittest.TestCase): + def setUp(self): + self.n_train = 3000 + self.n_test = 1000 + self.n_features = 200 + self.contamination = 0.1 + self.roc_floor = 0.8 + self.X_train, self.X_test, self.y_train, self.y_test = generate_data( + n_train=self.n_train, n_test=self.n_test, + n_features=self.n_features, contamination=self.contamination, + random_state=42) + + self.clf = DIF(skip_connection=True, contamination=self.contamination) + self.clf.fit(self.X_train) + + self.clf2 = DIF(skip_connection=False, contamination=self.contamination) + self.clf2.fit(self.X_train) + + def test_parameters(self): + assert (hasattr(self.clf, 'decision_scores_') and + self.clf.decision_scores_ is not None) + assert (hasattr(self.clf, 'labels_') and + self.clf.labels_ is not None) + assert (hasattr(self.clf, 'threshold_') and + self.clf.threshold_ is not None) + + def test_train_scores(self): + assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) + assert_equal(len(self.clf2.decision_scores_), self.X_train.shape[0]) + + def test_prediction_scores(self): + pred_scores = self.clf.decision_function(self.X_test) + pred_scores2 = self.clf2.decision_function(self.X_test) + + # check score shapes + assert_equal(pred_scores.shape[0], self.X_test.shape[0]) + assert_equal(pred_scores2.shape[0], self.X_test.shape[0]) + + # check performance + assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) + assert (roc_auc_score(self.y_test, pred_scores2) >= self.roc_floor) + + def test_prediction_labels(self): + pred_labels = self.clf.predict(self.X_test) + assert_equal(pred_labels.shape, self.y_test.shape) + + def test_prediction_proba(self): + pred_proba = self.clf.predict_proba(self.X_test) + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_linear(self): + pred_proba = self.clf.predict_proba(self.X_test, method='linear') + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_unify(self): + pred_proba = self.clf.predict_proba(self.X_test, method='unify') + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_parameter(self): + with assert_raises(ValueError): + self.clf.predict_proba(self.X_test, method='something') + + def test_prediction_labels_confidence(self): + pred_labels, confidence = self.clf.predict(self.X_test, + return_confidence=True) + 
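+        # one confidence value is returned per test sample, bounded in [0, 1]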
assert_equal(pred_labels.shape, self.y_test.shape) + assert_equal(confidence.shape, self.y_test.shape) + assert (confidence.min() >= 0) + assert (confidence.max() <= 1) + + def test_prediction_proba_linear_confidence(self): + pred_proba, confidence = self.clf.predict_proba(self.X_test, + method='linear', + return_confidence=True) + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + assert_equal(confidence.shape, self.y_test.shape) + assert (confidence.min() >= 0) + assert (confidence.max() <= 1) + + def test_fit_predict(self): + pred_labels = self.clf.fit_predict(self.X_train) + assert_equal(pred_labels.shape, self.y_train.shape) + + def test_fit_predict_score(self): + self.clf.fit_predict_score(self.X_test, self.y_test) + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='roc_auc_score') + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='prc_n_score') + with assert_raises(NotImplementedError): + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='something') + + def test_model_clone(self): + pass + # clone_clf = clone(self.clf) + + def tearDown(self): + pass + + +if __name__ == '__main__': + unittest.main() From 8c00d8a3630640d85d38c69a1805b0bdcd4e985c Mon Sep 17 00:00:00 2001 From: xuhongzuo Date: Tue, 5 Sep 2023 09:59:35 +0800 Subject: [PATCH 2/2] Deep Isolation Forest method implemented --- pyod/models/dif.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pyod/models/dif.py b/pyod/models/dif.py index e75f05d1d..1128f18bf 100644 --- a/pyod/models/dif.py +++ b/pyod/models/dif.py @@ -15,9 +15,6 @@ from sklearn.preprocessing import StandardScaler, MinMaxScaler from torch.utils.data import DataLoader -# from pyod.models.base import BaseDetector -# from pyod.utils.utility import check_parameter -# from pyod.utils.torch_utility import get_activation_by_name from .base import BaseDetector from ..utils.utility import check_parameter from ..utils.torch_utility import get_activation_by_name @@ -26,8 +23,7 @@ class DIF(BaseDetector): """Deep Isolation Forest (DIF) is an extension of iForest. It uses deep representation ensemble to achieve non-linear isolation on original data - space. See :cite:`xu2023dif` - for details. + space. See :cite:`xu2023dif` for details. 
Parameters ---------- @@ -265,7 +261,9 @@ def _deep_representation(self, net, X): x_reduced = [] with torch.no_grad(): - loader = DataLoader(X, batch_size=self.batch_size, drop_last=False, pin_memory=True, shuffle=False) + loader = DataLoader(X, batch_size=self.batch_size, + drop_last=False, pin_memory=True, + shuffle=False) for batch_x in loader: batch_x = batch_x.float().to(self.device) batch_x_reduced = net(batch_x) @@ -278,7 +276,7 @@ def _deep_representation(self, net, X): class MLPnet(torch.nn.Module): - def __init__(self, n_features, n_hidden=[500, 100], n_output=20, mid_channels=None, + def __init__(self, n_features, n_hidden=[500, 100], n_output=20, activation='ReLU', bias=False, batch_norm=False, skip_connection=False): super(MLPnet, self).__init__() @@ -295,8 +293,9 @@ def __init__(self, n_features, n_hidden=[500, 100], n_output=20, mid_channels=No self.layers = [] for i in range(num_layers+1): - in_channels, out_channels = self.get_in_out_channels(i, num_layers, n_features, - n_hidden, n_output, skip_connection) + in_channels, out_channels = \ + self.get_in_out_channels(i, num_layers, n_features, + n_hidden, n_output, skip_connection) self.layers += [ LinearBlock(in_channels, out_channels, bias=bias, batch_norm=batch_norm, @@ -362,10 +361,6 @@ def _cal_score(xx, clf): leaf_samples = np.zeros((xx.shape[0], len(clf.estimators_))) for ii, estimator_tree in enumerate(clf.estimators_): - # estimator_population_ind = sample_without_replacement(n_population=xx.shape[0], n_samples=256, - # random_state=estimator_tree.random_state) - # estimator_population = xx[estimator_population_ind] - tree = estimator_tree.tree_ n_node = tree.node_count @@ -383,7 +378,8 @@ def _cal_score(xx, clf): # The number of training samples in each test sample leaf n_node_samples = estimator_tree.tree_.n_node_samples - # node_indicator is a sparse matrix with shape (n_samples, n_nodes), indicating the path of input data samples + # node_indicator is a sparse matrix with shape (n_samples, n_nodes), + # indicating the path of input data samples # each layer would result in a non-zero element in this matrix, # and then the row-wise summation is the depth of data sample n_samples_leaf = estimator_tree.tree_.n_node_samples[leaves_index] @@ -394,7 +390,8 @@ def _cal_score(xx, clf): # decision path of data matrix XX node_indicator = np.array(node_indicator.todense()) - # set a matrix with shape [n_sample, n_node], representing the feature value of each sample on each node + # set a matrix with shape [n_sample, n_node], + # representing the feature value of each sample on each node # set the leaf node as -2 value_mat = np.array([xx[i][feature_lst] for i in range(xx.shape[0])]) value_mat[:, np.where(feature_lst == -2)[0]] = -2
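For reference, the core mechanism implemented in pyod/models/dif.py can be illustrated with a short standalone sketch. It is illustrative only: it keeps PyTorch's default weight initialization instead of the N(0, 1) re-initialization used in DIF.fit, it scores representations with sklearn's built-in score_samples rather than the deviation-weighted _cal_score above, and the toy data and all sizes are made up.

    import numpy as np
    import torch
    from sklearn.ensemble import IsolationForest
    from sklearn.preprocessing import MinMaxScaler, StandardScaler

    rng = np.random.RandomState(42)
    X = rng.randn(500, 20)                       # toy data: 500 samples, 20 features
    X = MinMaxScaler().fit_transform(X)          # same preprocessing as DIF.fit

    ensemble_scores = []
    for seed in range(5):                        # small ensemble to keep it quick
        torch.manual_seed(seed)
        # random-weight network, no training: the representation ensemble in
        # DIF is optimization-free
        net = torch.nn.Sequential(
            torch.nn.Linear(20, 100, bias=False),
            torch.nn.Tanh(),
            torch.nn.Linear(100, 20, bias=False),
        )
        with torch.no_grad():
            rep = net(torch.from_numpy(X).float()).numpy()
        rep = np.tanh(StandardScaler().fit_transform(rep))   # as in _deep_representation

        forest = IsolationForest(n_estimators=6, max_samples=256,
                                 random_state=seed)
        forest.fit(rep)
        # sklearn's score_samples is higher for inliers, so flip the sign to
        # make higher mean more abnormal
        ensemble_scores.append(-forest.score_samples(rep))

    final_scores = np.mean(ensemble_scores, axis=0)   # average over the ensemble
    print("top-5 most abnormal indices:", np.argsort(final_scores)[-5:])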