diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/_cfam.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/_cfam.py new file mode 100644 index 00000000..d24712a6 --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/_cfam.py @@ -0,0 +1,83 @@ +import os +import argparse +import numpy as np +import pandas as pd +import torch +from collections import OrderedDict +from typing import Optional, Union + +from sklearn.utils import check_random_state +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.cfamg import CFAMG +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.main import parse_args +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.data_preprocess import set_seed +from aeon.transformations.collection import BaseCollectionTransformer + +__all__ = ["CFAM"] + + +class CFAM(BaseCollectionTransformer): + _tags = { + "capability:multivariate": True, + "capability:unequal_length": False, + "requires_y": True, + } + + def __init__(self, + random_state=None, + ): + self.random_state = random_state + self.generated_samples = None + set_seed(self.random_state) + super().__init__() + + def _fit(self, X, y=None): + args = parse_args() + args.w_lambda, args.w_beta = 1, 1 + unique, counts = np.unique(y, return_counts=True) + target_stats = dict(zip(unique, counts)) + class_majority = max(target_stats, key=target_stats.get) + class_minority = min(target_stats, key=target_stats.get) + X_majority, X_minority = X[y == class_majority], X[y == class_minority] + y_majority, y_minority = y[y == class_majority], y[y == class_minority] + + class_label_project = {class_majority: 0, class_minority: 1} + self.class_label_project = class_label_project + y = np.array([class_label_project[label] for label in y]) + y_majority = np.array([class_label_project[label] for label in y_majority]) + y_minority = np.array([class_label_project[label] for label in y_minority]) + + ir = len(y_majority) / len(y_minority) + dataset = { + "train_data": (X, y), + "train_data_pos": (X_minority, y_minority), + "train_data_neg": (X_majority, y_majority), + "ir": ir, + } + args.dataset = dataset + self.CFAMG_model = CFAMG(args) + self.CFAMG_model.train_on_data() + return self + + def _transform(self, X, y=None): + X_train, y_train, generated_samples = self.CFAMG_model.generator_sample() + inv_class_label_project = {v: k for k, v in self.class_label_project.items()} + y_train = np.array([inv_class_label_project[label] for label in y_train]) + self.generated_samples = generated_samples + return X_train, y_train + + +if __name__ == "__main__": + + n_samples = 100 # Total number of labels + majority_num = 90 # number of majority class + minority_num = n_samples - majority_num # number of minority class + np.random.seed(42) + + X = np.random.rand(n_samples, 1, 10) + y = np.array([0] * majority_num + [1] * minority_num) + print(np.unique(y, return_counts=True)) + smote = CFAM(random_state=42) + + X_resampled, y_resampled = smote.fit_transform(X, y) + print(X_resampled.shape) + print(np.unique(y_resampled, return_counts=True)) diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/X_train.npy b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/X_train.npy new file mode 100644 index 00000000..658e88cc Binary files /dev/null and b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/X_train.npy differ diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/y_train.npy b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/y_train.npy new file mode 100644 index 00000000..68fcbb31 Binary files /dev/null and b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/FiftyWords/y_train.npy differ diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/README.md b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/README.md new file mode 100644 index 00000000..a578b4be --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/README.md @@ -0,0 +1,112 @@ +# Mitigating Data Imbalance in Time Series Classification Based on Counterfactual Minority Samples Augmentation + +In this repository, we provide the original PyTorch implementation of CFAMG framework. + +## Dataset + +We conducted extensive experiments on 53 real-world benchmark datasets selected from the UCR and UEA time series classification repositories. For details, please refer to the following table: +| Type | Dataset | Dimensions | Pos Class | #Length | | Training | | | Test | | +|:------------:|:------------------------------:|:----------:|:----------------------:|:-------:|:----:|:--------:|:------:|:----:|:----:|:-------:| +| | | | | | #Neg | #Pos | IR | #Neg | #Pos | IR | +| Univariate | ACSF1 | 1 | 1 | 1460 | 90 | 10 | 9.00 | 90 | 10 | 9.00 | +| Multivariate | ArticularyWordRecognition | 9 | 25.0 | 144 | 264 | 11 | 24.00 | 288 | 12 | 24.00 | +| Univariate | Car | 1 | 3 | 577 | 49 | 11 | 4.45 | 41 | 19 | 2.16 | +| Univariate | Computers | 1 | 2 | 720 | 125 | 41 | 3.05 | 125 | 125 | 1.00 | +| Univariate | DodgerLoopDay | 1 | 5, 2 | 288 | 51 | 16 | 3.19 | 56 | 21 | 2.67 | +| Univariate | Earthquakes | 1 | 1 | 512 | 264 | 58 | 4.55 | 104 | 35 | 2.97 | +| Univariate | ECG5000 | 1 | 1 | 140 | 208 | 97 | 2.14 | 1873 | 2627 | 0.71 | +| Univariate | EthanolLevel | 1 | 4 | 1751 | 378 | 126 | 3.00 | 376 | 124 | 3.03 | +| Univariate | FaceAll | 1 | 14 | 131 | 520 | 40 | 13.00 | 1658 | 32 | 51.81 | +| Univariate | FacesUCR | 1 | 11, 14 | 131 | 190 | 10 | 19.00 | 1940 | 110 | 17.64 | +| Univariate | FiftyWords | 1 | 41, 50, 49, 42, 25, 34 | 270 | 438 | 12 | 36.50 | 423 | 32 | 13.22 | +| Univariate | Fish | 1 | 4 | 463 | 154 | 21 | 7.33 | 146 | 29 | 5.03 | +| Univariate | FreezerRegularTrain | 1 | 2 | 301 | 75 | 25 | 3.00 | 1425 | 1425 | 1.00 | +| Univariate | GestureMidAirD2 | 1 | 23, 26 | 360 | 192 | 16 | 12.00 | 120 | 10 | 12.00 | +| Univariate | GestureMidAirD3 | 1 | 3, 13 | 360 | 192 | 16 | 12.00 | 120 | 10 | 12.00 | +| Univariate | GesturePebbleZ1 | 1 | 2 | 455 | 112 | 20 | 5.60 | 147 | 25 | 5.88 | +| Univariate | GesturePebbleZ2 | 1 | 2 | 455 | 123 | 23 | 5.35 | 136 | 22 | 6.18 | +| Univariate | GunPoint | 1 | 1 | 150 | 26 | 10 | 2.60 | 74 | 76 | 0.97 | +| Univariate | GunPointAgeSpan | 1 | 2 | 150 | 68 | 22 | 3.09 | 160 | 156 | 1.03 | +| Univariate | Ham | 1 | 1 | 431 | 57 | 17 | 3.35 | 54 | 51 | 1.06 | +| Multivariate | HandMovementDirection | 10 | backward | 400 | 120 | 40 | 3.00 | 59 | 15 | 3.93 | +| Univariate | Haptics | 1 | 1 | 1092 | 137 | 18 | 7.61 | 248 | 60 | 4.13 | +| Univariate | Herring | 1 | 2 | 512 | 39 | 10 | 3.90 | 38 | 26 | 1.46 | +| Univariate | InlineSkate | 1 | 1, 7 | 1882 | 80 | 20 | 4.00 | 446 | 104 | 4.29 | +| Univariate | ItalyPowerDemand | 1 | 2 | 24 | 34 | 11 | 3.09 | 513 | 516 | 0.99 | +| Univariate | Lightning2 | 1 | -1 | 637 | 40 | 20 | 2.00 | 33 | 28 | 1.18 | +| Univariate | Mallat | 1 | 5, 1, 3 | 1024 | 41 | 14 | 2.93 | 1459 | 886 | 1.65 | +| Univariate | MedicalImages | 1 | 8, 6 | 99 | 368 | 13 | 28.31 | 726 | 34 | 21.35 | +| Univariate | MelbournePedestrian | 1 | 9 | 24 | 1040 | 98 | 10.61 | 2129 | 190 | 11.21 | +| Univariate | MiddlePhalanxOutlineCorrect | 1 | 0 | 80 | 388 | 70 | 5.54 | 166 | 125 | 1.33 | +| Univariate | MixedShapesRegularTrain | 1 | 5 | 1024 | 400 | 100 | 4.00 | 2111 | 314 | 6.72 | +| Univariate | MixedShapesSmallTrain | 1 | 5 | 1024 | 80 | 20 | 4.00 | 2111 | 314 | 6.72 | +| Multivariate | NATOPS | 24 | 2.0 | 51 | 150 | 30 | 5.00 | 150 | 30 | 5.00 | +| Univariate | PickupGestureWiimoteZ | 1 | 10, 9 | 361 | 40 | 10 | 4.00 | 40 | 10 | 4.00 | +| Univariate | PigAirwayPressure | 1 | 52, 51, 50, 49, 48 | 2000 | 94 | 10 | 9.40 | 188 | 20 | 9.40 | +| Univariate | PigArtPressure | 1 | 52, 51, 50, 49, 48 | 2000 | 94 | 10 | 9.40 | 188 | 20 | 9.40 | +| Univariate | ProximalPhalanxOutlineAgeGroup | 1 | 1 | 80 | 328 | 72 | 4.56 | 188 | 17 | 11.06 | +| Univariate | ProximalPhalanxTW | 1 | 3 | 80 | 384 | 16 | 24.00 | 203 | 2 | 101.50 | +| Multivariate | RacketSports | 6 | squash_backhandboast | 30 | 117 | 34 | 3.44 | 118 | 34 | 3.47 | +| Univariate | RefrigerationDevices | 1 | 3 | 720 | 250 | 125 | 2.00 | 250 | 125 | 2.00 | +| Univariate | ScreenType | 1 | 3 | 720 | 250 | 125 | 2.00 | 250 | 125 | 2.00 | +| Multivariate | SelfRegulationSCP2 | 7 | positivity | 1152 | 100 | 33 | 3.03 | 90 | 90 | 1.00 | +| Univariate | SemgHandGenderCh2 | 1 | 2 | 1500 | 150 | 50 | 3.00 | 390 | 210 | 1.86 | +| Univariate | SemgHandMovementCh2 | 1 | 6 | 1500 | 375 | 75 | 5.00 | 375 | 75 | 5.00 | +| Univariate | ShakeGestureWiimoteZ | 1 | 10, 9 | 385 | 40 | 10 | 4.00 | 40 | 10 | 4.00 | +| Univariate | SmallKitchenAppliances | 1 | 3 | 720 | 250 | 125 | 2.00 | 250 | 125 | 2.00 | +| Univariate | SmoothSubspace | 1 | 3 | 15 | 100 | 50 | 2.00 | 100 | 50 | 2.00 | +| Univariate | Strawberry | 1 | 1 | 235 | 394 | 73 | 5.40 | 238 | 132 | 1.80 | +| Univariate | Trace | 1 | 2 | 275 | 79 | 21 | 3.76 | 71 | 29 | 2.45 | +| Univariate | TwoPatterns | 1 | 2 | 128 | 763 | 237 | 3.22 | 2989 | 1011 | 2.96 | +| Univariate | Wine | 1 | 2 | 234 | 30 | 10 | 3.00 | 27 | 27 | 1.00 | +| Univariate | Worms | 1 | 5 | 900 | 164 | 17 | 9.65 | 69 | 8 | 8.62 | +| Univariate | WormsTwoClass | 1 | 1 | 900 | 105 | 25 | 4.20 | 44 | 33 | 1.33 | + +## Usage + +#### Requirements + +The code was trained with `python 3.8`, `pytorch 1.13.1, `cuda 11.7`, and `cudnn 8.5.0`. + + ```shell +# create virtual environment +conda create --name CFAMG python=3.8 + +# activate environment +conda activate CFAMG + +# Install dependencies +conda env create -f environment.yml + +#### Run code + +The UCR and UEA datasets can be accessed through the **tslearn** package, which can be installed as follows: + +```shell +pip install tslearn + +```python +from tslearn.datasets import UCR_UEA_datasets + +To train and generate minority samples CFAMG framework on a dataset, run the following command: + +```shell +python main.py --log_dir --save_freq --num_epochs --latent_size --batch_size --lr --tensorboard --seed ``` + +We choose classifiers from the **tsai library** as a unified benchmark, and tsai can be installed as follows + +```shell +pip install tsai + + +## Directory Structure + +The code directory structure is shown as follows: +```shell + +CFAMG +├── main.py # entry for model training +├── cfamg.py # training, generating minority samples of CFAMG +├── data_preprocess.py # dataset loading and preprocessing +├── model_utlis.py # common utility functions +├── network.py # network layer definition diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/__init__.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/cfamg.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/cfamg.py new file mode 100644 index 00000000..2ca18114 --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/cfamg.py @@ -0,0 +1,397 @@ +import os +from collections import OrderedDict + +import numpy as np +from torch import nn +import torch.nn.functional as F +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.data_preprocess import create_dataLoader2 +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.model_utils import resample_from_normal +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.network import HiddenLayerMLP +import torch + + +class CriticFunc(nn.Module): + def __init__(self, x_dim, y_dim, dropout=0.1): + super().__init__() + cat_dim = x_dim + y_dim + self.critic = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(cat_dim, cat_dim // 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(cat_dim // 4, 1), + nn.Sigmoid() + ) + + def forward(self, x, y): + cat = torch.cat((x, y), dim=-1) + return self.critic(cat) + + +class DisentangledMILoss(nn.Module): + def __init__(self, x_dim, y_dim, dropout=0.1): + super().__init__() + self.critic_st = CriticFunc(x_dim, y_dim, dropout) + + def forward(self, z_c, z_nc): + # I(Z_c, Z_nc) + idx = torch.randperm(z_nc.shape[0]) + zc_shuffle = z_nc[idx].view(z_nc.size()) + f_cnc = self.critic_st(z_c, z_nc) + f_c_nc = self.critic_st(z_c, zc_shuffle) + mubo = f_cnc - f_c_nc + pos_mask = torch.zeros_like(f_cnc) + pos_mask[mubo < 0] = 1 + mubo_mask = mubo * pos_mask + reg = (mubo_mask ** 2).mean() + return mubo.mean() + reg + + +class HiddenNetworksOPtions(nn.Module): + def __init__(self, + input_dim, + hidden_dim, + dropout_list=None, + model_type='encoder'): + super().__init__() + self.model_type = model_type + self.hidden_net = HiddenLayerMLP(input_dim=input_dim, hidden_dim=hidden_dim, model_type=model_type, + dropout_list=dropout_list) + + def forward(self, x): + return self.hidden_net(x) + + +class VAE(nn.Module): + def __init__(self, + feat_dim, + seq_len, + latent_dim, + hidden_dim, + dropout_list + ): + super().__init__() + + self.feat_dim = feat_dim + self.seq_len = seq_len + + self.encoder_network = HiddenNetworksOPtions(input_dim=feat_dim * seq_len, + hidden_dim=hidden_dim, + dropout_list=dropout_list, + model_type='encoder') + + self.encoder_hidden = hidden_dim[-1] if isinstance(hidden_dim, list) else hidden_dim + + self.mean_net = nn.ModuleList([nn.Linear(self.encoder_hidden, latent_dim // 2) for _ in range(2)]) + self.var_net = nn.ModuleList( + [nn.Sequential(nn.Linear(self.encoder_hidden, latent_dim // 2), nn.Softplus()) for _ in + range(2)]) + + self.output = nn.Sequential( + HiddenNetworksOPtions(input_dim=latent_dim, hidden_dim=hidden_dim, dropout_list=dropout_list, + model_type='decoder'), + nn.Linear(hidden_dim[0] if isinstance(hidden_dim, list) else hidden_dim, feat_dim * seq_len)) + + def decoder_output(self, z): + output = self.output(z) + output = output.view(output.size(0), self.feat_dim, self.seq_len) + return output + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.encoder_network(x) + mean = [fc_mu(x) for fc_mu in self.mean_net] + var = [fc_var(x) for fc_var in self.var_net] + qz = [resample_from_normal(mean[i], var[i]) for i in range(2)] + return qz, mean, var + + +class CFAMGVAE(nn.Module): + def __init__(self, + feat_dim, + seq_len, + latent_dim, + hidden_dim, + dropout_list + ): + super().__init__() + + self.vae = VAE(feat_dim, seq_len, latent_dim, hidden_dim, dropout_list) + + self.classifier = nn.Sequential( + nn.Linear(latent_dim, 2), + nn.ReLU(), + nn.Softmax(dim=1)) + + self.criterion = nn.CrossEntropyLoss(reduction='none') + + self.mse = nn.MSELoss() + + self.MINet = DisentangledMILoss(latent_dim // 2, latent_dim // 2) + + def compute_KL2(self, z_q_mean, z_q_var): + z_q_var = torch.clamp(z_q_var, min=1e-8) + kl_divergence = -0.5 * torch.sum(1 + torch.log(z_q_var) - z_q_mean ** 2 - z_q_var, dim=-1) + return torch.mean(kl_divergence) + + def compute_classifier_loss(self, input, label): + pred_label = self.classifier(input) + if label.dtype != torch.long: + label = label.long() + loss = self.criterion(pred_label, label) + return loss.mean() + + def swap_classifier_loss(self, sample10, sample11, sample21, label): + # swap posority class loss, label: pos + swap_index = np.random.permutation(sample11.size(0)) + swap = torch.cat([sample10.detach().clone(), sample11[swap_index].detach().clone()], dim=1) + swap_loss1 = self.compute_classifier_loss(swap, label) + # concat causal pos and non-causal neg loss, label: pos + c_nc_cat = torch.cat([sample10.detach().clone(), sample21.detach().clone()], dim=1) + swap_pos_loss2 = self.compute_classifier_loss(c_nc_cat, label) + loss = swap_loss1 + swap_pos_loss2 + return loss + + def computer_loss(self, mu_list, var_list, z_sample, x, output): + # computer KL Loss + kl_loss = 0.0 + count = 0 + for mu, var in zip([*mu_list], [*var_list]): + kl_loss += self.compute_KL2(mu, var) + count += 1 + kl_loss = kl_loss / count + + # computer reconstruction loss + reconstruction_loss = self.mse(output, x) + + mi_loss = self.MINet(z_sample[0].detach(), z_sample[1].detach()) + loss = (kl_loss, reconstruction_loss, mi_loss) + return loss + + def forward(self, x, label=None, return_loss=True): + qz, mean, var = self.vae(x) + + if return_loss: + output = self.vae.decoder_output(torch.concat(qz, dim=1)) + loss = self.computer_loss(mean, var, qz, x, output) + return qz, loss + else: + return qz + + +class CFAMG: + def __init__(self, args): + self.args = args + + self.device = args.device + # project_name = args.project_name + + # self.log_dir_path = os.path.join(args.log_dir, project_name, args.dataset_name) + # self.result_dir = os.path.join(self.log_dir_path, "result") + # os.makedirs(self.result_dir, exist_ok=True) + self.dataset = args.dataset + self.pos_dataloader, self.neg_dataloader, self.feat_dim, self.seq_len, self.batch_size = create_dataLoader2( + args.dataset, + args.batch_size) + + self.pos_model = CFAMGVAE(feat_dim=self.feat_dim, + seq_len=self.seq_len, + latent_dim=self.args.latent_dim, + hidden_dim=self.args.hidden_dim, + dropout_list=self.args.dropout_list).to(self.device) + self.neg_model = CFAMGVAE(feat_dim=self.feat_dim, + seq_len=self.seq_len, + latent_dim=self.args.latent_dim, + hidden_dim=self.args.hidden_dim, + dropout_list=self.args.dropout_list).to(self.device) + self.optimizer_pos = torch.optim.Adam(self.pos_model.parameters(), lr=self.args.lr, + weight_decay=self.args.weight_decay) + self.optimizer_neg = torch.optim.Adam(self.neg_model.parameters(), lr=self.args.lr, + weight_decay=self.args.weight_decay) + + def train_on_data(self): + print('*' * 50) + print('Main Training Starts ...') + print('*' * 50) + + if self.args.use_lr_decay: + self.scheduler_pos = torch.optim.lr_scheduler.StepLR(self.optimizer_pos, step_size=self.args.lr_decay_step, + gamma=self.args.lr_gamma) + self.scheduler_neg = torch.optim.lr_scheduler.StepLR(self.optimizer_neg, step_size=self.args.lr_decay_step, + gamma=self.args.lr_gamma) + + best_loss = float('inf') + patience = 10 + counter = 0 + for epoch in range(self.args.num_epochs): + self.pos_model.train() + self.neg_model.train() + + if epoch < 30: + for param in self.pos_model.parameters(): + param.requires_grad = True + lr_scale = 1.0 + else: + for param in self.pos_model.vae.encoder_network.parameters(): + param.requires_grad = False + lr_scale = 0.1 + + for param_group in self.optimizer_pos.param_groups: + param_group['lr'] = self.args.lr * lr_scale + + from itertools import cycle + pos_loader_cycle = cycle(self.pos_dataloader) + neg_loader_iter = iter(self.neg_dataloader) + + self.pos_kl_loss, self.pos_reconstruction_loss, self.pos_mi_loss, self.pos_swap_loss = [], [], [], [] + self.neg_kl_loss, self.neg_reconstruction_loss, self.neg_mi_loss, self.neg_swap_loss = [], [], [], [] + self.total_pos_loss, self.total_neg_loss = [], [] + for _ in range(len(self.neg_dataloader)): + pos_samp, pos_label = next(pos_loader_cycle) + neg_samp, neg_label = next(neg_loader_iter) + + # 数据增强与设备迁移 + pos_samp, pos_label = pos_samp.to(self.device), pos_label.to(self.device) + neg_samp, neg_label = neg_samp.to(self.device), neg_label.to(self.device) + + # 前向传播 + pos_z, pos_losses = self.pos_model(pos_samp, pos_label) + pos_kl_loss, pos_reconstruction_loss, pos_mi_loss = pos_losses + self.pos_kl_loss.append(pos_kl_loss) + self.pos_reconstruction_loss.append(pos_reconstruction_loss) + self.pos_mi_loss.append(pos_mi_loss) + neg_z, neg_losses = self.neg_model(neg_samp, neg_label) + neg_kl_loss, neg_reconstruction_loss, neg_mi_loss = neg_losses + self.neg_kl_loss.append(neg_kl_loss) + self.neg_reconstruction_loss.append(neg_reconstruction_loss) + self.neg_mi_loss.append(neg_mi_loss) + + # 对齐批次并计算Swap Loss(带正则化) + pos_z1, pos_z2 = pos_z + neg_z1, neg_z2 = neg_z + pos_swap_loss = self.pos_model.swap_classifier_loss(pos_z1, pos_z2, neg_z2, pos_label) + neg_swap_loss = self.neg_model.swap_classifier_loss(neg_z1, neg_z2, pos_z2, neg_label) + self.pos_swap_loss.append(pos_swap_loss) + self.neg_swap_loss.append(neg_swap_loss) + + # 加权联合损失 + total_pos_loss = pos_kl_loss + pos_reconstruction_loss + self.args.w_lambda * pos_mi_loss + self.args.w_beta * pos_swap_loss + total_neg_loss = neg_kl_loss + neg_reconstruction_loss + self.args.w_lambda * neg_mi_loss + self.args.w_beta * neg_swap_loss + self.total_pos_loss.append(total_pos_loss) + self.total_neg_loss.append(total_neg_loss) + + # 梯度计算与裁剪 + self.optimizer_pos.zero_grad() + self.optimizer_neg.zero_grad() + total_pos_loss.backward(retain_graph=True) + total_neg_loss.backward() + + torch.nn.utils.clip_grad_norm_(self.pos_model.parameters(), max_norm=1.0) + torch.nn.utils.clip_grad_norm_(self.neg_model.parameters(), max_norm=1.0) + + # 选择性更新分类器 + # self.optimizer_classifier_pos.step() + self.optimizer_pos.step() + # self.optimizer_classifier_neg.step() + self.optimizer_neg.step() + + if self.args.use_lr_decay: + self.scheduler_pos.step() + self.scheduler_neg.step() + + if epoch % self.args.save_freq == 0: + # self.save_model(epoch) + pass + + if epoch % self.args.log_freq == 0: + self.board_loss(epoch) + + current_reconstruction_loss = sum(self.total_pos_loss) / len(self.total_pos_loss) + + # 早停与模型保存 + if current_reconstruction_loss < best_loss: + best_loss = current_reconstruction_loss + counter = 0 + else: + counter += 1 + if counter >= patience: + print(f"Early stopping at epoch {epoch}") + # self.save_model(epoch) + break + + def generator_sample(self): + self.pos_model.eval() + self.neg_model.eval() + X_pos, y_pos = self.args.dataset["train_data_pos"] + X_neg, y_neg = self.args.dataset["train_data_neg"] + num_majority, num_minority = len(y_neg), len(y_pos) + num_diff_samp = num_majority - num_minority + generated_samples = [] + + with torch.no_grad(): + for pos_samp, _ in self.pos_dataloader: + pos_samp = pos_samp.squeeze(0).to(self.device) + z_pos = self.pos_model(pos_samp, return_loss=False)[0] + for neg_samp, _ in self.neg_dataloader: + neg_samp = neg_samp.squeeze(0).to(self.device) + z_neg = self.neg_model(neg_samp, return_loss=False)[1] + if z_neg.shape[0] < z_pos.shape[0]: + z_pos = z_pos[:z_neg.shape[0]] + elif z_neg.shape[0] > z_pos.shape[0]: + z_neg = z_neg[:z_pos.shape[0]] + z = torch.concat((z_pos, z_neg), dim=-1) + generated_samp = self.pos_model.vae.decoder_output(z) + generated_samples.append(generated_samp) + + generated_samples = torch.cat(generated_samples, dim=0) + if generated_samples.size(0) < num_diff_samp: + num_diff_samp = generated_samples.size(0) + generated_samples = generated_samples[:num_diff_samp].cpu().detach().numpy() + balance_pos_samp = np.concatenate((X_pos, generated_samples), axis=0) + balance_pos_label = np.concatenate((y_pos, np.ones((num_diff_samp,))), axis=0) + + balance_samp = np.concatenate((balance_pos_samp, X_neg), axis=0) + balance_label = np.concatenate((balance_pos_label, y_neg), axis=0) + + return balance_samp, balance_label, generated_samples + + # def save_model(self, epoch, model_name='CFAMG'): + # model_path = os.path.join(self.result_dir, f"model_{model_name}_{epoch}.pth") + # state_dict = { + # 'steps': epoch, + # 'pos_state_dict': self.pos_model.state_dict(), + # 'pos_optimizer': self.optimizer_pos.state_dict(), + # 'neg_state_dict': self.neg_model.state_dict(), + # 'neg_optimizer': self.optimizer_neg.state_dict() + # } + # with open(model_path, "wb") as f: + # torch.save(state_dict, f) + # + # print(f'\n {epoch} model saved ...') + + def board_loss(self, epoch): + print('\n') + n1, n2 = len(self.pos_dataloader), len(self.neg_dataloader) + print( + f"Epoch : {epoch}," + f" Pos Loss : {sum(self.total_pos_loss) / n2}," + f" Pos KL Loss : {sum(self.pos_kl_loss) / n2}" + f" Pos Recon Loss : {sum(self.pos_reconstruction_loss) / n2}" + f" Pos Swap Loss : {self.args.w_beta * sum(self.pos_swap_loss) / n2}" + f" Pos MI Loss : {self.args.w_lambda * sum(self.pos_mi_loss) / n2}" + ) + print( + f"Epoch : {epoch}," + f" Neg Loss : {sum(self.total_neg_loss) / n2}," + f" Neg KL Loss : {sum(self.neg_kl_loss) / n2}" + f" Neg Recon Loss : {sum(self.neg_reconstruction_loss) / n2}" + f" Neg Swap Loss : {self.args.w_beta * sum(self.neg_swap_loss) / n2}" + f" Neg MI Loss : {self.args.w_lambda * sum(self.neg_mi_loss) / n2}" + ) + + def load_model(self, model_path): + model_param = torch.load(model_path) + self.pos_model.load_state_dict(model_param["pos_state_dict"]) + self.pos_model.eval() + self.neg_model.load_state_dict(model_param["neg_state_dict"]) + self.neg_model.eval() diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/data_preprocess.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/data_preprocess.py new file mode 100644 index 00000000..6e7f070f --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/data_preprocess.py @@ -0,0 +1,103 @@ +import random +import warnings +from collections import Counter +import numpy as np +import pandas as pd +import os +import torch +from torch.utils.data import DataLoader, TensorDataset + + +def load_dataset(dataset_name=None, root_path=None): + X_train = np.load(os.path.join(root_path, dataset_name, "X_train.npy")) + y_train = np.load(os.path.join(root_path, dataset_name, "y_train.npy")).astype(int) + print(Counter(y_train)) + X_train_pos, X_train_neg = X_train[y_train == 1], X_train[y_train == 0] + y_train_pos, y_train_neg = y_train[y_train == 1], y_train[y_train == 0] + ir = X_train_neg.shape[0] / X_train_pos.shape[0] + + print(f"Dataset name: {dataset_name}, The number of Positive sample : {len(y_train_pos)}") + + dataset = { + "train_data": (X_train, y_train), + "train_data_pos": (X_train_pos, y_train_pos), + "train_data_neg": (X_train_neg, y_train_neg), + "ir": ir, + } + + return dataset, ir + + +def set_seed(seed, cudnn_deterministic=False): + """ + Set all seed + :param seed: seed + :param cudnn_deterministic: whether set CUDNN deterministic + """ + if seed is not None: + print(f'Global seed set to {seed}') + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + """ + Even if `torch.backends.cudnn.deterministic` is set to False, + the reproducibility of other random operations in PyTorch can still be ensured by setting the random seed with + `torch.manual_seed`. + """ + torch.backends.cudnn.deterministic = False + + if cudnn_deterministic: + torch.backends.cudnn.deterministic = True + warnings.warn('You have chosen to seed training. This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! You may see unexpected behavior when restarting ' + 'from checkpoints.') + + +class MinMaxScaler(): + def fit_transform(self, data): + self.fit(data) + scaled_data = self.transform(data) + return scaled_data + + def fit(self, data): + self.mini = np.min(data, 0) + self.range = np.max(data, 0) - self.mini + return self + + def transform(self, data): + numerator = data - self.mini + scaled_data = numerator / (self.range + 1e-7) + return scaled_data + + def inverse_transform(self, data): + data *= self.range + data += self.mini + return data + + +def create_dataLoader2(dataset, batch_size): + """ + dataset : dict + data : tuple (Positive Sample, Negative Sample) + """ + + X_pos, y_pos = dataset["train_data_pos"] + X_neg, y_neg = dataset["train_data_neg"] + num_pos, num_neg = X_pos.shape[0], X_neg.shape[0] + batch_size = batch_size if batch_size < num_pos else num_pos + + assert X_pos.ndim == 3 + + X_pos, X_neg = torch.tensor(X_pos, dtype=torch.float32), torch.tensor(X_neg, dtype=torch.float32) + y_pos, y_neg = torch.tensor(y_pos, dtype=torch.int64), torch.tensor(y_neg, dtype=torch.int64) + + # if num_pos % batch_size == 1 or num_neg % batch_size == 1: + pos_dataloader = DataLoader(TensorDataset(X_pos, y_pos), batch_size=batch_size, shuffle=False, drop_last=True) + neg_dataloader = DataLoader(TensorDataset(X_neg, y_neg), batch_size=batch_size, shuffle=False, drop_last=True) + # else: + # pos_dataloader = DataLoader(TensorDataset(X_pos, y_pos), batch_size=batch_size, shuffle=False) + # neg_dataloader = DataLoader(TensorDataset(X_neg, y_neg), batch_size=batch_size, shuffle=False) + + _, feat_dim, seq_len = X_pos.shape + return pos_dataloader, neg_dataloader, feat_dim, seq_len, batch_size diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/main.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/main.py new file mode 100644 index 00000000..f5e3ced2 --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/main.py @@ -0,0 +1,100 @@ +import os +import numpy as np +import pandas as pd +import torch +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.data_preprocess import set_seed, load_dataset +from tsml_eval._wip.rt.transformations.collection.imbalance.pk_cfamg.cfamg import CFAMG + +pd.set_option('display.max_columns', None) +pd.set_option('display.max_colwidth', None) +pd.set_option('display.expand_frame_repr', False) + + +def parse_args(): + # Set path + output_dir = 'CFAMG' + log_dir = './Exp_log' + data_dir = None # action='store_true' implies a boolean flag, so default to None or False + + # Set data parameter + batch_size = 32 + + # Set logging + log_freq = 5 + save_freq = 200 + tensorboard = False # action="store_true" implies a boolean flag + wandb = False # action='store_true' implies a boolean flag + + # Set train parameter + num_epochs = 201 + latent_dim = 64 + hidden_dim = [32, 64, 128] + dropout_list = [0.1, 0.1, 0.2] + use_lr_decay = False + lr_decay_step = 100 + lr_gamma = 0.1 + lr = 1e-3 + weight_decay = 1e-4 + temp_epochs = 300 + cls_num_epochs = 100 + beta = 10 + + # Create a class or object to hold the parameters + class Args: + pass + + args = Args() + args.output_dir = output_dir + args.log_dir = log_dir + args.data_dir = data_dir + args.batch_size = batch_size + args.log_freq = log_freq + args.save_freq = save_freq + args.tensorboard = tensorboard + args.wandb = wandb + args.num_epochs = num_epochs + args.latent_dim = latent_dim + args.hidden_dim = hidden_dim + args.dropout_list = dropout_list + args.use_lr_decay = use_lr_decay + args.lr_decay_step = lr_decay_step + args.lr_gamma = lr_gamma + args.lr = lr + args.weight_decay = weight_decay + args.temp_epochs = temp_epochs + args.cls_num_epochs = cls_num_epochs + args.beta = beta + + args.save_path = os.path.join(os.getcwd(), args.output_dir) + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print('Device:', args.device) + return args + + +def run_CFAMG(args): + model = CFAMG(args) + model.train_on_data() + X_train, y_train, _ = model.generator_sample() + data_save_path = os.path.join(os.getcwd(), args.save_file, args.project_name, args.dataset_name) + os.makedirs(data_save_path, exist_ok=True) + np.save(os.path.join(data_save_path, 'X_train_syn_sample.npy'), X_train) + np.save(os.path.join(data_save_path, 'y_train_syn_sample.npy'), y_train) + print('Synthetic data has been saved ! ') + + +if __name__ == "__main__": + seed = 2024 + set_seed(seed) + args = parse_args() + datasets_path = os.getcwd() + time_log = pd.DataFrame(columns=["Datasets", "Time(s)"]) + args.save_file = "synthetic_dataset" + tsu_name = 'FiftyWords' + args.w_lambda, args.w_beta = 1, 1 + args.project_name = f'CFAMG' + args.exp_name = tsu_name + args.dataset_name = tsu_name + print(f"Model: {args.exp_name} || Dataset : {tsu_name}") + dataset, ir = load_dataset(dataset_name=tsu_name, root_path=datasets_path) + args.dataset = dataset + run_CFAMG(args) diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/model_utils.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/model_utils.py new file mode 100644 index 00000000..3563b83f --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/model_utils.py @@ -0,0 +1,40 @@ +import torch +import numpy as np +from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix + + +def resample_from_normal(mean, var): + epsilon = 1e-6 + qz_gaussian = torch.distributions.Normal(loc=mean, scale=var + epsilon) + qz = qz_gaussian.rsample() + return qz + + +def adjust_input(X_train, X_test, dim_need='2'): + if dim_need == '2': + if len(X_train.shape) > 2 or len(X_test.shape) > 2: + if isinstance(X_train, torch.Tensor) or isinstance(X_test, torch.Tensor): + X_train = X_train.squeeze(1) + X_test = X_test.squeeze(1) + elif isinstance(X_train, np.ndarray) or isinstance(X_test, np.ndarray): + X_train = X_train.reshape(X_train.shape[0], -1) + X_test = X_test.reshape(X_test.shape[0], -1) + elif dim_need == '3': + if len(X_train.shape) < 3 or len(X_test.shape) < 3: + if isinstance(X_train, torch.Tensor) or isinstance(X_test, torch.Tensor): + X_train = X_train.unsqueeze(1) + X_test = X_test.unsqueeze(1) + elif isinstance(X_train, np.ndarray) or isinstance(X_test, np.ndarray): + X_train = np.expand_dims(X_train, axis=1) + X_test = np.expand_dims(X_test, axis=1) + return X_train, X_test + + +def computer_f1_gmeans_auc(y_test, y_pred, y_proba): + f1 = f1_score(y_test, y_pred, average='binary') + + AUC = roc_auc_score(y_test, y_proba) + + tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel() + g_mean = np.sqrt((tp / (tp + fn)) * (tn / (tn + fp))) + return f1, g_mean, AUC diff --git a/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/network.py b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/network.py new file mode 100644 index 00000000..b1b77444 --- /dev/null +++ b/tsml_eval/_wip/rt/transformations/collection/imbalance/pk_cfamg/network.py @@ -0,0 +1,42 @@ +import math +import numpy as np +import torch.nn as nn +import torch + + +class HiddenLayerMLP(nn.Module): + def __init__(self, input_dim, hidden_dim, model_type, dropout_list=None): + super().__init__() + + self.mlp = nn.Sequential() + if model_type == 'encoder': + if isinstance(hidden_dim, list): + assert isinstance(dropout_list, list) + for hidden, dropout in zip(hidden_dim, dropout_list): + self.mlp.append(nn.Dropout(dropout)) + self.mlp.append(nn.Linear(input_dim, hidden)) + self.mlp.append(nn.ReLU()) + input_dim = hidden + else: + assert not isinstance(dropout_list, list) + self.mlp.append(nn.Dropout(dropout_list)) + self.mlp.append(nn.Linear(input_dim, hidden_dim)) + elif model_type == 'decoder': + if isinstance(hidden_dim, list): + assert isinstance(dropout_list, list) + hidden_dim = list(reversed(hidden_dim)) + dropout_list = list(reversed(dropout_list)) + for hidden, dropout in zip(hidden_dim, dropout_list): + self.mlp.append(nn.Dropout(dropout)) + self.mlp.append(nn.Linear(input_dim, hidden)) + self.mlp.append(nn.ReLU()) + input_dim = hidden + else: + assert not isinstance(dropout_list, list) + self.mlp.append(nn.Dropout(dropout_list)) + self.mlp.append(nn.Linear(input_dim, hidden_dim)) + + def forward(self, x): + if len(x.shape) == 3: + x = x.view(x.size(0), -1) + return self.mlp(x)