diff --git a/pyod/models/deep_svdd.py b/pyod/models/deep_svdd.py
index 4f3c8f1f..e76ec5ca 100644
--- a/pyod/models/deep_svdd.py
+++ b/pyod/models/deep_svdd.py
@@ -73,7 +73,7 @@ class InnerDeepSVDD(nn.Module):
     def __init__(self, n_features, use_ae,
                  hidden_neurons, hidden_activation,
                  output_activation,
-                 dropout_rate, l2_regularizer):
+                 dropout_rate, l2_regularizer, input_shape=None):
         super(InnerDeepSVDD, self).__init__()
         self.n_features = n_features
         self.use_ae = use_ae
@@ -82,61 +82,52 @@ def __init__(self, n_features, use_ae,
         self.output_activation = output_activation
         self.dropout_rate = dropout_rate
         self.l2_regularizer = l2_regularizer
+        self.input_shape = input_shape
         self.model = self._build_model()
+        self.c = None  # Center of the hypersphere for DeepSVDD

     def _init_c(self, X_norm, eps=0.1):
         intermediate_output = {}
-        hook_handle = self.model._modules.get(
-            'net_output').register_forward_hook(
-            lambda module, input, output: intermediate_output.update(
-                {'net_output': output})
+        hook_handle = self.model._modules.get('net_output').register_forward_hook(
+            lambda module, input, output: intermediate_output.update({'net_output': output})
         )
         output = self.model(X_norm)
-
         out = intermediate_output['net_output']
         hook_handle.remove()
-
         self.c = torch.mean(out, dim=0)
         self.c[(torch.abs(self.c) < eps) & (self.c < 0)] = -eps
         self.c[(torch.abs(self.c) < eps) & (self.c > 0)] = eps

     def _build_model(self):
         layers = nn.Sequential()
-        layers.add_module('input_layer',
-                          nn.Linear(self.n_features, self.hidden_neurons[0],
-                                    bias=False))
-        layers.add_module('hidden_activation_e0',
-                          get_activation_by_name(self.hidden_activation))
+        channels = self.input_shape[0]
+        layers.add_module('cnn_layer1', nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1))
+        layers.add_module('cnn_activation1', nn.ReLU())
+        layers.add_module('cnn_pool', nn.MaxPool2d(kernel_size=2, stride=2))
+        layers.add_module('cnn_layer2', nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1))
+        layers.add_module('cnn_activation2', nn.ReLU())
+        layers.add_module('cnn_adaptive_pool', nn.AdaptiveMaxPool2d((64, 64)))
+        layers.add_module('flatten', nn.Flatten())
+        layers.add_module('cnn_fc', nn.Linear(64 * 64 * 64, self.n_features, bias=False))
+        layers.add_module('cnn_fc_activation', nn.ReLU())
+        layers.add_module('input_layer', nn.Linear(self.n_features, self.hidden_neurons[0], bias=False))
+        layers.add_module('hidden_activation_e0', get_activation_by_name(self.hidden_activation))
         for i in range(1, len(self.hidden_neurons) - 1):
-            layers.add_module(f'hidden_layer_e{i}',
-                              nn.Linear(self.hidden_neurons[i - 1],
-                                        self.hidden_neurons[i], bias=False))
-            layers.add_module(f'hidden_activation_e{i}',
-                              get_activation_by_name(self.hidden_activation))
-            layers.add_module(f'hidden_dropout_e{i}',
-                              nn.Dropout(self.dropout_rate))
-        layers.add_module(f'net_output', nn.Linear(self.hidden_neurons[-2],
-                                                   self.hidden_neurons[-1],
-                                                   bias=False))
-        layers.add_module(f'hidden_activation_e{len(self.hidden_neurons)}',
-                          get_activation_by_name(self.hidden_activation))
+            layers.add_module(f'hidden_layer_e{i}', nn.Linear(self.hidden_neurons[i - 1], self.hidden_neurons[i], bias=False))
+            layers.add_module(f'hidden_activation_e{i}', get_activation_by_name(self.hidden_activation))
+            layers.add_module(f'hidden_dropout_e{i}', nn.Dropout(self.dropout_rate))
+        layers.add_module('net_output', nn.Linear(self.hidden_neurons[-2], self.hidden_neurons[-1], bias=False))
+        layers.add_module(f'hidden_activation_e{len(self.hidden_neurons)}', get_activation_by_name(self.hidden_activation))
         if self.use_ae:
+            # Add reverse layers for the autoencoder if needed
             for j in range(len(self.hidden_neurons) - 1, 0, -1):
-                layers.add_module(f'hidden_layer_d{j}',
-                                  nn.Linear(self.hidden_neurons[j],
-                                            self.hidden_neurons[j - 1],
-                                            bias=False))
-                layers.add_module(f'hidden_activation_d{j}',
-                                  get_activation_by_name(
-                                      self.hidden_activation))
-                layers.add_module(f'hidden_dropout_d{j}',
-                                  nn.Dropout(self.dropout_rate))
-            layers.add_module(f'output_layer',
-                              nn.Linear(self.hidden_neurons[0],
-                                        self.n_features, bias=False))
-            layers.add_module(f'output_activation',
-                              get_activation_by_name(self.output_activation))
+                layers.add_module(f'hidden_layer_d{j}', nn.Linear(self.hidden_neurons[j], self.hidden_neurons[j - 1], bias=False))
+                layers.add_module(f'hidden_activation_d{j}', get_activation_by_name(self.hidden_activation))
+                layers.add_module(f'hidden_dropout_d{j}', nn.Dropout(self.dropout_rate))
+            layers.add_module('output_layer', nn.Linear(self.hidden_neurons[0], self.n_features, bias=False))
+            layers.add_module('output_activation', get_activation_by_name(self.output_activation))
+
         return layers

     def forward(self, x):
@@ -155,7 +146,7 @@ class DeepSVDD(BaseDetector):
     Parameters
     ----------
-    n_features: int, 
+    n_features: int,
        Number of features in the input data.

    c: float, optional (default='forwad_nn_pass')
@@ -242,7 +233,7 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None,
                  batch_size=32, dropout_rate=0.2,
                  l2_regularizer=0.1, validation_size=0.1,
                  preprocessing=True,
-                 verbose=1, random_state=None, contamination=0.1):
+                 verbose=1, random_state=None, contamination=0.1, input_shape=None):
         super(DeepSVDD, self).__init__(contamination=contamination)
         self.n_features = n_features
@@ -262,6 +253,7 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None,
         self.random_state = random_state
         self.model_ = None
         self.best_model_dict = None
+        self.input_shape = input_shape

         if self.random_state is not None:
             torch.manual_seed(self.random_state)
@@ -273,7 +265,7 @@ def fit(self, X, y=None):
        Parameters
        ----------
-        X : numpy array of shape (n_samples, n_features)
+        X : list or numpy array of shape (n_samples, channels, height, width)
            The input samples.

        y : Ignored
@@ -284,59 +276,46 @@ def fit(self, X, y=None):
        self : object
            Fitted estimator.
        """
-        # validate inputs X and y (optional)
-        X = check_array(X)
-        self._set_n_classes(y)
-
-        # Verify and construct the hidden units
-        self.n_samples_, self.n_features_ = X.shape[0], X.shape[1]
-
-        # Standardize data for better performance
-        if self.preprocessing:
-            self.scaler_ = StandardScaler()
-            X_norm = self.scaler_.fit_transform(X)
-        else:
-            X_norm = np.copy(X)
-
-        # Shuffle the data for validation as Keras do not shuffling for
-        # Validation Split
-        np.random.shuffle(X_norm)
-
-        # Validate and complete the number of hidden neurons
-        if np.min(self.hidden_neurons) > self.n_features_ and self.use_ae:
-            raise ValueError("The number of neurons should not exceed "
-                             "the number of features")
-
-        # Build DeepSVDD model & fit with X
-        self.model_ = InnerDeepSVDD(self.n_features, use_ae=self.use_ae,
-                                    hidden_neurons=self.hidden_neurons,
-                                    hidden_activation=self.hidden_activation,
-                                    output_activation=self.output_activation,
-                                    dropout_rate=self.dropout_rate,
-                                    l2_regularizer=self.l2_regularizer)
-        X_norm = torch.tensor(X_norm, dtype=torch.float32)
+        # Convert to tensor directly for 4D data and normalize if needed
+        if isinstance(X, np.ndarray):
+            X = torch.tensor(X, dtype=torch.float32)
+
+        # Normalize the data (e.g., rescale if pixel values are in the range [0, 255])
+        if X.max() > 1:
+            X = X / 255.0
+
+        # Set CNN input shape directly
+        self.input_shape = X.shape[1:]  # (channels, height, width)
+
+        # Initialize the DeepSVDD model with updated input shape
+        self.model_ = InnerDeepSVDD(
+            n_features=self.n_features,  # Now determined by CNN output
+            use_ae=self.use_ae,
+            hidden_neurons=self.hidden_neurons,
+            hidden_activation=self.hidden_activation,
+            output_activation=self.output_activation,
+            dropout_rate=self.dropout_rate,
+            l2_regularizer=self.l2_regularizer,
+            input_shape=self.input_shape,
+        )
+
+        # No need to standardize further if CNN is extracting features directly
+        X_norm = X
+
+        # Initialize center c for DeepSVDD
         if self.c is None:
             self.c = 0.0
             self.model_._init_c(X_norm)

-        # Predict on X itself and calculate the reconstruction error as
-        # the outlier scores. Noted X_norm was shuffled has to recreate
-        if self.preprocessing:
-            X_norm = self.scaler_.transform(X)
-        else:
-            X_norm = np.copy(X)
-
-        X_norm = torch.tensor(X_norm, dtype=torch.float32)
+        # Prepare DataLoader for batch processing
         dataset = TensorDataset(X_norm, X_norm)
-        dataloader = DataLoader(dataset, batch_size=self.batch_size,
-                                shuffle=True)
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

         best_loss = float('inf')
         best_model_dict = None

+        optimizer = optimizer_dict[self.optimizer](self.model_.parameters(), weight_decay=self.l2_regularizer)
+        # w_d = 1e-6 * sum([torch.linalg.norm(w) for w in self.model_.parameters()])

-        optimizer = optimizer_dict[self.optimizer](self.model_.parameters(),
-                                                   weight_decay=self.l2_regularizer)
-
         for epoch in range(self.epochs):
             self.model_.train()
             epoch_loss = 0
@@ -345,10 +324,9 @@ def fit(self, X, y=None):
                 dist = torch.sum((outputs - self.c) ** 2, dim=-1)
                 w_d = 1e-6 * sum([torch.linalg.norm(w) for w in self.model_.parameters()])
-
+
                 if self.use_ae:
-                    loss = torch.mean(dist) + w_d + torch.mean(
-                        torch.square(outputs - batch_x))
+                    loss = torch.mean(dist) + w_d + torch.mean(torch.square(outputs - batch_x))
                 else:
                     loss = torch.mean(dist) + w_d
@@ -356,10 +334,11 @@ def fit(self, X, y=None):
                 loss.backward()
                 optimizer.step()
                 epoch_loss += loss.item()
-            if epoch_loss < best_loss:
-                best_loss = epoch_loss
-                best_model_dict = self.model_.state_dict()
+            epoch_loss /= len(dataloader)
             print(f"Epoch {epoch + 1}/{self.epochs}, Loss: {epoch_loss}")
+            if epoch_loss < best_loss:
+                best_loss = epoch_loss
+                best_model_dict = self.model_.state_dict()
         self.best_model_dict = best_model_dict

         self.decision_scores_ = self.decision_function(X)
@@ -369,32 +348,29 @@ def fit(self, X, y=None):
     def decision_function(self, X):
         """Predict raw anomaly score of X using the fitted detector.

-        The anomaly score of an input sample is computed based on different
-        detector algorithms. For consistency, outliers are assigned with
-        larger anomaly scores.
+        The anomaly score of an input sample is computed based on the DeepSVDD model.
+        Outliers are assigned with larger anomaly scores.

         Parameters
         ----------
-        X : numpy array of shape (n_samples, n_features)
-            The training input samples. Sparse matrices are accepted only
-            if they are supported by the base estimator.
+        X : numpy array of shape (n_samples, channels, height, width)
+            The input samples.

         Returns
         -------
         anomaly_scores : numpy array of shape (n_samples,)
             The anomaly score of the input samples.
         """
-        # check_is_fitted(self, ['model_', 'history_'])
-        X = check_array(X)
-
-        if self.preprocessing:
-            X_norm = self.scaler_.transform(X)
-        else:
-            X_norm = np.copy(X)
-        X_norm = torch.tensor(X_norm, dtype=torch.float32)
+        # Convert X to tensor if it isn't already, and normalize if needed
+        if isinstance(X, np.ndarray):
+            X = torch.tensor(X, dtype=torch.float32)
+
+        # Normalize data if pixel values are in [0, 255] range
+        if X.max() > 1:
+            X = X / 255.0

         self.model_.eval()
         with torch.no_grad():
-            outputs = self.model_(X_norm)
+            outputs = self.model_(X)
             dist = torch.sum((outputs - self.c) ** 2, dim=-1)
-            anomaly_scores = dist.numpy()
+            anomaly_scores = dist.cpu().numpy()
         return anomaly_scores
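
For reviewers, a minimal smoke test of the reworked image API, sketched against this diff. It is not part of the change itself; the shapes, `n_features=128`, and the random data are illustrative assumptions, and it relies on `hidden_neurons` still defaulting to `[64, 32]` and on the unchanged `BaseDetector.predict` threshold machinery.

```python
import numpy as np
from pyod.models.deep_svdd import DeepSVDD

rng = np.random.default_rng(0)

# 16 fake RGB "images" in (n_samples, channels, height, width) layout;
# fit() rescales to [0, 1] on its own when pixel values exceed 1.
X_train = rng.integers(0, 256, size=(16, 3, 32, 32)).astype(np.float32)

# n_features is now the width of the CNN projection fed into the MLP head
# (the cnn_fc layer maps the 64 * 64 * 64 pooled activations down to
# n_features), not the raw input dimensionality; input_shape is inferred
# inside fit() from X.shape[1:].
clf = DeepSVDD(n_features=128, epochs=2, batch_size=8, contamination=0.1)
clf.fit(X_train)

print(clf.decision_scores_[:5])  # larger score = farther from the center c
print(clf.predict(X_train[:5]))  # 0 = inlier, 1 = outlier (BaseDetector API)
```

Note the memory cost baked into the sketch: `cnn_fc` is a `Linear(262144, n_features)`, so even this toy configuration carries roughly 33M parameters, which may be worth flagging on the PR.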