Skip to content

Commit

Permalink
final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
andrea-dimarco committed Feb 16, 2024
1 parent 9bd7f11 commit de68714
Show file tree
Hide file tree
Showing 56 changed files with 37 additions and 159 deletions.
16 changes: 0 additions & 16 deletions TODO.txt

This file was deleted.

Binary file removed forecaster-sine.pth
Binary file not shown.
Binary file removed forecaster-wien.pth
Binary file not shown.
Binary file removed models/forecaster-wien.pth
Binary file not shown.
Binary file removed models/timegan-sine.pth
Binary file not shown.
Binary file removed models/timegan-wien.pth
Binary file not shown.
Binary file added report.pdf
Binary file not shown.
Binary file removed report/template.zip
Binary file not shown.
45 changes: 2 additions & 43 deletions src/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,7 @@
"source": [
"# Time-Series Generation using Contrastive Learning\n",
"\n",
"Consider learning a generative model for time-series data.\n",
"\n",
"The sequential setting poses a unique challenge: Not only should the generator capture the conditional dynamics of (stepwise) transitions, but its open-loop rollouts should also preserve the joint distribution of (multi-step) trajectories.\n",
"\n",
"On one hand, autoregressive models\n",
"trained by MLE allow learning and computing explicit transition distributions, but suffer from compounding error during rollouts.\n",
"\n",
"On the other hand, adversarial models based on GAN training alleviate such exposure bias, but transitions are implicit and hard to assess.\n",
"\n",
"In this work, we study a generative framework that seeks to combine the strengths of both: Motivated by a moment-matching objective to mitigate\n",
"compounding error, we optimize a local (but forward-looking) *transition policy*, where the reinforcement signal is provided by a global (but stepwise-decomposable) *energy model* trained by contrastive estimation. \n",
"\n",
"At **training**, the two components are learned cooperatively, avoiding the instabilities typical of adversarial objectives. \n",
"\n",
"At **inference**, the learned policy serves as the generator for iterative sampling, and the learned energy serves as a trajectory-level measure for evaluating sample quality.\n",
"\n",
"By expressly training a policy to imitate sequential behavior of time-series features in a dataset, this approach embodies *“generation by imitation”*. Theoretically, we illustrate the correctness of this formulation and the consistency of the algorithm.\n",
"\n",
"Empirically, we evaluate its ability to generate predictively useful samples from real-world datasets, verifying that it performs at the standard of existing benchmarks."
"This work proposes a general architecture that can be changed in order to generate Time-Series of various domains. Using an embedding network to provide a reversible mapping between features and latent representations, reducing the high-dimensionality of the adversarial learning space. This capitalizes on the fact that the temporal dynamics of even complex systems are often driven by fewer and lower-dimensional factors of variation. "
]
},
{
Expand Down Expand Up @@ -134,13 +116,6 @@
"import dataset_handling as dh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Eh eh"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -169,22 +144,6 @@
"hparams = Config()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comment this cell if you don't want to use Weights & Biases to log the process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -207,7 +166,7 @@
"metadata": {},
"outputs": [],
"source": [
"ut.set_seed(seed=1337)"
"ut.set_seed()"
]
},
{
Expand Down
16 changes: 6 additions & 10 deletions src/forecasting_model.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@


# Libraries
from typing import Sequence, Dict, Tuple, Union, Mapping

from dataclasses import asdict
from pathlib import Path

from torch.utils.data import DataLoader

import wandb
import torch
import torch.nn as nn
from torch import optim

import wandb
from pathlib import Path
import pytorch_lightning as pl
from dataclasses import asdict
from torch.utils.data import DataLoader
from typing import Sequence, Dict, Tuple, Union, Mapping

# My modules
import dataset_handling as dh
import utilities as ut
import dataset_handling as dh
from hyperparameters import Config

'''
Expand Down
12 changes: 6 additions & 6 deletions src/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
class Config:

## System parameters
operating_system:str = 'windows' # . Will affect the testing results
dataset_folder:str = "../datasets/"# Path to the datasets folder

## Training parameters
Expand Down Expand Up @@ -41,9 +40,9 @@ class Config:
# . . . . . . . . . with random mutual correlations
train_test_split: float = 0.7 #. . . Split between training and testing samples
train_val_split: float = 0.8 #. . . Split between training and validating samples
num_samples: int = 10**7 # . . . . . Number of samples to generate (if any)
num_samples: int = 10**6 # . . . . . Number of samples to generate (if any)
data_dim: int = 3 # . . . . . . . . Dimension of one generated sample (if any)
seq_len: int = 50**2 #. . . . . . . Length of the input sequences
seq_len: int = 10**3 #. . . . . . . Length of the input sequences


## Network parameters
Expand Down Expand Up @@ -71,12 +70,13 @@ class Config:


## Testing phase
operating_system:str = 'linux' #. . if 'windows' it won't run the anomaly detector
alpha: float = 0.1 #. . . . . . . . Parameter for the Anomaly Detector
h: float = 10 # . . . . . . . . . . Parameter for the Anomaly Detector
limit:int = 10000# . . . . . . . . . Amount of elements to consider when running tests
pic_frequency:int = 100 #. . . . . . How many steps to wait before saving a new picture during testing
limit:int = 1000# . . . . . . . . . Amount of elements to consider when running tests
pic_frequency:int = 100 # . . . . . How many steps to wait before saving a new picture during testing

forecaster_epochs:int = 25**1 #. . . Amount of epochs to train the forecaster model
forecaster_epochs:int = 10**1 #. . . Amount of epochs to train the forecaster model
forecaster_hidden:int = 50 # . . . . Hidden dimension for the forecaster
forecaster_layers:int = 1 #. . . . . Number of layers for the forecaster
forecaster_seq_len:int = 100#. . . . Lookback window for the forecaster
12 changes: 5 additions & 7 deletions src/testing_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def generate_stream_test(model:TimeGAN, test_dataset:dh.RealDataset,
with torch.no_grad():
horizon = min(limit, len(test_dataset)) if limit>0 else len(test_dataset)
timegan.eval()
synth = model.cycle(test_dataset.get_whole_stream()[:horizon]
synth = model(test_dataset.get_whole_noise_stream()[:horizon]
).reshape(horizon, test_dataset.p)
print("Synthetic stream has been generated.")
if compare:
Expand Down Expand Up @@ -347,9 +347,7 @@ def distribution_visualization(model:TimeGAN, test_dataset:dh.RealDataset,
with torch.no_grad():
horizon = min(limit, len(test_dataset)) if limit>0 else len(test_dataset)
timegan.eval()
# TODO: revert this
#synth = model(test_dataset.get_whole_stream()[:horizon]
synth = model.cycle(test_dataset.get_whole_stream()[:horizon]
synth = model(test_dataset.get_whole_noise_stream()[:horizon]
).reshape(horizon, test_dataset.p)
original = test_dataset.get_whole_stream()[:horizon]
print("Synthetic stream has been generated.")
Expand Down Expand Up @@ -393,10 +391,10 @@ def distribution_visualization(model:TimeGAN, test_dataset:dh.RealDataset,
)


## TESTING LOOP
## TESTS
limit = hparams.limit
frequency = hparams.pic_frequency
'''

avg_rec_loss = recovery_seq_test(model=timegan,
test_dataset=test_dataset,
limit=limit,
Expand Down Expand Up @@ -434,5 +432,5 @@ def distribution_visualization(model:TimeGAN, test_dataset:dh.RealDataset,
show_plot=False,
limit=hparams.limit
)
'''

AD_tests(model=timegan, test_dataset=test_dataset)
39 changes: 15 additions & 24 deletions src/timegan_model.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@


# Libraries
from typing import Sequence, Dict, Tuple, Union, Mapping

from dataclasses import asdict
from pathlib import Path

from torch.utils.data import DataLoader

import wandb
import torch
from torch import optim

import wandb
from pathlib import Path
import pytorch_lightning as pl
from dataclasses import asdict
from torch.utils.data import DataLoader
from typing import Sequence, Dict, Tuple, Union, Mapping

# My modules
import dataset_handling as dh
import utilities as ut
import dataset_handling as dh
from hyperparameters import Config
from modules.classifier_cell import ClassCell
from modules.regressor_cell import RegCell
from modules.classifier_cell import ClassCell

'''
This is the main model.
Expand Down Expand Up @@ -188,8 +181,7 @@ def val_dataloader(self) -> DataLoader:


def configure_optimizers(self
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer, optim.Optimizer, optim.Optimizer]:
#) -> Tuple[Sequence[optim.Optimizer], Sequence[Dict[str, Any]]]:
):
'''
Instantiate the optimizers and schedulers.
Expand Down Expand Up @@ -326,15 +318,14 @@ def DD_loss(self, X: torch.Tensor, Z: torch.Tensor,
w1:float=0.45, w2:float=0.10, w3:float=0.45
) -> torch.Tensor:
# Compute model outputs
# 1. Embedder
# 2. Generator
# 1. Generator
E_hat = self.Gen(Z)
# 3. Supervisor
# 2. Supervisor
H_hat = self.Sup(E_hat)
# 4. Recovery
# 3. Recovery
X_hat = self.Rec(E_hat)
X_hat_s = self.Rec(H_hat)
# 5. Discriminator
# 4. Discriminator
Y_real = self.DataDis(X)
Y_fake = self.DataDis(X_hat_s)
Y_fake_e = self.DataDis(X_hat)
Expand Down Expand Up @@ -455,15 +446,15 @@ def E_loss(self, X: torch.Tensor,
# 1. Embedder
H = self.Emb(X)
# 2. Supervisor
#H_hat_supervise = self.Sup(H)
H_hat_supervise = self.Sup(H)
# 3. Recovery
X_tilde = self.Rec(H)

# Loss Components
R_loss = self.reconstruction_loss(X, X_tilde)
#S_loss = self.reconstruction_loss(H, H_hat_supervise)
S_loss = self.reconstruction_loss(H, H_hat_supervise)

return w1*R_loss #+ w2*S_loss
return w1*R_loss + w2*S_loss


def R_loss(self, X: torch.Tensor,
Expand Down
52 changes: 1 addition & 51 deletions src/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,61 +10,11 @@
import dataset_handling as dh
from timegan_model import TimeGAN
from hyperparameters import Config
from dataset_handling import train_test_split
from data_generation import iid_sequence_generator, sine_process, wiener_process

import warnings
warnings.filterwarnings("ignore")


def generate_data(datasets_folder="./datasets/"):
hparams = Config()
print("Generating datasets.")
if hparams.dataset_name in ['sine', 'wien', 'iid', 'cov']:
# Generate and store the dataset as requested
dataset_path = f"{datasets_folder}{hparams.dataset_name}_generated_stream.csv"
if hparams.dataset_name == 'sine':
sine_process.save_sine_process(p=hparams.data_dim,
N=hparams.num_samples,
file_path=dataset_path)
elif hparams.dataset_name == 'wien':
wiener_process.save_wiener_process(p=hparams.data_dim,
N=hparams.num_samples,
file_path=dataset_path)
elif hparams.dataset_name == 'iid':
iid_sequence_generator.save_iid_sequence(p=hparams.data_dim,
N=hparams.num_samples,
file_path=dataset_path)
elif hparams.dataset_name == 'cov':
iid_sequence_generator.save_cov_sequence(p=hparams.data_dim,
N=hparams.num_samples,
file_path=dataset_path)
else:
raise ValueError
print(f"The {hparams.dataset_name} dataset has been succesfully created and stored into:\n\t- {dataset_path}")
elif hparams.dataset_name == 'real':
pass
else:
raise ValueError("Dataset not supported.")


if hparams.dataset_name in ['sine', 'wien', 'iid', 'cov']:
train_dataset_path = f"{datasets_folder}{hparams.dataset_name}_training.csv"
test_dataset_path = f"{datasets_folder}{hparams.dataset_name}_testing.csv"

train_test_split(X=loadtxt(dataset_path, delimiter=",", dtype=float32),
split=hparams.train_test_split,
train_file_name=train_dataset_path,
test_file_name=test_dataset_path
)
print(f"The {hparams.dataset_name} dataset has been split successfully into:\n\t- {train_dataset_path}\n\t- {test_dataset_path}")
elif hparams.dataset_name == 'real':
train_dataset_path = datasets_folder + hparams.train_file_name
test_dataset_path = datasets_folder + hparams.test_file_name
else:
raise ValueError("Dataset not supported.")


def train(datasets_folder="./datasets/"):
'''
Train the TimeGAN model
Expand Down Expand Up @@ -164,7 +114,7 @@ def validate_model(model:TimeGAN, datasets_folder:str="./datasets/", limit:int=1
# Training Area #
# # # # # # # # #
datasets_folder = "./datasets/"
generate_data(datasets_folder)
ut.generate_data(datasets_folder)
ut.set_seed(seed=1337)
train(datasets_folder=datasets_folder)

2 changes: 1 addition & 1 deletion test_results/generation_tests/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Save your generation test results here!!
Geneation test results will be saved here.
Binary file removed test_results/generation_tests/sine-generation-0.png
Binary file not shown.
Binary file removed test_results/generation_tests/sine-generation-1.png
Binary file not shown.
Binary file removed test_results/generation_tests/sine-generation-2.png
Binary file not shown.
Binary file removed test_results/generation_tests/sine-generation-3.png
Binary file not shown.
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-1.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-2.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-3.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-4.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-5.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-6.png
Binary file not shown.
Binary file removed test_results/generation_tests/wien-generation-7.png
Binary file not shown.
2 changes: 1 addition & 1 deletion test_results/recovery_tests/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Save your recovery test results here!!
Recovery test results will be saved here.
Binary file removed test_results/recovery_tests/sine-recovery-0.png
Binary file not shown.
Binary file removed test_results/recovery_tests/sine-recovery-1.png
Binary file not shown.
Binary file removed test_results/recovery_tests/sine-recovery-2.png
Binary file not shown.
Binary file removed test_results/recovery_tests/sine-recovery-3.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-0.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-1.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-2.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-3.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-4.png
Binary file not shown.
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-5.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-6.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-7.png
Binary file not shown.
Binary file removed test_results/recovery_tests/wien-recovery-8.png
Diff not rendered.
Binary file removed test_results/recovery_tests/wien-recovery-9.png
Diff not rendered.
Binary file removed test_results/sine-forecasting-plot.png
Diff not rendered.
File renamed without changes
File renamed without changes
Binary file removed test_results/sine-synth (keep).png
Diff not rendered.
Binary file removed test_results/sine-synth-(keep2).png
Diff not rendered.
Binary file removed test_results/sine-synth-0.png
Diff not rendered.
Binary file added test_results/wien-forecasting-plot.png
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
Binary file removed test_results/wien-synth (keep).png
Diff not rendered.
Binary file removed timegan-wien.pth
Binary file not shown.

0 comments on commit de68714

Please sign in to comment.