diff --git a/README.md b/README.md index 1fb9c40c..697e9338 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,16 @@ provides the possibility to perform benchmark experiments and comparisons by tra the models with the same autoencoding neural network architecture. The feature *make your own autoencoder* allows you to train any of these models with your own data and own Encoder and Decoder neural networks. It integrates experiment monitoring tools such [wandb](https://wandb.ai/), [mlflow](https://mlflow.org/) or [comet-ml](https://www.comet.com/signup?utm_source=pythae&utm_medium=partner&utm_campaign=AMS_US_EN_SNUP_Pythae_Comet_Integration) ๐Ÿงช and allows model sharing and loading from the [HuggingFace Hub](https://huggingface.co/models) ๐Ÿค— in a few lines of code. +**News** ๐Ÿ“ข + +As of v0.1.0, `Pythae` now supports distributed training using PyTorch's [DDP](https://pytorch.org/docs/stable/notes/ddp.html). You can now train your favorite VAE faster and on larger datasets, still with a few lines of code. +See our speed-up [benchmark](#benchmark). ## Quick access: - [Installation](#installation) - [Implemented models](#available-models) / [Implemented samplers](#available-samplers) - [Reproducibility statement](#reproducibility) / [Results flavor](#results) -- [Model training](#launching-a-model-training) / [Data generation](#launching-data-generation) / [Custom network architectures](#define-you-own-autoencoder-architecture) +- [Model training](#launching-a-model-training) / [Data generation](#launching-data-generation) / [Custom network architectures](#define-you-own-autoencoder-architecture) / [Distributed training](#distributed-training-with-pythae) - [Model sharing with ๐Ÿค— Hub](#sharing-your-models-with-the-huggingface-hub-) / [Experiment tracking with `wandb`](#monitoring-your-experiments-with-wandb-) / [Experiment tracking with `mlflow`](#monitoring-your-experiments-with-mlflow-) / [Experiment tracking with `comet_ml`](#monitoring-your-experiments-with-comet_ml-) - [Tutorials](#getting-your-hands-on-the-code) / [Documentation](https://pythae.readthedocs.io/en/latest/) - [Contributing ๐Ÿš€](#contributing-) / [Issues ๐Ÿ› ๏ธ](#dealing-with-issues-%EF%B8%8F) @@ -141,8 +145,15 @@ To launch a model training, you only need to call a `TrainingPipeline` instance. ... output_dir='my_model', ... num_epochs=50, ... learning_rate=1e-3, -... batch_size=200, -... steps_saving=None +... per_device_train_batch_size=200, +... per_device_eval_batch_size=200, +... train_dataloader_num_workers=2, +... eval_dataloader_num_workers=2, +... steps_saving=20, +... optimizer_cls="AdamW", +... optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)}, +... scheduler_cls="ReduceLROnPlateau", +... scheduler_params={"patience": 5, "factor": 0.5} ... ) >>> # Set up the model configuration >>> my_vae_config = model_config = VAEConfig( @@ -334,6 +345,44 @@ You can also find predefined neural network architectures for the most common da ``` Replace *mnist* by cifar or celeba to access to other neural nets. +## Distributed Training with `Pythae` +As of `v0.1.0`, Pythae now supports distributed training using PyTorch's [DDP](https://pytorch.org/docs/stable/notes/ddp.html). It allows you to train your favorite VAE faster and on larger dataset using multi-gpu and/or multi-node training. + +To do so, you can build a python script that will then be launched by a launcher (such as `srun` on a cluster). 
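+For instance, on a SLURM cluster the script can typically be launched with a command along the lines of `srun --nodes=2 --ntasks-per-node=4 --gres=gpu:4 python your_training_script.py` (a hypothetical sketch: the exact flags and script name depend on your cluster and setup).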
The only thing that is needed in the script is to specify some elements relative to the distributed environment (such as the number of nodes/gpus) directly in the training configuration as follows + +```python +>>> training_config = BaseTrainerConfig( +... num_epochs=10, +... learning_rate=1e-3, +... per_device_train_batch_size=64, +... per_device_eval_batch_size=64, +... train_dataloader_num_workers=8, +... eval_dataloader_num_workers=8, +... dist_backend="nccl", # distributed backend +... world_size=8 # number of gpus to use (n_nodes x n_gpus_per_node), +... rank=5 # process/gpu id, +... local_rank=1 # node id, +... master_addr="localhost" # master address, +... master_port="12345" # master port, +... ) +``` + +See this [example script](https://github.com/clementchadebec/benchmark_VAE/blob/main/examples/scripts/distributed_training_imagenet.py) that defines a multi-gpu VQVAE training on ImageNet dataset. Please note that the way the distributed environnement variables (`world_size`, `rank` ...) are recovered may be specific to the cluster and launcher you use. + +### Benchmark + +Below are indicated the training times for a Vector Quantized VAE (VQ-VAE) with `Pythae` for 100 epochs on MNIST on V100 16GB GPU(s), for 50 epochs on [FFHQ](https://github.com/NVlabs/ffhq-dataset) (1024x1024 images) and for 20 epochs on [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k) on V100 32GB GPU(s). + +| | Train Data | 1 GPU | 4 GPUs | 2x4 GPUs | +|:---:|:---:|:---:|:---:|---| +| MNIST (VQ-VAE) | 28x28 images (50k) | 235.18 s | 62.00 s | 35.86 s | +| FFHQ 1024x1024 (VQVAE) | 1024x1024 RGB images (60k) | 19h 1min | 5h 6min | 2h 37min | +| ImageNet-1k 128x128 (VQVAE) | 128x128 RGB images ($\approx$ 1.2M) | 6h 25min | 1h 41min | 51min 26s | + + +For each dataset, we provide the benchmarking scripts [here](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/scripts) + + ## Sharing your models with the HuggingFace Hub ๐Ÿค— Pythae also allows you to share your models on the [HuggingFace Hub](https://huggingface.co/models). To do so you need: - a valid HuggingFace account diff --git a/docs/old/advanced/custom_autoencoder.rst b/docs/old/advanced/custom_autoencoder.rst deleted file mode 100644 index bf60e70d..00000000 --- a/docs/old/advanced/custom_autoencoder.rst +++ /dev/null @@ -1,98 +0,0 @@ -.. _making-your-own-vae: - -################################## -Making your own autoencoder model -################################## - - -By default, the VAE models use Multi Layer Perceptron neural networks -for the encoder and decoder and metric (if applicable) which automatically adapt to the input data shape. The only thing that is needed is to state the data input dimension which equals to ``n_channels x height x width x ...`` in the :class:`ModelConfig`. This important since, if you do not provided any, an error is raised: - -.. code-block:: python - - >>> from pythae.models.base.base_config import BaseAEConfig - >>> from pythae.models import BaseAE - >>> config = BaseAEConfig() - >>> BaseAE(model_config=config) - Traceback (most recent call last): - File "", line 1, in - File "/home/clement/Documents/these/implem/pythae/src/pythae/models/base/base_vae.py", line 57, in __init__ - raise AttributeError("No input dimension provided !" - AttributeError: No input dimension provided !'input_dim' parameter of - BaseAEConfig instance must be set to 'data_shape' where the shape of the data is [mini_batch x data_shape] . Unable to build encoder automatically - -.. 
note:: - - In case you have different size of data, pythae will reshape it to the minimum size ``min_n_channels x min_height x min_width x ...`` - - -Hence building a basic network which used the basic provided architectures may be done as follows: - -.. code-block:: python - - >>> from pythae.models.my_model.my_model_config import MyModelConfig - >>> from pythae.models.my_model.my_model import MyModelConfig - >>> config = MyModelConfig( - ... input_dim=10 # Setting the data input dimension is needed if you do not use your own autoencoding architecture - ... # you parameters goes here - ... ) - >>> m = MyModel(model_config=config) # Built the model - - - -However, these networks are often not the best suited to generate. Hence, depending on your data, you may want to override the default architecture and use your own networks instead. Doing so is pretty easy! The only thing you have to do is -define you own encoder or decoder ensuring that they -inherit from the :class:`~pythae.models.nn.BaseEncoder` or :class:`~pythae.models.nn.BaseDecoder`. - -************************************************ -Setting your Encoder -************************************************ - -To build your on encoder only makes it inherit from :class:`~pythae.models.nn.BaseEncoder`, define your architecture and code the :class:`forward` method. -Your own Encoder should look as follows: - - -.. code-block:: python - - >>> from pythae.models.nn import BaseEncoder - - >>> class MyEncoder(BaseEncoder): - ... def __init__(self, args): - ... BaseEncoder.__init__(self) - ... # your code goes here - - ... def forward(self, x): - ... # your code goes here - ... return mu, log_var - -For a complete example, please see tutorial (using_your_architectures.ipynb) - -.. warning:: - When building your Encoder, the output order is important. Do not forget to set :math:`\mu` as first argument and the **log** variance then. - -************************************************ -Setting your decoder -************************************************ - -Likewise the encoder, to build your on encoder only makes it inherit from :class:`~pythae.models.nn.BaseDecoder`, define your architecture and code the :class:`forward` method. -Your own Decoder should look as follows: - - .. code-block:: - - >>> from pythae.models.nn import BaseDecoder - - >>> class My_decoder(BaseDecoder): - ... def __init__(self): - ... BaseDecoder.__init__(self) - ... # your code goes here - - ... def forward(self, z): - ... # your code goes here - ... return mu - - -For a complete example, please see tutorial (using_your_architectures.ipynb) - -.. note:: - - By convention, the output tensors :math:`\mu` should be in [0, 1]. Ensure, this is the case when building your net. diff --git a/docs/old/advanced/setting_configs.rst b/docs/old/advanced/setting_configs.rst deleted file mode 100644 index cd1884b2..00000000 --- a/docs/old/advanced/setting_configs.rst +++ /dev/null @@ -1,214 +0,0 @@ - -.. _setting your config: - -################################## -Setting up your own configurations -################################## - - - -The augmentation methods relies on default parameters for the model, training and generation. -Depending on your data these parameters should be modified. 
- - -************************************************ -Link between ``.json`` files and ``dataclasses`` -************************************************ - - -In pythae, the configurations of the models, trainers and samplers are stored and used as :class:`dataclasses.dataclass` and all inherit from the :class:`~pythae.config.BaseConfig`. Hence, any configuration class has a classmethod :class:`~pythae.config.BaseConfig.from_json_file` coming from :class:`~pythae.config.BaseConfig` allowing to directly load config from ``.json`` files into ``dataclasses``. - - -.. _loading from json: - -Loading a config from a ``.json`` -================================================= - -Say that you want to load a training configuration that is stored in a ``training_config.json`` file. To convert it in a :class:`~pythae.trainers.training_config.TrainingConfig` run the following - -.. code-block:: - - >>> from pythae.trainers.training_config import TrainingConfig - >>> config = TrainingConfig.from_json_file( - ... 'scripts/configs/training_config.json') - >>> config - TrainingConfig(output_dir='outputs/my_model_from_script', batch_size=200, max_epochs=2, learning_rate=0.001, train_early_stopping=50, eval_early_stopping=None, steps_saving=1000, seed=8, no_cuda=False, verbose=True) - -where the ``.json`` that was parsed should look like - -.. code-block:: bash - - $ cat scripts/configs/training_config.json - {"output_dir": "outputs/my_model_from_script", "batch_size": 200, "max_epochs": 2 "learning_rate": 1e-3, "train_early_stopping": 50, "eval_early_stopping": null, "steps_saving": 1000, "seed": 8, "no_cuda": false, "verbose": true} - - - -You must ensure that the keys provided on the ``.json`` config file match the one in the required ``dataclass`` and that the value has the required tpe. - -For instance, if you want to provide your own ``training_config.json`` in the pythae scripts, ensure that the keys in the ``.json`` file match the on in -:class:`~pythae.trainers.training_config.TrainingConfig` with values having the correct type. See `type checking`_. The provided scripts will indeed load ``dataclasses`` from the provided ``.json`` files. - - -Writing a ``.json`` from a :class:`~pythae.config.BaseConfig` instance. -================================================================================================== - - -You can also write a ``.json`` directly from your ``dataclass`` using the :class:`~pythae.config.BaseConfig.save_json` method from :class:`~pythae.config.BaseConfig` - -.. code-block:: python - - >>> from pythae.trainers.training_config import TrainingConfig - >>> config = TrainingConfig(max_epochs=10, learning_rate=0.1) - >>> config.save_json(dir_path='.', filename='test') - -The resulting ``.json`` should looks like this - -.. code-block:: - - $ cat test.json - {"output_dir": null, "batch_size": 50, "max_epochs": 10, "learning_rate": 0.1, "train_early_stopping": 50, "eval_early_stopping": null, "steps_saving": 1000, "seed": 8, "no_cuda": false, "verbose": true} - - -.. _type checking: - -Configuration type checking -================================================= - - -A type check is performed automatically when building the ``dataclasses`` with `pydantic `_. Hence, if you provide the wrong type in the config it will either: - -- be converted to the required type: - -.. 
code-block:: - - >>> from pythae.trainers.training_config import TrainingConfig - >>> config = TrainingConfig(max_epochs='10') - >>> config.max_epochs, type(config.max_epochs) - (10, ) - -- or raise a Validation Error - -.. code-block:: python - - - >>> from pythae.trainers.training_config import TrainingConfig - >>> config = TrainingConfig(max_epochs='10_') - Traceback (most recent call last): - File "", line 1, in - File "", line 13, in __init__ - File "pydantic/dataclasses.py", line 99, in pydantic.dataclasses._generate_pydantic_post_init._pydantic_post_init - # +=======+=======+=======+ - pydantic.error_wrappers.ValidationError: 1 validation error for TrainingConfig - max_epochs - value is not a valid integer (type=type_error.integer) - - -A similar check is performed on the ``.json`` when the classmethod :class:`~pythae.config.BaseConfig.from_json_file` is called. - - - - - -.. _model-setting: - -************************************************ -The model parameters -************************************************ - -Each model coded in pythae requires a :class:`ModelConfig` inheriting from :class:`~pythae.models.base.base_config.BaseAEConfig` class to be built. Hence, to build a basic model you need to run the following - - - -.. code-block:: python - - >>> from pythae.models.my_model.my_model_config import MyModelConfig - >>> from pythae.models.my_model.my_model import MyModel - >>> config = MyModelConfig( - ... input_dim=10 # Setting the data input dimension is needed if you do not use your own autoencoding architecture - ... # your parameters go here - ... ) - >>> m = MyModel(model_config=config) # Built the model - - -Let now say that you want to override the model default parameters. The only thing you have to do is to pass you arguments to the ``dataclass`` :class:`ModelConfig`. - - -Let say, we want to change the temperature T in the metric in a :class:`~pythae.models.RHVAE` model which defaults to 1.5 and raise it to 2. Well simply run the following. - - -.. code-block:: python - - >>> from pythae.models.rhvae.rhvae_config import RHVAEConfig - >>> from pythae.models import RHVAE - >>> config = RHVAEConfig(input_dim=10, temperature=2) - >>> m = RHVAE(model_config=config) - >>> m.temperature - Parameter containing: - tensor([2.]) - - -Check out the documentation to see the whole list of parameter you can amend. - -.. _trainer-setting: - - -************************************************ -The sampler parameters -************************************************ - -To generate from a pythae's model a :class:`ModelSampler` inheriting from :class:`~pythae.models.base.base_sampler.BaseSampler` is used. A :class:`ModelSampler` is instantiated with a pythae's model and a :class:`ModelSamplerConfig`. Hence, likewise the VAE models, the sampler parameters can be easily amended as follows - -.. code-block:: python - - >>> from pythae.models.my_model.my_model_config import MyModelSamplerConfig - >>> from pythae.models.my_model.my_model_sampler import MyModelSampler - >>> config = MyModelSamplerConfig( - ... # your parameters go here - ... ) - >>> m = MyModelSample(model=my_model, sampler_config=config) # Built the model - - -Let now say that you want to override the sampler default parameters. The only thing you have to do is to pass you arguments to the ``dataclass`` :class:`ModelSamplerConfig`. - -Let say, we want to change the number of leapfrog steps in the :class:`~pythae.models.rhvae.RHVAESampler` config model which defaults to 15 and make it to 5. 
Well your code should look like the following. - - -.. code-block:: python - - >>> from pythae.models import RHVAE - >>> from pythae.models.rhvae import RHVAESampler, RHVAESamplerConfig, RHVAEConfig - >>> custom_sampler_config = RHVAESamplerConfig( - ... n_lf=5 - ... ) # Set up sampler config - >>> custom_sampler = RHVAESampler( - ... model=model, sampler_config=custom_sampler_config - ... ) # Build sampler - >>> custom_sampler.n_lf - ... tensor([5]) - -Check out the documentation to see the whole list of parameter you can amend. - -.. _trainer-setting: - -************************************************ -The :class:`~pythae.trainers.Trainer` parameters -************************************************ - -Likewise the VAE models, the instance :class:`~pythae.trainers.Trainer` can be created with default parameters or you can easily amend them the same way it is done for the models. - - -Example: -~~~~~~~~ - -Say you want to train your model for 10 epochs, with no early stopping on the train et and a learning_rate of 0.1 - -.. code-block:: python - - >>> from pythae.trainers.training_config import TrainingConfig - >>> config = TrainingConfig( - ... max_epochs=10, learning_rate=0.1, train_early_stopping=None) - >>> config - TrainingConfig(output_dir=None, batch_size=50, max_epochs=10, learning_rate=0.1, train_early_stopping=None, eval_early_stopping=None, steps_saving=1000, seed=8, no_cuda=False, verbose=True) - - -You can find a comprehensive description of any parameters of the :class:`~pythae.trainers.Trainer` you can set in :class:`~pythae.trainers.training_config.TrainingConfig` \ No newline at end of file diff --git a/docs/old/advanced_use.rst b/docs/old/advanced_use.rst deleted file mode 100644 index 712a0ec2..00000000 --- a/docs/old/advanced_use.rst +++ /dev/null @@ -1,13 +0,0 @@ -********************************** -Advanced Usage -********************************** - -Here are described more advanced features you can have access to when using pythae. - -.. toctree:: - :maxdepth: 1 - - advanced/setting_configs - advanced/custom_autoencoder - - diff --git a/docs/old/background.rst b/docs/old/background.rst deleted file mode 100644 index c57a2a1e..00000000 --- a/docs/old/background.rst +++ /dev/null @@ -1,131 +0,0 @@ -############################### -Background -############################### - -************************************************ -Why data augmentation ? -************************************************ - -Even though always larger data sets are now available, the lack of labeled data remains a tremendous issue in many fields of application. Among others, a good example is healthcare where practitioners have to deal most of the time with (very) low sample sizes (think of small patient cohorts) along with very high dimensional data (think of neuroimaging data that are 3D volumes with millions of voxels). Unfortunately, this leads to a very poor representation of a given population and makes classical statistical analyses unreliable. Meanwhile, the remarkable performance of algorithms heavily relying on the deep learning framework has made them extremely attractive and very popular. However, such results are strongly conditioned by the number of training samples since such models usually need to be trained on huge data sets to prevent over-fitting or to give statistically meaningful results. 
A way to address such issues is to perform Data Augmentation (DA) which consists in creating synthetic samples to enrich an initial data set and allow machine learning algorithm to better generalize on unseen data. For instance, the easiest way to do this on images is to apply simple transformations such as the addition of Gaussian noise, cropping or padding, and assign the label of the initial image to the created ones. - -************************************************ -Limitations of classic DA -************************************************ -While such augmentation techniques have revealed very useful, they remain strongly data dependent and limited. Some transformations may indeed be uninformative or even induce bias. - - -.. centered:: - |pic1| apply rotation |pic2| - - - -.. |pic1| image:: imgs/nine_digits.png - :width: 30% - - -.. |pic2| image:: imgs/nine_digits-rot.png - :width: 30% - - -For instance, think of a digit representing a 6 which gives a 9 when rotated. While assessing the relevance of augmented data may be quite straightforward for simple data sets, it reveals very challenging for complex data and may require the intervention of an *expert* assessing the degree of relevance of the proposed transformations. - -************************************************ -Generative models: A new hope -************************************************ - -The recent rise in performance of generative models such as GAN or VAE has made them very attractive models to perform DA. However, the main limitation to a wider use of these models is that they most of the time produce blurry and fuzzy samples. This undesirable effect is even more emphasized when they are trained with a small number of samples which makes them very hard to use in practice to perform DA in the high dimensional (very) low sample size (HDLSS) setting. - - -**This is why pythae was born!** - - -************************************************ -A Flavour of pythae's Results -************************************************ - -Case Study 1: Classification on 3D MRI (ADNI & AIBL) -===================================================== - -A :class:`~pythae.models.RHVAE` model was used to perform Data Augmentation in the High Dimensional Low Sample Size Setting on 3D MRI neuroimaging data from ADNI (http://adni.loni.usc.edu/) and AIBL (https://aibl.csiro.au/) database. The model was used to try to enhance the classification task consisting in finding Alzheimer's disease patients (AD) from Cognitively Normal participants (CN) using T1-weighted MR images :cite:p:`chadebec_data_2021`. - - -Classification set up -------------------------------------------------------- - -Data Splitting -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The ADNI data set was split into 3 sets: train, validation and test. -First, the test set was created using 100 randomly chosen participants for each diagnostic label (i.e. 100 CN, 100 AD). The rest of the data set was split such that 80% is allocated from training and 20% for validation. The authors ensured that age, sex and site distributions between the three sets were not significantly different. The train set is referred to as *train-full* in the following. - -In addition, a smaller training set (denoted as *train-50*) was extracted from *train-full*. This set comprised only 50 images per diagnostic label, instead of 243 CN and 210 AD for *train-full*. It was ensured that age and sex distributions between *train-50* and *train-full* were not significantly different. 
This was not done for the site distribution as there are more than 50 sites in the ADNI data set (so they could not all be represented in this smaller training set). The AIBL data was **never used for training** or hyperparameter tuning and was only used as an **independent** test set. - -.. centered:: - |pic3| - Data Split for the classification task: Alzheimer Disease (AD) vs. Cognitively Normal (CN) - -.. |pic3| image:: imgs/Case_study_1.jpg - - -Data Processing -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -All the data was processed as follows: - - - Raw data are converted to the BIDS standard :cite:p:`gorgolewski_brain_2016`, - - Bias field correction is applied using N4ITK :cite:p:`tustison_n4itk_2010`, - - T1w images are linearly registered to the MNI standard space :cite:p:`fonov_unbiased_2009,fonov_unbiased_2011` with ANTS :cite:p:`avants_insight_2014` and cropped. This produced images of size 169x208x179 with :math:`1~\mathrm{mm}^{3}` isotropic voxels. - - An automatic quality check is performed using an open-source pretrained network :cite:p:`fonov_deep_2018`. All images passed the quality check. - - NIfTI files are converted to tensor format. - - (Optional) Images are down-sampled using a trilinear interpolation, leading to an image size of 84x104x89. - - Intensity rescaling between the minimum and maximum values of each image is performed. - - -Classifier -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To perform such classification task a CNN was used with two different paradigms to choose the architecture. First, the authors reused the same architecture as in :cite:p:`wen_convolutional_2020` which was obtained by optimizing manually the networks on the ADNI data set for the same task (AD vs CN). A slight adaption was done for the down-sampled images, which consisted in resizing the number of nodes in the fully-connected layers to keep the same ratio between the input and output feature maps in all layers. This architecture is denoted **baseline**. Secondly, a random search was launched :cite:p:`bergstra_random_2012` allowing to explore different hyperperameter values. The hyperparameters explored for the architecture were the number of convolutional blocks, of filters in the first layer and of convolutional layers in a block, the number of fully-connected layers and the dropout rate. Other hyperparameters such as the learning rate and the weight decay were also part of the search. 100 different random architectures were trained on the 5-fold cross-validation done on *train-full*. For each input, the selected architecture is the one that obtained the best mean balanced accuracy across the validation sets of the cross-validation. This architecture is referred to as **optimized**. - -.. centered:: - |pic4| - CNN architectures: *left*: The baseline net. *right*: The optimized one using a random search across 100 architectures. - -.. |pic4| image:: imgs/CNNs.jpeg - -Augmentation Set up -------------------------------------------------------- - -On the meantime, a :class:`~pythae.models.RHVAE` was trained on each class of the train sets (*train-50* or *train-full*) to be able to generate new synthetic data. Noteworthy is the fact that the VAE and the CNN shared the **same training set** and no augmentation was performed on the validation set or the test set. - - -.. centered:: - |pic5| - Data Augmentation scheme with a VAE. - -.. |pic5| image:: imgs/DA_diagram.png - - -Then the **baseline** (resp. **optimized**) CNN networks were then trained for 100 (resp. 
50) epochs using the cross entropy loss for training and validation losses. Balanced accuracy was also computed at the end of each epoch. The models were trained on either 1) only the *real* images; 2) only the synthetic samples created by the :class:`~pythae.models.RHVAE` or 3) the augmented training set (*real* + synthetic) on 20 independent runs for each experiment. The final model was chosen as the one that obtained the highest validation balanced accuracy during training. - - -Results -------------------------------------------------------- - -Below are presented some of the main results obtained in this case study. We refer the reader to :cite:p:`chadebec_data_2021` for the full results of the study. - -.. centered:: - |pic6| - Augmentation results with the **baseline** CNN network. - -.. |pic6| image:: imgs/baseline_results.png - -.. centered:: - |pic7| - Augmentation results with the **optimized** CNN network. - - -.. |pic7| image:: imgs/optimized_results.png - -The proposed method allowed for a significant gain in the model classification results even when the CNN was optimized on the real data only (random search not performed for augmented data set) and even though small size data sets were considered along with very challenging high dimensional data. - - -.. bibliography:: \ No newline at end of file diff --git a/docs/old/data/pythae.data.datasets.rst b/docs/old/data/pythae.data.datasets.rst deleted file mode 100644 index 4be0437e..00000000 --- a/docs/old/data/pythae.data.datasets.rst +++ /dev/null @@ -1,7 +0,0 @@ -********************************** -datasets -********************************** - -.. autoclass:: pythae.data.datasets.BaseDataset - :members: __getitem__ - diff --git a/docs/old/data/pythae.data.loaders.rst b/docs/old/data/pythae.data.loaders.rst deleted file mode 100644 index 34961b75..00000000 --- a/docs/old/data/pythae.data.loaders.rst +++ /dev/null @@ -1,10 +0,0 @@ -********************************** -loaders -********************************** - -.. autoclass:: pythae.data.loaders.BaseDataGetter - :members: - -.. autoclass:: pythae.data.loaders.ImageGetterFromFolder - :members: - diff --git a/docs/old/data/pythae.data.preprocessors.rst b/docs/old/data/pythae.data.preprocessors.rst deleted file mode 100644 index 509f00ff..00000000 --- a/docs/old/data/pythae.data.preprocessors.rst +++ /dev/null @@ -1,7 +0,0 @@ -********************************** -preprocessors -********************************** - -.. autoclass:: pythae.data.preprocessors.DataProcessor - :members: - diff --git a/docs/old/data/pythae.data.rst b/docs/old/data/pythae.data.rst deleted file mode 100644 index 7f14e2c3..00000000 --- a/docs/old/data/pythae.data.rst +++ /dev/null @@ -1,53 +0,0 @@ -********************************** -pythae.data -********************************** - - - - -.. toctree:: - :maxdepth: 1 - - pythae.data.datasets - pythae.data.loaders - pythae.data.preprocessors - - -.. automodule:: - pythae.data - - -Datasets --------------- - -.. automodule:: - pythae.data.datasets - -.. autosummary:: - ~pythae.data.datasets.BaseDataset - :nosignatures: - - -Loaders ---------------- - -.. automodule:: - pythae.data.loaders - -.. autosummary:: - ~pythae.data.loaders.BaseDataGetter - ~pythae.data.loaders.ImageGetterFromFolder - :nosignatures: - - - -Preprocessors ---------------- - - -.. automodule:: - pythae.data.preprocessors - -.. 
autosummary:: - ~pythae.data.preprocessors.DataProcessor - :nosignatures: \ No newline at end of file diff --git a/docs/old/getting_started.rst b/docs/old/getting_started.rst deleted file mode 100644 index 9ff5ae35..00000000 --- a/docs/old/getting_started.rst +++ /dev/null @@ -1,320 +0,0 @@ -################################## -Getting started -################################## - -************************************************ -Description -************************************************ - -This library provides a way to perform Data Augmentation using Variational Autoencoders in a -reliable way even in challenging contexts such as high dimensional and low sample size -data. - -************************************************ -Installation -************************************************ - -To install the library run the following using ``pip`` - -.. code-block:: bash - - $ pip install pythae - - -or alternatively you can clone the github repo to access to tests, tutorials and scripts. - -.. code-block:: bash - - $ https://github.com/clementchadebec/pythae.git - - -************************************************ -pythae's spirit & overview -************************************************ - -The pythae's library organizes as follows - -.. centered:: - |pic3| - pythae's overview - -.. |pic3| image:: imgs/pythae_diagram_simplified.jpg - - -If you clone the pythae's repository you will access to the following: - -- ``docs``: The folder in which the documentation can be retrieved. -- ``tests``: pythae's unit-testing using pytest. -- ``examples``: A list of ``ipynb`` tutorials describing the main functionalities of pythae. -- ``pythae``: The main library which can be installed with ``pip``. - - -In the main library, you will access to the following modules: - -- :ref:`pythae.models`: This is the module where any Variational Autoencoder model is implemented. It is composed of: - - - :ref:`pythae.models.nn`: The module gathers all the neural networks architectures for the encoders, decoders and metrics networks (if applicable) used within the models. - - pythae.models.base: This is the base module of the VAE models. - - pythae.models.other_model: By convention, each implemented model is contained within a folder located in :ref:`pythae.models` in which are located 4 modules: - - - *model_config.py*: Contains a :class:`OtherModelConfig` instance inheriting from :class:`~pythae.models.base.BaseAEConfig` where the model configuration is stored and a :class:`OtherModelSamplerConfig` instance inheriting from :class:`~pythae.models.base.BaseSamplerConfig` where the configuration of the sampler used to generate new samples is defined. - - *other_model_model.py*: An implementation of the other_model inheriting from :class:`~pythae.models.BaseAE`. - - *other_model_sampler.py*: An implementation of the sampler(s) to use to generate new data inheriting from :class:`~pythae.models.base.base_sampler.BaseSampler`. - - *other_model_utils.py*: A module where utils methods are stored. - -- :ref:`pythae.trainer`: This module contains the main function to perform a model training. In particular, it gathers a :class:`~pythae.trainers.training_config.TrainingConfig` instance stating any training parameters and a :class:`~pythae.trainers.Trainer` instance used to train the model. -- :ref:`pythae.data`: Here are located the modules allowing to load, pre-process and convert the data to types handled by pythae. -- :ref:`pythae.pipelines`: In this module can be found pythae's Pipelines. 
These are functions that allows a user to combine several pythae's modules together. - - -Please see the full module description for further details. - - - -************************************************ -Augmenting your Data -************************************************ - -In pythae, a typical augmentation process is divided into 2 distinct parts: - - - Training a model using the pythae's :class:`~pythae.pipelines.TrainingPipeline` or using the provided ``scripts/training.py`` script - - Generating new data from a trained model using pythae's :class:`~pythae.pipelines.GenerationPipeline` or using the provided ``scripts/generation.py`` script - -There exist two ways to augment your data pretty straightforwardly using pythae's built-in functions. - - - - -Using the provided scripts -================================================= - -pythae provides two scripts allowing you to augment your data directly with commandlines. - -.. note:: - To access to the predefined scripts you should first clone the pythae's repository. - The following scripts are located in ``pythae/scripts`` folder. For the time being, only :class:`~pythae.models.RHVAE` model training and generation is handled by the provided scripts. Models will be added as they are implemented in :ref:`pythae.models` - -Launching a model training: --------------------------------------------------- - -To launch a model training, run - -.. code-block:: bash - - $ python scripts/training.py --path_to_train_data "path/to/your/data/folder" - - - -The data must be located in ``path/to/your/data/folder`` where each input data is a file. Handled image types are ``.pt``, ``.nii``, ``.nii.gz``, ``.bmp``, ``.jpg``, ``.jpeg``, ``.png``. Depending on the usage, other types will be progressively added. - - -At the end of training, the model weights ``models.pt`` and model config ``model_config.json`` file -will be saved in a folder ``outputs/my_model_from_script/training_YYYY-MM-DD_hh-mm-ss/final_model``. - -.. tip:: - In the simplest configuration, default ``training_config.json`` and ``model_config.json`` are used (located in ``scripts/configs`` folder). You can easily override these parameters by defining your own ``.json`` file and passing them to the parser arguments. - - .. code-block:: bash - - $ python scripts/training.py - --path_to_train_data 'path/to/your/data/folder' - --path_to_model_config 'path/to/your/model/config.json' - --path_to_training_config 'path/to/your/training/config.json' - - See :ref:`setting your config` and tutorials for a more in depth example. - -.. note:: - For high dimensional data we advice you to provide you own network architectures. With the - provided MLP you may end up with a ``MemoryError``. - - - -Launching data generation: --------------------------------------------------- - -Then, to launch the data generation process from a trained model, you only need to run - -.. code-block:: bash - - $ python scripts/training.py --num_samples 10 --path_model_folder 'path/to/your/trained/model/folder' - -The generated data is stored in several ``.pt`` files in ``outputs/my_generated_data_from_script/generation_YYYY-MM-DD_hh_mm_ss``. By default, it stores batch data of 500 samples. - -.. tip:: - In the simplest configuration, default ``sampler_config.json`` is used. You can easily override these parameters by defining your own ``.json`` file and passing it the to the parser arguments. - - .. 
code-block:: bash - - $ python scripts/training.py - --path_to_train_data 'path/to/your/data/folder' - --path_to_sampler_config 'path/to/your/training/config.json' - - See :ref:`setting your config` and tutorials for a more in depth example. - -.. _retrieve-generated-data: - -Retrieve generated data --------------------------------------------------- - -Generated data can then be loaded pretty easily by running - -.. code-block:: python - - >>> import torch - >>> data = torch.load('path/to/generated_data.pt') - - - - -Using pythae's Pipelines -================================================= - -pythae also provides two pipelines that may be uses to either train a model on your own data or generate new data with a pretrained model. - - -.. tip:: - These pipelines are independent of the choice of the model and sampler. Hence, they can be used even if you want to access to more advanced feature such as defining your own autoencoding architecture. - -Launching a model training --------------------------------------------------- - -To launch a model training, you only need to call a :class:`~pythae.pipelines.TrainingPipeline` instance. -In its most basic version the :class:`~pythae.pipelines.TrainingPipeline` can be built without any arguments. -This will by default train a :class:`~pythae.models.RHVAE` model with default autoencoding architecture and parameters. - -.. code-block:: python - - >>> from pythae.pipelines import TrainingPipeline - >>> pipeline = TrainingPipeline() - >>> pipeline(train_data=dataset_to_augment) - -where ``dataset_to_augment`` is either a :class:`numpy.ndarray`, :class:`torch.Tensor` or a path to a folder where each file is a data (handled data format are ``.pt``, ``.nii``, ``.nii.gz``, ``.bmp``, ``.jpg``, ``.jpeg``, ``.png``). - -More generally, you can instantiate your own model and train it with the :class:`~pythae.pipelines.TrainingPipeline`. For instance, if you want to instantiate a basic :class:`~pythae.models.RHVAE` run: - - -.. code-block:: python - - >>> from pythae.models import RHVAE - >>> from pythae.models.rhvae import RHVAEConfig - >>> model_config = RHVAEConfig( - ... input_dim=int(intput_dim) - ... ) # input_dim is the shape of a flatten input data - ... # needed if you do not provided your own architectures - >>> model = RHVAE(model_config) - - -In case you instantiate yourself a model as shown above and you do not provided all the network architectures (encoder, decoder & metric if applicable), the :class:`ModelConfig` instance will expect you to provide the input dimension of your data which equals to ``n_channels x height x width x ...``. pythae's VAE models' networks indeed default to Multi Layer Perceptron neural networks which automatically adapt to the input data shape. Hence, if you do not provided any input dimension an error is raised: - -.. code-block:: python - - >>> from pythae.models.base.base_config import BaseAEConfig - >>> from pythae.models import BaseAE - >>> config = BaseAEConfig() - >>> BaseAE(model_config=config) - Traceback (most recent call last): - File "", line 1, in - File "/home/clement/Documents/these/implem/pythae/src/pythae/models/base/base_vae.py", line 57, in __init__ - raise AttributeError("No input dimension provided !" - AttributeError: No input dimension provided !'input_dim' parameter of - BaseAEConfig instance must be set to 'data_shape' where the shape of the data is [mini_batch x data_shape] . Unable to build encoder automatically - -.. 
note:: - - In case you have different size of data, pythae will reshape it to the minimum size ``min_n_channels x min_height x min_width x ...`` - - - -Then the :class:`~pythae.pipelines.TrainingPipeline` can be launched by running: - -.. code-block:: python - - >>> from pythae.pipelines import TrainingPipeline - >>> pipe = TrainingPipeline(model=model) - >>> pipe(train_data=dataset_to_augment) - -At the end of training, the model weights ``models.pt`` and model config ``model_config.json`` file -will be saved in a folder ``outputs/my_model_from_script/training_YYYY-MM-DD_hh-mm-ss/final_model``. - - -.. tip:: - In the simplest configuration, defaults training and model parameters are used. You can easily override these parameters by instantiating your own :class:`~pythae.trainers.training_config.TrainingConfig` and :class:`~pythae.models.base.base_config.ModelConfig` file and passing them the to the :class:`~pythae.pipelines.TrainingPipeline`. - - Example for a :class:`~pythae.models.RHVAE` run: - - .. code-block:: python - - >>> from pythae.models import RHVAE - >>> from pythae.model.rhvae import RHVAEConfig - >>> from pythae.trainers.training_config import TrainingConfig - >>> from pythae.pipelines import TrainingPipeline - >>> custom_model_config = RHVAEConfig( - ... input_dim=input_dim, *my_args, **my_kwargs - ... ) # Set up model config - >>> model = RHVAE( - ... model_config=custom_model_config - ... ) # Build model - >>> custom_training_config = TrainingConfig( - ... *my_args, **my_kwargs - ... ) # Set up training config - >>> pipe = TrainingPipeline( - ... model=model, training_config=custom_training_config - ... ) # Build Pipeline - - See :ref:`setting your config` and tutorials for a more in depth example. - -.. note:: - For high dimensional data we advice you to provide you own network architectures. With the - provided MLP you may end up with a ``MemoryError``. - - -Launching data generation --------------------------------------------------- - -To launch the data generation process from a trained model, run the following. - -.. code-block:: python - - >>> from pythae.pipelines import GenerationPipeline - >>> model = MODEL.load_from_folder( - ... 'path/to/your/trained/model' - ... ) # reload the model - >>> pipe = GenerationPipeline( - ... model=model - ... ) # define pipeline - >>> pipe(samples_number=10) # This will generate 10 data points - -The generated data is in ``.pt`` files in ``dummy_output_dir/generation_YYYY-MM-DD_hh-mm-ss``. By default, it stores batch data of 500 samples. - -.. note:: - - A model can be easily reloaded from a folder using the classmethod :class:`~pythae.models.BaseAE.load_from_folder` that is defined for each model implemented in pythae and allows to load a model directly from a given folder. - - - -.. tip:: - In the simplest configuration, defaults sampler parameters are used. You can easily override these parameters by instantiating your own :class:`~pythae.models.base.SamplerConfig` and passing it the to the :class:`~pythae.pipelines.GenerationPipeline`. - - Example for a :class:`~pythae.models.rhvae.RHVAESampler` run: - - .. code-block:: python - - >>> from pythae.models.rhvae import RHVAESampler - >>> from pythae.models.rhvae import RHVAESamplerConfig - >>> from pythae.pipelines import GenerationPipeline - >>> custom_sampler_config = RHVAESamplerConfig( - ... *my_args, **my_kwargs - ... ) # Set up sampler config - >>> custom_sampler = RHVAESampler( - ... model=model, sampler_config=custom_sampler_config - ... 
) # Build sampler - >>> pipe = generationPipeline( - ... model=model, sampler=custom_sampler - ... ) # Build Pipeline - - See :ref:`setting your config` and tutorials for a more in depth example. - - -Generated data can then be loaded pretty as explained in :ref:`retrieve-generated-data` - - diff --git a/docs/old/math_behind.rst b/docs/old/math_behind.rst deleted file mode 100644 index 552bbb14..00000000 --- a/docs/old/math_behind.rst +++ /dev/null @@ -1,86 +0,0 @@ -********************************** -The maths behind the code -********************************** - -.. _math_behind: - -Let's talk about math! -###################### - - -The main idea behind the model proposed in this library is to learned the latent structure -of the input data :math:`x \in \mathcal{X}`. - -Variational AutoEncoder -~~~~~~~~~~~~~~~~~~~~~~~ - -**Model Setting** - -Assume we are given a set of input data :math:`x \in \mathcal{X}`. A VAE aims at maximizing the -likelihood of a given parametric model :math:`\{\mathbb{P}_{\theta}, \theta \in \Theta\}`. It is -assumed that there exist latent variables :math:`z` living in a lower dimensional space -:math:`\mathcal{Z}`, referred to as the *latent space*, such that the marginal distribution -of the data can be written as - - -.. math:: - - p_{\theta}(x) = \int \limits _{\mathcal{Z}} p_{\theta}(x|z)q(z) dz \,, - - - -where :math:`q` is a prior distribution over the latent variables acting as a *regulation factor* -and :math:`p_{\theta}(x|z)` is most of the time taken as a simple parametrized distribution (*e.g.* -Gaussian, Bernoulli, etc.) and is referred to as the *decoder* the parameters of which are -given by neural networks. Since the integral of teh objective is most of the time intractable, -so is the posterior distribution: - -.. math:: - - p_{\theta}(z|x) = \frac{p_{\theta}(x|z) q(z)}{\int \limits_{\mathcal{Z}} p_{\theta}(x|z) q(z) dz}\,. - -This makes direct application of Bayesian inference impossible and so recourse to approximation -techniques such as variational inference is needed. Hence, a variational distribution -:math:`q_{\phi}(z|x)` is introduced and aims at approximating the true posterior distribution -:math:`p_{\theta}(z|x)`. This variational distribution is often referred to as the *encoder*. In the initial version of the VAE, :math:`q_{\phi}` is taken as a multivariate -Gaussian whose parameters :math:`\mu_{\phi}` and :math:`\Sigma_{\phi}` are again given by neural -networks. Importance sampling can then be applied to derive an unbiased estimate of the marginal -distribution :math:`p_{\theta}(x)` we want to maximize. - -.. math:: - - \hat{p}_{\theta}(x) = \frac{p_{\theta}(x|z)q(z)}{q_{\phi}(z|x)} \hspace{2mm} \text{and} \hspace{2mm} \mathbb{E}_{z \sim q_{\phi}}\big[\hat{p}_{\theta}\big] = p_{\theta}(x)\,. - -Using Jensen's inequality allows finding a lower bound on the objective function of the objective - -.. math:: - - \begin{aligned} - \log p_{\theta}(x) &= \log \mathbb{E}_{z \sim q_{\phi}}\big[\hat{p}_{\theta}\big]\\ - &\geq \mathbb{E}_{z \sim q_{\phi}}\big[\log \hat{p}_{\theta}\big]\\ - & \geq \mathbb{E}_{z \sim q_{\phi}}\big[ \log p_{\theta}(x, z) - \log q_{\phi}(z|x) \big] = ELBO\,. - \end{aligned} - -The Evidence Lower BOund (ELBO) is now tractable since both :math:`p_{\theta}(x, z)` and -:math:`q_{\phi}(z|x)` are known and so can be optimized with respect to the *encoder* and *decoder* parameters. 
- - -**Bringing Geometry to the Model** - -In the RHVAE, the assumption of an Euclidean latent space is relaxed and it is assumed that the -latent variables live in the Riemannian manifold :math:`\mathcal{Z} =(\mathbb{R}^d, g)` where :math:`g` is the Riemannian metric.. -This Riemannian metric is basically a smooth inner product on the tangent space -:math:`T_{\mathcal{Z}}` of the manifold defined at each point :math:`z \in \mathcal{Z}`. Hence, it can be represented by a definite positive matrix :math:`\mathbf{G}(z)` at each point of the manifold :math:`\mathcal{Z}`. This Riemannian metric plays a crucial role in the modeling of the latent space and since it is not known we propose to **parametrize** it and **learn** it directly from the data :math:`x \in \mathcal{X}`. The metric parametrization writes: - -.. math:: - - \mathbf{G}^{-1}(z) = \sum_{i=1}^N L_{\psi_i} L_{\psi_i}^{\top} \exp \Big(-\frac{\lVert z -c_i \rVert_2^2}{T^2} \Big) + \lambda I_d \,, - -where :math:`N` is the number of observations, :math:`L_{\psi_i}` are lower triangular matrices with positive diagonal coefficients learned from the data and parametrized with neural networks, :math:`c_i` are referred to as the *centroids* and correspond to the mean :math:`\mu_{\phi}(x_i)` of the encoded distributions of the latent variables :math:`z_i` :math:`(z_i \sim q_{\phi}(z_i|x_i) = \mathcal{N}(\mu_{\phi}(x_i), \Sigma_{\phi}(x_i))`, :math:`T` is a temperature scaling the metric close to the *centroids* and :math:`\lambda` is a regularization factor that also scales the metric tensor far from the latent codes. - - - -**Combining Geometrical Aspect And Normalizing Flows** - -A way to improve the vanilla VAE resides in trying to enhance the ELBO expression so that it becomes closer to the true objective. Trying to tweak the approximate posterior distribution so that it becomes *closer* to the true posterior can achieve such a goal. To do so, a method involving parametrized invertible mappings :math:`f_x` called *normalizing flows* were proposed in~\cite{rezende_variational_2015} to *sample* :math:`z`. A starting random variable :math:`z_0` is drawn from an initial distribution :math:`q_{\phi}(z|x)` and then :math:`K` normalizing flows are applied to :math:`z_0` resulting in a random variable :math:`z_K = f_x^K \circ \cdots \circ f_x^1(z_0)`. Ideally, we would like to have access to normalizing flows targeting the true posterior and allowing enriching the above distribution and so improve the lower bound. In that particular respect, a model inspired by the Hamiltonian Monte Carlo sampler~\cite{neal_mcmc_2011} and relying on Hamiltonian dynamics was proposed in~\cite{salimans_markov_2015} and~\cite{caterini_hamiltonian_2018}. The strength of such a model relies in the choice of the normalizing flows which are guided by the gradient of the true posterior distribution. - diff --git a/docs/old/pythae.config.rst b/docs/old/pythae.config.rst deleted file mode 100644 index 90d3c26e..00000000 --- a/docs/old/pythae.config.rst +++ /dev/null @@ -1,6 +0,0 @@ -********************************** -pythae.config -********************************** - -.. 
autoclass:: pythae.config.BaseConfig
-    :members:
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 5bc78327..7c18a2dc 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -41,11 +41,11 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
     "sphinx.ext.autosectionlabel",
-    "sphinxcontrib.bibtex"
+    "sphinxcontrib.bibtex",
 ]

-suppress_warnings = ['autosectionlabel.*']
+suppress_warnings = ["autosectionlabel.*"]

 bibtex_bibfiles = ["references.bib"]
diff --git a/docs/source/index.rst b/docs/source/index.rst
index fd430758..16125226 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -10,6 +10,11 @@ Welcome to pythae's documentation!
 This library aims at gathering some of the common (Variational) Autoencoders implementations so that
 we can conduct benchmark analysis and reproducible research!

+**News** 📢
+
+As of v0.1.0, `Pythae` now supports distributed training using PyTorch's `DDP <https://pytorch.org/docs/stable/notes/ddp.html>`_. You can now train your favorite VAE faster and on larger datasets, still with a few lines of code.
+See :ref:`Distributed Training`.
+
 .. toctree::
    :maxdepth: 1
    :caption: Pythae
diff --git a/docs/source/models/pythae.models.rst b/docs/source/models/pythae.models.rst
index bf718882..5c991081 100644
--- a/docs/source/models/pythae.models.rst
+++ b/docs/source/models/pythae.models.rst
@@ -81,8 +81,13 @@ instance.
     ... output_dir='my_model',
     ... num_epochs=50,
     ... learning_rate=1e-3,
-    ... batch_size=200,
-    ... steps_saving=None
+    ... per_device_train_batch_size=200,
+    ... per_device_eval_batch_size=200,
+    ... steps_saving=None,
+    ... optimizer_cls="AdamW",
+    ... optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
+    ... scheduler_cls="ReduceLROnPlateau",
+    ... scheduler_params={"patience": 5, "factor": 0.5}
     ... )
     >>> # Set up the model configuration
     >>> my_vae_config = model_config = VAEConfig(
@@ -102,4 +107,62 @@ instance.
     >>> pipeline(
     ...     train_data=your_train_data, # must be torch.Tensor or np.array
     ...     eval_data=your_eval_data # must be torch.Tensor or np.array
-    ... )
\ No newline at end of file
+    ... )
+
+
+.. _Distributed Training:
+
+Distributed Training with Pythae
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+As of v0.1.0, Pythae now supports distributed training using PyTorch's `DDP <https://pytorch.org/docs/stable/notes/ddp.html>`_. It allows you to train your favorite VAE faster and on larger datasets using multi-node and/or multi-gpu training.
+
+To do so, you can build a python script that will then be launched by a launcher (such as `srun` on a cluster). The only thing that is needed in the script is to specify some elements relative to the distributed environment (such as the number of nodes/gpus) directly in the training configuration as follows:
+
+.. code-block::
+
+    >>> training_config = BaseTrainerConfig(
+    ...     num_epochs=10,
+    ...     learning_rate=1e-3,
+    ...     per_device_train_batch_size=64,
+    ...     per_device_eval_batch_size=64,
+    ...     dist_backend="nccl", # distributed backend
+    ...     world_size=8, # number of gpus to use (n_nodes x n_gpus_per_node)
+    ...     rank=5, # global process/gpu id
+    ...     local_rank=1, # gpu id within the node
+    ...     master_addr="localhost", # master address
+    ...     master_port="12345", # master port
+    ... )
+
+
+See this `example script <https://github.com/clementchadebec/benchmark_VAE/blob/main/examples/scripts/distributed_training_imagenet.py>`_ that defines a multi-gpu VQVAE training. Be careful: the way the environment variables (`world_size`, `rank`, ...) are recovered may be specific to the cluster and launcher you use.
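+
+For instance, on a SLURM-based cluster these values can typically be recovered from the environment variables set by the launcher. The snippet below is only a sketch under that assumption: `SLURM_NTASKS`, `SLURM_PROCID` and `SLURM_LOCALID` are standard SLURM variables, while `MASTER_ADDR` and `MASTER_PORT` are assumed to have been exported beforehand (other launchers expose different variables).
+
+.. code-block::
+
+    >>> import os
+    >>> # SLURM_NTASKS: total number of processes, SLURM_PROCID: global rank,
+    >>> # SLURM_LOCALID: process/gpu id within the node
+    >>> training_config = BaseTrainerConfig(
+    ...     dist_backend="nccl",
+    ...     world_size=int(os.environ["SLURM_NTASKS"]),
+    ...     rank=int(os.environ["SLURM_PROCID"]),
+    ...     local_rank=int(os.environ["SLURM_LOCALID"]),
+    ...     master_addr=os.environ.get("MASTER_ADDR", "localhost"),
+    ...     master_port=os.environ.get("MASTER_PORT", "12345"),
+    ... )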
+ +Benchmark +############ + +Below are indicated the training times for a Vector Quantized VAE (VQ-VAE) with `Pythae` for 100 epochs +on MNIST on V100 16GB GPU(s), for 50 epochs on `FFHQ `_ (1024x1024 images) +and for 20 epochs on `ImageNet-1k `_ on V100 32GB GPU(s). + +.. list-table:: Training times of a VQ-VAE with Pythae + :widths: 25 25 25 25 25 + :header-rows: 1 + + * - Dataset + - Data type (train size) + - 1 GPU + - 4 GPUs + - 2x4 GPUs + * - MNIST + - 28x28 images (50k) + - 221.01s + - 60.32s + - 34.50s + * - FFHQ + - 1024x1024 RGB images (60k) + - 19h 1min + - 5h 6min + - 2h 37min + * - ImageNet-1k + - 128x128 RGB images (~ 1.2M) + - 6h 25min + - 1h 41min + - 51min 26s \ No newline at end of file diff --git a/docs/source/pipelines/pythae.pipelines.rst b/docs/source/pipelines/pythae.pipelines.rst index 32372275..a74893a2 100644 --- a/docs/source/pipelines/pythae.pipelines.rst +++ b/docs/source/pipelines/pythae.pipelines.rst @@ -36,7 +36,8 @@ instance. ... output_dir='my_model', ... num_epochs=50, ... learning_rate=1e-3, - ... batch_size=200, + ... per_device_train_batch_size=64, + ... per_device_eval_batch_size=64, ... steps_saving=None ... ) >>> # Set up the model configuration diff --git a/docs/source/samplers/pythae.samplers.rst b/docs/source/samplers/pythae.samplers.rst index 2f7f5994..a84ce7b7 100644 --- a/docs/source/samplers/pythae.samplers.rst +++ b/docs/source/samplers/pythae.samplers.rst @@ -53,18 +53,18 @@ Normal sampling >>> from pythae.samplers import NormalSampler >>> # Retrieve the trained model >>> my_trained_vae = VAE.load_from_folder( - ... 'path/to/your/trained/model' + ... 'path/to/your/trained/model' ... ) >>> # Define your sampler >>> my_samper = NormalSampler( - ... model=my_trained_vae + ... model=my_trained_vae ... ) >>> # Generate samples >>> gen_data = my_samper.sample( - ... num_samples=50, - ... batch_size=10, - ... output_dir=None, - ... return_gen=True + ... num_samples=50, + ... batch_size=10, + ... output_dir=None, + ... return_gen=True ... ) Gaussian mixture sampling @@ -76,24 +76,24 @@ Gaussian mixture sampling >>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig >>> # Retrieve the trained model >>> my_trained_vae = VAE.load_from_folder( - ... 'path/to/your/trained/model' + ... 'path/to/your/trained/model' ... ) >>> # Define your sampler - ... gmm_sampler_config = GaussianMixtureSamplerConfig( - ... n_components=10 + ... gmm_sampler_config = GaussianMixtureSamplerConfig( + ... n_components=10 ... ) >>> my_samper = GaussianMixtureSampler( - ... sampler_config=gmm_sampler_config, - ... model=my_trained_vae + ... sampler_config=gmm_sampler_config, + ... model=my_trained_vae ... ) >>> # fit the sampler >>> gmm_sampler.fit(train_dataset) >>> # Generate samples >>> gen_data = my_samper.sample( - ... num_samples=50, - ... batch_size=10, - ... output_dir=None, - ... return_gen=True + ... num_samples=50, + ... batch_size=10, + ... output_dir=None, + ... return_gen=True ... ) See also `tutorials `_. 
\ No newline at end of file diff --git a/examples/notebooks/comet_experiment_monitoring.ipynb b/examples/notebooks/comet_experiment_monitoring.ipynb index b19e3fd9..ffa78769 100644 --- a/examples/notebooks/comet_experiment_monitoring.ipynb +++ b/examples/notebooks/comet_experiment_monitoring.ipynb @@ -71,7 +71,8 @@ "training_config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more,\n", " steps_predict=3\n", ")\n", diff --git a/examples/notebooks/custom_dataset.ipynb b/examples/notebooks/custom_dataset.ipynb index 157b3e38..b2bd65bb 100644 --- a/examples/notebooks/custom_dataset.ipynb +++ b/examples/notebooks/custom_dataset.ipynb @@ -185,7 +185,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-3,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/hf_hub_models_sharing.ipynb b/examples/notebooks/hf_hub_models_sharing.ipynb index aa2f1dac..bd1bbae3 100644 --- a/examples/notebooks/hf_hub_models_sharing.ipynb +++ b/examples/notebooks/hf_hub_models_sharing.ipynb @@ -73,7 +73,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=1, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/making_your_own_autoencoder.ipynb b/examples/notebooks/making_your_own_autoencoder.ipynb index 6e8d2231..276f039a 100644 --- a/examples/notebooks/making_your_own_autoencoder.ipynb +++ b/examples/notebooks/making_your_own_autoencoder.ipynb @@ -291,7 +291,8 @@ "training_config = BaseTrainerConfig(\n", " output_dir='my_model_with_custom_archi',\n", " learning_rate=1e-3,\n", - " batch_size=200,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " steps_saving=None,\n", " num_epochs=10)" ] diff --git a/examples/notebooks/mlflow_experiment_monitoring.ipynb b/examples/notebooks/mlflow_experiment_monitoring.ipynb index 3e1093e4..23eaa337 100644 --- a/examples/notebooks/mlflow_experiment_monitoring.ipynb +++ b/examples/notebooks/mlflow_experiment_monitoring.ipynb @@ -71,7 +71,8 @@ "training_config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/adversarial_ae_training.ipynb b/examples/notebooks/models_training/adversarial_ae_training.ipynb index f19ba326..b507b798 100644 --- a/examples/notebooks/models_training/adversarial_ae_training.ipynb +++ b/examples/notebooks/models_training/adversarial_ae_training.ipynb @@ -58,7 +58,8 @@ "config = AdversarialTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/ae_training.ipynb b/examples/notebooks/models_training/ae_training.ipynb index a569e9ab..01bd9fb1 100644 --- 
a/examples/notebooks/models_training/ae_training.ipynb +++ b/examples/notebooks/models_training/ae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -341,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/beta_tc_vae_training.ipynb b/examples/notebooks/models_training/beta_tc_vae_training.ipynb index 3cbeacdf..57571119 100644 --- a/examples/notebooks/models_training/beta_tc_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_tc_vae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/beta_vae_training.ipynb b/examples/notebooks/models_training/beta_vae_training.ipynb index 42df1b16..40e4ea1b 100644 --- a/examples/notebooks/models_training/beta_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_vae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -343,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/ciwae_training.ipynb b/examples/notebooks/models_training/ciwae_training.ipynb index 52d84e06..8e3e784d 100644 --- a/examples/notebooks/models_training/ciwae_training.ipynb +++ b/examples/notebooks/models_training/ciwae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb index dd63129e..f47cb691 100644 --- a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb +++ b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/factor_vae_training.ipynb b/examples/notebooks/models_training/factor_vae_training.ipynb index 6d6b1953..b0c013cf 100644 --- a/examples/notebooks/models_training/factor_vae_training.ipynb +++ b/examples/notebooks/models_training/factor_vae_training.ipynb @@ -58,7 +58,8 @@ "config = AdversarialTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " 
batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -343,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/hvae_training.ipynb b/examples/notebooks/models_training/hvae_training.ipynb index f8f70e5f..ae67b09c 100644 --- a/examples/notebooks/models_training/hvae_training.ipynb +++ b/examples/notebooks/models_training/hvae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/info_vae_training.ipynb b/examples/notebooks/models_training/info_vae_training.ipynb index 656c6342..8984813e 100644 --- a/examples/notebooks/models_training/info_vae_training.ipynb +++ b/examples/notebooks/models_training/info_vae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/iwae_training.ipynb b/examples/notebooks/models_training/iwae_training.ipynb index e7904b21..5ace04a8 100644 --- a/examples/notebooks/models_training/iwae_training.ipynb +++ b/examples/notebooks/models_training/iwae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -342,7 +343,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/miwae_training.ipynb b/examples/notebooks/models_training/miwae_training.ipynb index 0e8cdb8d..3bd3a59d 100644 --- a/examples/notebooks/models_training/miwae_training.ipynb +++ b/examples/notebooks/models_training/miwae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb index a3a5c63e..c5bc0d17 100644 --- a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb +++ b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -343,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, 
"orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/piwae_training.ipynb b/examples/notebooks/models_training/piwae_training.ipynb index 67199eed..af348ba3 100644 --- a/examples/notebooks/models_training/piwae_training.ipynb +++ b/examples/notebooks/models_training/piwae_training.ipynb @@ -58,7 +58,8 @@ "config = CoupledOptimizerTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/pvae_training.ipynb b/examples/notebooks/models_training/pvae_training.ipynb index 1c9c254e..40649651 100644 --- a/examples/notebooks/models_training/pvae_training.ipynb +++ b/examples/notebooks/models_training/pvae_training.ipynb @@ -160,7 +160,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=5e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/rae_gp_training.ipynb b/examples/notebooks/models_training/rae_gp_training.ipynb index 56525274..032e9bec 100644 --- a/examples/notebooks/models_training/rae_gp_training.ipynb +++ b/examples/notebooks/models_training/rae_gp_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/rae_l2_training.ipynb b/examples/notebooks/models_training/rae_l2_training.ipynb index 69a7f4ce..000cd233 100644 --- a/examples/notebooks/models_training/rae_l2_training.ipynb +++ b/examples/notebooks/models_training/rae_l2_training.ipynb @@ -58,7 +58,8 @@ "config = CoupledOptimizerTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/rhvae_training.ipynb b/examples/notebooks/models_training/rhvae_training.ipynb index e645b4d2..9286f57f 100644 --- a/examples/notebooks/models_training/rhvae_training.ipynb +++ b/examples/notebooks/models_training/rhvae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/svae_training.ipynb b/examples/notebooks/models_training/svae_training.ipynb index 71fdbe4a..7fc9d88e 100644 --- a/examples/notebooks/models_training/svae_training.ipynb +++ b/examples/notebooks/models_training/svae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/vae_iaf_training.ipynb 
b/examples/notebooks/models_training/vae_iaf_training.ipynb index 6c6f0c19..9db773e9 100644 --- a/examples/notebooks/models_training/vae_iaf_training.ipynb +++ b/examples/notebooks/models_training/vae_iaf_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -343,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vae_lin_nf_training.ipynb b/examples/notebooks/models_training/vae_lin_nf_training.ipynb index 892b64e0..f459249c 100644 --- a/examples/notebooks/models_training/vae_lin_nf_training.ipynb +++ b/examples/notebooks/models_training/vae_lin_nf_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/vae_lstm.ipynb b/examples/notebooks/models_training/vae_lstm.ipynb index 3c083070..6d8c632f 100644 --- a/examples/notebooks/models_training/vae_lstm.ipynb +++ b/examples/notebooks/models_training/vae_lstm.ipynb @@ -176,7 +176,8 @@ "training_config = BaseTrainerConfig(\n", " num_epochs=50,\n", " learning_rate=1e-3,\n", - " batch_size=64,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " output_dir='my_lstm_models'\n", ")" ] diff --git a/examples/notebooks/models_training/vae_training.ipynb b/examples/notebooks/models_training/vae_training.ipynb index 0be27f45..bef601a2 100644 --- a/examples/notebooks/models_training/vae_training.ipynb +++ b/examples/notebooks/models_training/vae_training.ipynb @@ -58,8 +58,11 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", + " optimizer_cls=\"AdamW\",\n", + " optimizer_params={\"weight_decay\": 0.05, \"betas\": (0.91, 0.99)}\n", ")\n", "\n", "\n", @@ -341,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vaegan_training.ipynb b/examples/notebooks/models_training/vaegan_training.ipynb index 0ab05e3b..08dce4d5 100644 --- a/examples/notebooks/models_training/vaegan_training.ipynb +++ b/examples/notebooks/models_training/vaegan_training.ipynb @@ -58,7 +58,8 @@ "config = CoupledOptimizerAdversarialTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -347,7 +348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vamp_training.ipynb b/examples/notebooks/models_training/vamp_training.ipynb index 1a9f76fd..208e3472 100644 --- 
a/examples/notebooks/models_training/vamp_training.ipynb +++ b/examples/notebooks/models_training/vamp_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/models_training/vqvae_training.ipynb b/examples/notebooks/models_training/vqvae_training.ipynb index cc94d109..d9bd6fc7 100644 --- a/examples/notebooks/models_training/vqvae_training.ipynb +++ b/examples/notebooks/models_training/vqvae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-3,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", @@ -282,7 +283,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/wae_training.ipynb b/examples/notebooks/models_training/wae_training.ipynb index 03ea6657..d78d1396 100644 --- a/examples/notebooks/models_training/wae_training.ipynb +++ b/examples/notebooks/models_training/wae_training.ipynb @@ -58,7 +58,8 @@ "config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/notebooks/wandb_experiment_monitoring.ipynb b/examples/notebooks/wandb_experiment_monitoring.ipynb index 17fbc9d6..385324b7 100644 --- a/examples/notebooks/wandb_experiment_monitoring.ipynb +++ b/examples/notebooks/wandb_experiment_monitoring.ipynb @@ -71,7 +71,8 @@ "training_config = BaseTrainerConfig(\n", " output_dir='my_model',\n", " learning_rate=1e-4,\n", - " batch_size=100,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", " num_epochs=10, # Change this to train the model a bit more\n", ")\n", "\n", diff --git a/examples/scripts/configs/binary_mnist/base_training_config.json b/examples/scripts/configs/binary_mnist/base_training_config.json index 776f4f8c..8dd0623c 100644 --- a/examples/scripts/configs/binary_mnist/base_training_config.json +++ b/examples/scripts/configs/binary_mnist/base_training_config.json @@ -1,7 +1,8 @@ { "name": "BaseTrainerConfig", "output_dir": "reproducibility/binary_mnist", - "batch_size": 100, + "per_device_train_batch_size": 100, + "per_device_eval_batch_size": 100, "num_epochs": 500, "learning_rate": 1e-4, "steps_saving": null, diff --git a/examples/scripts/configs/celeba/base_training_config.json b/examples/scripts/configs/celeba/base_training_config.json index 73eed926..43bee787 100644 --- a/examples/scripts/configs/celeba/base_training_config.json +++ b/examples/scripts/configs/celeba/base_training_config.json @@ -1,7 +1,8 @@ { "name": "BaseTrainerConfig", "output_dir": "my_models_on_celeba", - "batch_size": 100, + "per_device_train_batch_size": 100, + "per_device_eval_batch_size": 100, "num_epochs": 50, "learning_rate": 0.001, "steps_saving": null, diff --git a/examples/scripts/custom_nn.py b/examples/scripts/custom_nn.py index 7a380e51..82bec21e 100644 --- a/examples/scripts/custom_nn.py +++ b/examples/scripts/custom_nn.py 
@@ -1,13 +1,13 @@ +from typing import List + +import numpy as np import torch import torch.nn as nn -import numpy as np -from typing import List -from pythae.models.nn.base_architectures import BaseEncoder, BaseDecoder -from pythae.models.base.base_utils import ModelOutput from pythae.models import BaseAEConfig - -from pythae.models.nn import BaseEncoder, BaseDecoder, BaseDiscriminator +from pythae.models.base.base_utils import ModelOutput +from pythae.models.nn import BaseDecoder, BaseDiscriminator, BaseEncoder +from pythae.models.nn.base_architectures import BaseDecoder, BaseEncoder class Fully_Conv_Encoder_Conv_AE_MNIST(BaseEncoder): diff --git a/examples/scripts/distributed_training_ffhq.py b/examples/scripts/distributed_training_ffhq.py new file mode 100644 index 00000000..1a9ce7ca --- /dev/null +++ b/examples/scripts/distributed_training_ffhq.py @@ -0,0 +1,213 @@ +import argparse +import logging +import os +import time + +import hostlist +import torch +import torch.nn as nn +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +from pythae.data.datasets import DatasetOutput +from pythae.models import VQVAE, VQVAEConfig +from pythae.models.base.base_utils import ModelOutput +from pythae.models.nn.base_architectures import BaseDecoder, BaseEncoder +from pythae.models.nn.benchmarks.utils import ResBlock +from pythae.trainers import BaseTrainer, BaseTrainerConfig + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + +PATH = os.path.dirname(os.path.abspath(__file__)) + +ap = argparse.ArgumentParser() + +# Training setting +ap.add_argument( + "--use_wandb", + help="whether to log the metrics in wandb", + action="store_true", +) +ap.add_argument( + "--wandb_project", + help="wandb project name", + default="ffhq-distributed", +) +ap.add_argument( + "--wandb_entity", + help="wandb entity name", + default="pythae", +) + +args = ap.parse_args() + + +class Encoder_ResNet_VQVAE_FFHQ(BaseEncoder): + def __init__(self, args): + BaseEncoder.__init__(self) + + self.latent_dim = args.latent_dim + self.n_channels = 3 + + self.layers = nn.Sequential( + nn.Conv2d(self.n_channels, 32, 4, 2, padding=1), + nn.Conv2d(32, 64, 4, 2, padding=1), + nn.Conv2d(64, 128, 4, 2, padding=1), + nn.Conv2d(128, 256, 4, 2, padding=1), + nn.Conv2d(256, 256, 4, 2, padding=1), + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + ) + + self.pre_qantized = nn.Conv2d(256, self.latent_dim, 1, 1) + + def forward(self, x: torch.Tensor): + output = ModelOutput() + out = x + out = self.layers(out) + output["embedding"] = self.pre_qantized(out) + + return output + + +class Decoder_ResNet_VQVAE_FFHQ(BaseDecoder): + def __init__(self, args): + BaseDecoder.__init__(self) + + self.latent_dim = args.latent_dim + self.n_channels = 3 + + self.dequantize = nn.ConvTranspose2d(self.latent_dim, 256, 1, 1) + + self.layers = nn.Sequential( + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + ResBlock(in_channels=256, out_channels=64), + nn.ConvTranspose2d(256, 256, 4, 2, padding=1), + nn.ConvTranspose2d(256, 128, 4, 2, padding=1), + nn.ConvTranspose2d(128, 64, 4, 2, padding=1), + nn.ConvTranspose2d(64, 32, 4, 2, padding=1), + nn.ConvTranspose2d(32, self.n_channels, 4, 2, padding=1), + nn.Sigmoid(), + ) + + 
def forward(self, z: torch.Tensor): + output = ModelOutput() + + out = self.dequantize(z) + output["reconstruction"] = self.layers(out) + + return output + + +class FFHQ(Dataset): + def __init__(self, data_dir=None, is_train=True, transforms=None): + self.imgs_path = [os.path.join(data_dir, n) for n in os.listdir(data_dir)] + if is_train: + self.imgs_path = self.imgs_path[:60000] + else: + self.imgs_path = self.imgs_path[60000:] + self.transforms = transforms + + def __len__(self): + return len(self.imgs_path) + + def __getitem__(self, idx): + img = Image.open(self.imgs_path[idx]) + if self.transforms is not None: + img = self.transforms(img) + return DatasetOutput(data=img) + + +def main(args): + + img_transforms = transforms.Compose([transforms.ToTensor()]) + + train_dataset = FFHQ( + data_dir="/gpfsscratch/rech/wlr/uhw48em/data/ffhq/images1024x1024/all_images", + is_train=True, + transforms=img_transforms, + ) + eval_dataset = FFHQ( + data_dir="/gpfsscratch/rech/wlr/uhw48em/data/ffhq/images1024x1024/all_images", + is_train=False, + transforms=img_transforms, + ) + + model_config = VQVAEConfig( + input_dim=(3, 1024, 1024), latent_dim=128, use_ema=True, num_embeddings=1024 + ) + + encoder = Encoder_ResNet_VQVAE_FFHQ(model_config) + decoder = Decoder_ResNet_VQVAE_FFHQ(model_config) + + model = VQVAE(model_config=model_config, encoder=encoder, decoder=decoder) + + gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + + training_config = BaseTrainerConfig( + num_epochs=50, + train_dataloader_num_workers=8, + eval_dataloader_num_workers=8, + output_dir="my_models_on_ffhq", + per_device_train_batch_size=64, + per_device_eval_batch_size=64, + learning_rate=1e-4, + steps_saving=None, + steps_predict=None, + no_cuda=False, + world_size=int(os.environ["SLURM_NTASKS"]), + dist_backend="nccl", + rank=int(os.environ["SLURM_PROCID"]), + local_rank=int(os.environ["SLURM_LOCALID"]), + master_addr=hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0], + master_port=str(12345 + int(min(gpu_ids))), + ) + + if int(os.environ["SLURM_PROCID"]) == 0: + logger.info(model) + logger.info(f"Training config: {training_config}\n") + + callbacks = [] + + # Only log to wandb if main process + if args.use_wandb and (training_config.rank == 0 or training_config == -1): + from pythae.trainers.training_callbacks import WandbCallback + + wandb_cb = WandbCallback() + wandb_cb.setup( + training_config, + model_config=model_config, + project_name=args.wandb_project, + entity_name=args.wandb_entity, + ) + + callbacks.append(wandb_cb) + + trainer = BaseTrainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + training_config=training_config, + callbacks=callbacks, + ) + + start_time = time.time() + + trainer.train() + + end_time = time.time() + + logger.info(f"Total execution time: {(end_time - start_time)} seconds") + + +if __name__ == "__main__": + + main(args) diff --git a/examples/scripts/distributed_training_imagenet.py b/examples/scripts/distributed_training_imagenet.py new file mode 100644 index 00000000..9e849c70 --- /dev/null +++ b/examples/scripts/distributed_training_imagenet.py @@ -0,0 +1,207 @@ +import argparse +import logging +import os +import time + +import hostlist +import torch +import torch.nn as nn +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +from pythae.data.datasets import DatasetOutput +from pythae.models import VQVAE, VQVAEConfig +from pythae.models.base.base_utils import ModelOutput +from 
pythae.models.nn.base_architectures import BaseDecoder, BaseEncoder +from pythae.models.nn.benchmarks.utils import ResBlock +from pythae.trainers import BaseTrainer, BaseTrainerConfig + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + +PATH = os.path.dirname(os.path.abspath(__file__)) + +ap = argparse.ArgumentParser() + +# Training setting +ap.add_argument( + "--use_wandb", + help="whether to log the metrics in wandb", + action="store_true", +) +ap.add_argument( + "--wandb_project", + help="wandb project name", + default="imagenet-distributed", +) +ap.add_argument( + "--wandb_entity", + help="wandb entity name", + default="pythae", +) + +args = ap.parse_args() + + +class Encoder_ResNet_VQVAE_ImageNet(BaseEncoder): + def __init__(self, args): + BaseEncoder.__init__(self) + + self.latent_dim = args.latent_dim + self.n_channels = 3 + + self.layers = nn.Sequential( + nn.Conv2d(self.n_channels, 32, 4, 2, padding=1), + nn.Conv2d(32, 64, 4, 2, padding=1), + nn.Conv2d(64, 128, 4, 2, padding=1), + nn.Conv2d(128, 128, 4, 2, padding=1), + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + ) + + self.pre_qantized = nn.Conv2d(128, self.latent_dim, 1, 1) + + def forward(self, x: torch.Tensor): + output = ModelOutput() + out = x + out = self.layers(out) + output["embedding"] = self.pre_qantized(out) + + return output + + +class Decoder_ResNet_VQVAE_ImageNet(BaseDecoder): + def __init__(self, args): + BaseDecoder.__init__(self) + + self.latent_dim = args.latent_dim + self.n_channels = 3 + + self.dequantize = nn.ConvTranspose2d(self.latent_dim, 128, 1, 1) + + self.layers = nn.Sequential( + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + ResBlock(in_channels=128, out_channels=64), + nn.ConvTranspose2d(128, 128, 4, 2, padding=1), + nn.ConvTranspose2d(128, 64, 4, 2, padding=1), + nn.ConvTranspose2d(64, 32, 4, 2, padding=1), + nn.ConvTranspose2d(32, self.n_channels, 4, 2, padding=1), + nn.Sigmoid(), + ) + + def forward(self, z: torch.Tensor): + output = ModelOutput() + + out = self.dequantize(z) + output["reconstruction"] = self.layers(out) + + return output + + +class ImageNet(Dataset): + def __init__(self, data_dir=None, transforms=None): + self.imgs_path = [os.path.join(data_dir, n) for n in os.listdir(data_dir)] + self.transforms = transforms + + def __len__(self): + return len(self.imgs_path) + + def __getitem__(self, idx): + img = Image.open(self.imgs_path[idx]).convert("RGB") + if self.transforms is not None: + img = self.transforms(img) + return DatasetOutput(data=img) + + +def main(args): + + img_transforms = transforms.Compose( + [transforms.Resize((128, 128)), transforms.ToTensor()] + ) + + train_dataset = ImageNet( + data_dir="/gpfsscratch/rech/wlr/uhw48em/data/imagenet/train", + transforms=img_transforms, + ) + eval_dataset = ImageNet( + data_dir="/gpfsscratch/rech/wlr/uhw48em/data/imagenet/val", + transforms=img_transforms, + ) + + model_config = VQVAEConfig( + input_dim=(3, 128, 128), latent_dim=128, use_ema=True, num_embeddings=1024 + ) + + encoder = Encoder_ResNet_VQVAE_ImageNet(model_config) + decoder = Decoder_ResNet_VQVAE_ImageNet(model_config) + + model = VQVAE(model_config=model_config, encoder=encoder, decoder=decoder) + + gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + + 
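+    # Distributed settings below are recovered from the SLURM environment variables
+    # set by the launcher: SLURM_NTASKS gives the total number of processes
+    # (world_size), SLURM_PROCID the global rank and SLURM_LOCALID the process/gpu
+    # id within the node. The master address is taken as the first node of the job
+    # and the port is offset by one of the allocated gpu ids, presumably to avoid
+    # port clashes between jobs sharing a node.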
training_config = BaseTrainerConfig( + num_epochs=20, + train_dataloader_num_workers=8, + eval_dataloader_num_workers=8, + output_dir="my_models_on_imagenet", + per_device_train_batch_size=128, + per_device_eval_batch_size=128, + learning_rate=1e-4, + steps_saving=None, + steps_predict=None, + no_cuda=False, + world_size=int(os.environ["SLURM_NTASKS"]), + dist_backend="nccl", + rank=int(os.environ["SLURM_PROCID"]), + local_rank=int(os.environ["SLURM_LOCALID"]), + master_addr=hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0], + master_port=str(12345 + int(min(gpu_ids))), + ) + + if int(os.environ["SLURM_PROCID"]) == 0: + logger.info(model) + logger.info(f"Training config: {training_config}\n") + + callbacks = [] + + # Only log to wandb if main process + if args.use_wandb and (training_config.rank == 0 or training_config == -1): + from pythae.trainers.training_callbacks import WandbCallback + + wandb_cb = WandbCallback() + wandb_cb.setup( + training_config, + model_config=model_config, + project_name=args.wandb_project, + entity_name=args.wandb_entity, + ) + + callbacks.append(wandb_cb) + + trainer = BaseTrainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + training_config=training_config, + callbacks=callbacks, + ) + + start_time = time.time() + + trainer.train() + + end_time = time.time() + + logger.info(f"Total execution time: {(end_time - start_time)} seconds") + + +if __name__ == "__main__": + + main(args) diff --git a/examples/scripts/distributed_training_mnist.py b/examples/scripts/distributed_training_mnist.py new file mode 100644 index 00000000..d1110ae9 --- /dev/null +++ b/examples/scripts/distributed_training_mnist.py @@ -0,0 +1,140 @@ +import argparse +import logging +import os +import time + +import hostlist +import numpy as np +import torch +from torch.utils.data import Dataset + +from pythae.data.datasets import DatasetOutput +from pythae.models import VQVAE, VQVAEConfig +from pythae.models.nn.benchmarks.mnist import ( + Decoder_ResNet_VQVAE_MNIST, + Encoder_ResNet_VQVAE_MNIST, +) +from pythae.trainers import BaseTrainer, BaseTrainerConfig + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + +PATH = os.path.dirname(os.path.abspath(__file__)) + +ap = argparse.ArgumentParser() + +# Training setting +ap.add_argument( + "--use_wandb", + help="whether to log the metrics in wandb", + action="store_true", +) +ap.add_argument( + "--wandb_project", + help="wandb project name", + default="mnist-distributed", +) +ap.add_argument( + "--wandb_entity", + help="wandb entity name", + default="clementchadebec", +) + +args = ap.parse_args() + + +class MNIST(Dataset): + def __init__(self, data): + self.data = data.type(torch.float) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + x = self.data[index] + return DatasetOutput(data=x) + + +def main(args): + + ### Load data + train_data = torch.tensor( + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 + ) + eval_data = torch.tensor( + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 + ) + + train_dataset = MNIST(train_data) + eval_dataset = MNIST(eval_data) + + model_config = VQVAEConfig( + input_dim=(1, 28, 28), latent_dim=16, use_ema=True, num_embeddings=256 + ) + + encoder = Encoder_ResNet_VQVAE_MNIST(model_config) + decoder = Decoder_ResNet_VQVAE_MNIST(model_config) + + model = VQVAE(model_config=model_config, encoder=encoder, 
decoder=decoder) + + gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + + training_config = BaseTrainerConfig( + num_epochs=100, + output_dir="my_models_on_mnist", + per_device_train_batch_size=256, + per_device_eval_batch_size=256, + learning_rate=1e-3, + steps_saving=None, + steps_predict=None, + no_cuda=False, + world_size=int(os.environ["SLURM_NTASKS"]), + dist_backend="nccl", + rank=int(os.environ["SLURM_PROCID"]), + local_rank=int(os.environ["SLURM_LOCALID"]), + master_addr=hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0], + master_port=str(12345 + int(min(gpu_ids))), + ) + + if int(os.environ["SLURM_PROCID"]) == 0: + logger.info(model) + logger.info(f"Training config: {training_config}\n") + + callbacks = [] + + # Only log to wandb if main process + if args.use_wandb and (training_config.rank == 0 or training_config == -1): + from pythae.trainers.training_callbacks import WandbCallback + + wandb_cb = WandbCallback() + wandb_cb.setup( + training_config, + model_config=model_config, + project_name=args.wandb_project, + entity_name=args.wandb_entity, + ) + + callbacks.append(wandb_cb) + + trainer = BaseTrainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + training_config=training_config, + callbacks=callbacks, + ) + + start_time = time.time() + + trainer.train() + + end_time = time.time() + + logger.info(f"Total execution time: {(end_time - start_time)} seconds") + + +if __name__ == "__main__": + + main(args) diff --git a/examples/scripts/reproducibility/README.md b/examples/scripts/reproducibility/README.md index e135680f..616b10af 100644 --- a/examples/scripts/reproducibility/README.md +++ b/examples/scripts/reproducibility/README.md @@ -5,7 +5,7 @@ We validate the implementations by reproducing some results presented in the ori We gather here the scripts to train those models. They can be launched with the following commandline ```bash -python aae.py --model_config 'configs/celeba/aae/aae_config.json' --training_config 'configs/celeba/aae/base_training_config.json' +python aae.py ``` The folder structure should be as follows: @@ -13,11 +13,6 @@ The folder structure should be as follows: . 
โ”œโ”€โ”€ aae.py โ”œโ”€โ”€ betatcvae.py -โ”œโ”€โ”€ configs -โ”‚ย ย  โ”œโ”€โ”€ binary_mnist -โ”‚ย ย  โ”œโ”€โ”€ celeba -โ”‚ย ย  โ”œโ”€โ”€ dsprites -โ”‚ย ย  โ””โ”€โ”€ mnist โ”œโ”€โ”€ data โ”‚ย ย  โ”œโ”€โ”€ binary_mnist โ”‚ย ย  โ”œโ”€โ”€ celeba diff --git a/examples/scripts/reproducibility/aae.py b/examples/scripts/reproducibility/aae.py index cf465261..7acb6f24 100644 --- a/examples/scripts/reproducibility/aae.py +++ b/examples/scripts/reproducibility/aae.py @@ -1,18 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.trainers import AdversarialTrainer, BaseTrainerConfig from pythae.models import Adversarial_AE, Adversarial_AE_Config -from pythae.models.nn import BaseEncoder, BaseDecoder, BaseDiscriminator -import torch.nn as nn from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseDiscriminator, BaseEncoder +from pythae.trainers import AdversarialTrainer, AdversarialTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -21,24 +19,6 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - - -args = ap.parse_args() - ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args): @@ -154,8 +134,7 @@ def __init__(self, args: dict): layers.append( nn.Sequential( - nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), - nn.Sigmoid() + nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), nn.Sigmoid() ) ) @@ -199,6 +178,7 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output + ### Define paper discriminator network class Discriminator(BaseDiscriminator): def __init__(self, args: dict): @@ -217,7 +197,7 @@ def __init__(self, args: dict): nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512), - nn.ReLU() + nn.ReLU(), ) ) @@ -261,38 +241,56 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output -def main(args): +def main(): train_data = ( - np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))[ - "data" - ] - / 255.0 - ) - eval_data = ( - np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))["data"] / 255.0 ) - - data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = Adversarial_AE_Config.from_json_file(args.model_config) - - else: - model_config = Adversarial_AE_Config() + data_input_dim = tuple(train_data.shape[1:]) - model_config.input_dim = data_input_dim + model_config = Adversarial_AE_Config( + input_dim=data_input_dim, + latent_dim=64, + reconstruction_loss="mse", + adversarial_loss_scale=0.5, + reconstruction_loss_scale=0.05, + deterministic_posterior=True, + ) model = Adversarial_AE( model_config=model_config, encoder=Encoder(model_config), decoder=Decoder(model_config), - discriminator=Discriminator(model_config) + discriminator=Discriminator(model_config), ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + 
training_config = AdversarialTrainerConfig( + output_dir="my_models_on_celeba", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=100, + autoencoder_learning_rate=3e-4, + discriminator_learning_rate=1e-3, + steps_saving=3, + steps_predict=1000, + no_cuda=False, + autoencoder_scheduler_cls="LambdaLR", + autoencoder_scheduler_params={ + "lr_lambda": lambda epoch: 1 * (epoch < 30) + + 0.5 * (30 <= epoch < 50) + + 0.2 * (50 <= epoch), + "verbose": True, + }, + discriminator_scheduler_cls="LambdaLR", + discriminator_scheduler_params={ + "lr_lambda": lambda epoch: 1 * (epoch < 30) + + 0.5 * (30 <= epoch < 50) + + 0.2 * (50 <= epoch), + "verbose": True, + }, + ) ### Process data data_processor = DataProcessor() @@ -300,26 +298,6 @@ def main(args): train_data = data_processor.process_data(train_data) train_dataset = data_processor.to_dataset(train_data) - logger.info("Preprocessing eval data...\n") - eval_data = data_processor.process_data(eval_data) - eval_dataset = data_processor.to_dataset(eval_data) - - import itertools - - ### Optimizers - ae_optimizer = torch.optim.Adam(itertools.chain(model.encoder.parameters(), model.decoder.parameters()), lr=3e-4) - dis_optimizer = torch.optim.Adam(model.discriminator.parameters(), lr=1e-3) - - ### Schedulers - lambda_lr = lambda epoch: 1 * (epoch < 30) + 0.5 * (30 <= epoch < 50) + 0.2 * (50 <= epoch) - - ae_scheduler = torch.optim.lr_scheduler.LambdaLR( - ae_optimizer, lr_lambda=lambda_lr, verbose=True - ) - dis_scheduler = torch.optim.lr_scheduler.LambdaLR( - dis_optimizer, lr_lambda=lambda_lr, verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -329,15 +307,12 @@ def main(args): model=model, train_dataset=train_dataset, training_config=training_config, - autoencoder_optimizer=ae_optimizer, - discriminator_optimizer=dis_optimizer, - autoencoder_scheduler=ae_scheduler, - discriminator_scheduler=dis_scheduler, callbacks=None, ) trainer.train() + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/betatcvae.py b/examples/scripts/reproducibility/betatcvae.py index ddbea655..b9906bde 100644 --- a/examples/scripts/reproducibility/betatcvae.py +++ b/examples/scripts/reproducibility/betatcvae.py @@ -1,19 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor from pythae.models import BetaTCVAE, BetaTCVAEConfig -from pythae.trainers import BaseTrainer, BaseTrainerConfig - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -22,23 +19,6 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() - ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args: dict): @@ -53,7 +33,7 @@ def __init__(self, args: dict): nn.Linear(np.prod(args.input_dim), 
1200), nn.ReLU(inplace=True), nn.Linear(1200, 1200), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) ) @@ -63,7 +43,6 @@ def __init__(self, args: dict): self.layers = layers self.depth = len(layers) - def forward(self, x, output_layer_levels: List[int] = None): output = ModelOutput() @@ -74,9 +53,7 @@ def forward(self, x, output_layer_levels: List[int] = None): assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels - ), ( - f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." - ) + ), f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." if -1 in output_layer_levels: max_depth = self.depth @@ -97,6 +74,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -115,7 +93,7 @@ def __init__(self, args: dict): nn.Linear(1200, 1200), nn.Tanh(), nn.Linear(1200, np.prod(args.input_dim)), - nn.Sigmoid() + nn.Sigmoid(), ) self.layers = layers @@ -156,22 +134,27 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): + data = np.load( + os.path.join( + PATH, f"data/dsprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz" + ), + encoding="latin1", + ) + train_data = torch.from_numpy(data["imgs"]).float() - data = np.load(os.path.join(PATH, f"data/dsprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"), encoding='latin1') - train_data = torch.from_numpy(data['imgs']).float() - data_input_dim = tuple(train_data.shape[1:]) - ### Build the model - if args.model_config is not None: - model_config = BetaTCVAEConfig.from_json_file(args.model_config) - - else: - model_config = BetaTCVAEConfig() - - model_config.input_dim = data_input_dim + model_config = BetaTCVAEConfig( + input_dim=data_input_dim, + latent_dim=10, + reconstruction_loss="bce", + beta=6, + gamma=1, + alpha=1, + use_mss=False, + ) model = BetaTCVAE( model_config=model_config, @@ -180,7 +163,18 @@ def main(args): ) ### Set the training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/dsprites", + per_device_train_batch_size=1000, + per_device_eval_batch_size=1000, + num_epochs=50, + learning_rate=1e-3, + steps_saving=None, + steps_predict=None, + no_cuda=False, + optimizer_cls="Adam", + optimizer_params={"eps": 1e-4}, + ) ### Process data data_processor = DataProcessor() @@ -188,14 +182,6 @@ def main(args): train_data = data_processor.process_data(train_data) train_dataset = data_processor.to_dataset(train_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate, eps=1e-4) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[1000], gamma=10**(-1/7), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -205,15 +191,12 @@ def main(args): model=model, train_dataset=train_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) - print(trainer.scheduler) - trainer.train() - + + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/ciwae.py b/examples/scripts/reproducibility/ciwae.py index f1ebf9fd..491f049b 100644 --- a/examples/scripts/reproducibility/ciwae.py +++ b/examples/scripts/reproducibility/ciwae.py @@ -1,22 +1,17 @@ -import argparse 
import logging import os from typing import List import numpy as np import torch - -from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import CIWAE, CIWAEConfig -from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.data.datasets import DatasetOutput +import torch.nn as nn from torch.utils.data import Dataset -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.data.datasets import DatasetOutput +from pythae.models import CIWAE, AutoModel, CIWAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -25,30 +20,15 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() +device = "cuda" if torch.cuda.is_available() else "cpu" -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() def unif_init(m, n_in, n_out): - scale = np.sqrt(6./(n_in+n_out)) - m.weight.data.uniform_( - -scale, scale - ) + scale = np.sqrt(6.0 / (n_in + n_out)) + m.weight.data.uniform_(-scale, scale) m.bias.data = torch.zeros((1, n_out)) + ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args: dict): @@ -56,7 +36,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.fc1 = nn.Linear(np.prod(args.input_dim), 200) self.fc2 = nn.Linear(200, 200) @@ -77,6 +56,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -105,6 +85,7 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output + class DynBinarizedMNIST(Dataset): def __init__(self, data): self.data = data.type(torch.float) @@ -118,37 +99,30 @@ def __getitem__(self, index): return DatasetOutput(data=x) -def main(args): +def main(): - - ### Load data train_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 ) eval_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 ) train_data = torch.cat((train_data, eval_data)) test_data = ( - np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = CIWAEConfig.from_json_file(args.model_config) - - else: - model_config = CIWAEConfig() - - model_config.input_dim = data_input_dim + model_config = CIWAEConfig( + input_dim=data_input_dim, + latent_dim=50, + reconstruction_loss="bce", + number_samples=64, + beta=0.5, + ) model = CIWAE( model_config=model_config, @@ -157,20 +131,29 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + 
training_config = BaseTrainerConfig( + output_dir="reproducibility/mnist", + per_device_train_batch_size=20, + per_device_eval_batch_size=20, + num_epochs=3280, + learning_rate=1e-3, + steps_saving=1000, + steps_predict=None, + no_cuda=False, + optimizer_cls="Adam", + optimizer_params={"eps": 1e-4}, + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [2, 5, 14, 28, 41, 122, 365, 1094], + "gamma": 10 ** (-1 / 7), + "verbose": True, + }, + ) ### Process data logger.info("Preprocessing train data...") train_dataset = DynBinarizedMNIST(train_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate, eps=1e-4) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[2, 5, 14, 28, 41, 122, 365, 1094], gamma=10**(-1/7), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -179,20 +162,31 @@ def main(args): trainer = BaseTrainer( model=model, train_dataset=train_dataset, - eval_dataset=None,#eval_dataset, + eval_dataset=None, # eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() ### Reload model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) - test_data = (test_data > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device)).float() + test_data = ( + test_data + > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device) + ).float() ### Compute NLL with torch.no_grad(): @@ -201,14 +195,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=5000, batch_size=5000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/configs/binary_mnist/hvae/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/hvae/base_training_config.json deleted file mode 100644 index 2fc8681f..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/hvae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 100, - "num_epochs": 2000, - "learning_rate": 5e-4, - "steps_saving": 50, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/hvae/hvae_config.json b/examples/scripts/reproducibility/configs/binary_mnist/hvae/hvae_config.json deleted file mode 100644 index 3e262677..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/hvae/hvae_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "HVAEConfig", - "latent_dim": 64, - "reconstruction_loss": "bce", - "n_lf": 4, - "eps_lf": 0.05, - "beta_zero": 1, - "learn_eps_lf": false, - "learn_beta_zero": false -} diff --git 
a/examples/scripts/reproducibility/configs/binary_mnist/iwae/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/iwae/base_training_config.json deleted file mode 100644 index 47a19b26..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/iwae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 20, - "num_epochs": 3280, - "learning_rate": 1e-3, - "steps_saving": null, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/iwae/iwae_config.json b/examples/scripts/reproducibility/configs/binary_mnist/iwae/iwae_config.json deleted file mode 100644 index 2ce782f9..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/iwae/iwae_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "IWAEConfig", - "latent_dim": 50, - "reconstruction_loss": "bce", - "number_samples": 50 -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/svae/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/svae/base_training_config.json deleted file mode 100644 index 1eff96f8..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/svae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 64, - "num_epochs": 500, - "learning_rate": 1e-3, - "steps_saving": 50, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/svae/svae_config.json b/examples/scripts/reproducibility/configs/binary_mnist/svae/svae_config.json deleted file mode 100644 index 3f239786..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/svae/svae_config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "name": "SVAEConfig", - "latent_dim": 11, - "reconstruction_loss": "bce" -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vae/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae/base_training_config.json deleted file mode 100644 index c773827e..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 100, - "num_epochs": 500, - "learning_rate": 1e-4, - "steps_saving": null, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vae/vae_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae/vae_config.json deleted file mode 100644 index 0c561b58..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae/vae_config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "name": "VAEConfig", - "latent_dim": 40, - "reconstruction_loss": "bce" -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/base_training_config.json deleted file mode 100644 index 15457cf3..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 200, - "num_epochs": 2000, - "learning_rate": 1e-3, - "steps_saving": 500, - "steps_predict": null, - "no_cuda": false -} diff 
--git a/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/vae_iaf_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/vae_iaf_config.json deleted file mode 100644 index d21c556e..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae_iaf/vae_iaf_config.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "VAE_IAF_Config", - "latent_dim": 32, - "reconstruction_loss": "bce", - "n_made_blocks": 2, - "n_hidden_in_made":2, - "hidden_size": 320 -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/base_training_config.json deleted file mode 100644 index 1cb9e9c1..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 100, - "num_epochs": 10000, - "learning_rate": 1e-5, - "steps_saving": 1000, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/vae_lin_nf_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/vae_lin_nf_config.json deleted file mode 100644 index 2e629ef7..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vae_lin_nf/vae_lin_nf_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "VAE_LinNF_Config", - "latent_dim": 40, - "reconstruction_loss": "bce", - "flows": ["Planar","Planar","Planar","Planar","Planar","Planar","Planar","Planar","Planar","Planar"] -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vamp/base_training_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vamp/base_training_config.json deleted file mode 100644 index 91cb6b1c..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vamp/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 100, - "num_epochs": 2000, - "learning_rate": 1e-4, - "steps_saving": null, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/binary_mnist/vamp/vamp_config.json b/examples/scripts/reproducibility/configs/binary_mnist/vamp/vamp_config.json deleted file mode 100644 index 4ed8e260..00000000 --- a/examples/scripts/reproducibility/configs/binary_mnist/vamp/vamp_config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "name": "VAMPConfig", - "latent_dim": 40, - "reconstruction_loss": "bce", - "number_components": 500, - "linear_scheduling_steps": 100 -} diff --git a/examples/scripts/reproducibility/configs/celeba/aae/aae_config.json b/examples/scripts/reproducibility/configs/celeba/aae/aae_config.json deleted file mode 100644 index eceb4632..00000000 --- a/examples/scripts/reproducibility/configs/celeba/aae/aae_config.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "AdversarialAEConfig", - "latent_dim": 64, - "reconstruction_loss": "mse", - "adversarial_loss_scale": 0.5, - "reconstruction_loss_scale": 0.05, - "deterministic_posterior": true - } diff --git a/examples/scripts/reproducibility/configs/celeba/aae/base_training_config.json b/examples/scripts/reproducibility/configs/celeba/aae/base_training_config.json deleted file mode 100644 index 09ad4955..00000000 --- a/examples/scripts/reproducibility/configs/celeba/aae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": 
"BaseTrainerConfig", - "output_dir": "my_models_on_celeba", - "batch_size": 100, - "num_epochs": 100, - "learning_rate": 0.0001, - "steps_saving": 3, - "steps_predict": 1000, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/celeba/rae_gp/base_training_config.json b/examples/scripts/reproducibility/configs/celeba/rae_gp/base_training_config.json deleted file mode 100644 index 978c0ff9..00000000 --- a/examples/scripts/reproducibility/configs/celeba/rae_gp/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "my_models_on_celeba", - "batch_size": 100, - "num_epochs": 100, - "learning_rate": 0.001, - "steps_saving": null, - "steps_predict": 100, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/celeba/rae_gp/rae_gp_config.json b/examples/scripts/reproducibility/configs/celeba/rae_gp/rae_gp_config.json deleted file mode 100644 index 162d740d..00000000 --- a/examples/scripts/reproducibility/configs/celeba/rae_gp/rae_gp_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "RAE_GP_Config", - "latent_dim": 16, - "embedding_weight": 1e-3, - "reg_weight": 1e-6 - } \ No newline at end of file diff --git a/examples/scripts/reproducibility/configs/celeba/rae_l2/base_training_config.json b/examples/scripts/reproducibility/configs/celeba/rae_l2/base_training_config.json deleted file mode 100644 index 978c0ff9..00000000 --- a/examples/scripts/reproducibility/configs/celeba/rae_l2/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "my_models_on_celeba", - "batch_size": 100, - "num_epochs": 100, - "learning_rate": 0.001, - "steps_saving": null, - "steps_predict": 100, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/celeba/rae_l2/rae_l2_config.json b/examples/scripts/reproducibility/configs/celeba/rae_l2/rae_l2_config.json deleted file mode 100644 index c3daf434..00000000 --- a/examples/scripts/reproducibility/configs/celeba/rae_l2/rae_l2_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "RAE_L2_Config", - "latent_dim": 16, - "embedding_weight": 1e-6, - "reg_weight": 1e-3 - } diff --git a/examples/scripts/reproducibility/configs/celeba/wae/base_training_config.json b/examples/scripts/reproducibility/configs/celeba/wae/base_training_config.json deleted file mode 100644 index ab5e41a2..00000000 --- a/examples/scripts/reproducibility/configs/celeba/wae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "my_models_on_celeba", - "batch_size": 100, - "num_epochs": 100, - "learning_rate": 0.0001, - "steps_saving": 3, - "steps_predict": 100, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/celeba/wae/wae_config.json b/examples/scripts/reproducibility/configs/celeba/wae/wae_config.json deleted file mode 100644 index 4a4045b8..00000000 --- a/examples/scripts/reproducibility/configs/celeba/wae/wae_config.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "WAE_MMD_Config", - "latent_dim": 64, - "kernel_choice": "imq", - "reg_weight": 100, - "kernel_bandwidth": 2.0, - "reconstruction_loss_scale": 0.05 -} diff --git a/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/base_training_config.json b/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/base_training_config.json deleted file mode 100644 index 43d7e70d..00000000 --- a/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/base_training_config.json +++ /dev/null @@ 
-1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/dsprites", - "batch_size": 1000, - "num_epochs": 50, - "learning_rate": 1e-3, - "steps_saving": null, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json b/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json deleted file mode 100644 index b58df6a5..00000000 --- a/examples/scripts/reproducibility/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "name": "BetaTCVAEConfig", - "latent_dim": 10, - "reconstruction_loss": "bce", - "beta": 6, - "gamma": 1, - "alpha": 1, - "use_mss": false -} \ No newline at end of file diff --git a/examples/scripts/reproducibility/configs/dsprites/factorvae/base_training_config.json b/examples/scripts/reproducibility/configs/dsprites/factorvae/base_training_config.json deleted file mode 100644 index 86d37bba..00000000 --- a/examples/scripts/reproducibility/configs/dsprites/factorvae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/dsprites", - "batch_size": 64, - "num_epochs": 500, - "learning_rate": 1e-3, - "steps_saving": 50, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/dsprites/factorvae/factorvae_config.json b/examples/scripts/reproducibility/configs/dsprites/factorvae/factorvae_config.json deleted file mode 100644 index 7051250f..00000000 --- a/examples/scripts/reproducibility/configs/dsprites/factorvae/factorvae_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "FactorVAEConfig", - "latent_dim": 10, - "reconstruction_loss": "bce", - "gamma": 35 -} diff --git a/examples/scripts/reproducibility/configs/mnist/ciwae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/ciwae/base_training_config.json deleted file mode 100644 index 08d9e7d9..00000000 --- a/examples/scripts/reproducibility/configs/mnist/ciwae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/mnist", - "batch_size": 20, - "num_epochs": 3280, - "learning_rate": 1e-3, - "steps_saving": 1000, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/mnist/ciwae/ciwae_config.json b/examples/scripts/reproducibility/configs/mnist/ciwae/ciwae_config.json deleted file mode 100644 index fce3f6c0..00000000 --- a/examples/scripts/reproducibility/configs/mnist/ciwae/ciwae_config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "name": "CIWAEConfig", - "latent_dim": 50, - "reconstruction_loss": "bce", - "number_samples": 64, - "beta": 0.5 -} diff --git a/examples/scripts/reproducibility/configs/mnist/miwae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/miwae/base_training_config.json deleted file mode 100644 index 08d9e7d9..00000000 --- a/examples/scripts/reproducibility/configs/mnist/miwae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/mnist", - "batch_size": 20, - "num_epochs": 3280, - "learning_rate": 1e-3, - "steps_saving": 1000, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/mnist/miwae/miwae_config.json b/examples/scripts/reproducibility/configs/mnist/miwae/miwae_config.json deleted file mode 100644 index 707bdb95..00000000 --- 
a/examples/scripts/reproducibility/configs/mnist/miwae/miwae_config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "name": "MIWAEConfig", - "latent_dim": 50, - "reconstruction_loss": "bce", - "number_samples": 8, - "number_gradient_estimates": 8 -} diff --git a/examples/scripts/reproducibility/configs/mnist/piwae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/piwae/base_training_config.json deleted file mode 100644 index 08d9e7d9..00000000 --- a/examples/scripts/reproducibility/configs/mnist/piwae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/mnist", - "batch_size": 20, - "num_epochs": 3280, - "learning_rate": 1e-3, - "steps_saving": 1000, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/mnist/piwae/piwae_config.json b/examples/scripts/reproducibility/configs/mnist/piwae/piwae_config.json deleted file mode 100644 index 0f186c8c..00000000 --- a/examples/scripts/reproducibility/configs/mnist/piwae/piwae_config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "name": "PIWAEConfig", - "latent_dim": 50, - "reconstruction_loss": "bce", - "number_samples": 8, - "number_gradient_estimates": 8 -} diff --git a/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json deleted file mode 100644 index f068e8c0..00000000 --- a/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/mnist", - "batch_size": 128, - "num_epochs": 80, - "learning_rate": 5e-4, - "steps_saving": 100, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json b/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json deleted file mode 100644 index 6a010425..00000000 --- a/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "PoincareVAEConfig", - "latent_dim": 10, - "reconstruction_loss": "bce", - "prior_distribution": "wrapped_normal", - "posterior_distribution": "wrapped_normal", - "curvature": 0.7 -} diff --git a/examples/scripts/reproducibility/configs/mnist/svae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/svae/base_training_config.json deleted file mode 100644 index 1eff96f8..00000000 --- a/examples/scripts/reproducibility/configs/mnist/svae/base_training_config.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "BaseTrainerConfig", - "output_dir": "reproducibility/binary_mnist", - "batch_size": 64, - "num_epochs": 500, - "learning_rate": 1e-3, - "steps_saving": 50, - "steps_predict": null, - "no_cuda": false -} diff --git a/examples/scripts/reproducibility/configs/mnist/svae/svae_config.json b/examples/scripts/reproducibility/configs/mnist/svae/svae_config.json deleted file mode 100644 index 3f239786..00000000 --- a/examples/scripts/reproducibility/configs/mnist/svae/svae_config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "name": "SVAEConfig", - "latent_dim": 11, - "reconstruction_loss": "bce" -} diff --git a/examples/scripts/reproducibility/hvae.py b/examples/scripts/reproducibility/hvae.py index 2c46b2b7..4dc7897a 100644 --- a/examples/scripts/reproducibility/hvae.py +++ b/examples/scripts/reproducibility/hvae.py @@ -1,20 +1,16 @@ -import argparse import logging import os from typing import List import 
numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import HVAE, HVAEConfig -from pythae.trainers import BaseTrainerConfig, BaseTrainer - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.models import HVAE, AutoModel, HVAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -23,22 +19,7 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() +device = "cuda" if torch.cuda.is_available() else "cpu" ### Define paper encoder network class Encoder(BaseEncoder): @@ -50,10 +31,7 @@ def __init__(self, args: dict): layers = nn.ModuleList() layers.append( - nn.Sequential( - nn.Linear(np.prod(args.input_dim), 300), - nn.Softplus() - ) + nn.Sequential(nn.Linear(np.prod(args.input_dim), 300), nn.Softplus()) ) self.layers = layers @@ -62,7 +40,6 @@ def __init__(self, args: dict): self.embedding = nn.Linear(300, self.latent_dim) self.log_var = nn.Linear(300, self.latent_dim) - def forward(self, x, output_layer_levels: List[int] = None): output = ModelOutput() @@ -73,9 +50,7 @@ def forward(self, x, output_layer_levels: List[int] = None): assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels - ), ( - f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." - ) + ), f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." 
if -1 in output_layer_levels: max_depth = self.depth @@ -96,6 +71,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -103,21 +79,14 @@ def __init__(self, args: dict): self.input_dim = args.input_dim - # assert 0, np.prod(args.input_dim) - layers = nn.ModuleList() - layers.append( - nn.Sequential( - nn.Linear(args.latent_dim, 300), - nn.Softplus() - ) - ) + layers.append(nn.Sequential(nn.Linear(args.latent_dim, 300), nn.Softplus())) layers.append( nn.Sequential( nn.Linear(300, np.prod(args.input_dim)), - #nn.Sigmoid() + # nn.Sigmoid() ) ) @@ -159,23 +128,30 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output -def main(args): +def main(): - train_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat")).reshape(-1, 1, 28, 28) - eval_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat")).reshape(-1, 1, 28, 28) - test_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat")).reshape(-1, 1, 28, 28) + train_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat") + ).reshape(-1, 1, 28, 28) + eval_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat") + ).reshape(-1, 1, 28, 28) + test_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat") + ).reshape(-1, 1, 28, 28) data_input_dim = tuple(train_data.shape[1:]) - - ### Build the model - if args.model_config is not None: - model_config = HVAEConfig.from_json_file(args.model_config) - - else: - model_config = HVAEConfig() - - model_config.input_dim = data_input_dim + model_config = HVAEConfig( + input_dim=data_input_dim, + latent_dim=64, + reconstruction_loss="bce", + n_lf=4, + eps_lf=0.05, + beta_zero=1, + learn_eps_lf=False, + learn_beta_zero=False, + ) model = HVAE( model_config=model_config, @@ -183,9 +159,24 @@ def main(args): decoder=Decoder(model_config), ) - ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/binary_mnist", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=2000, + learning_rate=5e-4, + steps_saving=50, + steps_predict=None, + no_cuda=False, + optimizer_cls="Adamax", + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [200, 350, 500, 750, 1000], + "gamma": 10 ** (-1 / 5), + "verbose": True, + }, + ) ### Process data data_processor = DataProcessor() @@ -197,14 +188,6 @@ def main(args): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizer - optimizer = torch.optim.Adamax(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[200, 350, 500, 750, 1000], gamma=10**(-1/5), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -215,17 +198,23 @@ def main(args): train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) - print(trainer.scheduler) - trainer.train() - + ### Reload the model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, 
f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) @@ -236,14 +225,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=1000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/iwae.py b/examples/scripts/reproducibility/iwae.py index 58e11f42..97d4b237 100644 --- a/examples/scripts/reproducibility/iwae.py +++ b/examples/scripts/reproducibility/iwae.py @@ -1,20 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import IWAE, IWAEConfig -from pythae.trainers import BaseTrainer, BaseTrainerConfig - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.models import IWAE, AutoModel, IWAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -23,30 +19,15 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) +device = "cuda" if torch.cuda.is_available() else "cpu" -args = ap.parse_args() def unif_init(m, n_in, n_out): - scale = np.sqrt(6./(n_in+n_out)) - m.weight.data.uniform_( - -scale, scale - ) + scale = np.sqrt(6.0 / (n_in + n_out)) + m.weight.data.uniform_(-scale, scale) m.bias.data = torch.zeros((1, n_out)) + ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args: dict): @@ -54,7 +35,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.fc1 = nn.Linear(np.prod(args.input_dim), 200) self.fc2 = nn.Linear(200, 200) @@ -75,6 +55,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -104,24 +85,32 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): - - - - train_data = torch.tensor(np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat"))).type(torch.float) - eval_data = torch.tensor(np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat"))).type(torch.float) - test_data = torch.tensor(np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat"))).type(torch.float) + train_data = torch.tensor( + np.loadtxt( + 
os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat") + ) + ).type(torch.float) + eval_data = torch.tensor( + np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat") + ) + ).type(torch.float) + test_data = torch.tensor( + np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat") + ) + ).type(torch.float) data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = IWAEConfig.from_json_file(args.model_config) - - else: - model_config = IWAEConfig() - - model_config.input_dim = data_input_dim + model_config = IWAEConfig( + input_dim=data_input_dim, + latent_dim=50, + reconstruction_loss="bce", + number_samples=50, + ) model = IWAE( model_config=model_config, @@ -130,26 +119,35 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/binary_mnist", + per_device_train_batch_size=20, + per_device_eval_batch_size=20, + num_epochs=3280, + learning_rate=1e-3, + steps_saving=None, + steps_predict=None, + no_cuda=False, + optimizer_cls="Adam", + optimizer_params={"eps": 1e-4}, + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [2, 5, 14, 28, 41, 122, 365, 1094], + "gamma": 10 ** (-1 / 7), + "verbose": True, + }, + ) ### Process data data_processor = DataProcessor() logger.info("Preprocessing train data...") - #train_data = data_processor.process_data(train_data) + # train_data = data_processor.process_data(train_data) train_dataset = data_processor.to_dataset(train_data) logger.info("Preprocessing eval data...\n") - #ieval_data = data_processor.process_data(eval_data) + # ieval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate, eps=1e-4) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[2, 5, 14, 28, 41, 122, 365, 1094], gamma=10**(-1/7), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -160,15 +158,23 @@ def main(args): train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() ### Reload model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) @@ -179,14 +185,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=5000, batch_size=5000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/miwae.py b/examples/scripts/reproducibility/miwae.py index 45f29c04..1a25a05d 100644 --- a/examples/scripts/reproducibility/miwae.py +++ b/examples/scripts/reproducibility/miwae.py @@ -1,22 +1,17 @@ 
-import argparse import logging import os from typing import List import numpy as np import torch - -from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import MIWAE, MIWAEConfig -from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.data.datasets import DatasetOutput +import torch.nn as nn from torch.utils.data import Dataset -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.data.datasets import DatasetOutput +from pythae.models import MIWAE, AutoModel, MIWAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -25,30 +20,15 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = "cuda" if torch.cuda.is_available() else "cpu" -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() def unif_init(m, n_in, n_out): - scale = np.sqrt(6./(n_in+n_out)) - m.weight.data.uniform_( - -scale, scale - ) + scale = np.sqrt(6.0 / (n_in + n_out)) + m.weight.data.uniform_(-scale, scale) m.bias.data = torch.zeros((1, n_out)) + ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args: dict): @@ -56,7 +36,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.fc1 = nn.Linear(np.prod(args.input_dim), 200) self.fc2 = nn.Linear(200, 200) @@ -77,6 +56,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -119,37 +99,30 @@ def __getitem__(self, index): return DatasetOutput(data=x) -def main(args): +def main(): - - ### Load data train_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 ) eval_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 ) train_data = torch.cat((train_data, eval_data)) test_data = ( - np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = MIWAEConfig.from_json_file(args.model_config) - - else: - model_config = MIWAEConfig() - - model_config.input_dim = data_input_dim + model_config = MIWAEConfig( + input_dim=data_input_dim, + latent_dim=50, + reconstruction_loss="bce", + number_samples=8, + number_gradient_estimates=8, + ) model = MIWAE( model_config=model_config, @@ -158,19 +131,28 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/mnist", + per_device_train_batch_size=20, + per_device_eval_batch_size=20, + num_epochs=3280, + 
learning_rate=1e-3, + steps_saving=1000, + steps_predict=None, + no_cuda=False, + optimizer_cls="Adam", + optimizer_params={"eps": 1e-4}, + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [2, 5, 14, 28, 41, 122, 365, 1094], + "gamma": 10 ** (-1 / 7), + "verbose": True, + }, + ) logger.info("Preprocessing train data...") train_dataset = DynBinarizedMNIST(train_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate, eps=1e-4) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[2, 5, 14, 28, 41, 122, 365, 1094], gamma=10**(-1/7), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -179,20 +161,31 @@ def main(args): trainer = BaseTrainer( model=model, train_dataset=train_dataset, - eval_dataset=None,#eval_dataset, + eval_dataset=None, # eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() ### Reload model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) - test_data = (test_data > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device)).float() + test_data = ( + test_data + > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device) + ).float() ### Compute NLL with torch.no_grad(): @@ -201,14 +194,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=5000, batch_size=5000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/piwae.py b/examples/scripts/reproducibility/piwae.py index 6f2267bf..24cc6c9d 100644 --- a/examples/scripts/reproducibility/piwae.py +++ b/examples/scripts/reproducibility/piwae.py @@ -1,23 +1,17 @@ -import argparse import logging import os from typing import List import numpy as np import torch - -from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import PIWAE, PIWAEConfig -from pythae.trainers import CoupledOptimizerTrainerConfig, CoupledOptimizerTrainer -from pythae.data.datasets import DatasetOutput +import torch.nn as nn from torch.utils.data import Dataset - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.data.datasets import DatasetOutput +from pythae.models import PIWAE, AutoModel, PIWAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -26,30 +20,15 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = "cuda" if torch.cuda.is_available() else "cpu" 
-ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() def unif_init(m, n_in, n_out): - scale = np.sqrt(6./(n_in+n_out)) - m.weight.data.uniform_( - -scale, scale - ) + scale = np.sqrt(6.0 / (n_in + n_out)) + m.weight.data.uniform_(-scale, scale) m.bias.data = torch.zeros((1, n_out)) + ### Define paper encoder network class Encoder(BaseEncoder): def __init__(self, args: dict): @@ -57,7 +36,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.fc1 = nn.Linear(np.prod(args.input_dim), 200) self.fc2 = nn.Linear(200, 200) @@ -78,6 +56,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + class DynBinarizedMNIST(Dataset): def __init__(self, data): self.data = data.type(torch.float) @@ -90,6 +69,7 @@ def __getitem__(self, index): x = (x > torch.distributions.Uniform(0, 1).sample(x.shape)).float() return DatasetOutput(data=x) + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -119,38 +99,30 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): - - - ### Load data train_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 ) eval_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 ) train_data = torch.cat((train_data, eval_data)) test_data = ( - np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = PIWAEConfig.from_json_file(args.model_config) - - else: - model_config = PIWAEConfig() - - model_config.input_dim = data_input_dim + model_config = PIWAEConfig( + input_dim=data_input_dim, + latent_dim=50, + reconstruction_loss="bce", + number_samples=8, + number_gradient_estimates=8, + ) model = PIWAE( model_config=model_config, @@ -159,23 +131,36 @@ def main(args): ) ### Set training config - training_config = CoupledOptimizerTrainerConfig.from_json_file(args.training_config) + training_config = CoupledOptimizerTrainerConfig( + output_dir="reproducibility/mnist", + per_device_train_batch_size=20, + per_device_eval_batch_size=20, + num_epochs=3280, + learning_rate=1e-3, + steps_saving=1000, + steps_predict=None, + no_cuda=False, + encoder_optimizer_cls="Adam", + encoder_optimizer_params={"eps": 1e-4}, + decoder_optimizer_cls="Adam", + decoder_optimizer_params={"eps": 1e-4}, + encoder_scheduler_cls="MultiStepLR", + encoder_scheduler_params={ + "milestones": [2, 5, 14, 28, 41, 122, 365, 1094], + "gamma": 10 ** (-1 / 7), + "verbose": True, + }, + decoder_scheduler_cls="MultiStepLR", + decoder_scheduler_params={ + "milestones": [2, 5, 14, 28, 41, 122, 365, 1094], + "gamma": 10 ** (-1 / 7), + "verbose": True, + }, + ) logger.info("Preprocessing train data...") train_dataset = DynBinarizedMNIST(train_data) - ### Optimizers - enc_optimizer = torch.optim.Adam(model.encoder.parameters(), 
lr=training_config.learning_rate, eps=1e-4) - dec_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=training_config.learning_rate, eps=1e-4) - - ### Schedulers - enc_scheduler = torch.optim.lr_scheduler.MultiStepLR( - enc_optimizer, milestones=[2, 5, 14, 28, 41, 122, 365, 1094], gamma=10**(-1/7), verbose=True - ) - dec_scheduler = torch.optim.lr_scheduler.MultiStepLR( - dec_optimizer, milestones=[2, 5, 14, 28, 41, 122, 365, 1094], gamma=10**(-1/7), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -183,22 +168,31 @@ def main(args): trainer = CoupledOptimizerTrainer( model=model, train_dataset=train_dataset, - eval_dataset=None,#eval_dataset, + eval_dataset=None, # eval_dataset, training_config=training_config, - encoder_optimizer=enc_optimizer, - decoder_optimizer=dec_optimizer, - encoder_scheduler=enc_scheduler, - decoder_scheduler=dec_scheduler, callbacks=None, ) trainer.train() ### Reload model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) - test_data = (test_data > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device)).float() + test_data = ( + test_data + > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device) + ).float() ### Compute NLL with torch.no_grad(): @@ -207,14 +201,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=5000, batch_size=5000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/pvae.py b/examples/scripts/reproducibility/pvae.py index 50493c4a..d1380ad3 100644 --- a/examples/scripts/reproducibility/pvae.py +++ b/examples/scripts/reproducibility/pvae.py @@ -1,23 +1,18 @@ -import argparse import logging -import os -import numpy as np import math -from typing import List +import os import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from pythae.data.preprocessors import DataProcessor -from pythae.models import PoincareVAE, PoincareVAEConfig -from pythae.models.pvae.pvae_utils import PoincareBall -from pythae.models import AutoModel -from pythae.trainers import BaseTrainerConfig, BaseTrainer -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.data.preprocessors import DataProcessor +from pythae.models import AutoModel, PoincareVAE, PoincareVAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.models.pvae.pvae_utils import PoincareBall +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -26,23 +21,8 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - +device = "cuda" if torch.cuda.is_available() else "cpu" 
-ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() class RiemannianLayer(nn.Module): def __init__(self, in_features, out_features, manifold, over_param, weight_norm): @@ -58,14 +38,18 @@ def __init__(self, in_features, out_features, manifold, over_param, weight_norm) @property def weight(self): - return self.manifold.transp0(self.bias, self._weight) # weight \in T_0 => weight \in T_bias + return self.manifold.transp0( + self.bias, self._weight + ) # weight \in T_0 => weight \in T_bias @property def bias(self): if self.over_param: return self._bias else: - return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold + return self.manifold.expmap0( + self._weight * self._bias + ) # reparameterisation of a point on the manifold def reset_parameters(self): nn.init.kaiming_normal_(self._weight, a=math.sqrt(5)) @@ -73,27 +57,41 @@ def reset_parameters(self): bound = 4 / math.sqrt(fan_in) nn.init.uniform_(self._bias, -bound, bound) if self.over_param: - with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias)) + with torch.no_grad(): + self._bias.set_(self.manifold.expmap0(self._bias)) + class GeodesicLayer(RiemannianLayer): - def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False): - super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm) + def __init__( + self, in_features, out_features, manifold, over_param=False, weight_norm=False + ): + super(GeodesicLayer, self).__init__( + in_features, out_features, manifold, over_param, weight_norm + ) def forward(self, input): input = input.unsqueeze(0) - input = input.unsqueeze(-2).expand(*input.shape[:-(len(input.shape) - 2)], self.out_features, self.in_features) - res = self.manifold.normdist2plane(input, self.bias, self.weight, - signed=True, norm=self.weight_norm) + input = input.unsqueeze(-2).expand( + *input.shape[: -(len(input.shape) - 2)], self.out_features, self.in_features + ) + res = self.manifold.normdist2plane( + input, self.bias, self.weight, signed=True, norm=self.weight_norm + ) return res + ### Define paper encoder network class Encoder(BaseEncoder): - """ Usual encoder followed by an exponential map """ + """Usual encoder followed by an exponential map""" + def __init__(self, model_config, prior_iso): super(Encoder, self).__init__() - self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature) + self.manifold = PoincareBall( + dim=model_config.latent_dim, c=model_config.curvature + ) self.enc = nn.Sequential( - nn.Linear(np.prod(model_config.input_dim), 600), nn.ReLU(), + nn.Linear(np.prod(model_config.input_dim), 600), + nn.ReLU(), ) self.fc21 = nn.Linear(600, model_config.latent_dim) self.fc22 = nn.Linear(600, model_config.latent_dim if not prior_iso else 1) @@ -104,65 +102,63 @@ def forward(self, x): mu = self.manifold.expmap0(mu) return ModelOutput( embedding=mu, - log_covariance=torch.log(F.softplus(self.fc22(e)) + 1e-5), # expects log_covariance - log_concentration=torch.log(F.softplus(self.fc22(e)) + 1e-5) # for Riemannian Normal - + log_covariance=torch.log( + F.softplus(self.fc22(e)) + 1e-5 + ), # expects log_covariance + log_concentration=torch.log( + F.softplus(self.fc22(e)) + 1e-5 + ), # for Riemannian Normal ) 
+ ### Define paper decoder network class Decoder(BaseDecoder): - """ First layer is a Hypergyroplane followed by usual decoder """ + """First layer is a Hypergyroplane followed by usual decoder""" + def __init__(self, model_config): super(Decoder, self).__init__() - self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature) + self.manifold = PoincareBall( + dim=model_config.latent_dim, c=model_config.curvature + ) self.input_dim = model_config.input_dim self.dec = nn.Sequential( GeodesicLayer(model_config.latent_dim, 600, self.manifold), nn.ReLU(), nn.Linear(600, np.prod(model_config.input_dim)), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, z): out = self.dec(z).reshape((z.shape[0],) + self.input_dim) # reshape data - return ModelOutput( - reconstruction=out - ) + return ModelOutput(reconstruction=out) -def main(args): +def main(): ### Load data train_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))[ - "data" - ] - / 255.0 - ).clamp(1e-5, 1-1e-5) + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 + ).clamp(1e-5, 1 - 1e-5) eval_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] - / 255.0 - ).clamp(1e-5, 1-1e-5) + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 + ).clamp(1e-5, 1 - 1e-5) train_data = torch.cat((train_data, eval_data)) test_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] - / 255.0 - ).clamp(1e-5, 1-1e-5) + np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] / 255.0 + ).clamp(1e-5, 1 - 1e-5) data_input_dim = tuple(train_data.shape[1:]) - - if args.model_config is not None: - model_config = PoincareVAEConfig.from_json_file(args.model_config) - - else: - model_config = PoincareVAEConfig() - - model_config.input_dim = data_input_dim - - + model_config = PoincareVAEConfig( + input_dim=data_input_dim, + latent_dim=10, + reconstruction_loss="bce", + prior_distribution="wrapped_normal", + posterior_distribution="wrapped_normal", + curvature=0.7, + ) model = PoincareVAE( model_config=model_config, @@ -171,7 +167,16 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/mnist", + per_device_train_batch_size=128, + per_device_eval_batch_size=128, + num_epochs=80, + learning_rate=5e-4, + steps_saving=100, + steps_predict=None, + no_cuda=False, + ) ### Process data data_processor = DataProcessor() @@ -179,18 +184,6 @@ def main(args): train_data = data_processor.process_data(torch.bernoulli(train_data)) train_dataset = data_processor.to_dataset(train_data) - logger.info("Preprocessing eval data...\n") - eval_data = data_processor.process_data(torch.bernoulli(eval_data)) - eval_dataset = data_processor.to_dataset(eval_data) - - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[10000000], gamma=10**(-1/3), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -199,17 +192,25 @@ def main(args): trainer = BaseTrainer( model=model, train_dataset=train_dataset, - eval_dataset=None,#eval_dataset, + eval_dataset=None, # eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) ### Launch training trainer.train() - - 
trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) @@ -220,14 +221,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=500, batch_size=500) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/rae_gp.py b/examples/scripts/reproducibility/rae_gp.py index cb19457d..807ce5bf 100644 --- a/examples/scripts/reproducibility/rae_gp.py +++ b/examples/scripts/reproducibility/rae_gp.py @@ -1,18 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.trainers import BaseTrainer, BaseTrainerConfig from pythae.models import RAE_GP, RAE_GP_Config -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -21,23 +19,7 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() - +device = "cuda" if torch.cuda.is_available() else "cpu" ### Define paper encoder network class Encoder(BaseEncoder): @@ -153,8 +135,7 @@ def __init__(self, args: dict): layers.append( nn.Sequential( - nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), - nn.Sigmoid() + nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), nn.Sigmoid() ) ) @@ -199,29 +180,20 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output - -def main(args): +def main(): train_data = ( - np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))["data"] / 255.0 ) eval_data = ( - np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - if args.model_config is not None: - model_config = RAE_GP_Config.from_json_file(args.model_config) - - else: - model_config = RAE_GP_Config() - - model_config.input_dim = data_input_dim + model_config = RAE_GP_Config( + input_dim=data_input_dim, latent_dim=16, embedding_weight=1e-3, reg_weight=1e-6 + ) model = RAE_GP( model_config=model_config, @@ -230,7 +202,18 @@ def main(args): ) ### Set training config - training_config = 
BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="my_models_on_celeba", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=100, + learning_rate=0.001, + steps_saving=None, + steps_predict=100, + no_cuda=False, + scheduler_cls="ReduceLROnPlateau", + scheduler_params={"factor": 0.5, "patience": 5, "verbose": True}, + ) ### Process data data_processor = DataProcessor() @@ -242,14 +225,6 @@ def main(args): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, factor=0.5, patience=5, verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -260,13 +235,12 @@ def main(args): train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/rae_l2.py b/examples/scripts/reproducibility/rae_l2.py index debf061e..3d42ea6c 100644 --- a/examples/scripts/reproducibility/rae_l2.py +++ b/examples/scripts/reproducibility/rae_l2.py @@ -1,20 +1,17 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor from pythae.models import RAE_L2, RAE_L2_Config -from pythae.models.rhvae import RHVAEConfig -from pythae.trainers import BaseTrainerConfig, CoupledOptimizerTrainer - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.models.rhvae import RHVAEConfig +from pythae.trainers import CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -23,23 +20,7 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - - -args = ap.parse_args() +device = "cuda" if torch.cuda.is_available() else "cpu" ### Define paper encoder network class Encoder(BaseEncoder): @@ -155,8 +136,7 @@ def __init__(self, args: dict): layers.append( nn.Sequential( - nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), - nn.Sigmoid() + nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), nn.Sigmoid() ) ) @@ -201,29 +181,20 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output - -def main(args): +def main(): train_data = ( - np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))["data"] / 255.0 ) eval_data = ( - np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] / 255.0 ) - - data_input_dim = tuple(train_data.shape[1:]) - - if args.model_config is not None: 
- model_config = RAE_L2_Config.from_json_file(args.model_config) - else: - model_config = RAE_L2_Config() + data_input_dim = tuple(train_data.shape[1:]) - model_config.input_dim = data_input_dim + model_config = RAE_L2_Config( + input_dim=data_input_dim, latent_dim=16, embedding_weight=1e-6, reg_weight=1e-3 + ) model = RAE_L2( model_config=model_config, @@ -232,7 +203,22 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = CoupledOptimizerTrainerConfig( + output_dir="my_models_on_celeba", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=100, + learning_rate=0.001, + steps_saving=None, + steps_predict=100, + no_cuda=False, + encoder_scheduler_cls="ReduceLROnPlateau", + encoder_scheduler_params={"factor": 0.5, "patience": 5, "verbose": True}, + decoder_optimizer_cls="Adam", + decoder_optimizer_params={"weight_decay": model_config.reg_weight}, + decoder_scheduler_cls="ReduceLROnPlateau", + decoder_scheduler_params={"factor": 0.5, "patience": 5, "verbose": True}, + ) ### Process data data_processor = DataProcessor() @@ -244,39 +230,22 @@ def main(args): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizers - enc_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=training_config.learning_rate) - dec_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=training_config.learning_rate, weight_decay=model_config.reg_weight) - - ### Schedulers - enc_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - enc_optimizer, factor=0.5, patience=5, verbose=True - ) - dec_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - enc_optimizer, factor=0.5, patience=5, verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - #logger.info("Using Base Trainer\n") + # logger.info("Using Base Trainer\n") trainer = CoupledOptimizerTrainer( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - encoder_optimizer=enc_optimizer, - decoder_optimizer=dec_optimizer, - encoder_scheduler=enc_scheduler, - decoder_scheduler=dec_scheduler, callbacks=None, ) - print(trainer.scheduler) - trainer.train() + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/svae.py b/examples/scripts/reproducibility/svae.py index 93e76ba1..c2eb843c 100644 --- a/examples/scripts/reproducibility/svae.py +++ b/examples/scripts/reproducibility/svae.py @@ -1,22 +1,17 @@ -import argparse import logging import os from typing import List import numpy as np import torch - -from pythae.data.preprocessors import DataProcessor -from pythae.models import SVAE, SVAEConfig -from pythae.models import AutoModel -from pythae.trainers import BaseTrainerConfig, BaseTrainer -from pythae.data.datasets import DatasetOutput +import torch.nn as nn from torch.utils.data import Dataset -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.data.datasets import DatasetOutput +from pythae.models import SVAE, AutoModel, SVAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -25,23 +20,7 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() 
else 'cpu' - - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() +device = "cuda" if torch.cuda.is_available() else "cpu" ### Define paper encoder network class Encoder(BaseEncoder): @@ -57,7 +36,7 @@ def __init__(self, args: dict): nn.Linear(np.prod(args.input_dim), 256), nn.ReLU(), nn.Linear(256, 128), - nn.ReLU() + nn.ReLU(), ) ) @@ -67,7 +46,6 @@ def __init__(self, args: dict): self.embedding = nn.Linear(128, self.latent_dim) self.log_concentration = nn.Linear(128, 1) - def forward(self, x, output_layer_levels: List[int] = None): output = ModelOutput() @@ -78,9 +56,7 @@ def forward(self, x, output_layer_levels: List[int] = None): assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels - ), ( - f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." - ) + ), f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." if -1 in output_layer_levels: max_depth = self.depth @@ -101,6 +77,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -108,8 +85,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim - # assert 0, np.prod(args.input_dim) - layers = nn.ModuleList() layers.append( @@ -117,15 +92,12 @@ def __init__(self, args: dict): nn.Linear(args.latent_dim, 128), nn.ReLU(), nn.Linear(128, 256), - nn.ReLU() + nn.ReLU(), ) ) layers.append( - nn.Sequential( - nn.Linear(256, np.prod(args.input_dim)), - nn.Sigmoid() - ) + nn.Sequential(nn.Linear(256, np.prod(args.input_dim)), nn.Sigmoid()) ) self.layers = layers @@ -165,6 +137,7 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output + class DynBinarizedMNIST(Dataset): def __init__(self, data): self.data = data.type(torch.float) @@ -177,37 +150,28 @@ def __getitem__(self, index): x = (x > torch.distributions.Uniform(0, 1).sample(x.shape)).float() return DatasetOutput(data=x) -def main(args): + +def main(): ### Load data train_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))[ - "data" - ] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "train_data.npz"))["data"] / 255.0 ) eval_data = torch.tensor( - np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "eval_data.npz"))["data"] / 255.0 ) train_data = torch.cat((train_data, eval_data)) test_data = ( - np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/mnist", "test_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - - if args.model_config is not None: - model_config = SVAEConfig.from_json_file(args.model_config) - - else: - model_config = SVAEConfig() - - model_config.input_dim = data_input_dim + model_config = SVAEConfig( + input_dim=data_input_dim, latent_dim=11, reconstruction_loss="bce" + ) model = SVAE( model_config=model_config, @@ -216,22 +180,20 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/binary_mnist", + per_device_train_batch_size=64, + 
per_device_eval_batch_size=64, + num_epochs=500, + learning_rate=1e-3, + steps_saving=50, + steps_predict=None, + no_cuda=False, + ) logger.info("Preprocessing train data...") train_dataset = DynBinarizedMNIST(train_data) - logger.info("Preprocessing eval data...\n") - eval_dataset = DynBinarizedMNIST(eval_data) - - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[10000000], gamma=10**(-1/3), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -240,20 +202,31 @@ def main(args): trainer = BaseTrainer( model=model, train_dataset=train_dataset, - eval_dataset=None,#eval_dataset, + eval_dataset=None, # eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) ### Launch training trainer.train() - - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) - test_data = (test_data > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device)).float() + test_data = ( + test_data + > torch.distributions.Uniform(0, 1).sample(test_data.shape).to(test_data.device) + ).float() ### Compute NLL with torch.no_grad(): @@ -262,14 +235,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=500, batch_size=500) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/vae.py b/examples/scripts/reproducibility/vae.py index 633fe021..286f5343 100644 --- a/examples/scripts/reproducibility/vae.py +++ b/examples/scripts/reproducibility/vae.py @@ -1,20 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import VAE, VAEConfig -from pythae.trainers import BaseTrainerConfig, BaseTrainer - -from pythae.models.nn import BaseDecoder, BaseEncoder -import torch.nn as nn +from pythae.models import VAE, AutoModel, VAEConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -23,22 +19,7 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() +device = "cuda" if torch.cuda.is_available() else 
"cpu" ### Define paper encoder network class Encoder(BaseEncoder): @@ -53,8 +34,8 @@ def __init__(self, args: dict): nn.Sequential( nn.Linear(np.prod(args.input_dim), 400), nn.ReLU(), - #nn.Linear(400, 400), - #nn.ReLU() + # nn.Linear(400, 400), + # nn.ReLU() ) ) @@ -64,7 +45,6 @@ def __init__(self, args: dict): self.embedding = nn.Linear(400, self.latent_dim) self.log_var = nn.Linear(400, self.latent_dim) - def forward(self, x, output_layer_levels: List[int] = None): output = ModelOutput() @@ -75,9 +55,7 @@ def forward(self, x, output_layer_levels: List[int] = None): assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels - ), ( - f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." - ) + ), f"Cannot output layer deeper than depth ({self.depth}). Got ({output_layer_levels})." if -1 in output_layer_levels: max_depth = self.depth @@ -98,6 +76,7 @@ def forward(self, x, output_layer_levels: List[int] = None): return output + ### Define paper decoder network class Decoder(BaseDecoder): def __init__(self, args: dict): @@ -111,16 +90,13 @@ def __init__(self, args: dict): nn.Sequential( nn.Linear(args.latent_dim, 400), nn.ReLU(), - #nn.Linear(400, 400), - #nn.ReLU() + # nn.Linear(400, 400), + # nn.ReLU() ) ) layers.append( - nn.Sequential( - nn.Linear(400, np.prod(args.input_dim)), - nn.Sigmoid() - ) + nn.Sequential(nn.Linear(400, np.prod(args.input_dim)), nn.Sigmoid()) ) self.layers = layers @@ -161,23 +137,24 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): - - train_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat")) - eval_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat")) - test_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat")) + train_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat") + ) + eval_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat") + ) + test_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat") + ) data_input_dim = tuple(train_data.shape[1:]) ### Build model - if args.model_config is not None: - model_config = VAEConfig.from_json_file(args.model_config) - - else: - model_config = VAEConfig() - - model_config.input_dim = data_input_dim + model_config = VAEConfig( + input_dim=data_input_dim, latent_dim=40, reconstruction_loss="bce" + ) model = VAE( model_config=model_config, @@ -186,7 +163,22 @@ def main(args): ) ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + training_config = BaseTrainerConfig( + output_dir="reproducibility/binary_mnist", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=500, + learning_rate=1e-4, + steps_saving=None, + steps_predict=None, + no_cuda=False, + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [200, 350, 500, 750, 1000], + "gamma": 10 ** (-1 / 5), + "verbose": True, + }, + ) ### Process data data_processor = DataProcessor() @@ -198,14 +190,6 @@ def main(args): eval_data = data_processor.process_data(eval_data).reshape(-1, 1, 28, 28) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[200, 
350, 500, 750, 1000], gamma=10**(-1/5), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -216,15 +200,23 @@ def main(args): train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() - + ### Reload the model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) @@ -235,14 +227,11 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=200) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/vamp.py b/examples/scripts/reproducibility/vamp.py index 647c6361..522eb102 100644 --- a/examples/scripts/reproducibility/vamp.py +++ b/examples/scripts/reproducibility/vamp.py @@ -1,21 +1,16 @@ -import argparse import logging import os from typing import List import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor -from pythae.models import AutoModel -from pythae.models import VAMP, VAMPConfig -from pythae.trainers import (AdversarialTrainerConfig, BaseTrainerConfig, BaseTrainer, - CoupledOptimizerTrainerConfig) - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn +from pythae.models import VAMP, AutoModel, VAMPConfig from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -24,30 +19,17 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() +device = "cuda" if torch.cuda.is_available() else "cpu" def he_init(m): - s = np.sqrt( 2. 
/ m.in_features ) + s = np.sqrt(2.0 / m.in_features) m.weight.data.normal_(0, s) + ### Define custom layers from the paper + class NonLinear(nn.Module): def __init__(self, input_size, output_size, bias=True, activation=None): super(NonLinear, self).__init__() @@ -58,7 +40,7 @@ def __init__(self, input_size, output_size, bias=True, activation=None): def forward(self, x): h = self.linear(x) if self.activation is not None: - h = self.activation( h ) + h = self.activation(h) return h @@ -75,9 +57,9 @@ def __init__(self, input_size, output_size, activation=None): def forward(self, x): h = self.h(x) if self.activation is not None: - h = self.activation( self.h( x ) ) + h = self.activation(self.h(x)) - g = self.sigmoid( self.g( x ) ) + g = self.sigmoid(self.g(x)) return h * g @@ -93,8 +75,7 @@ def __init__(self, args: dict): layers.append( nn.Sequential( - GatedDense(np.prod(args.input_dim), 300), - GatedDense(300, 300) + GatedDense(np.prod(args.input_dim), 300), GatedDense(300, 300) ) ) @@ -102,7 +83,9 @@ def __init__(self, args: dict): self.depth = len(layers) self.embedding = nn.Linear(300, self.latent_dim) - self.log_var = NonLinear(300, self.latent_dim, activation=nn.Hardtanh(min_val=-6.,max_val=2.)) + self.log_var = NonLinear( + 300, self.latent_dim, activation=nn.Hardtanh(min_val=-6.0, max_val=2.0) + ) for m in self.modules(): if isinstance(m, nn.Linear): @@ -150,15 +133,10 @@ def __init__(self, args: dict): self.input_dim = args.input_dim - # assert 0, np.prod(args.input_dim) - layers = nn.ModuleList() layers.append( - nn.Sequential( - GatedDense(args.latent_dim, 300), - GatedDense(300, 300) - ) + nn.Sequential(GatedDense(args.latent_dim, 300), GatedDense(300, 300)) ) layers.append( @@ -209,23 +187,28 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): - - train_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat")) - eval_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat")) - test_data = np.loadtxt(os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat")) + train_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_train.amat") + ) + eval_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_valid.amat") + ) + test_data = np.loadtxt( + os.path.join(PATH, f"data/binary_mnist", "binarized_mnist_test.amat") + ) data_input_dim = tuple(train_data.shape[1:]) ### Build model - if args.model_config is not None: - model_config = VAMPConfig.from_json_file(args.model_config) - - else: - model_config = VAMPConfig() - - model_config.input_dim = data_input_dim + model_config = VAMPConfig( + input_dim=data_input_dim, + latent_dim=40, + reconstruction_loss="bce", + number_components=500, + linear_scheduling_steps=100, + ) model = VAMP( model_config=model_config, @@ -233,10 +216,24 @@ def main(args): decoder=Decoder(model_config), ) - ### Set training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) - + training_config = BaseTrainerConfig( + output_dir="reproducibility/binary_mnist", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=2000, + learning_rate=1e-4, + steps_saving=None, + steps_predict=None, + no_cuda=False, + optimizer_cls="RMSprop", + scheduler_cls="MultiStepLR", + scheduler_params={ + "milestones": [200, 350, 500, 750, 1000], + "gamma": 10 ** (-1 / 5), + "verbose": True, + }, + ) ### Process data data_processor = 
DataProcessor() @@ -248,14 +245,6 @@ def main(args): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) - ### Optimizer - optimizer = torch.optim.RMSprop(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[200, 350, 500, 750, 1000], gamma=10**(-1/5), verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -266,19 +255,26 @@ def main(args): train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() - + ### Reload model - trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device).eval() + trained_model = ( + AutoModel.load_from_folder( + os.path.join( + training_config.output_dir, + f"{trainer.model.model_name}_training_{trainer._training_signature}", + "final_model", + ) + ) + .to(device) + .eval() + ) test_data = torch.tensor(test_data).to(device).type(torch.float) - ### Compute NLL with torch.no_grad(): nll = [] @@ -286,13 +282,10 @@ def main(args): nll_i = trained_model.get_nll(test_data, n_samples=5000, batch_size=5000) logger.info(f"Round {i+1} nll: {nll_i}") nll.append(nll_i) - - logger.info( - f'\nmean_nll: {np.mean(nll)}' - ) - logger.info( - f'\std_nll: {np.std(nll)}' - ) + + logger.info(f"\nmean_nll: {np.mean(nll)}") + logger.info(f"\std_nll: {np.std(nll)}") + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/reproducibility/wae.py b/examples/scripts/reproducibility/wae.py index 7011e87d..d40977ff 100644 --- a/examples/scripts/reproducibility/wae.py +++ b/examples/scripts/reproducibility/wae.py @@ -1,4 +1,3 @@ -import argparse import logging import os from time import time @@ -6,15 +5,13 @@ import numpy as np import torch +import torch.nn as nn from pythae.data.preprocessors import DataProcessor from pythae.models import WAE_MMD, WAE_MMD_Config -from pythae.trainers import BaseTrainer, BaseTrainerConfig - -from pythae.models.nn import BaseEncoder, BaseDecoder -import torch.nn as nn from pythae.models.base.base_utils import ModelOutput - +from pythae.models.nn import BaseDecoder, BaseEncoder +from pythae.trainers import BaseTrainer, BaseTrainerConfig logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -23,28 +20,10 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -ap = argparse.ArgumentParser() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -# Training setting -ap.add_argument( - "--model_config", - help="path to model config file (expected json file)", - default=None, -) -ap.add_argument( - "--training_config", - help="path to training config_file (expected json file)", - default=os.path.join(PATH, "configs/base_training_config.json"), -) - -args = ap.parse_args() - +device = "cuda" if torch.cuda.is_available() else "cpu" ### Define paper encoder network - class Encoder(BaseEncoder): def __init__(self, args): BaseEncoder.__init__(self) @@ -122,6 +101,7 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): ### Define paper decoder network + class Decoder(BaseDecoder): def __init__(self, args: dict): BaseDecoder.__init__(self) @@ -159,8 +139,7 @@ def __init__(self, args: dict): layers.append( nn.Sequential( - nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), - nn.Sigmoid() + 
nn.ConvTranspose2d(128, self.n_channels, 5, 1, padding=1), nn.Sigmoid() ) ) @@ -205,39 +184,50 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): return output +def main(): -def main(args): - - ### Load data train_data = ( - np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))[ - "data" - ] - / 255.0 - ) - eval_data = ( - np.load(os.path.join(PATH, f"data/celeba", "eval_data.npz"))["data"] - / 255.0 + np.load(os.path.join(PATH, f"data/celeba", "train_data.npz"))["data"] / 255.0 ) data_input_dim = tuple(train_data.shape[1:]) - ### Build model - model_config = WAE_MMD_Config.from_json_file(args.model_config) - model_config.input_dim = data_input_dim + model_config = WAE_MMD_Config( + input_dim=data_input_dim, + latent_dim=64, + kernel_choice="imq", + reg_weight=100, + kernel_bandwidth=2.0, + reconstruction_loss_scale=0.05, + ) + model = WAE_MMD( model_config=model_config, encoder=Encoder(model_config), decoder=Decoder(model_config), ) - - - ### Get training config - training_config = BaseTrainerConfig.from_json_file(args.training_config) + ### Get training config + training_config = BaseTrainerConfig( + output_dir="my_models_on_celeba", + per_device_train_batch_size=100, + per_device_eval_batch_size=100, + num_epochs=100, + learning_rate=0.0001, + steps_saving=3, + steps_predict=100, + no_cuda=False, + scheduler_cls="LambdaLR", + scheduler_params={ + "lr_lambda": lambda epoch: 1 * (epoch < 30) + + 0.5 * (30 <= epoch < 50) + + 0.2 * (50 <= epoch), + "verbose": True, + }, + ) ### Process data data_processor = DataProcessor() @@ -245,21 +235,6 @@ def main(args): train_data = data_processor.process_data(train_data) train_dataset = data_processor.to_dataset(train_data) - logger.info("Preprocessing eval data...\n") - eval_data = data_processor.process_data(eval_data) - eval_dataset = data_processor.to_dataset(eval_data) - - ### Optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) - - ### Scheduler - - lambda_lr = lambda epoch: 1 * (epoch < 30) + 0.5 * (30 <= epoch < 50) + 0.2 * (50 <= epoch) - - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=lambda_lr, verbose=True - ) - seed = 123 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -269,14 +244,12 @@ def main(args): model=model, train_dataset=train_dataset, training_config=training_config, - optimizer=optimizer, - scheduler=scheduler, callbacks=None, ) trainer.train() - + if __name__ == "__main__": - main(args) + main() diff --git a/examples/scripts/training.py b/examples/scripts/training.py index 68fec65c..b81e33cc 100644 --- a/examples/scripts/training.py +++ b/examples/scripts/training.py @@ -1,19 +1,14 @@ import argparse -import importlib import logging import os import numpy as np -import torch -from pythae.data.preprocessors import DataProcessor -from pythae.models import RHVAE -from pythae.models.rhvae import RHVAEConfig from pythae.pipelines import TrainingPipeline from pythae.trainers import ( + AdversarialTrainerConfig, BaseTrainerConfig, CoupledOptimizerTrainerConfig, - AdversarialTrainerConfig, ) logger = logging.getLogger(__name__) @@ -75,7 +70,7 @@ "--nn", help="neural nets to use", default="convnet", - choices=["default", "convnet","resnet"] + choices=["default", "convnet", "resnet"], ) ap.add_argument( "--training_config", @@ -107,22 +102,45 @@ def main(args): if args.nn == "convnet": - from pythae.models.nn.benchmarks.mnist import Encoder_Conv_AE_MNIST as Encoder_AE - from pythae.models.nn.benchmarks.mnist import 
Encoder_Conv_VAE_MNIST as Encoder_VAE - from pythae.models.nn.benchmarks.mnist import Encoder_Conv_SVAE_MNIST as Encoder_SVAE - from pythae.models.nn.benchmarks.mnist import Encoder_Conv_AE_MNIST as Encoder_VQVAE - from pythae.models.nn.benchmarks.mnist import Decoder_Conv_AE_MNIST as Decoder_AE - from pythae.models.nn.benchmarks.mnist import Decoder_Conv_AE_MNIST as Decoder_VQVAE + from pythae.models.nn.benchmarks.mnist import ( + Decoder_Conv_AE_MNIST as Decoder_AE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Decoder_Conv_AE_MNIST as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_Conv_AE_MNIST as Encoder_AE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_Conv_AE_MNIST as Encoder_VQVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_Conv_SVAE_MNIST as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_Conv_VAE_MNIST as Encoder_VAE, + ) elif args.nn == "resnet": - from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_AE_MNIST as Encoder_AE - from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST as Encoder_VAE - from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_SVAE_MNIST as Encoder_SVAE - from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VQVAE_MNIST as Encoder_VQVAE - from pythae.models.nn.benchmarks.mnist import Decoder_ResNet_AE_MNIST as Decoder_AE - from pythae.models.nn.benchmarks.mnist import Decoder_ResNet_VQVAE_MNIST as Decoder_VQVAE - - + from pythae.models.nn.benchmarks.mnist import ( + Encoder_ResNet_AE_MNIST as Encoder_AE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_ResNet_VAE_MNIST as Encoder_VAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_ResNet_SVAE_MNIST as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Encoder_ResNet_VQVAE_MNIST as Encoder_VQVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Decoder_ResNet_AE_MNIST as Decoder_AE, + ) + from pythae.models.nn.benchmarks.mnist import ( + Decoder_ResNet_VQVAE_MNIST as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.mnist import ( Discriminator_Conv_MNIST as Discriminator, ) @@ -131,39 +149,87 @@ def main(args): if args.nn == "convnet": - from pythae.models.nn.benchmarks.cifar import Encoder_Conv_AE_CIFAR as Encoder_AE - from pythae.models.nn.benchmarks.cifar import Encoder_Conv_VAE_CIFAR as Encoder_VAE - from pythae.models.nn.benchmarks.cifar import Encoder_Conv_SVAE_CIFAR as Encoder_SVAE - from pythae.models.nn.benchmarks.cifar import Encoder_Conv_AE_CIFAR as Encoder_VQVAE - from pythae.models.nn.benchmarks.cifar import Decoder_Conv_AE_CIFAR as Decoder_AE - from pythae.models.nn.benchmarks.cifar import Decoder_Conv_AE_CIFAR as Decoder_VQVAE + from pythae.models.nn.benchmarks.cifar import ( + Decoder_Conv_AE_CIFAR as Decoder_AE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Decoder_Conv_AE_CIFAR as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_Conv_AE_CIFAR as Encoder_AE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_Conv_AE_CIFAR as Encoder_VQVAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_Conv_SVAE_CIFAR as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_Conv_VAE_CIFAR as Encoder_VAE, + ) elif args.nn == "resnet": - from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_AE_CIFAR as Encoder_AE - from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_VAE_CIFAR as Encoder_VAE - 
from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_SVAE_CIFAR as Encoder_SVAE - from pythae.models.nn.benchmarks.cifar import Encoder_ResNet_VQVAE_CIFAR as Encoder_VQVAE - from pythae.models.nn.benchmarks.cifar import Decoder_ResNet_AE_CIFAR as Decoder_AE - from pythae.models.nn.benchmarks.cifar import Decoder_ResNet_VQVAE_CIFAR as Decoder_VQVAE + from pythae.models.nn.benchmarks.cifar import ( + Decoder_ResNet_AE_CIFAR as Decoder_AE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Decoder_ResNet_VQVAE_CIFAR as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_ResNet_AE_CIFAR as Encoder_AE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_ResNet_SVAE_CIFAR as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_ResNet_VAE_CIFAR as Encoder_VAE, + ) + from pythae.models.nn.benchmarks.cifar import ( + Encoder_ResNet_VQVAE_CIFAR as Encoder_VQVAE, + ) elif args.dataset == "celeba": if args.nn == "convnet": - from pythae.models.nn.benchmarks.celeba import Encoder_Conv_AE_CELEBA as Encoder_AE - from pythae.models.nn.benchmarks.celeba import Encoder_Conv_VAE_CELEBA as Encoder_VAE - from pythae.models.nn.benchmarks.celeba import Encoder_Conv_SVAE_CELEBA as Encoder_SVAE - from pythae.models.nn.benchmarks.celeba import Encoder_Conv_AE_CELEBA as Encoder_VQVAE - from pythae.models.nn.benchmarks.celeba import Decoder_Conv_AE_CELEBA as Decoder_AE - from pythae.models.nn.benchmarks.celeba import Decoder_Conv_AE_CELEBA as Decoder_VQVAE + from pythae.models.nn.benchmarks.celeba import ( + Decoder_Conv_AE_CELEBA as Decoder_AE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Decoder_Conv_AE_CELEBA as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_Conv_AE_CELEBA as Encoder_AE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_Conv_AE_CELEBA as Encoder_VQVAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_Conv_SVAE_CELEBA as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_Conv_VAE_CELEBA as Encoder_VAE, + ) elif args.nn == "resnet": - from pythae.models.nn.benchmarks.celeba import Encoder_ResNet_AE_CELEBA as Encoder_AE - from pythae.models.nn.benchmarks.celeba import Encoder_ResNet_VAE_CELEBA as Encoder_VAE - from pythae.models.nn.benchmarks.celeba import Encoder_ResNet_SVAE_CELEBA as Encoder_SVAE - from pythae.models.nn.benchmarks.celeba import Encoder_ResNet_VQVAE_CELEBA as Encoder_VQVAE - from pythae.models.nn.benchmarks.celeba import Decoder_ResNet_AE_CELEBA as Decoder_AE - from pythae.models.nn.benchmarks.celeba import Decoder_ResNet_VQVAE_CELEBA as Decoder_VQVAE + from pythae.models.nn.benchmarks.celeba import ( + Decoder_ResNet_AE_CELEBA as Decoder_AE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Decoder_ResNet_VQVAE_CELEBA as Decoder_VQVAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_ResNet_AE_CELEBA as Encoder_AE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_ResNet_SVAE_CELEBA as Encoder_SVAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_ResNet_VAE_CELEBA as Encoder_VAE, + ) + from pythae.models.nn.benchmarks.celeba import ( + Encoder_ResNet_VQVAE_CELEBA as Encoder_VQVAE, + ) try: logger.info(f"\nLoading {args.dataset} data...\n") diff --git a/setup.py b/setup.py index ae833daa..852c9ad9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="pythae", - version="0.0.9", + version="0.1.0", author="Clement Chadebec (HekA team 
INRIA)", author_email="clement.chadebec@inria.fr", description="Unifying Generative Autoencoders in Python", diff --git a/src/pythae/models/msssim_vae/msssim_vae_utils.py b/src/pythae/models/msssim_vae/msssim_vae_utils.py index 7fe972de..6e7d8e6a 100644 --- a/src/pythae/models/msssim_vae/msssim_vae_utils.py +++ b/src/pythae/models/msssim_vae/msssim_vae_utils.py @@ -13,7 +13,7 @@ def __init__(self, window_size=11): def _gaussian(self, sigma): gauss = torch.Tensor( [ - np.exp(-((x - self.window_size // 2) ** 2) / float(2 * sigma**2)) + np.exp(-((x - self.window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(self.window_size) ] ) @@ -129,8 +129,8 @@ def forward(self, img1, img2): mssim = (mssim + 1) / 2 mcs = (mcs + 1) / 2 - pow1 = mcs**weights - pow2 = mssim**weights + pow1 = mcs ** weights + pow2 = mssim ** weights output = torch.prod(pow1[:-1] * pow2[-1]) return 1 - output diff --git a/src/pythae/models/nn/default_architectures.py b/src/pythae/models/nn/default_architectures.py index b410bb5e..9eb52ddd 100644 --- a/src/pythae/models/nn/default_architectures.py +++ b/src/pythae/models/nn/default_architectures.py @@ -166,8 +166,6 @@ def __init__(self, args: dict): self.input_dim = args.input_dim - # assert 0, np.prod(args.input_dim) - layers = nn.ModuleList() layers.append(nn.Sequential(nn.Linear(args.latent_dim, 512), nn.ReLU())) diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py index 849f0cc1..3465e449 100644 --- a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py +++ b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py @@ -86,8 +86,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: x.shape[3], ) - # assert 0, (out.reshape(-1, self.model_config.n_embeddings).shape, x.reshape(x.shape[0], -1,).long().shape) - loss = F.cross_entropy(out, x.long()) return ModelOutput(out=out, loss=loss) diff --git a/src/pythae/models/vq_vae/vq_vae_config.py b/src/pythae/models/vq_vae/vq_vae_config.py index a6949a4e..86cc25b6 100644 --- a/src/pythae/models/vq_vae/vq_vae_config.py +++ b/src/pythae/models/vq_vae/vq_vae_config.py @@ -6,10 +6,11 @@ @dataclass class VQVAEConfig(AEConfig): r""" - Vector Quentized VAE model config config class + Vector Quantized VAE model config config class Parameters: input_dim (tuple): The input_data dimension. + latent_dim (int): The latent space dimension. Default: None. commitment_loss_factor (float): The commitment loss factor in the loss. Default: 0.25. quantization_loss_factor: The quantization loss factor in the loss. Default: 1. num_embedding (int): The number of embedding points. 
Default: 512 diff --git a/src/pythae/models/vq_vae/vq_vae_model.py b/src/pythae/models/vq_vae/vq_vae_model.py index b9fd16c8..d36ccb37 100644 --- a/src/pythae/models/vq_vae/vq_vae_model.py +++ b/src/pythae/models/vq_vae/vq_vae_model.py @@ -84,6 +84,7 @@ def forward(self, inputs: BaseDataset, **kwargs): """ x = inputs["data"] + uses_ddp = kwargs.pop("uses_ddp", False) encoder_output = self.encoder(x) @@ -97,7 +98,7 @@ def forward(self, inputs: BaseDataset, **kwargs): embeddings = embeddings.permute(0, 2, 3, 1) - quantizer_output = self.quantizer(embeddings) + quantizer_output = self.quantizer(embeddings, uses_ddp=uses_ddp) quantized_embed = quantizer_output.quantized_vector quantized_indices = quantizer_output.quantized_indices diff --git a/src/pythae/models/vq_vae/vq_vae_utils.py b/src/pythae/models/vq_vae/vq_vae_utils.py index 65667bc3..7eaae14c 100644 --- a/src/pythae/models/vq_vae/vq_vae_utils.py +++ b/src/pythae/models/vq_vae/vq_vae_utils.py @@ -1,4 +1,5 @@ import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -26,7 +27,7 @@ def __init__(self, model_config: VQVAEConfig): -1 / self.num_embeddings, 1 / self.num_embeddings ) - def forward(self, z: torch.Tensor): + def forward(self, z: torch.Tensor, uses_ddp: bool = False): distances = ( (z.reshape(-1, self.embedding_dim) ** 2).sum(dim=-1, keepdim=True) @@ -89,26 +90,24 @@ def __init__(self, model_config: VQVAEConfig): self.commitment_loss_factor = model_config.commitment_loss_factor self.decay = model_config.decay - self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim) + embeddings = torch.empty(self.num_embeddings, self.embedding_dim) - self.embeddings.weight.data.uniform_( - -1 / self.num_embeddings, 1 / self.num_embeddings - ) + embeddings.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) self.register_buffer("cluster_size", torch.zeros(self.num_embeddings)) - - self.ema_embed = nn.Parameter( - torch.Tensor(self.num_embeddings, self.embedding_dim) + self.register_buffer( + "ema_embed", torch.zeros(self.num_embeddings, self.embedding_dim) ) - self.ema_embed.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) - def forward(self, z: torch.Tensor): + self.register_buffer("embeddings", embeddings) + + def forward(self, z: torch.Tensor, uses_ddp: bool = False): distances = ( (z.reshape(-1, self.embedding_dim) ** 2).sum(dim=-1, keepdim=True) - + (self.embeddings.weight ** 2).sum(dim=-1) - - 2 * z.reshape(-1, self.embedding_dim) @ self.embeddings.weight.T + + (self.embeddings ** 2).sum(dim=-1) + - 2 * z.reshape(-1, self.embedding_dim) @ self.embeddings.T ) closest = distances.argmin(-1).unsqueeze(-1) @@ -122,20 +121,24 @@ def forward(self, z: torch.Tensor): ) # quantization - quantized = one_hot_encoding @ self.embeddings.weight + quantized = one_hot_encoding @ self.embeddings quantized = quantized.reshape_as(z) if self.training: n_i = torch.sum(one_hot_encoding, dim=0) + if uses_ddp: + dist.all_reduce(n_i) + self.cluster_size = self.cluster_size * self.decay + n_i * (1 - self.decay) dw = one_hot_encoding.T @ z.reshape(-1, self.embedding_dim) - self.ema_embed = nn.Parameter( - self.ema_embed * self.decay + dw * (1 - self.decay) - ) + if uses_ddp: + dist.all_reduce(dw) + + ema_embed = self.ema_embed * self.decay + dw * (1 - self.decay) n = torch.sum(self.cluster_size) @@ -143,9 +146,8 @@ def forward(self, z: torch.Tensor): (self.cluster_size + 1e-5) / (n + self.num_embeddings * 1e-5) * n ) - self.embeddings.weight = nn.Parameter( - self.ema_embed / 
self.cluster_size.unsqueeze(-1) - ) + self.embeddings.data.copy_(ema_embed / self.cluster_size.unsqueeze(-1)) + self.ema_embed.data.copy_(ema_embed) commitment_loss = F.mse_loss( quantized.detach().reshape(-1, self.embedding_dim), diff --git a/src/pythae/pipelines/training.py b/src/pythae/pipelines/training.py index 8960940c..92775ea9 100644 --- a/src/pythae/pipelines/training.py +++ b/src/pythae/pipelines/training.py @@ -73,8 +73,14 @@ def __init__( f"is expected for training a {model.model_name}" ) if model.model_name == "RAE_L2": - training_config.encoder_optim_decay = 0.0 - training_config.decoder_optim_decay = model.model_config.reg_weight + if training_config.decoder_optimizer_params is None: + training_config.decoder_optimizer_params = { + "weight_decay": model.model_config.reg_weight + } + else: + training_config.decoder_optimizer_params[ + "weight_decay" + ] = model.model_config.reg_weight elif model.model_name == "Adversarial_AE" or model.model_name == "FactorVAE": if not isinstance(training_config, AdversarialTrainerConfig): @@ -138,7 +144,7 @@ def _check_dataset(self, dataset: BaseDataset): loader_out = next(iter(dataloader)) assert loader_out.data.shape[0] == min( len(dataset), 2 - ), "Error when combining dataset wih loader." + ), "Error when combining dataset with loader." def __call__( self, diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py index 7475073e..4dc1d6e9 100644 --- a/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py +++ b/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py @@ -6,12 +6,14 @@ from typing import List, Optional import torch +import torch.distributed as dist import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.optim.lr_scheduler as lr_scheduler from ...data.datasets import BaseDataset from ...models import BaseAE from ..base_trainer import BaseTrainer +from ..trainer_utils import set_seed from ..training_callbacks import TrainingCallback from .adversarial_trainer_config import AdversarialTrainerConfig @@ -51,10 +53,6 @@ def __init__( train_dataset: BaseDataset, eval_dataset: Optional[BaseDataset] = None, training_config: Optional[AdversarialTrainerConfig] = None, - autoencoder_optimizer: Optional[torch.optim.Optimizer] = None, - discriminator_optimizer: Optional[torch.optim.Optimizer] = None, - autoencoder_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, - discriminator_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, callbacks: List[TrainingCallback] = None, ): @@ -64,66 +62,121 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=None, callbacks=callbacks, ) - # set autoencoder optimizer - if autoencoder_optimizer is None: - autoencoder_optimizer = self.set_default_autoencoder_optimizer(model) + def set_autoencoder_optimizer(self): + autoencoder_optimizer_cls = getattr( + optim, self.training_config.autoencoder_optimizer_cls + ) + if self.training_config.autoencoder_optimizer_params is not None: + if self.distributed: + autoencoder_optimizer = autoencoder_optimizer_cls( + itertools.chain( + self.model.module.encoder.parameters(), + self.model.module.decoder.parameters(), + ), + lr=self.training_config.autoencoder_learning_rate, + **self.training_config.autoencoder_optimizer_params, + ) + else: + autoencoder_optimizer = autoencoder_optimizer_cls( + itertools.chain( + self.model.encoder.parameters(), 
self.model.decoder.parameters() + ), + lr=self.training_config.autoencoder_learning_rate, + **self.training_config.autoencoder_optimizer_params, + ) else: - autoencoder_optimizer = self._set_optimizer_on_device( - autoencoder_optimizer, self.device - ) + if self.distributed: + autoencoder_optimizer = autoencoder_optimizer_cls( + itertools.chain( + self.model.module.encoder.parameters(), + self.model.module.decoder.parameters(), + ), + lr=self.training_config.autoencoder_learning_rate, + ) + else: + autoencoder_optimizer = autoencoder_optimizer_cls( + itertools.chain( + self.model.encoder.parameters(), self.model.decoder.parameters() + ), + lr=self.training_config.autoencoder_learning_rate, + ) - if autoencoder_scheduler is None: - autoencoder_scheduler = self.set_default_scheduler( - model, autoencoder_optimizer + self.autoencoder_optimizer = autoencoder_optimizer + + def set_autoencoder_scheduler(self): + if self.training_config.autoencoder_scheduler_cls is not None: + autoencoder_scheduler_cls = getattr( + lr_scheduler, self.training_config.autoencoder_scheduler_cls ) - # set discriminator optimizer - if discriminator_optimizer is None: - discriminator_optimizer = self.set_default_discriminator_optimizer(model) + if self.training_config.autoencoder_scheduler_params is not None: + scheduler = autoencoder_scheduler_cls( + self.autoencoder_optimizer, + **self.training_config.autoencoder_scheduler_params, + ) + else: + scheduler = autoencoder_scheduler_cls(self.autoencoder_optimizer) else: - discriminator_optimizer = self._set_optimizer_on_device( - discriminator_optimizer, self.device - ) + scheduler = None - if discriminator_scheduler is None: - discriminator_scheduler = self.set_default_scheduler( - model, discriminator_optimizer - ) + self.autoencoder_scheduler = scheduler - self.autoencoder_optimizer = autoencoder_optimizer - self.discriminator_optimizer = discriminator_optimizer - self.autoencoder_scheduler = autoencoder_scheduler - self.discriminator_scheduler = discriminator_scheduler + def set_discriminator_optimizer(self): + discriminator_cls = getattr( + optim, self.training_config.discriminator_optimizer_cls + ) - self.optimizer = None + if self.training_config.discriminator_optimizer_params is not None: + if self.distributed: + discriminator_optimizer = discriminator_cls( + self.model.module.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + **self.training_config.discriminator_optimizer_params, + ) + else: + discriminator_optimizer = discriminator_cls( + self.model.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + **self.training_config.discriminator_optimizer_params, + ) - def set_default_autoencoder_optimizer(self, model: BaseAE) -> torch.optim.Optimizer: + else: + if self.distributed: + discriminator_optimizer = discriminator_cls( + self.model.module.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + ) + else: + discriminator_optimizer = discriminator_cls( + self.model.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + ) - optimizer = torch.optim.Adam( - itertools.chain(model.encoder.parameters(), model.decoder.parameters()), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.autoencoder_optim_decay, - ) + self.discriminator_optimizer = discriminator_optimizer - return optimizer + def set_discriminator_scheduler(self) -> torch.optim.lr_scheduler: + if self.training_config.discriminator_scheduler_cls is not None: + 
discriminator_scheduler_cls = getattr( + lr_scheduler, self.training_config.discriminator_scheduler_cls + ) - def set_default_discriminator_optimizer( - self, model: BaseAE - ) -> torch.optim.Optimizer: + if self.training_config.discriminator_scheduler_params is not None: + scheduler = discriminator_scheduler_cls( + self.discriminator_optimizer, + **self.training_config.discriminator_scheduler_params, + ) + else: + scheduler = discriminator_scheduler_cls(self.discriminator_optimizer) - optimizer = optim.Adam( - model.discriminator.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.discriminator_optim_decay, - ) + else: + scheduler = None - return optimizer + self.discriminator_scheduler = scheduler def _optimizers_step(self, model_output): @@ -140,18 +193,45 @@ def _optimizers_step(self, model_output): self.discriminator_optimizer.step() def _schedulers_step(self, autoencoder_metrics=None, discriminator_metrics=None): - if isinstance(self.autoencoder_scheduler, ReduceLROnPlateau): + + if self.autoencoder_scheduler is None: + pass + + elif isinstance(self.autoencoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.autoencoder_scheduler.step(autoencoder_metrics) else: self.autoencoder_scheduler.step() - if isinstance(self.discriminator_scheduler, ReduceLROnPlateau): + if self.discriminator_scheduler is None: + pass + + elif isinstance(self.discriminator_scheduler, lr_scheduler.ReduceLROnPlateau): self.discriminator_scheduler.step(discriminator_metrics) else: self.discriminator_scheduler.step() + def prepare_training(self): + """Sets up the trainer for training""" + + # set random seed + set_seed(self.training_config.seed) + + # set autoencoder optimizer and scheduler + self.set_autoencoder_optimizer() + self.set_autoencoder_scheduler() + + # set discriminator optimizer and scheduler + self.set_discriminator_optimizer() + self.set_discriminator_scheduler() + + # create foder for saving + self._set_output_dir() + + # set callbacks + self._setup_callbacks() + def train(self, log_output_dir: str = None): """This function is the main training function @@ -159,72 +239,39 @@ def train(self, log_output_dir: str = None): log_output_dir (str): The path in which the log will be stored """ + self.prepare_training() + self.callback_handler.on_train_begin( - training_config=self.training_config, model_config=self.model.model_config + training_config=self.training_config, model_config=self.model_config ) - # run sanity check on the model - self._run_model_sanity_check(self.model, self.train_loader) - - logger.info("Model passed sanity check !\n") - - self._training_signature = ( - str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-") - ) + log_verbose = False - training_dir = os.path.join( - self.training_config.output_dir, - f"{self.model.model_name}_training_{self._training_signature}", + msg = ( + f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" + " - per_device_train_batch_size: " + f"{self.training_config.per_device_train_batch_size}\n" + " - per_device_eval_batch_size: " + f"{self.training_config.per_device_eval_batch_size}\n" + f" - checkpoint saving every: {self.training_config.steps_saving}\n" + f"Autoencoder Optimizer: {self.autoencoder_optimizer}\n" + f"Autoencoder Scheduler: {self.autoencoder_scheduler}\n" + f"Discriminator Optimizer: {self.discriminator_optimizer}\n" + f"Discriminator Scheduler: {self.discriminator_scheduler}\n" ) - self.training_dir = training_dir - - if not os.path.exists(training_dir): - 
os.makedirs(training_dir) - logger.info( - f"Created {training_dir}. \n" - "Training config, checkpoints and final model will be saved here.\n" - ) - - log_verbose = False + if self.is_main_process: + logger.info(msg) # set up log file - if log_output_dir is not None: - log_dir = log_output_dir + if log_output_dir is not None and self.is_main_process: log_verbose = True + file_logger = self._get_file_logger(log_output_dir=log_output_dir) - # if dir does not exist create it - if not os.path.exists(log_dir): - os.makedirs(log_dir) - logger.info(f"Created {log_dir} folder since did not exists.") - logger.info("Training logs will be recodered here.\n") - logger.info(" -> Training can be monitored here.\n") - - # create and set logger - log_name = f"training_logs_{self._training_signature}" - - file_logger = logging.getLogger(log_name) - file_logger.setLevel(logging.INFO) - f_handler = logging.FileHandler( - os.path.join(log_dir, f"training_logs_{self._training_signature}.log") - ) - f_handler.setLevel(logging.INFO) - file_logger.addHandler(f_handler) - - # Do not output logs in the console - file_logger.propagate = False - - file_logger.info("Training started !\n") - file_logger.info( - f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" - f" - batch_size: {self.training_config.batch_size}\n" - f" - checkpoint saving every {self.training_config.steps_saving}\n" - ) - - file_logger.info(f"Model Architecture: {self.model}\n") - file_logger.info(f"Optimizer: {self.optimizer}\n") + file_logger.info(msg) - logger.info("Successfully launched training !\n") + if self.is_main_process: + logger.info("Successfully launched training !\n") # set best losses for early stopping best_train_loss = 1e10 @@ -296,6 +343,7 @@ def train(self, log_output_dir: str = None): if ( self.training_config.steps_predict is not None and epoch % self.training_config.steps_predict == 0 + and self.is_main_process ): true_data, reconstructions, generations = self.predict(best_model) @@ -315,24 +363,33 @@ def train(self, log_output_dir: str = None): self.training_config.steps_saving is not None and epoch % self.training_config.steps_saving == 0 ): - self.save_checkpoint( - model=best_model, dir_path=training_dir, epoch=epoch - ) - logger.info(f"Saved checkpoint at epoch {epoch}\n") + if self.is_main_process: + self.save_checkpoint( + model=best_model, dir_path=self.training_dir, epoch=epoch + ) + logger.info(f"Saved checkpoint at epoch {epoch}\n") - if log_verbose: - file_logger.info(f"Saved checkpoint at epoch {epoch}\n") + if log_verbose: + file_logger.info(f"Saved checkpoint at epoch {epoch}\n") self.callback_handler.on_log( - self.training_config, metrics, logger=logger, global_step=epoch + self.training_config, + metrics, + logger=logger, + global_step=epoch, + rank=self.rank, ) - final_dir = os.path.join(training_dir, "final_model") + final_dir = os.path.join(self.training_dir, "final_model") - self.save_model(best_model, dir_path=final_dir) - logger.info("----------------------------------") - logger.info("Training ended!") - logger.info(f"Saved final model in {final_dir}") + if self.is_main_process: + self.save_model(best_model, dir_path=final_dir) + logger.info("----------------------------------") + logger.info("Training ended!") + logger.info(f"Saved final model in {final_dir}") + + if self.distributed: + dist.destroy_process_group() self.callback_handler.on_train_end(self.training_config) @@ -349,6 +406,7 @@ def eval_step(self, epoch: int): training_config=self.training_config, 
eval_loader=self.eval_loader, epoch=epoch, + rank=self.rank, ) self.model.eval() @@ -365,12 +423,18 @@ def eval_step(self, epoch: int): with torch.no_grad(): model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) except RuntimeError: model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) autoencoder_loss = model_output.autoencoder_loss @@ -406,6 +470,7 @@ def train_step(self, epoch: int): training_config=self.training_config, train_loader=self.train_loader, epoch=epoch, + rank=self.rank, ) # set model in train model @@ -420,7 +485,10 @@ def train_step(self, epoch: int): inputs = self._set_inputs_to_device(inputs) model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.train_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.train_loader.dataset), + uses_ddp=self.distributed, ) self._optimizers_step(model_output) @@ -439,7 +507,10 @@ def train_step(self, epoch: int): ) # Allows model updates if needed - self.model.update() + if self.distributed: + self.model.module.update() + else: + self.model.update() epoch_autoencoder_loss /= len(self.train_loader) epoch_discriminator_loss /= len(self.train_loader) @@ -470,7 +541,11 @@ def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): ) # save model - model.save(checkpoint_dir) + if self.distributed: + model.module.save(checkpoint_dir) + + else: + model.save(checkpoint_dir) # save training config self.training_config.save_json(checkpoint_dir, "training_config") diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py index de3d35dc..460f3700 100644 --- a/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py +++ b/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py @@ -1,3 +1,6 @@ +from typing import Union + +import torch.nn as nn from pydantic.dataclasses import dataclass from ..base_trainer import BaseTrainerConfig @@ -12,18 +15,179 @@ class AdversarialTrainerConfig(BaseTrainerConfig): output_dir (str): The directory where model checkpoints, configs and final model will be stored. Default: None. - - batch_size (int): The number of training samples per batch. Default 100 + per_device_train_batch_size (int): The number of training samples per batch and per device. + Default 64 + per_device_eval_batch_size (int): The number of evaluation samples per batch and per device. + Default 64 num_epochs (int): The maximal number of epochs for training. Default: 100 - learning_rate (int): The learning rate applied to the `Optimizer`. Default: 1e-4 + train_dataloader_num_workers (int): Number of subprocesses to use for train data loading. + 0 means that the data will be loaded in the main process. Default: 0 + eval_dataloader_num_workers (int): Number of subprocesses to use for evaluation data + loading. 0 means that the data will be loaded in the main process. Default: 0 + autoencoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the autoencoder. Default: :class:`~torch.optim.Adam`. + autoencoder_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the autoencoder. If None, uses the default parameters. + Default: None. 
+ autoencoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the autoencoder. Default :class:`~torch.optim.Adam`. + autoencoder_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the autoencoder. If None, uses the default parameters. + Default: None. + discriminator_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the discriminator. Default: :class:`~torch.optim.Adam`. + discriminator_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the discriminator. If None, uses the default parameters. + Default: None. + discriminator_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the discriminator. Default :class:`~torch.optim.Adam`. + discriminator_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the discriminator. If None, uses the default parameters. + Default: None. + autoencoder_learning_rate (int): The learning rate applied to the `Optimizer` for the encoder. + Default: 1e-4 + discriminator_learning_rate (int): The learning rate applied to the `Optimizer` for the + discriminator. Default: 1e-4 steps_saving (int): A model checkpoint will be saved every `steps_saving` epoch. Default: None - keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False. + steps_predict (int): A prediction using the best model will be run every `steps_predict` + epoch. Default: None + keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False seed (int): The random seed for reproducibility no_cuda (bool): Disable `cuda` training. Default: False - encoderdecoder_optim_decay (float): The decay to apply in the optimizer. Default: 0 - discriminator_optim_decay (float): The decay to apply in the optimizer. Default: 0 + world_size (int): The total number of process to run. Default: -1 + local_rank (int): The rank of the node for distributed training. Default: -1 + rank (int): The rank of the process for distributed training. Default: -1 + dist_backend (str): The distributed backend to use. Default: 'nccl' + master_addr (str): The master address for distributed training. Default: 'localhost' + master_port (str): The master port for distributed training. Default: '12345' """ - autoencoder_optim_decay: float = 0 - discriminator_optim_decay: float = 0 + autoencoder_optimizer_cls: str = "Adam" + autoencoder_optimizer_params: Union[dict, None] = None + autoencoder_scheduler_cls: str = None + autoencoder_scheduler_params: Union[dict, None] = None + discriminator_optimizer_cls: str = "Adam" + discriminator_optimizer_params: Union[dict, None] = None + discriminator_scheduler_cls: str = None + discriminator_scheduler_params: Union[dict, None] = None + autoencoder_learning_rate: float = 1e-4 + discriminator_learning_rate: float = 1e-4 + + def __post_init_post_parse__(self): + """Check compatibilty""" + + # Autoencoder optimizer and scheduler + try: + import torch.optim as optim + + autoencoder_optimizer_cls = getattr(optim, self.autoencoder_optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.autoencoder_optimizer_cls}` autoencoder optimizer " + "from 'torch.optim'. 
Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.autoencoder_optimizer_params is not None: + try: + autoencoder_optimizer = autoencoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.autoencoder_learning_rate, + **self.autoencoder_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{autoencoder_optimizer_cls}` optimizer. " + f"Got {self.autoencoder_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + autoencoder_optimizer = autoencoder_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.autoencoder_learning_rate + ) + + if self.autoencoder_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + autoencoder_scheduder_cls = getattr( + schedulers, self.autoencoder_scheduler_cls + ) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.autoencoder_scheduler_cls}` autoencoder scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.autoencoder_scheduler_params is not None: + try: + autoencoder_scheduder_cls( + autoencoder_optimizer, **self.autoencoder_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{autoencoder_scheduder_cls}` scheduler. " + f"Got {self.autoencoder_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + + # Discriminator optimizer and scheduler + try: + discriminator_optimizer_cls = getattr( + optim, self.discriminator_optimizer_cls + ) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.discriminator_optimizer_cls}` discriminator optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.discriminator_optimizer_params is not None: + try: + discriminator_optimizer = discriminator_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.discriminator_learning_rate, + **self.discriminator_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{discriminator_optimizer_cls}` optimizer. " + f"Got {self.discriminator_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + discriminator_optimizer = discriminator_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.discriminator_learning_rate + ) + + if self.discriminator_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + discriminator_scheduder_cls = getattr( + schedulers, self.discriminator_scheduler_cls + ) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.discriminator_scheduler_cls}` discriminator scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.discriminator_scheduler_params is not None: + try: + discriminator_scheduder_cls( + discriminator_optimizer, + **self.discriminator_scheduler_params, + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. 
Check that the provided dict contains only " + f"keys and values suitable for `{discriminator_scheduder_cls}` scheduler. " + f"Got {self.discriminator_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py index 93f4bc1b..6b2af011 100644 --- a/src/pythae/trainers/base_trainer/base_trainer.py +++ b/src/pythae/trainers/base_trainer/base_trainer.py @@ -5,9 +5,12 @@ from typing import Any, Dict, List, Optional import torch +import torch.distributed as dist import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.optim.lr_scheduler as lr_scheduler +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from ...customexception import ModelError from ...data.datasets import BaseDataset @@ -45,12 +48,6 @@ class BaseTrainer: parameters used for training. If None, a basic training instance of :class:`BaseTrainerConfig` is used. Default: None. - optimizer (~torch.optim.Optimizer): An instance of `torch.optim.Optimizer` used for - training. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None. - - scheduler (~torch.optim.lr_scheduler): An instance of `torch.optim.Optimizer` used for - training. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None. - callbacks (List[~pythae.trainers.training_callbacks.TrainingCallbacks]): A list of callbacks to use during training. """ @@ -61,8 +58,6 @@ def __init__( train_dataset: BaseDataset, eval_dataset: Optional[BaseDataset] = None, training_config: Optional[BaseTrainerConfig] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, callbacks: List[TrainingCallback] = None, ): @@ -73,46 +68,43 @@ def __init__( output_dir = "dummy_output_dir" training_config.output_dir = output_dir - if not os.path.exists(training_config.output_dir): - os.makedirs(training_config.output_dir) - logger.info( - f"Created {training_config.output_dir} folder since did not exist.\n" - ) - self.training_config = training_config + self.model_config = model.model_config + self.model_name = model.model_name - set_seed(self.training_config.seed) + # for distributed training + self.world_size = self.training_config.world_size + self.local_rank = self.training_config.local_rank + self.rank = self.training_config.rank + self.dist_backend = self.training_config.dist_backend - device = ( - "cuda" - if torch.cuda.is_available() and not training_config.no_cuda - else "cpu" - ) + if self.world_size > 1: + self.distributed = True + else: + self.distributed = False + + if self.distributed: + device = self._setup_devices() + + else: + device = ( + "cuda" + if torch.cuda.is_available() and not self.training_config.no_cuda + else "cpu" + ) + + self.device = device # place model on device model = model.to(device) model.device = device - # set optimizer - if optimizer is None: - optimizer = self.set_default_optimizer(model) - - else: - optimizer = self._set_optimizer_on_device(optimizer, device) - - # set scheduler - if scheduler is None: - scheduler = self.set_default_scheduler(model, optimizer) + if self.distributed: + model = DDP(model, device_ids=[self.local_rank]) self.train_dataset = train_dataset self.eval_dataset = eval_dataset - self.model = model - self.optimizer = optimizer - 
self.scheduler = scheduler - - self.device = device - # Define the loaders train_loader = self.get_train_dataloader(train_dataset) @@ -128,53 +120,176 @@ def __init__( self.train_loader = train_loader self.eval_loader = eval_loader + self.callbacks = callbacks - if callbacks is None: - callbacks = [TrainingCallback()] + # run sanity check on the model + self._run_model_sanity_check(model, train_loader) - self.callback_handler = CallbackHandler( - callbacks=callbacks, model=model, optimizer=optimizer, scheduler=scheduler - ) + if self.is_main_process: + logger.info("Model passed sanity check !\n" "Ready for training.\n") - self.callback_handler.add_callback(ProgressBarCallback()) - self.callback_handler.add_callback(MetricConsolePrinterCallback()) + self.model = model + + @property + def is_main_process(self): + if self.rank == 0 or self.rank == -1: + return True + else: + return False + + def _setup_devices(self): + """Sets up the devices to perform distributed training.""" + + if dist.is_available() and dist.is_initialized() and self.local_rank == -1: + logger.warning( + "torch.distributed process group is initialized, but local_rank == -1. " + ) + if self.training_config.no_cuda: + self._n_gpus = 0 + device = "cpu" + + else: + torch.cuda.set_device(self.local_rank) + device = torch.device("cuda", self.local_rank) + + if not dist.is_initialized(): + dist.init_process_group( + backend=self.dist_backend, + init_method="env://", + world_size=self.world_size, + rank=self.rank, + ) + + return device def get_train_dataloader( self, train_dataset: BaseDataset ) -> torch.utils.data.DataLoader: - + if self.distributed: + train_sampler = DistributedSampler( + train_dataset, num_replicas=self.world_size, rank=self.rank + ) + else: + train_sampler = None return DataLoader( dataset=train_dataset, - batch_size=self.training_config.batch_size, - shuffle=True, + batch_size=self.training_config.per_device_train_batch_size, + num_workers=self.training_config.train_dataloader_num_workers, + shuffle=(train_sampler is None), + sampler=train_sampler, ) def get_eval_dataloader( self, eval_dataset: BaseDataset ) -> torch.utils.data.DataLoader: + if self.distributed: + eval_sampler = DistributedSampler( + eval_dataset, num_replicas=self.world_size, rank=self.rank + ) + else: + eval_sampler = None return DataLoader( dataset=eval_dataset, - batch_size=self.training_config.batch_size, - shuffle=False, + batch_size=self.training_config.per_device_eval_batch_size, + num_workers=self.training_config.eval_dataloader_num_workers, + shuffle=(eval_sampler is None), + sampler=eval_sampler, ) - def set_default_optimizer(self, model: BaseAE) -> torch.optim.Optimizer: + def set_optimizer(self): + optimizer_cls = getattr(optim, self.training_config.optimizer_cls) + + if self.training_config.optimizer_params is not None: + optimizer = optimizer_cls( + self.model.parameters(), + lr=self.training_config.learning_rate, + **self.training_config.optimizer_params, + ) + else: + optimizer = optimizer_cls( + self.model.parameters(), lr=self.training_config.learning_rate + ) + + self.optimizer = optimizer + + def set_scheduler(self): + if self.training_config.scheduler_cls is not None: + scheduler_cls = getattr(lr_scheduler, self.training_config.scheduler_cls) - optimizer = optim.Adam( - model.parameters(), lr=self.training_config.learning_rate + if self.training_config.scheduler_params is not None: + scheduler = scheduler_cls( + self.optimizer, **self.training_config.scheduler_params + ) + else: + scheduler = 
scheduler_cls(self.optimizer) + + else: + scheduler = None + + self.scheduler = scheduler + + def _set_output_dir(self): + # Create folder + if not os.path.exists(self.training_config.output_dir) and self.is_main_process: + os.makedirs(self.training_config.output_dir, exist_ok=True) + logger.info( + f"Created {self.training_config.output_dir} folder since did not exist.\n" + ) + + self._training_signature = ( + str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-") ) - return optimizer + training_dir = os.path.join( + self.training_config.output_dir, + f"{self.model_name}_training_{self._training_signature}", + ) - def set_default_scheduler( - self, model: BaseAE, optimizer: torch.optim.Optimizer - ) -> torch.optim.lr_scheduler: + self.training_dir = training_dir - scheduler = optim.lr_scheduler.ReduceLROnPlateau( - optimizer, factor=0.5, patience=10, verbose=True + if not os.path.exists(training_dir) and self.is_main_process: + os.makedirs(training_dir, exist_ok=True) + logger.info( + f"Created {training_dir}. \n" + "Training config, checkpoints and final model will be saved here.\n" + ) + + def _get_file_logger(self, log_output_dir): + log_dir = log_output_dir + + # if dir does not exist create it + if not os.path.exists(log_dir) and self.is_main_process: + os.makedirs(log_dir, exist_ok=True) + logger.info(f"Created {log_dir} folder since did not exists.") + logger.info("Training logs will be recodered here.\n") + logger.info(" -> Training can be monitored here.\n") + + # create and set logger + log_name = f"training_logs_{self._training_signature}" + + file_logger = logging.getLogger(log_name) + file_logger.setLevel(logging.INFO) + f_handler = logging.FileHandler( + os.path.join(log_dir, f"training_logs_{self._training_signature}.log") ) + f_handler.setLevel(logging.INFO) + file_logger.addHandler(f_handler) + + # Do not output logs in the console + file_logger.propagate = False + + return file_logger - return scheduler + def _setup_callbacks(self): + if self.callbacks is None: + self.callbacks = [TrainingCallback()] + + self.callback_handler = CallbackHandler( + callbacks=self.callbacks, model=self.model + ) + + self.callback_handler.add_callback(ProgressBarCallback()) + self.callback_handler.add_callback(MetricConsolePrinterCallback()) def _run_model_sanity_check(self, model, loader): try: @@ -227,6 +342,7 @@ def _set_inputs_to_device(self, inputs: Dict[str, Any]): return inputs_on_device def _optimizers_step(self, model_output=None): + loss = model_output.loss self.optimizer.zero_grad() @@ -234,12 +350,32 @@ def _optimizers_step(self, model_output=None): self.optimizer.step() def _schedulers_step(self, metrics=None): - if isinstance(self.scheduler, ReduceLROnPlateau): + if self.scheduler is None: + pass + + elif isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau): self.scheduler.step(metrics) else: self.scheduler.step() + def prepare_training(self): + """Sets up the trainer for training""" + # set random seed + set_seed(self.training_config.seed) + + # set optimizer + self.set_optimizer() + + # set scheduler + self.set_scheduler() + + # create folder for saving + self._set_output_dir() + + # set callbacks + self._setup_callbacks() + def train(self, log_output_dir: str = None): """This function is the main training function @@ -247,72 +383,37 @@ def train(self, log_output_dir: str = None): log_output_dir (str): The path in which the log will be stored """ + self.prepare_training() + self.callback_handler.on_train_begin( - training_config=self.training_config, 
model_config=self.model.model_config + training_config=self.training_config, model_config=self.model_config ) - # run sanity check on the model - self._run_model_sanity_check(self.model, self.train_loader) - - logger.info("Model passed sanity check !\n") - - self._training_signature = ( - str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-") - ) + log_verbose = False - training_dir = os.path.join( - self.training_config.output_dir, - f"{self.model.model_name}_training_{self._training_signature}", + msg = ( + f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" + " - per_device_train_batch_size: " + f"{self.training_config.per_device_train_batch_size}\n" + " - per_device_eval_batch_size: " + f"{self.training_config.per_device_eval_batch_size}\n" + f" - checkpoint saving every: {self.training_config.steps_saving}\n" + f"Optimizer: {self.optimizer}\n" + f"Scheduler: {self.scheduler}\n" ) - self.training_dir = training_dir - - if not os.path.exists(training_dir): - os.makedirs(training_dir) - logger.info( - f"Created {training_dir}. \n" - "Training config, checkpoints and final model will be saved here.\n" - ) - - log_verbose = False + if self.is_main_process: + logger.info(msg) # set up log file - if log_output_dir is not None: - log_dir = log_output_dir + if log_output_dir is not None and self.is_main_process: log_verbose = True + file_logger = self._get_file_logger(log_output_dir=log_output_dir) - # if dir does not exist create it - if not os.path.exists(log_dir): - os.makedirs(log_dir) - logger.info(f"Created {log_dir} folder since did not exists.") - logger.info("Training logs will be recodered here.\n") - logger.info(" -> Training can be monitored here.\n") + file_logger.info(msg) - # create and set logger - log_name = f"training_logs_{self._training_signature}" - - file_logger = logging.getLogger(log_name) - file_logger.setLevel(logging.INFO) - f_handler = logging.FileHandler( - os.path.join(log_dir, f"training_logs_{self._training_signature}.log") - ) - f_handler.setLevel(logging.INFO) - file_logger.addHandler(f_handler) - - # Do not output logs in the console - file_logger.propagate = False - - file_logger.info("Training started !\n") - file_logger.info( - f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" - f" - batch_size: {self.training_config.batch_size}\n" - f" - checkpoint saving every {self.training_config.steps_saving}\n" - ) - - file_logger.info(f"Model Architecture: {self.model}\n") - file_logger.info(f"Optimizer: {self.optimizer}\n") - - logger.info("Successfully launched training !\n") + if self.is_main_process: + logger.info("Successfully launched training !\n") # set best losses for early stopping best_train_loss = 1e10 @@ -360,6 +461,7 @@ def train(self, log_output_dir: str = None): if ( self.training_config.steps_predict is not None and epoch % self.training_config.steps_predict == 0 + and self.is_main_process ): true_data, reconstructions, generations = self.predict(best_model) @@ -378,23 +480,33 @@ def train(self, log_output_dir: str = None): self.training_config.steps_saving is not None and epoch % self.training_config.steps_saving == 0 ): - self.save_checkpoint( - model=best_model, dir_path=training_dir, epoch=epoch - ) - logger.info(f"Saved checkpoint at epoch {epoch}\n") + if self.is_main_process: + self.save_checkpoint( + model=best_model, dir_path=self.training_dir, epoch=epoch + ) + logger.info(f"Saved checkpoint at epoch {epoch}\n") - if log_verbose: - file_logger.info(f"Saved checkpoint at epoch 
{epoch}\n") + if log_verbose: + file_logger.info(f"Saved checkpoint at epoch {epoch}\n") self.callback_handler.on_log( - self.training_config, metrics, logger=logger, global_step=epoch + self.training_config, + metrics, + logger=logger, + global_step=epoch, + rank=self.rank, ) - final_dir = os.path.join(training_dir, "final_model") + final_dir = os.path.join(self.training_dir, "final_model") + + if self.is_main_process: + self.save_model(best_model, dir_path=final_dir) + + logger.info("Training ended!") + logger.info(f"Saved final model in {final_dir}") - self.save_model(best_model, dir_path=final_dir) - logger.info("Training ended!") - logger.info(f"Saved final model in {final_dir}") + if self.distributed: + dist.destroy_process_group() self.callback_handler.on_train_end(self.training_config) @@ -412,6 +524,7 @@ def eval_step(self, epoch: int): training_config=self.training_config, eval_loader=self.eval_loader, epoch=epoch, + rank=self.rank, ) self.model.eval() @@ -426,12 +539,18 @@ def eval_step(self, epoch: int): with torch.no_grad(): model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) except RuntimeError: model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) loss = model_output.loss @@ -460,6 +579,7 @@ def train_step(self, epoch: int): training_config=self.training_config, train_loader=self.train_loader, epoch=epoch, + rank=self.rank, ) # set model in train model @@ -472,7 +592,10 @@ def train_step(self, epoch: int): inputs = self._set_inputs_to_device(inputs) model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.train_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.train_loader.dataset), + uses_ddp=self.distributed, ) self._optimizers_step(model_output) @@ -489,7 +612,10 @@ def train_step(self, epoch: int): ) # Allows model updates if needed - self.model.update() + if self.distributed: + self.model.module.update() + else: + self.model.update() epoch_loss /= len(self.train_loader) @@ -507,7 +633,11 @@ def save_model(self, model: BaseAE, dir_path: str): os.makedirs(dir_path) # save model - model.save(dir_path) + if self.distributed: + model.module.save(dir_path) + + else: + model.save(dir_path) # save training config self.training_config.save_json(dir_path, "training_config") @@ -533,13 +663,18 @@ def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): ) # save scheduler - torch.save( - deepcopy(self.scheduler.state_dict()), - os.path.join(checkpoint_dir, "scheduler.pt"), - ) + if self.scheduler is not None: + torch.save( + deepcopy(self.scheduler.state_dict()), + os.path.join(checkpoint_dir, "scheduler.pt"), + ) # save model - model.save(checkpoint_dir) + if self.distributed: + model.module.save(checkpoint_dir) + + else: + model.save(checkpoint_dir) # save training config self.training_config.save_json(checkpoint_dir, "training_config") @@ -548,17 +683,22 @@ def predict(self, model: BaseAE): model.eval() - # with torch.no_grad(): - - inputs = self.eval_loader.dataset[ - : min(self.eval_loader.dataset.data.shape[0], 10) - ] + inputs = next(iter(self.eval_loader)) inputs = self._set_inputs_to_device(inputs) model_out = model(inputs) - reconstructions = model_out.recon_x.cpu().detach() - z_enc = model_out.z + reconstructions = model_out.recon_x.cpu().detach()[ + : 
min(inputs["data"].shape[0], 10) + ] + z_enc = model_out.z[: min(inputs["data"].shape[0], 10)] z = torch.randn_like(z_enc) - normal_generation = model.decoder(z).reconstruction.detach().cpu() + if self.distributed: + normal_generation = model.module.decoder(z).reconstruction.detach().cpu() + else: + normal_generation = model.decoder(z).reconstruction.detach().cpu() - return inputs["data"], reconstructions, normal_generation + return ( + inputs["data"][: min(inputs["data"].shape[0], 10)], + reconstructions, + normal_generation, + ) diff --git a/src/pythae/trainers/base_trainer/base_training_config.py b/src/pythae/trainers/base_trainer/base_training_config.py index 67e83135..1ace7199 100644 --- a/src/pythae/trainers/base_trainer/base_training_config.py +++ b/src/pythae/trainers/base_trainer/base_training_config.py @@ -1,5 +1,8 @@ +import os +from dataclasses import field from typing import Union +import torch.nn as nn from pydantic.dataclasses import dataclass from ...config import BaseConfig @@ -14,25 +17,135 @@ class BaseTrainerConfig(BaseConfig): output_dir (str): The directory where model checkpoints, configs and final model will be stored. Default: None. - - batch_size (int): The number of training samples per batch. Default 100 + per_device_train_batch_size (int): The number of training samples per batch and per device. + Default 64 + per_device_eval_batch_size (int): The number of evaluation samples per batch and per device. + Default 64 num_epochs (int): The maximal number of epochs for training. Default: 100 + train_dataloader_num_workers (int): Number of subprocesses to use for train data loading. + 0 means that the data will be loaded in the main process. Default: 0 + eval_dataloader_num_workers (int): Number of subprocesses to use for evaluation data + loading. 0 means that the data will be loaded in the main process. Default: 0 + optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + training. Default: :class:`~torch.optim.Adam`. + optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer`. If None, uses the default parameters. Default: None. + scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + training. If None, no scheduler is used. Default None. + scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler`. If None, uses the default parameters. Default: None. learning_rate (int): The learning rate applied to the `Optimizer`. Default: 1e-4 steps_saving (int): A model checkpoint will be saved every `steps_saving` epoch. Default: None - steps_saving (int): A prediction using the best model will be run every `steps_predict` + steps_predict (int): A prediction using the best model will be run every `steps_predict` epoch. Default: None - keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False. + keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False seed (int): The random seed for reproducibility no_cuda (bool): Disable `cuda` training. Default: False + world_size (int): The total number of process to run. Default: -1 + local_rank (int): The rank of the node for distributed training. Default: -1 + rank (int): The rank of the process for distributed training. Default: -1 + dist_backend (str): The distributed backend to use. Default: 'nccl' + master_addr (str): The master address for distributed training. Default: 'localhost' + master_port (str): The master port for distributed training. 
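The distributed fields documented above (`world_size`, `rank`, `local_rank`, `master_addr`, `master_port`) can also be inherited from the environment: as the `__post_init__` in this diff shows, any field left at its default is filled from `WORLD_SIZE`, `RANK`, `LOCAL_RANK`, `MASTER_ADDR` and `MASTER_PORT`, which launchers such as `torchrun` export. A small sketch follows; the explicit `os.environ` lines only stand in for what a real launcher would set, and the import path is assumed.

```python
# Sketch: emulate the variables a launcher (e.g. torchrun) would export,
# then let the config pick them up instead of passing them explicitly.
import os

from pythae.trainers import BaseTrainerConfig  # import path assumed

os.environ.setdefault("WORLD_SIZE", "4")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")

config = BaseTrainerConfig(
    per_device_train_batch_size=64,
    dist_backend="nccl",
)
# Fields left at their defaults (-1) are resolved from the environment.
assert config.world_size == 4 and config.rank == 0 and config.local_rank == 0
```

With `world_size > 1`, the `BaseTrainer` changes in this diff wrap the model in `DistributedDataParallel` and build the dataloaders with a `DistributedSampler`, so each rank trains on its own shard of the data.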
Default: '12345' """ output_dir: str = None - batch_size: int = 100 + per_device_train_batch_size: int = 64 + per_device_eval_batch_size: int = 64 num_epochs: int = 100 + train_dataloader_num_workers: int = 0 + eval_dataloader_num_workers: int = 0 + optimizer_cls: str = "Adam" + optimizer_params: Union[dict, None] = None + scheduler_cls: Union[str, None] = None + scheduler_params: Union[dict, None] = None learning_rate: float = 1e-4 steps_saving: Union[int, None] = None steps_predict: Union[int, None] = None keep_best_on_train: bool = False seed: int = 8 no_cuda: bool = False + world_size: int = field(default=-1) + local_rank: int = field(default=-1) + rank: int = field(default=-1) + dist_backend: str = field(default="nccl") + master_addr: str = field(default="localhost") + master_port: str = field(default="12345") + + def __post_init__(self): + super().__post_init__() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if self.local_rank == -1 and env_local_rank != -1: + self.local_rank = env_local_rank + + env_world_size = int(os.environ.get("WORLD_SIZE", -1)) + if self.world_size == -1 and env_world_size != -1: + self.world_size = env_world_size + + env_rank = int(os.environ.get("RANK", -1)) + if self.rank == -1 and env_rank != -1: + self.rank = env_rank + + env_master_addr = os.environ.get("MASTER_ADDR", "localhost") + if self.master_addr == "localhost" and env_master_addr != "localhost": + self.master_addr = env_master_addr + os.environ["MASTER_ADDR"] = self.master_addr + + env_master_port = os.environ.get("MASTER_PORT", "12345") + if self.master_port == "12345" and env_master_port != "12345": + self.master_port = env_master_port + os.environ["MASTER_PORT"] = self.master_port + + def __post_init_post_parse__(self): + """Check compatibilty""" + try: + import torch.optim as optim + + optimizer_cls = getattr(optim, self.optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.optimizer_cls}` optimizer from 'torch.optim'. " + "Check spelling and that it is part of 'torch.optim.Optimizers.'" + ) + if self.optimizer_params is not None: + try: + optimizer = optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.learning_rate, + **self.optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{optimizer_cls}` optimizer. " + f"Got {self.optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + optimizer = optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.learning_rate + ) + + if self.scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + scheduder_cls = getattr(schedulers, self.scheduler_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.scheduler_cls}` scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.scheduler_params is not None: + try: + scheduder_cls(optimizer, **self.scheduler_params) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{scheduder_cls}` scheduler. 
" + f"Got {self.scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py index ae95f8ef..5e46dfeb 100644 --- a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py +++ b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py @@ -6,12 +6,14 @@ from typing import List, Optional import torch +import torch.distributed as dist import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.optim.lr_scheduler as lr_scheduler from ...data.datasets import BaseDataset from ...models import BaseAE from ..base_trainer import BaseTrainer +from ..trainer_utils import set_seed from ..training_callbacks import TrainingCallback from .coupled_optimizer_adversarial_trainer_config import ( CoupledOptimizerAdversarialTrainerConfig, @@ -57,12 +59,6 @@ def __init__( train_dataset: BaseDataset, eval_dataset: Optional[BaseDataset] = None, training_config: Optional[CoupledOptimizerAdversarialTrainerConfig] = None, - encoder_optimizer: Optional[torch.optim.Optimizer] = None, - decoder_optimizer: Optional[torch.optim.Optimizer] = None, - discriminator_optimizer: Optional[torch.optim.Optimizer] = None, - encoder_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, - decoder_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, - discriminator_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, callbacks: List[TrainingCallback] = None, ): @@ -72,88 +68,162 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=None, callbacks=callbacks, ) - # set encoder optimizer - if encoder_optimizer is None: - encoder_optimizer = self.set_default_encoder_optimizer(model) + def set_encoder_optimizer(self): + encoder_optimizer_cls = getattr( + optim, self.training_config.encoder_optimizer_cls + ) + if self.training_config.encoder_optimizer_params is not None: + if self.distributed: + encoder_optimizer = encoder_optimizer_cls( + self.model.module.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + **self.training_config.encoder_optimizer_params, + ) + else: + encoder_optimizer = encoder_optimizer_cls( + self.model.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + **self.training_config.encoder_optimizer_params, + ) else: - encoder_optimizer = self._set_optimizer_on_device( - encoder_optimizer, self.device - ) + if self.distributed: + encoder_optimizer = encoder_optimizer_cls( + self.model.module.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + ) + else: + encoder_optimizer = encoder_optimizer_cls( + self.model.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + ) + + self.encoder_optimizer = encoder_optimizer - if encoder_scheduler is None: - encoder_scheduler = self.set_default_scheduler(model, encoder_optimizer) + def set_encoder_scheduler(self): + if self.training_config.encoder_scheduler_cls is not None: + encoder_scheduler_cls = getattr( + lr_scheduler, self.training_config.encoder_scheduler_cls + ) - # set decoder optimizer - if decoder_optimizer is None: - decoder_optimizer = self.set_default_decoder_optimizer(model) + if self.training_config.encoder_scheduler_params is 
not None: + scheduler = encoder_scheduler_cls( + self.encoder_optimizer, + **self.training_config.encoder_scheduler_params, + ) + else: + scheduler = encoder_scheduler_cls(self.encoder_optimizer) else: - decoder_optimizer = self._set_optimizer_on_device( - decoder_optimizer, self.device - ) + scheduler = None - if decoder_scheduler is None: - decoder_scheduler = self.set_default_scheduler(model, decoder_optimizer) + self.encoder_scheduler = scheduler - # set decoder optimizer - if discriminator_optimizer is None: - discriminator_optimizer = self.set_default_discriminator_optimizer(model) + def set_decoder_optimizer(self): + decoder_optimizer_cls = getattr( + optim, self.training_config.decoder_optimizer_cls + ) + if self.training_config.decoder_optimizer_params is not None: + if self.distributed: + decoder_optimizer = decoder_optimizer_cls( + self.model.module.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + **self.training_config.decoder_optimizer_params, + ) + else: + decoder_optimizer = decoder_optimizer_cls( + self.model.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + **self.training_config.decoder_optimizer_params, + ) else: - discriminator_optimizer = self._set_optimizer_on_device( - discriminator_optimizer, self.device - ) + if self.distributed: + decoder_optimizer = decoder_optimizer_cls( + self.model.module.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + ) + else: + decoder_optimizer = decoder_optimizer_cls( + self.model.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + ) - if discriminator_scheduler is None: - discriminator_scheduler = self.set_default_scheduler( - model, discriminator_optimizer + self.decoder_optimizer = decoder_optimizer + + def set_decoder_scheduler(self): + if self.training_config.decoder_scheduler_cls is not None: + decoder_scheduler_cls = getattr( + lr_scheduler, self.training_config.decoder_scheduler_cls ) - self.encoder_optimizer = encoder_optimizer - self.decoder_optimizer = decoder_optimizer - self.discriminator_optimizer = discriminator_optimizer - self.encoder_scheduler = encoder_scheduler - self.decoder_scheduler = decoder_scheduler - self.discriminator_scheduler = discriminator_scheduler + if self.training_config.decoder_scheduler_params is not None: + scheduler = decoder_scheduler_cls( + self.decoder_optimizer, + **self.training_config.decoder_scheduler_params, + ) + else: + scheduler = decoder_scheduler_cls(self.decoder_optimizer) - self.optimizer = None + else: + scheduler = None - def set_default_encoder_optimizer(self, model: BaseAE) -> torch.optim.Optimizer: + self.decoder_scheduler = scheduler - optimizer = optim.Adam( - model.encoder.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.encoder_optim_decay, + def set_discriminator_optimizer(self): + discriminator_cls = getattr( + optim, self.training_config.discriminator_optimizer_cls ) - return optimizer + if self.training_config.discriminator_optimizer_params is not None: + if self.distributed: + discriminator_optimizer = discriminator_cls( + self.model.module.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + **self.training_config.discriminator_optimizer_params, + ) + else: + discriminator_optimizer = discriminator_cls( + self.model.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + **self.training_config.discriminator_optimizer_params, + ) - def set_default_decoder_optimizer(self, model: 
BaseAE) -> torch.optim.Optimizer: + else: + if self.distributed: + discriminator_optimizer = discriminator_cls( + self.model.module.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + ) + else: + discriminator_optimizer = discriminator_cls( + self.model.discriminator.parameters(), + lr=self.training_config.discriminator_learning_rate, + ) - optimizer = optim.Adam( - model.decoder.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.decoder_optim_decay, - ) + self.discriminator_optimizer = discriminator_optimizer - return optimizer + def set_discriminator_scheduler(self) -> torch.optim.lr_scheduler: + if self.training_config.discriminator_scheduler_cls is not None: + discriminator_scheduler_cls = getattr( + lr_scheduler, self.training_config.discriminator_scheduler_cls + ) - def set_default_discriminator_optimizer( - self, model: BaseAE - ) -> torch.optim.Optimizer: + if self.training_config.discriminator_scheduler_params is not None: + scheduler = discriminator_scheduler_cls( + self.discriminator_optimizer, + **self.training_config.discriminator_scheduler_params, + ) + else: + scheduler = discriminator_scheduler_cls(self.discriminator_optimizer) - optimizer = optim.Adam( - model.discriminator.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.discriminator_optim_decay, - ) + else: + scheduler = None - return optimizer + self.discriminator_scheduler = scheduler def _optimizers_step(self, model_output): @@ -186,24 +256,57 @@ def _optimizers_step(self, model_output): def _schedulers_step( self, encoder_metrics=None, decoder_metrics=None, discriminator_metrics=None ): - if isinstance(self.encoder_scheduler, ReduceLROnPlateau): + if self.encoder_scheduler is None: + pass + + elif isinstance(self.encoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.encoder_scheduler.step(encoder_metrics) else: self.encoder_scheduler.step() - if isinstance(self.decoder_scheduler, ReduceLROnPlateau): + if self.decoder_scheduler is None: + pass + + elif isinstance(self.decoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.decoder_scheduler.step(decoder_metrics) else: self.decoder_scheduler.step() - if isinstance(self.discriminator_scheduler, ReduceLROnPlateau): + if self.discriminator_scheduler is None: + pass + + elif isinstance(self.discriminator_scheduler, lr_scheduler.ReduceLROnPlateau): self.discriminator_scheduler.step(discriminator_metrics) else: self.discriminator_scheduler.step() + def prepare_training(self): + """Sets up the trainer for training""" + + # set random seed + set_seed(self.training_config.seed) + + # set encoder optimizer and scheduler + self.set_encoder_optimizer() + self.set_encoder_scheduler() + + # set decoder optimizer and scheduler + self.set_decoder_optimizer() + self.set_decoder_scheduler() + + # set discriminator optimizer and scheduler + self.set_discriminator_optimizer() + self.set_discriminator_scheduler() + + # create foder for saving + self._set_output_dir() + + # set callbacks + self._setup_callbacks() + def train(self, log_output_dir: str = None): """This function is the main training function @@ -211,72 +314,41 @@ def train(self, log_output_dir: str = None): log_output_dir (str): The path in which the log will be stored """ + self.prepare_training() + self.callback_handler.on_train_begin( - training_config=self.training_config, model_config=self.model.model_config + training_config=self.training_config, model_config=self.model_config ) - # run sanity check on 
the model - self._run_model_sanity_check(self.model, self.train_loader) - - logger.info("Model passed sanity check !\n") - - self._training_signature = ( - str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-") - ) + log_verbose = False - training_dir = os.path.join( - self.training_config.output_dir, - f"{self.model.model_name}_training_{self._training_signature}", + msg = ( + f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" + " - per_device_train_batch_size: " + f"{self.training_config.per_device_train_batch_size}\n" + " - per_device_eval_batch_size: " + f"{self.training_config.per_device_eval_batch_size}\n" + f" - checkpoint saving every: {self.training_config.steps_saving}\n" + f"Encoder Optimizer: {self.encoder_optimizer}\n" + f"Encoder Scheduler: {self.encoder_scheduler}\n" + f"Decoder Optimizer: {self.decoder_optimizer}\n" + f"Decoder Scheduler: {self.decoder_scheduler}\n" + f"Discriminator Optimizer: {self.discriminator_optimizer}\n" + f"Discriminator Scheduler: {self.discriminator_scheduler}\n" ) - self.training_dir = training_dir - - if not os.path.exists(training_dir): - os.makedirs(training_dir) - logger.info( - f"Created {training_dir}. \n" - "Training config, checkpoints and final model will be saved here.\n" - ) - - log_verbose = False + if self.is_main_process: + logger.info(msg) # set up log file - if log_output_dir is not None: - log_dir = log_output_dir + if log_output_dir is not None and self.is_main_process: log_verbose = True + file_logger = self._get_file_logger(log_output_dir=log_output_dir) - # if dir does not exist create it - if not os.path.exists(log_dir): - os.makedirs(log_dir) - logger.info(f"Created {log_dir} folder since did not exists.") - logger.info("Training logs will be recodered here.\n") - logger.info(" -> Training can be monitored here.\n") + file_logger.info(msg) - # create and set logger - log_name = f"training_logs_{self._training_signature}" - - file_logger = logging.getLogger(log_name) - file_logger.setLevel(logging.INFO) - f_handler = logging.FileHandler( - os.path.join(log_dir, f"training_logs_{self._training_signature}.log") - ) - f_handler.setLevel(logging.INFO) - file_logger.addHandler(f_handler) - - # Do not output logs in the console - file_logger.propagate = False - - file_logger.info("Training started !\n") - file_logger.info( - f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" - f" - batch_size: {self.training_config.batch_size}\n" - f" - checkpoint saving every {self.training_config.steps_saving}\n" - ) - - file_logger.info(f"Model Architecture: {self.model}\n") - file_logger.info(f"Optimizer: {self.optimizer}\n") - - logger.info("Successfully launched training !\n") + if self.is_main_process: + logger.info("Successfully launched training !\n") # set best losses for early stopping best_train_loss = 1e10 @@ -353,6 +425,7 @@ def train(self, log_output_dir: str = None): if ( self.training_config.steps_predict is not None and epoch % self.training_config.steps_predict == 0 + and self.is_main_process ): true_data, reconstructions, generations = self.predict(best_model) @@ -372,24 +445,33 @@ def train(self, log_output_dir: str = None): self.training_config.steps_saving is not None and epoch % self.training_config.steps_saving == 0 ): - self.save_checkpoint( - model=best_model, dir_path=training_dir, epoch=epoch - ) - logger.info(f"Saved checkpoint at epoch {epoch}\n") + if self.is_main_process: + self.save_checkpoint( + model=best_model, dir_path=self.training_dir, epoch=epoch + ) 
+ logger.info(f"Saved checkpoint at epoch {epoch}\n") if log_verbose: file_logger.info(f"Saved checkpoint at epoch {epoch}\n") self.callback_handler.on_log( - self.training_config, metrics, logger=logger, global_step=epoch + self.training_config, + metrics, + logger=logger, + global_step=epoch, + rank=self.rank, ) - final_dir = os.path.join(training_dir, "final_model") + final_dir = os.path.join(self.training_dir, "final_model") + + if self.is_main_process: + self.save_model(best_model, dir_path=final_dir) + logger.info("----------------------------------") + logger.info("Training ended!") + logger.info(f"Saved final model in {final_dir}") - self.save_model(best_model, dir_path=final_dir) - logger.info("----------------------------------") - logger.info("Training ended!") - logger.info(f"Saved final model in {final_dir}") + if self.distributed: + dist.destroy_process_group() self.callback_handler.on_train_end(training_config=self.training_config) @@ -406,6 +488,7 @@ def eval_step(self, epoch: int): training_config=self.training_config, eval_loader=self.eval_loader, epoch=epoch, + rank=self.rank, ) self.model.eval() @@ -423,12 +506,18 @@ def eval_step(self, epoch: int): with torch.no_grad(): model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) except RuntimeError: model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) encoder_loss = model_output.encoder_loss @@ -472,6 +561,7 @@ def train_step(self, epoch: int): training_config=self.training_config, train_loader=self.train_loader, epoch=epoch, + rank=self.rank, ) # set model in train model @@ -487,7 +577,10 @@ def train_step(self, epoch: int): inputs = self._set_inputs_to_device(inputs) model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.train_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.train_loader.dataset), + uses_ddp=self.distributed, ) self._optimizers_step(model_output) @@ -508,7 +601,10 @@ def train_step(self, epoch: int): ) # Allows model updates if needed - self.model.update() + if self.distributed: + self.model.module.update() + else: + self.model.update() epoch_encoder_loss /= len(self.train_loader) epoch_decoder_loss /= len(self.train_loader) @@ -549,7 +645,11 @@ def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): ) # save model - model.save(checkpoint_dir) + if self.distributed: + model.module.save(checkpoint_dir) + + else: + model.save(checkpoint_dir) # save training config self.training_config.save_json(checkpoint_dir, "training_config") diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py index 86edcdf3..73f9a8d3 100644 --- a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py +++ b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py @@ -1,3 +1,6 @@ +from typing import Union + +import torch.nn as nn from pydantic.dataclasses import dataclass from ..base_trainer import BaseTrainerConfig @@ -12,20 +15,249 @@ class CoupledOptimizerAdversarialTrainerConfig(BaseTrainerConfig): output_dir (str): The directory where model 
checkpoints, configs and final model will be stored. Default: None. - - batch_size (int): The number of training samples per batch. Default 100 + per_device_train_batch_size (int): The number of training samples per batch and per device. + Default 64 + per_device_eval_batch_size (int): The number of evaluation samples per batch and per device. + Default 64 num_epochs (int): The maximal number of epochs for training. Default: 100 - learning_rate (int): The learning rate applied to the `Optimizer`. Default: 1e-4 + train_dataloader_num_workers (int): Number of subprocesses to use for train data loading. + 0 means that the data will be loaded in the main process. Default: 0 + eval_dataloader_num_workers (int): Number of subprocesses to use for evaluation data + loading. 0 means that the data will be loaded in the main process. Default: 0 + encoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the encoder. Default: :class:`~torch.optim.Adam`. + encoder_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the encoder. If None, uses the default parameters. + Default: None. + encoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the encoder. Default :class:`~torch.optim.Adam`. + encoder_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the encoder. If None, uses the default parameters. + Default: None. + decoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the decoder. Default: :class:`~torch.optim.Adam`. + decoder_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the decoder. If None, uses the default parameters. + Default: None. + decoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the decoder. Default :class:`~torch.optim.Adam`. + decoder_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the decoder. If None, uses the default parameters. + Default: None. + discriminator_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the discriminator. Default: :class:`~torch.optim.Adam`. + discriminator_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the discriminator. If None, uses the default parameters. + Default: None. + discriminator_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the discriminator. Default :class:`~torch.optim.Adam`. + discriminator_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the discriminator. If None, uses the default parameters. + Default: None. + encoder_learning_rate (int): The learning rate applied to the `Optimizer` for the encoder. + Default: 1e-4 + decoder_learning_rate (int): The learning rate applied to the `Optimizer` for the encoder. + Default: 1e-4 + discriminator_learning_rate (int): The learning rate applied to the `Optimizer` for the + discriminator. Default: 1e-4 steps_saving (int): A model checkpoint will be saved every `steps_saving` epoch. Default: None - keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False. + steps_saving (int): A prediction using the best model will be run every `steps_predict` + epoch. 
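As with the adversarial trainer above, the coupled-optimizer variant exposes one optimizer/scheduler slot per component (encoder, decoder, discriminator), each selected by name. A hedged usage sketch, with field names taken from this diff and the import path assumed:

```python
# Sketch only: field names come from CoupledOptimizerAdversarialTrainerConfig in this diff;
# the import path is assumed.
from pythae.trainers import CoupledOptimizerAdversarialTrainerConfig

training_config = CoupledOptimizerAdversarialTrainerConfig(
    num_epochs=50,
    per_device_train_batch_size=64,
    encoder_optimizer_cls="Adam",
    encoder_learning_rate=1e-4,
    decoder_optimizer_cls="Adam",
    decoder_learning_rate=1e-4,
    # the discriminator can run with its own optimizer, rate and scheduler
    discriminator_optimizer_cls="SGD",
    discriminator_optimizer_params={"momentum": 0.9},
    discriminator_learning_rate=1e-3,
    discriminator_scheduler_cls="ReduceLROnPlateau",
    discriminator_scheduler_params={"patience": 5, "factor": 0.5},
)
```

At training time, the trainer resolves each `*_optimizer_cls` from `torch.optim` and each `*_scheduler_cls` from `torch.optim.lr_scheduler`, as the setter methods in the accompanying trainer diff show.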
Default: None + keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False seed (int): The random seed for reproducibility no_cuda (bool): Disable `cuda` training. Default: False - encoder_optim_decay (float): The decay to apply in the optimizer. Default: 0 - decoder_optim_decay (float): The decay to apply in the optimizer. Default: 0 - discriminator_optim_decay (float): The decay to apply in the optimizer. Default: 0 - """ + world_size (int): The total number of process to run. Default: -1 + local_rank (int): The rank of the node for distributed training. Default: -1 + rank (int): The rank of the process for distributed training. Default: -1 + dist_backend (str): The distributed backend to use. Default: 'nccl' + master_addr (str): The master address for distributed training. Default: 'localhost' + master_port (str): The master port for distributed training. Default: '12345'""" + + encoder_optimizer_cls: str = "Adam" + encoder_optimizer_params: Union[dict, None] = None + encoder_scheduler_cls: str = None + encoder_scheduler_params: Union[dict, None] = None + discriminator_optimizer_cls: str = "Adam" + decoder_optimizer_cls: str = "Adam" + decoder_optimizer_params: Union[dict, None] = None + decoder_scheduler_cls: str = None + decoder_scheduler_params: Union[dict, None] = None + discriminator_optimizer_cls: str = "Adam" + discriminator_optimizer_params: Union[dict, None] = None + discriminator_scheduler_cls: str = None + discriminator_scheduler_params: Union[dict, None] = None + encoder_learning_rate: float = 1e-4 + decoder_learning_rate: float = 1e-4 + discriminator_learning_rate: float = 1e-4 + + def __post_init_post_parse__(self): + """Check compatibilty""" + + # Encoder optimizer and scheduler + try: + import torch.optim as optim + + encoder_optimizer_cls = getattr(optim, self.encoder_optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.encoder_optimizer_cls}` encoder optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.encoder_optimizer_params is not None: + try: + encoder_optimizer = encoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.encoder_learning_rate, + **self.encoder_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{encoder_optimizer_cls}` optimizer. " + f"Got {self.encoder_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + encoder_optimizer = encoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.encoder_learning_rate, + ) + + if self.encoder_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + encoder_scheduder_cls = getattr(schedulers, self.encoder_scheduler_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.encoder_scheduler_cls}` encoder scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.encoder_scheduler_params is not None: + try: + encoder_scheduder_cls( + encoder_optimizer, **self.encoder_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{encoder_scheduder_cls}` scheduler. 
" + f"Got {self.encoder_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + + # Decoder optimizer and scheduler + try: + import torch.optim as optim + + decoder_optimizer_cls = getattr(optim, self.decoder_optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.decoder_optimizer_cls}` decoder optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.decoder_optimizer_params is not None: + try: + decoder_optimizer = decoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.decoder_learning_rate, + **self.decoder_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{decoder_optimizer_cls}` optimizer. " + f"Got {self.decoder_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + decoder_optimizer = decoder_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.decoder_learning_rate + ) + + if self.decoder_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + decoder_scheduder_cls = getattr(schedulers, self.decoder_scheduler_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.decoder_scheduler_cls}` decoder scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.decoder_scheduler_params is not None: + try: + decoder_scheduder_cls( + decoder_optimizer, **self.decoder_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{decoder_scheduder_cls}` scheduler. " + f"Got {self.decoder_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + + # Discriminator optimizer and scheduler + try: + discriminator_optimizer_cls = getattr( + optim, self.discriminator_optimizer_cls + ) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.discriminator_optimizer_cls}` discriminator optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.discriminator_optimizer_params is not None: + try: + discriminator_optimizer = discriminator_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.discriminator_learning_rate, + **self.discriminator_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{discriminator_optimizer_cls}` optimizer. " + f"Got {self.discriminator_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + discriminator_optimizer = discriminator_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.discriminator_learning_rate + ) + + if self.discriminator_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + discriminator_scheduder_cls = getattr( + schedulers, self.discriminator_scheduler_cls + ) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.discriminator_scheduler_cls}` discriminator scheduler from " + "'torch.optim.lr_scheduler'. 
Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) - encoder_optim_decay: float = 0 - decoder_optim_decay: float = 0 - discriminator_optim_decay: float = 0 + if self.discriminator_scheduler_params is not None: + try: + discriminator_scheduder_cls( + discriminator_optimizer, **self.discriminator_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{discriminator_scheduder_cls}` scheduler. " + f"Got {self.discriminator_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py index 765a10c7..5c1b6dea 100644 --- a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py +++ b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py @@ -5,12 +5,14 @@ from typing import List, Optional import torch +import torch.distributed as dist import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.optim.lr_scheduler as lr_scheduler from ...data.datasets import BaseDataset from ...models import BaseAE from ..base_trainer import BaseTrainer +from ..trainer_utils import set_seed from ..training_callbacks import TrainingCallback from .coupled_optimizer_trainer_config import CoupledOptimizerTrainerConfig @@ -50,10 +52,6 @@ def __init__( train_dataset: BaseDataset, eval_dataset: Optional[BaseDataset] = None, training_config: Optional[CoupledOptimizerTrainerConfig] = None, - encoder_optimizer: Optional[torch.optim.Optimizer] = None, - decoder_optimizer: Optional[torch.optim.Optimizer] = None, - encoder_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, - decoder_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, callbacks: List[TrainingCallback] = None, ): @@ -63,60 +61,109 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, - optimizer=None, callbacks=callbacks, ) - # set encoder optimizer - if encoder_optimizer is None: - encoder_optimizer = self.set_default_encoder_optimizer(model) + def set_encoder_optimizer(self): + encoder_optimizer_cls = getattr( + optim, self.training_config.encoder_optimizer_cls + ) + if self.training_config.encoder_optimizer_params is not None: + if self.distributed: + encoder_optimizer = encoder_optimizer_cls( + self.model.module.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + **self.training_config.encoder_optimizer_params, + ) + else: + encoder_optimizer = encoder_optimizer_cls( + self.model.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + **self.training_config.encoder_optimizer_params, + ) else: - encoder_optimizer = self._set_optimizer_on_device( - encoder_optimizer, self.device - ) + if self.distributed: + encoder_optimizer = encoder_optimizer_cls( + self.model.module.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + ) + else: + encoder_optimizer = encoder_optimizer_cls( + self.model.encoder.parameters(), + lr=self.training_config.encoder_learning_rate, + ) - if encoder_scheduler is None: - encoder_scheduler = self.set_default_scheduler(model, encoder_optimizer) + self.encoder_optimizer = encoder_optimizer + + def set_encoder_scheduler(self): + if self.training_config.encoder_scheduler_cls is not 
None: + encoder_scheduler_cls = getattr( + lr_scheduler, self.training_config.encoder_scheduler_cls + ) - # set decoder optimizer - if decoder_optimizer is None: - decoder_optimizer = self.set_default_decoder_optimizer(model) + if self.training_config.encoder_scheduler_params is not None: + scheduler = encoder_scheduler_cls( + self.encoder_optimizer, + **self.training_config.encoder_scheduler_params, + ) + else: + scheduler = encoder_scheduler_cls(self.encoder_optimizer) else: - decoder_optimizer = self._set_optimizer_on_device( - decoder_optimizer, self.device - ) + scheduler = None - if decoder_scheduler is None: - decoder_scheduler = self.set_default_scheduler(model, decoder_optimizer) + self.encoder_scheduler = scheduler - self.encoder_optimizer = encoder_optimizer - self.decoder_optimizer = decoder_optimizer - self.encoder_scheduler = encoder_scheduler - self.decoder_scheduler = decoder_scheduler + def set_decoder_optimizer(self): + decoder_cls = getattr(optim, self.training_config.decoder_optimizer_cls) - self.optimizer = None + if self.training_config.decoder_optimizer_params is not None: + if self.distributed: + decoder_optimizer = decoder_cls( + self.model.module.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + **self.training_config.decoder_optimizer_params, + ) + else: + decoder_optimizer = decoder_cls( + self.model.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + **self.training_config.decoder_optimizer_params, + ) - def set_default_encoder_optimizer(self, model: BaseAE) -> torch.optim.Optimizer: + else: + if self.distributed: + decoder_optimizer = decoder_cls( + self.model.module.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + ) + else: + decoder_optimizer = decoder_cls( + self.model.decoder.parameters(), + lr=self.training_config.decoder_learning_rate, + ) - optimizer = optim.Adam( - model.encoder.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.encoder_optim_decay, - ) + self.decoder_optimizer = decoder_optimizer - return optimizer + def set_decoder_scheduler(self) -> torch.optim.lr_scheduler: + if self.training_config.decoder_scheduler_cls is not None: + decoder_scheduler_cls = getattr( + lr_scheduler, self.training_config.decoder_scheduler_cls + ) - def set_default_decoder_optimizer(self, model: BaseAE) -> torch.optim.Optimizer: + if self.training_config.decoder_scheduler_params is not None: + scheduler = decoder_scheduler_cls( + self.decoder_optimizer, + **self.training_config.decoder_scheduler_params, + ) + else: + scheduler = decoder_scheduler_cls(self.decoder_optimizer) - optimizer = optim.Adam( - model.decoder.parameters(), - lr=self.training_config.learning_rate, - weight_decay=self.training_config.decoder_optim_decay, - ) + else: + scheduler = None - return optimizer + self.decoder_scheduler = scheduler def _optimizers_step(self, model_output): @@ -139,18 +186,44 @@ def _optimizers_step(self, model_output): self.decoder_optimizer.step() def _schedulers_step(self, encoder_metrics=None, decoder_metrics=None): - if isinstance(self.encoder_scheduler, ReduceLROnPlateau): + + if self.encoder_scheduler is None: + pass + + elif isinstance(self.encoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.encoder_scheduler.step(encoder_metrics) else: self.encoder_scheduler.step() - if isinstance(self.decoder_scheduler, ReduceLROnPlateau): + if self.decoder_scheduler is None: + pass + + elif isinstance(self.decoder_scheduler, lr_scheduler.ReduceLROnPlateau): 
self.decoder_scheduler.step(decoder_metrics) else: self.decoder_scheduler.step() + def prepare_training(self): + + # set random seed + set_seed(self.training_config.seed) + + # set autoencoder optimizer and scheduler + self.set_encoder_optimizer() + self.set_encoder_scheduler() + + # set discriminator optimizer and scheduler + self.set_decoder_optimizer() + self.set_decoder_scheduler() + + # create foder for saving + self._set_output_dir() + + # set callbacks + self._setup_callbacks() + def train(self, log_output_dir: str = None): """This function is the main training function @@ -158,72 +231,40 @@ def train(self, log_output_dir: str = None): log_output_dir (str): The path in which the log will be stored """ + self.prepare_training() + self.callback_handler.on_train_begin( - training_config=self.training_config, model_config=self.model.model_config + training_config=self.training_config, model_config=self.model_config ) - # run sanity check on the model - self._run_model_sanity_check(self.model, self.train_loader) - - logger.info("Model passed sanity check !\n") - - self._training_signature = ( - str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-") - ) + log_verbose = False - training_dir = os.path.join( - self.training_config.output_dir, - f"{self.model.model_name}_training_{self._training_signature}", + msg = ( + f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" + " - per_device_train_batch_size: " + f"{self.training_config.per_device_train_batch_size}\n" + " - per_device_eval_batch_size: " + f"{self.training_config.per_device_eval_batch_size}\n" + f" - checkpoint saving every: {self.training_config.steps_saving}\n" + f"Encoder Optimizer: {self.encoder_optimizer}\n" + f"Encoder Scheduler: {self.encoder_scheduler}\n" + f"Decoder Optimizer: {self.decoder_optimizer}\n" + f"Decoder Scheduler: {self.decoder_scheduler}\n" ) - self.training_dir = training_dir - - if not os.path.exists(training_dir): - os.makedirs(training_dir) - logger.info( - f"Created {training_dir}. 
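A minimal end-to-end sketch of the reworked trainer flow follows: `train()` now triggers `prepare_training()` itself (seeding, optimizer/scheduler creation, output directory and callback setup), so no optimizer objects are passed to the trainer. The model choice (`PIWAE`, which exposes separate encoder/decoder losses) and the `BaseDataset` construction are assumptions made for illustration only.

```python
# Illustrative sketch, not part of the patch. PIWAE is used here only because
# it produces separate encoder/decoder losses; the BaseDataset call assumes a
# (data, labels) signature.
import torch

from pythae.data.datasets import BaseDataset
from pythae.models import PIWAE, PIWAEConfig
from pythae.trainers import CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig

train_data = torch.rand(100, 1, 28, 28)
train_dataset = BaseDataset(train_data, torch.zeros(100))

model = PIWAE(PIWAEConfig(input_dim=(1, 28, 28), latent_dim=16))

training_config = CoupledOptimizerTrainerConfig(
    output_dir="my_model",
    num_epochs=5,
    encoder_learning_rate=1e-3,
    decoder_learning_rate=1e-3,
)

trainer = CoupledOptimizerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
    training_config=training_config,
)

# Seeding, optimizers, schedulers, output folder and callbacks are all set up
# inside train() via prepare_training().
trainer.train()
```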
\n" - "Training config, checkpoints and final model will be saved here.\n" - ) - - log_verbose = False + if self.is_main_process: + logger.info(msg) # set up log file - if log_output_dir is not None: - log_dir = log_output_dir + if log_output_dir is not None and self.is_main_process: log_verbose = True - # if dir does not exist create it - if not os.path.exists(log_dir): - os.makedirs(log_dir) - logger.info(f"Created {log_dir} folder since did not exists.") - logger.info("Training logs will be recodered here.\n") - logger.info(" -> Training can be monitored here.\n") - - # create and set logger - log_name = f"training_logs_{self._training_signature}" - - file_logger = logging.getLogger(log_name) - file_logger.setLevel(logging.INFO) - f_handler = logging.FileHandler( - os.path.join(log_dir, f"training_logs_{self._training_signature}.log") - ) - f_handler.setLevel(logging.INFO) - file_logger.addHandler(f_handler) - - # Do not output logs in the console - file_logger.propagate = False - - file_logger.info("Training started !\n") - file_logger.info( - f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" - f" - batch_size: {self.training_config.batch_size}\n" - f" - checkpoint saving every {self.training_config.steps_saving}\n" - ) + file_logger = self._get_file_logger(log_output_dir=log_output_dir) - file_logger.info(f"Model Architecture: {self.model}\n") - file_logger.info(f"Optimizer: {self.optimizer}\n") + file_logger.info(msg) - logger.info("Successfully launched training !\n") + if self.is_main_process: + logger.info("Successfully launched training !\n") # set best losses for early stopping best_train_loss = 1e10 @@ -294,6 +335,7 @@ def train(self, log_output_dir: str = None): if ( self.training_config.steps_predict is not None and epoch % self.training_config.steps_predict == 0 + and self.is_main_process ): true_data, reconstructions, generations = self.predict(best_model) @@ -313,23 +355,33 @@ def train(self, log_output_dir: str = None): self.training_config.steps_saving is not None and epoch % self.training_config.steps_saving == 0 ): - self.save_checkpoint( - model=best_model, dir_path=training_dir, epoch=epoch - ) - logger.info(f"Saved checkpoint at epoch {epoch}\n") + if self.is_main_process: + self.save_checkpoint( + model=best_model, dir_path=self.training_dir, epoch=epoch + ) + logger.info(f"Saved checkpoint at epoch {epoch}\n") - if log_verbose: - file_logger.info(f"Saved checkpoint at epoch {epoch}\n") + if log_verbose: + file_logger.info(f"Saved checkpoint at epoch {epoch}\n") self.callback_handler.on_log( - self.training_config, metrics, logger=logger, global_step=epoch + self.training_config, + metrics, + logger=logger, + global_step=epoch, + rank=self.rank, ) - final_dir = os.path.join(training_dir, "final_model") + final_dir = os.path.join(self.training_dir, "final_model") - self.save_model(best_model, dir_path=final_dir) - logger.info("Training ended!") - logger.info(f"Saved final model in {final_dir}") + if self.is_main_process: + self.save_model(best_model, dir_path=final_dir) + logger.info("----------------------------------") + logger.info("Training ended!") + logger.info(f"Saved final model in {final_dir}") + + if self.distributed: + dist.destroy_process_group() self.callback_handler.on_train_end(self.training_config) @@ -346,6 +398,7 @@ def eval_step(self, epoch: int): training_config=self.training_config, eval_loader=self.eval_loader, epoch=epoch, + rank=self.rank, ) self.model.eval() @@ -362,12 +415,18 @@ def eval_step(self, epoch: int): with 
torch.no_grad(): model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) except RuntimeError: model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.eval_loader.dataset), + uses_ddp=self.distributed, ) encoder_loss = model_output.encoder_loss @@ -403,6 +462,7 @@ def train_step(self, epoch: int): training_config=self.training_config, train_loader=self.train_loader, epoch=epoch, + rank=self.rank, ) # set model in train model @@ -417,7 +477,10 @@ def train_step(self, epoch: int): inputs = self._set_inputs_to_device(inputs) model_output = self.model( - inputs, epoch=epoch, dataset_size=len(self.train_loader.dataset) + inputs, + epoch=epoch, + dataset_size=len(self.train_loader.dataset), + uses_ddp=self.distributed, ) self._optimizers_step(model_output) @@ -439,7 +502,10 @@ def train_step(self, epoch: int): ) # Allows model updates if needed - self.model.update() + if self.distributed: + self.model.module.update() + else: + self.model.update() epoch_encoder_loss /= len(self.train_loader) epoch_decoder_loss /= len(self.train_loader) @@ -470,7 +536,11 @@ def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): ) # save model - model.save(checkpoint_dir) + if self.distributed: + model.module.save(checkpoint_dir) + + else: + model.save(checkpoint_dir) # save training config self.training_config.save_json(checkpoint_dir, "training_config") diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py index 34cf8013..e458442d 100644 --- a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py +++ b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py @@ -1,3 +1,6 @@ +from typing import Union + +import torch.nn as nn from pydantic.dataclasses import dataclass from ..base_trainer import BaseTrainerConfig @@ -12,18 +15,173 @@ class CoupledOptimizerTrainerConfig(BaseTrainerConfig): output_dir (str): The directory where model checkpoints, configs and final model will be stored. Default: None. - - batch_size (int): The number of training samples per batch. Default 100 + per_device_train_batch_size (int): The number of training samples per batch and per device. + Default 64 + per_device_eval_batch_size (int): The number of evaluation samples per batch and per device. + Default 64 num_epochs (int): The maximal number of epochs for training. Default: 100 - learning_rate (int): The learning rate applied to the `Optimizer`. Default: 1e-4 + train_dataloader_num_workers (int): Number of subprocesses to use for train data loading. + 0 means that the data will be loaded in the main process. Default: 0 + eval_dataloader_num_workers (int): Number of subprocesses to use for evaluation data + loading. 0 means that the data will be loaded in the main process. Default: 0 + encoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the encoder. Default: :class:`~torch.optim.Adam`. + encoder_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the encoder. If None, uses the default parameters. + Default: None. + encoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the encoder. 
Default :class:`~torch.optim.Adam`. + encoder_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the encoder. If None, uses the default parameters. + Default: None. + decoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for + the training of the decoder. Default: :class:`~torch.optim.Adam`. + decoder_optimizer_params (dict): A dict containing the parameters to use for the + `torch.optim.Optimizer` for the decoder. If None, uses the default parameters. + Default: None. + decoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used for + the training of the decoder. Default :class:`~torch.optim.Adam`. + decoder_scheduler_params (dict): A dict containing the parameters to use for the + `torch.optim.le_scheduler` for the decoder. If None, uses the default parameters. + Default: None. + encoder_learning_rate (int): The learning rate applied to the `Optimizer` for the encoder. + Default: 1e-4 + decoder_learning_rate (int): The learning rate applied to the `Optimizer` for the + decoder. Default: 1e-4 steps_saving (int): A model checkpoint will be saved every `steps_saving` epoch. Default: None - keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False. + steps_saving (int): A prediction using the best model will be run every `steps_predict` + epoch. Default: None + keep_best_on_train (bool): Whether to keep the best model on the train set. Default: False seed (int): The random seed for reproducibility no_cuda (bool): Disable `cuda` training. Default: False - encoder_optim_decay (float): The decay to apply in the optimizer. Default: 0 - decoder_optim_decay (float): The decay to apply in the optimizer. Default: 1e-5 + world_size (int): The total number of process to run. Default: -1 + local_rank (int): The rank of the node for distributed training. Default: -1 + rank (int): The rank of the process for distributed training. Default: -1 + dist_backend (str): The distributed backend to use. Default: 'nccl' + master_addr (str): The master address for distributed training. Default: 'localhost' + master_port (str): The master port for distributed training. Default: '12345' + """ - encoder_optim_decay: float = 0 - decoder_optim_decay: float = 1e-5 + encoder_optimizer_cls: str = "Adam" + encoder_optimizer_params: Union[dict, None] = None + encoder_scheduler_cls: str = None + encoder_scheduler_params: Union[dict, None] = None + decoder_optimizer_cls: str = "Adam" + decoder_optimizer_params: Union[dict, None] = None + decoder_scheduler_cls: str = None + decoder_scheduler_params: Union[dict, None] = None + encoder_learning_rate: float = 1e-4 + decoder_learning_rate: float = 1e-4 + + def __post_init_post_parse__(self): + """Check compatibilty""" + + # encoder optimizer and scheduler + try: + import torch.optim as optim + + encoder_optimizer_cls = getattr(optim, self.encoder_optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.encoder_optimizer_cls}` encoder optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.encoder_optimizer_params is not None: + try: + encoder_optimizer = encoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.encoder_learning_rate, + **self.encoder_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. 
Check that the provided dict contains only " + f"keys and values suitable for `{encoder_optimizer_cls}` optimizer. " + f"Got {self.encoder_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + encoder_optimizer = encoder_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.encoder_learning_rate + ) + + if self.encoder_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + encoder_scheduder_cls = getattr(schedulers, self.encoder_scheduler_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.encoder_scheduler_cls}` encoder scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.encoder_scheduler_params is not None: + try: + encoder_scheduder_cls( + encoder_optimizer, **self.encoder_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{encoder_scheduder_cls}` scheduler. " + f"Got {self.encoder_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + + # decoder optimizer and scheduler + try: + decoder_optimizer_cls = getattr(optim, self.decoder_optimizer_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.decoder_optimizer_cls}` decoder optimizer " + "from 'torch.optim'. Check spelling and that it is part of " + "'torch.optim.Optimizers.'" + ) + if self.decoder_optimizer_params is not None: + try: + decoder_optimizer = decoder_optimizer_cls( + nn.Linear(2, 2).parameters(), + lr=self.decoder_learning_rate, + **self.decoder_optimizer_params, + ) + except TypeError as e: + raise TypeError( + "Error in optimizer's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{decoder_optimizer_cls}` optimizer. " + f"Got {self.decoder_optimizer_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e + else: + decoder_optimizer = decoder_optimizer_cls( + nn.Linear(2, 2).parameters(), lr=self.decoder_learning_rate + ) + + if self.decoder_scheduler_cls is not None: + try: + import torch.optim.lr_scheduler as schedulers + + decoder_scheduder_cls = getattr(schedulers, self.decoder_scheduler_cls) + except AttributeError as e: + raise AttributeError( + f"Unable to import `{self.decoder_scheduler_cls}` decoder scheduler from " + "'torch.optim.lr_scheduler'. Check spelling and that it is part of " + "'torch.optim.lr_scheduler.'" + ) + + if self.decoder_scheduler_params is not None: + try: + decoder_scheduder_cls( + decoder_optimizer, **self.decoder_scheduler_params + ) + except TypeError as e: + raise TypeError( + "Error in scheduler's parameters. Check that the provided dict contains only " + f"keys and values suitable for `{decoder_scheduder_cls}` scheduler. " + f"Got {self.decoder_scheduler_params} as parameters.\n" + f"Exception raised: {type(e)} with message: " + str(e) + ) from e diff --git a/src/pythae/trainers/training_callbacks.py b/src/pythae/trainers/training_callbacks.py index 3760f0ec..3864aaea 100644 --- a/src/pythae/trainers/training_callbacks.py +++ b/src/pythae/trainers/training_callbacks.py @@ -117,13 +117,11 @@ class CallbackHandler: Class to handle list of Callback. 
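The validation above relies on a simple trick: the named optimizer (or scheduler) is resolved from `torch.optim` and instantiated on a throwaway `nn.Linear(2, 2)`, so a misspelled class name or an unsupported keyword argument fails at configuration time rather than mid-training. A standalone sketch of the idea follows; the helper name is ours and not part of the library.

```python
from typing import Optional

import torch.nn as nn
import torch.optim as optim


def validate_optimizer(cls_name: str, lr: float, params: Optional[dict] = None) -> None:
    """Fail fast if `cls_name` is not a torch.optim optimizer or `params` are invalid."""
    try:
        optimizer_cls = getattr(optim, cls_name)
    except AttributeError as e:
        raise AttributeError(
            f"Unable to import `{cls_name}` from 'torch.optim'."
        ) from e
    # Instantiating on a dummy module checks the keyword arguments without
    # touching (or even requiring) the real model.
    optimizer_cls(nn.Linear(2, 2).parameters(), lr=lr, **(params or {}))


validate_optimizer("AdamW", 1e-3, {"weight_decay": 0.05})  # passes
# validate_optimizer("AdamW", 1e-3, {"momentum": 0.9})     # would raise TypeError
```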
""" - def __init__(self, callbacks, model, optimizer, scheduler): + def __init__(self, callbacks, model): self.callbacks = [] for cb in callbacks: self.add_callback(cb) self.model = model - self.optimizer = optimizer - self.scheduler = scheduler def add_callback(self, callback): cb = callback() if isinstance(callback, type) else callback @@ -187,8 +185,6 @@ def call_event(self, event, training_config, **kwargs): result = getattr(callback, event)( training_config, model=self.model, - optimizer=self.optimizer, - scheduler=self.scheduler, **kwargs, ) @@ -207,9 +203,11 @@ def __init__(self): self.logger.setLevel(logging.INFO) def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs): + logger = kwargs.pop("logger", self.logger) + rank = kwargs.pop("rank", -1) - if logger is not None: + if logger is not None and (rank == -1 or rank == 0): epoch_train_loss = logs.get("train_epoch_loss", None) epoch_eval_loss = logs.get("eval_epoch_loss", None) @@ -237,22 +235,26 @@ def __init__(self): def on_train_step_begin(self, training_config: BaseTrainerConfig, **kwargs): epoch = kwargs.pop("epoch", None) train_loader = kwargs.pop("train_loader", None) + rank = kwargs.pop("rank", -1) if train_loader is not None: - self.train_progress_bar = tqdm( - total=len(train_loader), - unit="batch", - desc=f"Training of epoch {epoch}/{training_config.num_epochs}", - ) + if rank == 0 or rank == -1: + self.train_progress_bar = tqdm( + total=len(train_loader), + unit="batch", + desc=f"Training of epoch {epoch}/{training_config.num_epochs}", + ) def on_eval_step_begin(self, training_config: BaseTrainerConfig, **kwargs): epoch = kwargs.pop("epoch", None) eval_loader = kwargs.pop("eval_loader", None) + rank = kwargs.pop("rank", -1) if eval_loader is not None: - self.eval_progress_bar = tqdm( - total=len(eval_loader), - unit="batch", - desc=f"Eval of epoch {epoch}/{training_config.num_epochs}", - ) + if rank == 0 or rank == -1: + self.eval_progress_bar = tqdm( + total=len(eval_loader), + unit="batch", + desc=f"Eval of epoch {epoch}/{training_config.num_epochs}", + ) def on_train_step_end(self, training_config: BaseTrainerConfig, **kwargs): if self.train_progress_bar is not None: @@ -581,12 +583,8 @@ def setup( ) experiment.log_other("Created from", "pythae") - experiment.log_parameters( - training_config, prefix="training_config/" - ) - experiment.log_parameters( - model_config, prefix="model_config/" - ) + experiment.log_parameters(training_config, prefix="training_config/") + experiment.log_parameters(model_config, prefix="model_config/") def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs): model_config = kwargs.pop("model_config", None) diff --git a/tests/data/baseAE/configs/training_config00.json b/tests/data/baseAE/configs/training_config00.json index 00743d58..347fb9ff 100644 --- a/tests/data/baseAE/configs/training_config00.json +++ b/tests/data/baseAE/configs/training_config00.json @@ -1,6 +1,7 @@ { "name": "BaseTrainerConfig", - "batch_size": 13, + "per_device_train_batch_size": 13, + "per_device_eval_batch_size": 42, "num_epochs": 2, "learning_rate": 1e-5 } \ No newline at end of file diff --git a/tests/data/custom_architectures.py b/tests/data/custom_architectures.py index 0b086b1e..c6497227 100644 --- a/tests/data/custom_architectures.py +++ b/tests/data/custom_architectures.py @@ -1,12 +1,13 @@ import typing +from typing import List +import numpy as np import torch import torch.nn as nn -import numpy as np -from typing import List -from pythae.models.nn import * + from 
pythae.models.base.base_utils import ModelOutput -import torch.nn as nn +from pythae.models.nn import * + class Layer(nn.Module): def __init__(self) -> None: @@ -365,7 +366,9 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer()) + self.layers = nn.Sequential( + nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer() + ) self.mu = nn.Linear(10, self.latent_dim) def forward(self, x): @@ -375,6 +378,7 @@ def forward(self, x): return output + class Encoder_VAE_MLP_Custom(BaseEncoder): def __init__(self, args: dict): BaseEncoder.__init__(self) @@ -390,7 +394,9 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer()) + self.layers = nn.Sequential( + nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer() + ) self.mu = nn.Linear(10, self.latent_dim) self.std = nn.Linear(10, self.latent_dim) diff --git a/tests/test_AE.py b/tests/test_AE.py index 4f5917f6..3c12b653 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -3,16 +3,18 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError -from pythae.models.base.base_utils import ModelOutput from pythae.models import AE, AEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, IAFSamplerConfig - - +from pythae.models.base.base_utils import ModelOutput +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_AE_Conv, @@ -118,7 +120,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -208,7 +212,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -289,14 +293,15 @@ def test_model_train_output(self, ae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -311,23 +316,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return AE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + 
(demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -338,12 +350,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return AE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_AE_Training: @pytest.fixture @@ -387,26 +399,21 @@ def ae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, ae, training_configs): - if request.param is not None: - optimizer = request.param( - ae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_ae_train_step(self, ae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, ae, train_dataset, training_configs): trainer = BaseTrainer( model=ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_ae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -421,14 +428,7 @@ def test_ae_train_step(self, ae, train_dataset, training_configs, optimizers): ] ) - def test_ae_eval_step(self, ae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -444,16 +444,7 @@ def test_ae_eval_step(self, ae, train_dataset, training_configs, optimizers): ] ) - def test_ae_predict_step( - self, ae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ae_predict_step(self, train_dataset, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -469,21 +460,11 @@ def test_ae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_ae_main_train_loop( - self, tmpdir, ae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -499,19 +480,10 @@ def test_ae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, ae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, ae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = 
trainer.train_step(epoch=1) @@ -593,21 +565,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, ae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, ae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -655,19 +618,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, ae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, ae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -775,10 +729,13 @@ def test_ae_training_pipeline(self, tmpdir, ae, train_dataset, training_configs) assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_AE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -789,7 +746,7 @@ def ae_model(self): NormalSamplerConfig(), GaussianMixtureSamplerConfig(), MAFSamplerConfig(), - IAFSamplerConfig() + IAFSamplerConfig(), ] ) def sampler_configs(self, request): @@ -804,7 +761,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py index fbfabc00..0265fa8e 100644 --- a/tests/test_Adversarial_AE.py +++ b/tests/test_Adversarial_AE.py @@ -1,30 +1,30 @@ import os -import numpy as np from copy import deepcopy +import numpy as np import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError -from pythae.models.base.base_utils import ModelOutput from pythae.models import Adversarial_AE, Adversarial_AE_Config, AutoModel +from pythae.models.base.base_utils import ModelOutput +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import ( AdversarialTrainer, AdversarialTrainerConfig, BaseTrainerConfig, ) -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline -from pythae.models.nn.default_architectures import ( - Decoder_AE_MLP, - Encoder_VAE_MLP, - Discriminator_MLP, -) from tests.data.custom_architectures import ( Decoder_AE_Conv, - Encoder_VAE_Conv, Discriminator_MLP_Custom, + Encoder_VAE_Conv, NetBadInheritance, ) @@ -178,7 +178,9 @@ def test_default_model_saving(self, tmpdir, model_configs): 
model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -309,7 +311,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "discriminator.pkl", - "environment.json" + "environment.json", ] ) @@ -405,28 +407,32 @@ def test_model_train_output(self, adversarial_ae, demo_data): assert isinstance(out, ModelOutput) - assert set( - [ - "loss", - "recon_loss", - "autoencoder_loss", - "discriminator_loss", - "recon_x", - "z", - ] - ) == set(out.keys()) + assert ( + set( + [ + "loss", + "recon_loss", + "autoencoder_loss", + "discriminator_loss", + "recon_x", + "z", + ] + ) + == set(out.keys()) + ) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -441,23 +447,30 @@ def adversarial_ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return Adversarial_AE(model_configs) - def test_interpolate(self, adversarial_ae, demo_data, granularity): with pytest.raises(AssertionError): adversarial_ae.interpolate(demo_data, demo_data[1:], granularity) interp = adversarial_ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -468,9 +481,8 @@ def adversarial_ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return Adversarial_AE(model_configs) - def test_reconstruct(self, adversarial_ae, demo_data): - + recon = adversarial_ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -554,34 +566,21 @@ def adversarial_ae( return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, adversarial_ae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - adversarial_ae.encoder.parameters(), lr=training_configs.learning_rate - ) - decoder_optimizer = request.param( - adversarial_ae.discriminator.parameters(), - lr=training_configs.learning_rate, - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - - return (encoder_optimizer, decoder_optimizer) - - def test_adversarial_ae_train_step( - self, adversarial_ae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, adversarial_ae, train_dataset, training_configs): trainer = AdversarialTrainer( model=adversarial_ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], ) + trainer.prepare_training() + + return trainer + + def 
test_adversarial_ae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -596,17 +595,7 @@ def test_adversarial_ae_train_step( ] ) - def test_adversarial_ae_eval_step( - self, adversarial_ae, train_dataset, training_configs, optimizers - ): - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_adversarial_ae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -622,17 +611,7 @@ def test_adversarial_ae_eval_step( ] ) - def test_adversarial_ae_predict_step( - self, adversarial_ae, train_dataset, training_configs, optimizers - ): - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_adversarial_ae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -648,22 +627,11 @@ def test_adversarial_ae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_adversarial_ae_main_train_loop( - self, tmpdir, adversarial_ae, train_dataset, training_configs, optimizers - ): - - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_adversarial_ae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -679,20 +647,10 @@ def test_adversarial_ae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, adversarial_ae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, adversarial_ae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -819,21 +777,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, adversarial_ae, train_dataset, training_configs, optimizers + self, adversarial_ae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - model = deepcopy(trainer.model) trainer.train() @@ -893,20 +843,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, adversarial_ae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, adversarial_ae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=adversarial_ae, - train_dataset=train_dataset, - training_config=training_configs, 
- autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -963,7 +903,7 @@ def test_final_model_saving( assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) def test_adversarial_ae_training_pipeline( - self, tmpdir, adversarial_ae, train_dataset, training_configs + self, adversarial_ae, train_dataset, training_configs ): with pytest.raises(AssertionError): @@ -1037,14 +977,19 @@ def test_adversarial_ae_training_pipeline( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) + class Test_Adversarial_AE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): - return Adversarial_AE(Adversarial_AE_Config(input_dim=(1, 28, 28), latent_dim=7)) + return Adversarial_AE( + Adversarial_AE_Config(input_dim=(1, 28, 28), latent_dim=7) + ) @pytest.fixture( params=[ @@ -1052,7 +997,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -1067,7 +1012,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index 177b4903..ea5f6829 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -3,15 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import AutoModel, BetaTCVAE, BetaTCVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import BetaTCVAE, BetaTCVAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig - +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -126,7 +130,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -216,7 +222,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -305,9 +311,9 @@ class Test_Model_interpolate: params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, 
"data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -322,23 +328,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return BetaTCVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -349,12 +362,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return BetaTCVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -426,28 +439,21 @@ def betavae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, betavae, training_configs): - if request.param is not None: - optimizer = request.param( - betavae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_betavae_train_step( - self, betavae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, betavae, train_dataset, training_configs): trainer = BaseTrainer( model=betavae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_betavae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -462,16 +468,7 @@ def test_betavae_train_step( ] ) - def test_betavae_eval_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -487,16 +484,7 @@ def test_betavae_eval_step( ] ) - def test_betavae_predict_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -512,22 +500,11 @@ def test_betavae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - - def test_betavae_main_train_loop( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=betavae, - 
train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -543,19 +520,10 @@ def test_betavae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, betavae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -638,20 +606,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, betavae, train_dataset, training_configs, optimizers + self, betavae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -699,19 +660,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, betavae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -821,10 +773,13 @@ def test_betavae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_BetaTC_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -836,7 +791,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -851,7 +806,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index ccafa4c1..dd821559 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -3,16 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import AutoModel, BetaVAE, BetaVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import BetaVAE, BetaVAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig - - +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) 
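The same refactor repeats across the test modules in this patch: per-test trainer construction with explicit optimizer fixtures is replaced by a single `trainer` fixture that calls `prepare_training()`. A condensed sketch of the pattern is shown below; `model`, `train_dataset` and `training_configs` stand in for the model-specific fixtures that each test file already defines.

```python
# Condensed sketch of the recurring fixture pattern; `model`, `train_dataset`
# and `training_configs` are placeholders for per-model fixtures.
import pytest

from pythae.trainers import BaseTrainer


@pytest.fixture
def trainer(model, train_dataset, training_configs):
    trainer = BaseTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        training_config=training_configs,
    )
    # Optimizers and schedulers are now built from the config here,
    # so tests no longer pass optimizer objects around.
    trainer.prepare_training()
    return trainer


def test_train_step(trainer):
    step_1_loss = trainer.train_step(epoch=1)
    assert step_1_loss is not None
```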
from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -118,7 +121,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -208,7 +213,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -291,14 +296,15 @@ def test_model_train_output(self, betavae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -313,23 +319,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return BetaVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -340,9 +353,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return BetaVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -418,28 +430,21 @@ def betavae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, betavae, training_configs): - if request.param is not None: - optimizer = request.param( - betavae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_betavae_train_step( - self, betavae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, betavae, train_dataset, training_configs): trainer = BaseTrainer( model=betavae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_betavae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -454,16 +459,7 @@ def test_betavae_train_step( ] ) - def test_betavae_eval_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - 
training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -479,16 +475,7 @@ def test_betavae_eval_step( ] ) - def test_betavae_predict_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -504,21 +491,11 @@ def test_betavae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_betavae_main_train_loop( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -534,19 +511,10 @@ def test_betavae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, betavae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -629,20 +597,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, betavae, train_dataset, training_configs, optimizers + self, betavae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -690,19 +651,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, betavae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -750,9 +702,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_betavae_training_pipeline( - self, tmpdir, betavae, train_dataset, training_configs - ): + def test_betavae_training_pipeline(self, betavae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -812,10 +762,13 @@ def test_betavae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_BetaVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, 
"data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -827,7 +780,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -842,7 +795,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py index 3391dbb2..74df8809 100644 --- a/tests/test_CIWAE.py +++ b/tests/test_CIWAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop from pythae.customexception import BadInheritanceError +from pythae.models import CIWAE, AutoModel, CIWAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import CIWAE, CIWAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -124,7 +129,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -214,7 +221,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -297,8 +304,7 @@ def test_model_train_output(self, CIWAE, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert ( out.recon_x.shape - == (demo_data["data"].shape[0],) - + demo_data["data"].shape[1:] + == (demo_data["data"].shape[0],) + demo_data["data"].shape[1:] ) @@ -307,9 +313,9 @@ class Test_Model_interpolate: params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -324,23 +330,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return CIWAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, 
"data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -351,12 +364,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return CIWAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -407,7 +420,7 @@ def training_configs(self, tmpdir, request): torch.rand(1), ] ) - def CIWAE(self, model_configs, custom_encoder, custom_decoder, request): + def ciwae(self, model_configs, custom_encoder, custom_decoder, request): # randomized alpha = request.param @@ -426,26 +439,21 @@ def CIWAE(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, CIWAE, training_configs): - if request.param is not None: - optimizer = request.param( - CIWAE.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_CIWAE_train_step(self, CIWAE, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, ciwae, train_dataset, training_configs): trainer = BaseTrainer( - model=CIWAE, + model=ciwae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_ciwae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -460,14 +468,7 @@ def test_CIWAE_train_step(self, CIWAE, train_dataset, training_configs, optimize ] ) - def test_CIWAE_eval_step(self, CIWAE, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ciwae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -483,16 +484,7 @@ def test_CIWAE_eval_step(self, CIWAE, train_dataset, training_configs, optimizer ] ) - def test_CIWAE_predict_step( - self, CIWAE, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ciwae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -508,21 +500,11 @@ def test_CIWAE_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_CIWAE_main_train_loop( - self, tmpdir, CIWAE, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_ciwae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -538,19 +520,10 @@ def test_CIWAE_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, CIWAE, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, ciwae, trainer, training_configs): dir_path = 
training_configs.output_dir - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -570,14 +543,14 @@ def test_checkpoint_saving( ) # check pickled custom decoder - if not CIWAE.model_config.uses_default_decoder: + if not ciwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not CIWAE.model_config.uses_default_encoder: + if not ciwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -632,21 +605,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, CIWAE, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, ciwae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -670,14 +634,14 @@ def test_checkpoint_saving_during_training( ) # check pickled custom decoder - if not CIWAE.model_config.uses_default_decoder: + if not ciwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not CIWAE.model_config.uses_default_encoder: + if not ciwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -694,19 +658,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, CIWAE, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, ciwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=CIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -726,14 +681,14 @@ def test_final_model_saving( ) # check pickled custom decoder - if not CIWAE.model_config.uses_default_decoder: + if not ciwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not CIWAE.model_config.uses_default_encoder: + if not ciwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -754,14 +709,12 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_CIWAE_training_pipeline( - self, tmpdir, CIWAE, train_dataset, training_configs - ): + def test_CIWAE_training_pipeline(self, ciwae, train_dataset, training_configs): dir_path = training_configs.output_dir # build pipeline - pipeline = TrainingPipeline(model=CIWAE, training_config=training_configs) + pipeline = TrainingPipeline(model=ciwae, training_config=training_configs) assert pipeline.training_config.__dict__ == training_configs.__dict__ @@ -788,14 +741,14 @@ def test_CIWAE_training_pipeline( ) # check pickled custom decoder - if not CIWAE.model_config.uses_default_decoder: + if not ciwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not CIWAE.model_config.uses_default_encoder: + 
if not ciwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -816,10 +769,13 @@ def test_CIWAE_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_CIWAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -831,7 +787,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -846,7 +802,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index 59a958ee..f74618a9 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -3,16 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import AutoModel, DisentangledBetaVAE, DisentangledBetaVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import DisentangledBetaVAE, DisentangledBetaVAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig - - +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -133,7 +136,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -225,7 +230,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -310,14 +315,15 @@ def test_model_train_output(self, betavae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -332,23 +338,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return DisentangledBetaVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, 
demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -359,9 +372,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return DisentangledBetaVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -437,28 +449,21 @@ def betavae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, betavae, training_configs): - if request.param is not None: - optimizer = request.param( - betavae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_betavae_train_step( - self, betavae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, betavae, train_dataset, training_configs): trainer = BaseTrainer( model=betavae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_betavae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -473,16 +478,7 @@ def test_betavae_train_step( ] ) - def test_betavae_eval_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -498,16 +494,7 @@ def test_betavae_eval_step( ] ) - def test_betavae_predict_step( - self, betavae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -523,21 +510,11 @@ def test_betavae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_betavae_main_train_loop( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_betavae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -553,19 +530,10 @@ def test_betavae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, betavae, 
trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -648,20 +616,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, betavae, train_dataset, training_configs, optimizers + self, betavae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -709,19 +670,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, betavae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, betavae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=betavae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -769,9 +721,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_betavae_training_pipeline( - self, tmpdir, betavae, train_dataset, training_configs - ): + def test_betavae_training_pipeline(self, betavae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -832,14 +782,19 @@ def test_betavae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_DisBetaVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): - return DisentangledBetaVAE(DisentangledBetaVAEConfig(input_dim=(1, 28, 28), latent_dim=7)) + return DisentangledBetaVAE( + DisentangledBetaVAEConfig(input_dim=(1, 28, 28), latent_dim=7) + ) @pytest.fixture( params=[ @@ -847,7 +802,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -862,7 +817,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py index 226e27f1..6c5db068 100644 --- a/tests/test_FactorVAE.py +++ b/tests/test_FactorVAE.py @@ -1,24 +1,26 @@ import os -import numpy as np from copy import deepcopy import pytest import torch -from torch.optim import Adam -from pythae.data.preprocessors import DataProcessor from pythae.customexception import BadInheritanceError +from pythae.data.preprocessors import DataProcessor +from pythae.models import AutoModel, FactorVAE, FactorVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import FactorVAE, FactorVAEConfig, AutoModel +from 
pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import ( AdversarialTrainer, AdversarialTrainerConfig, BaseTrainerConfig, ) -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig - -from pythae.pipelines import TrainingPipeline, GenerationPipeline - from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -79,10 +81,7 @@ def test_raises_bad_inheritance(self, model_configs, bad_net): factor_ae = FactorVAE(model_configs, decoder=bad_net) def test_raises_no_input_dim( - self, - model_configs_no_input_dim, - custom_encoder, - custom_decoder + self, model_configs_no_input_dim, custom_encoder, custom_decoder ): with pytest.raises(AttributeError): factor_ae = FactorVAE(model_configs_no_input_dim) @@ -94,14 +93,10 @@ def test_raises_no_input_dim( factor_ae = FactorVAE(model_configs_no_input_dim, decoder=custom_decoder) factor_ae = FactorVAE( - model_configs_no_input_dim, - encoder=custom_encoder, - decoder=custom_decoder + model_configs_no_input_dim, encoder=custom_encoder, decoder=custom_decoder ) - def test_build_custom_arch( - self, model_configs, custom_encoder, custom_decoder - ): + def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): factor_ae = FactorVAE( model_configs, encoder=custom_encoder, decoder=custom_decoder @@ -115,6 +110,7 @@ def test_build_custom_arch( assert factor_ae.model_config.uses_default_discriminator + class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): @@ -127,7 +123,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -198,23 +196,14 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder ] ) - def test_full_custom_model_saving( - self, - tmpdir, - model_configs, - custom_encoder, - custom_decoder + self, tmpdir, model_configs, custom_encoder, custom_decoder ): tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") - model = FactorVAE( - model_configs, - encoder=custom_encoder, - decoder=custom_decoder - ) + model = FactorVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) model.state_dict()["encoder.layers.0.0.weight"][0] = 0 @@ -226,7 +215,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -244,21 +233,13 @@ def test_full_custom_model_saving( ) def test_raises_missing_files( - self, - tmpdir, - model_configs, - custom_encoder, - custom_decoder + self, tmpdir, model_configs, custom_encoder, custom_decoder ): tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") - model = FactorVAE( - model_configs, - encoder=custom_encoder, - decoder=custom_decoder - ) + model = FactorVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) model.state_dict()["encoder.layers.0.0.weight"][0] = 0 @@ -309,40 +290,49 @@ def test_model_train_output(self, factor_ae, demo_data): # factor_ae = FactorVAE(model_configs) with 
pytest.raises(ArithmeticError): - factor_ae({'data': demo_data['data'][0]}) + factor_ae({"data": demo_data["data"][0]}) factor_ae.train() - + out = factor_ae(demo_data) assert isinstance(out, ModelOutput) - assert set( - [ - "loss", - "recon_loss", - "autoencoder_loss", - "discriminator_loss", - "recon_x", - "z", - "z_bis_permuted" - ] - ) == set(out.keys()) + assert ( + set( + [ + "loss", + "recon_loss", + "autoencoder_loss", + "discriminator_loss", + "recon_x", + "z", + "z_bis_permuted", + ] + ) + == set(out.keys()) + ) - assert out.z.shape[0] == int(demo_data["data"].shape[0] / 2) + 1 * (demo_data["data"].shape[0] % 2 != 0) + assert out.z.shape[0] == int(demo_data["data"].shape[0] / 2) + 1 * ( + demo_data["data"].shape[0] % 2 != 0 + ) assert out.z_bis_permuted.shape[0] == int(demo_data["data"].shape[0] / 2) - assert out.recon_x.shape == (int(demo_data["data"].shape[0] / 2) + 1 * (demo_data["data"].shape[0] % 2 != 0),) + (demo_data["data"].shape[1:]) - + assert out.recon_x.shape == ( + int(demo_data["data"].shape[0] / 2) + + 1 * (demo_data["data"].shape[0] % 2 != 0), + ) + (demo_data["data"].shape[1:]) + assert not torch.equal(out.z, out.z_bis_permuted) + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -357,23 +347,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return FactorVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -384,9 +381,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return FactorVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -424,7 +420,7 @@ def train_dataset(self): data = torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ : ] - return DataProcessor().to_dataset(data['data']) + return DataProcessor().to_dataset(data["data"]) @pytest.fixture( params=[ @@ -446,13 +442,7 @@ def training_configs(self, tmpdir, request): torch.rand(1), ] ) - def factor_ae( - self, - model_configs, - custom_encoder, - custom_decoder, - request - ): + def factor_ae(self, model_configs, custom_encoder, custom_decoder, request): # randomized alpha = request.param @@ -475,54 +465,33 @@ def factor_ae( ) elif 0.625 <= alpha < 0: - model = FactorVAE( - model_configs, - encoder=custom_encoder - ) + model = FactorVAE(model_configs, encoder=custom_encoder) elif 0.750 <= alpha < 0.875: - model = FactorVAE( - model_configs, - decoder=custom_decoder - ) + model = FactorVAE(model_configs, decoder=custom_decoder) else: model = 
FactorVAE( - model_configs, - encoder=custom_encoder, - decoder=custom_decoder + model_configs, encoder=custom_encoder, decoder=custom_decoder ) return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, factor_ae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - factor_ae.encoder.parameters(), lr=training_configs.learning_rate - ) - decoder_optimizer = request.param( - factor_ae.discriminator.parameters(), - lr=training_configs.learning_rate, - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - - return (encoder_optimizer, decoder_optimizer) - - def test_factor_ae_train_step( - self, factor_ae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, factor_ae, train_dataset, training_configs): trainer = AdversarialTrainer( model=factor_ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], ) + trainer.prepare_training() + + return trainer + + def test_factor_ae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -537,17 +506,7 @@ def test_factor_ae_train_step( ] ) - def test_factor_ae_eval_step( - self, factor_ae, train_dataset, training_configs, optimizers - ): - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_factor_ae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -563,17 +522,7 @@ def test_factor_ae_eval_step( ] ) - def test_factor_ae_predict_step( - self, factor_ae, train_dataset, training_configs, optimizers - ): - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_factor_ae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -589,22 +538,15 @@ def test_factor_ae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) - assert recon.shape == (int(inputs.shape[0] / 2) + 1 * (inputs.shape[0] % 2 != 0),) + (inputs.shape[1:]) - assert generated.shape == (int(inputs.shape[0] / 2) + 1 * (inputs.shape[0] % 2 != 0),) + (inputs.shape[1:]) + assert inputs.cpu() in train_dataset.data + assert recon.shape == ( + int(inputs.shape[0] / 2) + 1 * (inputs.shape[0] % 2 != 0), + ) + (inputs.shape[1:]) + assert generated.shape == ( + int(inputs.shape[0] / 2) + 1 * (inputs.shape[0] % 2 != 0), + ) + (inputs.shape[1:]) - def test_factor_ae_main_train_loop( - self, tmpdir, factor_ae, train_dataset, training_configs, optimizers - ): - - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) + def test_factor_ae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -620,20 +562,10 @@ def test_factor_ae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, factor_ae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, factor_ae, trainer, 
training_configs): dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -753,21 +685,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, factor_ae, train_dataset, training_configs, optimizers + self, factor_ae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - model = deepcopy(trainer.model) trainer.train() @@ -820,20 +744,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, factor_ae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, factor_ae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = AdversarialTrainer( - model=factor_ae, - train_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -883,7 +797,7 @@ def test_final_model_saving( assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) def test_factor_ae_training_pipeline( - self, tmpdir, factor_ae, train_dataset, training_configs + self, factor_ae, train_dataset, training_configs ): with pytest.raises(AssertionError): @@ -948,10 +862,13 @@ def test_factor_ae_training_pipeline( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) + class Test_FactorVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -963,7 +880,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -978,7 +895,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_HVAE.py b/tests/test_HVAE.py index 25490435..56f03ea5 100644 --- a/tests/test_HVAE.py +++ b/tests/test_HVAE.py @@ -3,15 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import HVAE, AutoModel, HVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import HVAE, HVAEConfig, AutoModel - +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.samplers 
import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -35,7 +39,10 @@ def model_configs_no_input_dim(request): learn_beta_zero=True, ), HVAEConfig( - input_dim=(1, 2, 18), latent_dim=5, eps_lf=0.0001, learn_eps_lf=True, + input_dim=(1, 2, 18), + latent_dim=5, + eps_lf=0.0001, + learn_eps_lf=True, ), ] ) @@ -124,7 +131,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -214,7 +223,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -323,14 +332,15 @@ def test_nll_compute(self, hvae, demo_data, nll_params): assert isinstance(nll, float) assert nll < 0 + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -345,23 +355,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return HVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -372,9 +389,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return HVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -422,26 +438,21 @@ def hvae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, hvae, training_configs): - if request.param is not None: - optimizer = request.param( - hvae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_hvae_train_step(self, hvae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, hvae, train_dataset, training_configs): trainer = BaseTrainer( model=hvae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_hvae_train_step(self, hvae, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -474,14 +485,7 @@ def 
test_hvae_train_step(self, hvae, train_dataset, training_configs, optimizers == step_1_model_state_dict["beta_zero_sqrt"] ) - def test_hvae_eval_step(self, hvae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_hvae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -497,16 +501,7 @@ def test_hvae_eval_step(self, hvae, train_dataset, training_configs, optimizers) ] ) - def test_hvae_predict_step( - self, hvae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_hvae_predict_step(self, train_dataset, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -522,21 +517,11 @@ def test_hvae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_hvae_main_train_loop( - self, tmpdir, hvae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_hvae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -552,19 +537,10 @@ def test_hvae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, hvae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, hvae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -646,21 +622,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, hvae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, hvae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -708,19 +675,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, hvae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, hvae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=hvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -830,10 +788,13 @@ def test_hvae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_HVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def 
ae_model(self): @@ -845,7 +806,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -860,7 +821,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_IAF.py b/tests/test_IAF.py index 0116f458..89334cdf 100644 --- a/tests/test_IAF.py +++ b/tests/test_IAF.py @@ -1,21 +1,16 @@ -import pytest import os -import torch -import numpy as np -import shutil - from copy import deepcopy -from torch.optim import Adam -from pythae.models.base.base_utils import ModelOutput -from pythae.models.normalizing_flows import IAF, IAFConfig -from pythae.models.normalizing_flows import NFModel +import numpy as np +import pytest +import torch + from pythae.data.datasets import BaseDataset from pythae.models import AutoModel - - -from pythae.trainers import BaseTrainer, BaseTrainerConfig +from pythae.models.base.base_utils import ModelOutput +from pythae.models.normalizing_flows import IAF, IAFConfig, NFModel from pythae.pipelines import TrainingPipeline +from pythae.trainers import BaseTrainer, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -80,7 +75,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -210,31 +207,24 @@ def prior(self, model_configs, request): torch.eye(np.prod(model_configs.input_dim)).to(device), ) - @pytest.fixture(params=[Adam]) - def optimizers(self, request, iaf, training_configs): - if request.param is not None: - optimizer = request.param( - iaf.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_iaf_train_step( - self, iaf, prior, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, prior, iaf, train_dataset, training_configs): nf_model = NFModel(prior=prior, flow=iaf) trainer = BaseTrainer( model=nf_model, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_iaf_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -249,19 +239,7 @@ def test_iaf_train_step( ] ) - def test_iaf_eval_step( - self, iaf, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=iaf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_iaf_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -277,18 +255,7 @@ def test_iaf_eval_step( ] ) - def test_iaf_main_train_loop( - self, iaf, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=iaf) - - trainer = BaseTrainer( - model=nf_model, - 
train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_iaf_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -304,21 +271,10 @@ def test_iaf_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, iaf, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=iaf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -383,23 +339,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, iaf, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=iaf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model.flow) trainer.train() @@ -433,21 +378,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, iaf, prior, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=iaf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model.flow) diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index 9b076f29..cfc0f716 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop from pythae.customexception import BadInheritanceError +from pythae.models import IWAE, AutoModel, IWAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import IWAE, IWAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -20,7 +25,14 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -@pytest.fixture(params=[IWAEConfig(), IWAEConfig(latent_dim=5,)]) +@pytest.fixture( + params=[ + IWAEConfig(), + IWAEConfig( + latent_dim=5, + ), + ] +) def model_configs_no_input_dim(request): return request.param @@ -121,7 +133,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -211,7 +225,7 @@ def 
test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -294,8 +308,7 @@ def test_model_train_output(self, iwae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert ( out.recon_x.shape - == (demo_data["data"].shape[0],) - + demo_data["data"].shape[1:] + == (demo_data["data"].shape[0],) + demo_data["data"].shape[1:] ) @@ -304,9 +317,9 @@ class Test_Model_interpolate: params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -321,23 +334,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return IWAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -348,12 +368,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return IWAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -423,26 +443,21 @@ def iwae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, iwae, training_configs): - if request.param is not None: - optimizer = request.param( - iwae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_iwae_train_step(self, iwae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, iwae, train_dataset, training_configs): trainer = BaseTrainer( model=iwae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_iwae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -457,14 +472,7 @@ def test_iwae_train_step(self, iwae, train_dataset, training_configs, optimizers ] ) - def test_iwae_eval_step(self, iwae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_iwae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -480,16 +488,7 @@ def test_iwae_eval_step(self, iwae, train_dataset, training_configs, optimizers) ] ) - def test_iwae_predict_step( - self, iwae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - 
eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_iwae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -505,21 +504,11 @@ def test_iwae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_iwae_main_train_loop( - self, tmpdir, iwae, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_iwae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -535,19 +524,10 @@ def test_iwae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, iwae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, iwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -629,21 +609,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, iwae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, iwae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -691,19 +662,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, iwae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, iwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=iwae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -813,10 +775,13 @@ def test_iwae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_IWAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -828,7 +793,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -843,7 +808,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_MADE.py b/tests/test_MADE.py index e5634bbb..a6a7a8e2 100644 --- a/tests/test_MADE.py +++ b/tests/test_MADE.py @@ -1,20 +1,12 @@ -import pytest 
import os -import torch - -from copy import deepcopy -from torch.optim import Adam import numpy as np -import shutil +import pytest +import torch +from pythae.models import AutoModel from pythae.models.base.base_utils import ModelOutput from pythae.models.normalizing_flows import MADE, MADEConfig -from pythae.models.normalizing_flows import NFModel -from pythae.models import AutoModel - -from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline PATH = os.path.dirname(os.path.abspath(__file__)) @@ -73,7 +65,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) diff --git a/tests/test_MAF.py b/tests/test_MAF.py index 720dd671..d50d3c75 100644 --- a/tests/test_MAF.py +++ b/tests/test_MAF.py @@ -1,20 +1,15 @@ -import pytest import os -import torch -import numpy as np -import shutil - from copy import deepcopy -from torch.optim import Adam - -from pythae.models.base.base_utils import ModelOutput -from pythae.models.normalizing_flows import MAF, MAFConfig -from pythae.models.normalizing_flows import NFModel -from pythae.models import AutoModel +import numpy as np +import pytest +import torch -from pythae.trainers import BaseTrainer, BaseTrainerConfig +from pythae.models import AutoModel +from pythae.models.base.base_utils import ModelOutput +from pythae.models.normalizing_flows import MAF, MAFConfig, NFModel from pythae.pipelines import TrainingPipeline +from pythae.trainers import BaseTrainer, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -72,7 +67,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -202,31 +199,24 @@ def prior(self, model_configs, request): torch.eye(np.prod(model_configs.input_dim)).to(device), ) - @pytest.fixture(params=[Adam]) - def optimizers(self, request, maf, training_configs): - if request.param is not None: - optimizer = request.param( - maf.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_maf_train_step( - self, maf, prior, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, maf, prior, train_dataset, training_configs): nf_model = NFModel(prior=prior, flow=maf) trainer = BaseTrainer( model=nf_model, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_maf_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -241,19 +231,7 @@ def test_maf_train_step( ] ) - def test_maf_eval_step( - self, maf, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=maf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def 
test_maf_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -269,18 +247,7 @@ def test_maf_eval_step( ] ) - def test_maf_main_train_loop( - self, maf, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=maf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_maf_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -296,21 +263,10 @@ def test_maf_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, maf, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=maf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -375,23 +331,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, maf, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=maf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model.flow) trainer.train() @@ -425,21 +370,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, maf, prior, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=maf) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model.flow) diff --git a/tests/test_MIWAE.py b/tests/test_MIWAE.py index b7f9427e..80424203 100644 --- a/tests/test_MIWAE.py +++ b/tests/test_MIWAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop from pythae.customexception import BadInheritanceError +from pythae.models import MIWAE, AutoModel, MIWAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import MIWAE, MIWAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -20,7 +25,9 @@ PATH = os.path.dirname(os.path.abspath(__file__)) -@pytest.fixture(params=[MIWAEConfig(), MIWAEConfig(latent_dim=5, number_gradient_estimates=3)]) +@pytest.fixture( + params=[MIWAEConfig(), MIWAEConfig(latent_dim=5, number_gradient_estimates=3)] +) def model_configs_no_input_dim(request): return request.param @@ -121,7 
+128,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -211,7 +220,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -294,8 +303,7 @@ def test_model_train_output(self, MIWAE, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert ( out.recon_x.shape - == (demo_data["data"].shape[0],) - + demo_data["data"].shape[1:] + == (demo_data["data"].shape[0],) + demo_data["data"].shape[1:] ) @@ -304,9 +312,9 @@ class Test_Model_interpolate: params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -321,23 +329,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return MIWAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -348,12 +363,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return MIWAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -404,7 +419,7 @@ def training_configs(self, tmpdir, request): torch.rand(1), ] ) - def MIWAE(self, model_configs, custom_encoder, custom_decoder, request): + def miwae(self, model_configs, custom_encoder, custom_decoder, request): # randomized alpha = request.param @@ -423,26 +438,21 @@ def MIWAE(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, MIWAE, training_configs): - if request.param is not None: - optimizer = request.param( - MIWAE.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_MIWAE_train_step(self, MIWAE, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, miwae, train_dataset, training_configs): trainer = BaseTrainer( - model=MIWAE, + model=miwae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_miwae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -457,14 +467,7 @@ def test_MIWAE_train_step(self, 
MIWAE, train_dataset, training_configs, optimize ] ) - def test_MIWAE_eval_step(self, MIWAE, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_miwae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -480,16 +483,7 @@ def test_MIWAE_eval_step(self, MIWAE, train_dataset, training_configs, optimizer ] ) - def test_MIWAE_predict_step( - self, MIWAE, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_miwae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -505,21 +499,11 @@ def test_MIWAE_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_MIWAE_main_train_loop( - self, tmpdir, MIWAE, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_miwae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -535,19 +519,10 @@ def test_MIWAE_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, MIWAE, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, miwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -567,14 +542,14 @@ def test_checkpoint_saving( ) # check pickled custom decoder - if not MIWAE.model_config.uses_default_decoder: + if not miwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not MIWAE.model_config.uses_default_encoder: + if not miwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -629,21 +604,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, MIWAE, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, miwae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -667,14 +633,14 @@ def test_checkpoint_saving_during_training( ) # check pickled custom decoder - if not MIWAE.model_config.uses_default_decoder: + if not miwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not MIWAE.model_config.uses_default_encoder: + if not miwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -691,19 +657,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - 
self, tmpdir, MIWAE, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, miwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=MIWAE, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -723,14 +680,14 @@ def test_final_model_saving( ) # check pickled custom decoder - if not MIWAE.model_config.uses_default_decoder: + if not miwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not MIWAE.model_config.uses_default_encoder: + if not miwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -751,14 +708,12 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_MIWAE_training_pipeline( - self, tmpdir, MIWAE, train_dataset, training_configs - ): + def test_MIWAE_training_pipeline(self, miwae, train_dataset, training_configs): dir_path = training_configs.output_dir # build pipeline - pipeline = TrainingPipeline(model=MIWAE, training_config=training_configs) + pipeline = TrainingPipeline(model=miwae, training_config=training_configs) assert pipeline.training_config.__dict__ == training_configs.__dict__ @@ -785,14 +740,14 @@ def test_MIWAE_training_pipeline( ) # check pickled custom decoder - if not MIWAE.model_config.uses_default_decoder: + if not miwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not MIWAE.model_config.uses_default_encoder: + if not miwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -813,14 +768,21 @@ def test_MIWAE_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_MIWAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): - return MIWAE(MIWAEConfig(input_dim=(1, 28, 28), latent_dim=7, number_gradient_estimates=2)) + return MIWAE( + MIWAEConfig( + input_dim=(1, 28, 28), latent_dim=7, number_gradient_estimates=2 + ) + ) @pytest.fixture( params=[ @@ -828,7 +790,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -843,7 +805,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py index 0eb9e6d7..dda5d998 100644 --- a/tests/test_MSSSIMVAE.py +++ b/tests/test_MSSSIMVAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import MSSSIM_VAE, AutoModel, MSSSIM_VAEConfig from pythae.models.base.base_utils import 
ModelOutput -from pythae.models import MSSSIM_VAE, MSSSIM_VAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -122,7 +127,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -214,7 +221,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -299,12 +306,13 @@ def test_model_train_output(self, msssim_vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ] ] ) def demo_data(self, request): @@ -319,21 +327,28 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return MSSSIM_VAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ] ] ) def demo_data(self, request): @@ -344,12 +359,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return MSSSIM_VAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -421,28 +436,21 @@ def msssim_vae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, msssim_vae, training_configs): - if request.param is not None: - optimizer = request.param( - msssim_vae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_msssim_vae_train_step( - self, msssim_vae, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, msssim_vae, train_dataset, training_configs): trainer = BaseTrainer( model=msssim_vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + 
trainer.prepare_training() + + return trainer + + def test_msssim_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -457,16 +465,7 @@ def test_msssim_vae_train_step( ] ) - def test_msssim_vae_eval_step( - self, msssim_vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_msssim_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -482,16 +481,7 @@ def test_msssim_vae_eval_step( ] ) - def test_msssim_vae_predict_step( - self, msssim_vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_msssim_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -507,21 +497,11 @@ def test_msssim_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_msssim_vae_main_train_loop( - self, tmpdir, msssim_vae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_msssim_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -537,19 +517,10 @@ def test_msssim_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, msssim_vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, msssim_vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -632,20 +603,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, msssim_vae, train_dataset, training_configs, optimizers + self, msssim_vae, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -693,19 +657,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, msssim_vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, msssim_vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=msssim_vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -754,7 +709,7 @@ def test_final_model_saving( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_msssim_vae_training_pipeline( - self, tmpdir, msssim_vae, train_dataset, training_configs + self, msssim_vae, train_dataset, training_configs ): dir_path = 
training_configs.output_dir @@ -815,10 +770,13 @@ def test_msssim_vae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_MSSSIM_VAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -830,7 +788,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -845,7 +803,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_PIWAE.py b/tests/test_PIWAE.py index e55d9c55..67346738 100644 --- a/tests/test_PIWAE.py +++ b/tests/test_PIWAE.py @@ -3,18 +3,22 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import PIWAE, AutoModel, PIWAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import PIWAE, PIWAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, +) from pythae.trainers import ( + BaseTrainerConfig, CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig, - BaseTrainerConfig, ) -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -31,9 +35,7 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ - PIWAEConfig( - input_dim=(1, 28, 28), latent_dim=10, number_gradient_estimates=3 - ), + PIWAEConfig(input_dim=(1, 28, 28), latent_dim=10, number_gradient_estimates=3), PIWAEConfig(input_dim=(1, 2, 18), latent_dim=5), ] ) @@ -122,7 +124,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -212,7 +216,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -276,42 +280,47 @@ def demo_data(self): return data # This is an extract of 3 data from MNIST (unnormalized) used to test custom architecture @pytest.fixture - def rae(self, model_configs, demo_data): + def piwae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data["data"][0].shape) return PIWAE(model_configs) - def test_model_train_output(self, rae, demo_data): + def test_model_train_output(self, piwae, demo_data): - rae.train() + piwae.train() - out = rae(demo_data) + out = piwae(demo_data) assert isinstance(out, ModelOutput) - assert set([ - "loss", - "reconstruction_loss", - 
"encoder_loss", - "decoder_loss", - "update_encoder", - "update_decoder", - "reg_loss", - "recon_x", - "z"]) == set( - out.keys() + assert ( + set( + [ + "loss", + "reconstruction_loss", + "encoder_loss", + "decoder_loss", + "update_encoder", + "update_decoder", + "reg_loss", + "recon_x", + "z", + ] + ) + == set(out.keys()) ) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -326,23 +335,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return PIWAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -353,9 +369,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return PIWAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -371,9 +386,10 @@ def train_dataset(self): CoupledOptimizerTrainerConfig( num_epochs=3, steps_saving=2, - learning_rate=1e-5, - encoder_optim_decay=1e-3, - decoder_optim_decay=1e-3, + encoder_learning_rate=1e-5, + decoder_learning_rate=1e-6, + encoder_optimizer_cls="AdamW", + decoder_optimizer_cls="SGD", ) ] ) @@ -392,7 +408,7 @@ def training_configs(self, tmpdir, request): torch.rand(1), ] ) - def rae(self, model_configs, custom_encoder, custom_decoder, request): + def piwae(self, model_configs, custom_encoder, custom_decoder, request): # randomized alpha = request.param @@ -407,37 +423,25 @@ def rae(self, model_configs, custom_encoder, custom_decoder, request): model = PIWAE(model_configs, decoder=custom_decoder) else: - model = PIWAE( - model_configs, encoder=custom_encoder, decoder=custom_decoder - ) + model = PIWAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, rae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - rae.encoder.parameters(), lr=training_configs.learning_rate - ) - decoder_optimizer = request.param( - rae.decoder.parameters(), lr=training_configs.learning_rate - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - - return (encoder_optimizer, decoder_optimizer) - - def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, piwae, train_dataset, training_configs): trainer = CoupledOptimizerTrainer( - model=rae, + model=piwae, train_dataset=train_dataset, + eval_dataset=train_dataset, 
training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], ) + trainer.prepare_training() + + return trainer + + def test_piwae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -452,15 +456,7 @@ def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_piwae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -476,17 +472,7 @@ def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_predict_step( - self, rae, train_dataset, training_configs, optimizers - ): - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_piwae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -502,22 +488,11 @@ def test_rae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_rae_main_train_loop( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): - - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_piwae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -533,20 +508,10 @@ def test_rae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, piwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -572,14 +537,14 @@ def test_checkpoint_saving( ).issubset(set(files_list)) # check pickled custom decoder - if not rae.model_config.uses_default_decoder: + if not piwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not rae.model_config.uses_default_encoder: + if not piwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -660,22 +625,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, piwae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - 
decoder_optimizer=optimizers[1], - ) - model = deepcopy(trainer.model) trainer.train() @@ -704,14 +659,14 @@ def test_checkpoint_saving_during_training( ).issubset(set(files_list)) # check pickled custom decoder - if not rae.model_config.uses_default_decoder: + if not piwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not rae.model_config.uses_default_encoder: + if not piwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -728,20 +683,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, piwae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -761,14 +706,14 @@ def test_final_model_saving( ) # check pickled custom decoder - if not rae.model_config.uses_default_decoder: + if not piwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not rae.model_config.uses_default_encoder: + if not piwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -789,15 +734,19 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_configs): + def test_piwae_training_pipeline( + self, tmpdir, piwae, train_dataset, training_configs + ): with pytest.raises(AssertionError): - pipeline = TrainingPipeline(model=rae, training_config=BaseTrainerConfig()) + pipeline = TrainingPipeline( + model=piwae, training_config=BaseTrainerConfig() + ) dir_path = training_configs.output_dir # build pipeline - pipeline = TrainingPipeline(model=rae, training_config=training_configs) + pipeline = TrainingPipeline(model=piwae, training_config=training_configs) assert pipeline.training_config.__dict__ == training_configs.__dict__ @@ -824,14 +773,14 @@ def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_config ) # check pickled custom decoder - if not rae.model_config.uses_default_decoder: + if not piwae.model_config.uses_default_decoder: assert "decoder.pkl" in files_list else: assert not "decoder.pkl" in files_list # check pickled custom encoder - if not rae.model_config.uses_default_encoder: + if not piwae.model_config.uses_default_encoder: assert "encoder.pkl" in files_list else: @@ -852,10 +801,13 @@ def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) -class Test_RAE_Generation: + +class Test_PIWAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -881,7 +833,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, 
eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_PixelCNN.py b/tests/test_PixelCNN.py index 23ac9e45..4bb3cd49 100644 --- a/tests/test_PixelCNN.py +++ b/tests/test_PixelCNN.py @@ -1,17 +1,15 @@ -import pytest import os -import torch -import numpy as np - from copy import deepcopy -from torch.optim import Adam +import numpy as np +import pytest +import torch + +from pythae.models import AutoModel from pythae.models.base.base_utils import ModelOutput from pythae.models.normalizing_flows import PixelCNN, PixelCNNConfig -from pythae.models import AutoModel - -from pythae.trainers import BaseTrainer, BaseTrainerConfig from pythae.pipelines import TrainingPipeline +from pythae.trainers import BaseTrainer, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -72,7 +70,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -125,7 +125,7 @@ def demo_data(self): data = torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ : ] - return data # This is an extract of 3 data from MNIST (unnormalized) used to test custom architecture + return data # This is an extract of 3 data from MNIST (unnormalized) used to test custom architecture @pytest.fixture def pixelcnn(self, model_configs, demo_data): @@ -140,8 +140,17 @@ def test_model_train_output(self, pixelcnn, demo_data): assert isinstance(out, ModelOutput) assert set(["out", "loss"]) == set(out.keys()) - - assert out.out.shape == (demo_data["data"].shape[0], pixelcnn.model_config.n_embeddings,demo_data["data"].shape[1] ) + demo_data["data"].shape[2:] + + assert ( + out.out.shape + == ( + demo_data["data"].shape[0], + pixelcnn.model_config.n_embeddings, + demo_data["data"].shape[1], + ) + + demo_data["data"].shape[2:] + ) + @pytest.mark.slow class Test_PixelCNN_Training: @@ -160,9 +169,7 @@ def training_configs(self, tmpdir, request): @pytest.fixture( params=[ - PixelCNNConfig( - input_dim=(1, 28, 28), n_layers=10, kernel_size=7 - ), + PixelCNNConfig(input_dim=(1, 28, 28), n_layers=10, kernel_size=7), ] ) def model_configs(self, request): @@ -173,30 +180,21 @@ def pixelcnn(self, model_configs): model = PixelCNN(model_configs) return model - - @pytest.fixture(params=[Adam]) - def optimizers(self, request, pixelcnn, training_configs): - if request.param is not None: - optimizer = request.param( - pixelcnn.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_pixelcnn_train_step( - self, pixelcnn, train_dataset, training_configs, optimizers - ): - + @pytest.fixture + def trainer(self, pixelcnn, train_dataset, training_configs): trainer = BaseTrainer( model=pixelcnn, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_pixelcnn_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -211,17 +209,7 @@ def test_pixelcnn_train_step( ] ) - def test_pixelcnn_eval_step( - 
self, pixelcnn, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=pixelcnn, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_pixelcnn_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -231,9 +219,10 @@ def test_pixelcnn_eval_step( # check that weights were updated - for key in start_model_state_dict.keys(): - if not torch.equal(start_model_state_dict[key], step_1_model_state_dict[key]): + if not torch.equal( + start_model_state_dict[key], step_1_model_state_dict[key] + ): print(key, pixelcnn) assert all( @@ -242,16 +231,8 @@ def test_pixelcnn_eval_step( for key in start_model_state_dict.keys() ] ) - def test_pixelcnn_main_train_loop( - self, pixelcnn, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=pixelcnn, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_pixelcnn_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -267,19 +248,10 @@ def test_pixelcnn_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, pixelcnn, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=pixelcnn, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -344,21 +316,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, pixelcnn, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=pixelcnn, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -392,19 +355,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, pixelcnn, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=pixelcnn, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -436,7 +390,7 @@ def test_final_model_saving( ) def test_pixelcnn_training_pipeline( - self, tmpdir, pixelcnn, train_dataset, training_configs + self, pixelcnn, train_dataset, training_configs ): dir_path = training_configs.output_dir @@ -448,8 +402,8 @@ def test_pixelcnn_training_pipeline( # Launch Pipeline pipeline( - train_data=train_dataset.data, # gives tensor to pipeline - eval_data=train_dataset.data, # gives tensor to pipeline + train_data=train_dataset.data, # gives tensor to pipeline + eval_data=train_dataset.data, # gives tensor to pipeline ) model = deepcopy(pipeline.trainer._best_model) diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py index 9c086d79..a292800d 100644 --- a/tests/test_PoincareVAE.py +++ b/tests/test_PoincareVAE.py @@ -2,36 +2,64 @@ from copy import deepcopy import pytest -from sklearn import manifold import torch -from torch.optim import Adam from pythae.customexception import 
BadInheritanceError +from pythae.models import AutoModel, PoincareVAE, PoincareVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import PoincareVAE, PoincareVAEConfig, AutoModel from pythae.models.pvae.pvae_utils import PoincareBall -from pythae.samplers import PoincareDiskSamplerConfig, NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + PoincareDiskSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, - Encoder_VAE_Conv, Encoder_SVAE_Conv, + Encoder_VAE_Conv, NetBadInheritance, ) PATH = os.path.dirname(os.path.abspath(__file__)) -@pytest.fixture(params=[PoincareVAEConfig(), PoincareVAEConfig(latent_dim=5, prior_distribution="wrapped_normal", posterior_distribution="riemannian_normal", curvature=0.5)]) +@pytest.fixture( + params=[ + PoincareVAEConfig(), + PoincareVAEConfig( + latent_dim=5, + prior_distribution="wrapped_normal", + posterior_distribution="riemannian_normal", + curvature=0.5, + ), + ] +) def model_configs_no_input_dim(request): return request.param @pytest.fixture( params=[ - PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=2, reconstruction_loss="bce", prior_distribution="wrapped_normal", posterior_distribution="wrapped_normal", curvature=0.7), - PoincareVAEConfig(input_dim=(1, 28), latent_dim=5, prior_distribution="riemannian_normal", posterior_distribution="riemannian_normal", curvature=0.8), + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=2, + reconstruction_loss="bce", + prior_distribution="wrapped_normal", + posterior_distribution="wrapped_normal", + curvature=0.7, + ), + PoincareVAEConfig( + input_dim=(1, 28), + latent_dim=5, + prior_distribution="riemannian_normal", + posterior_distribution="riemannian_normal", + curvature=0.8, + ), ] ) def model_configs(request): @@ -89,7 +117,9 @@ def test_raises_no_input_dim( def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + model = PoincareVAE( + model_configs, encoder=custom_encoder, decoder=custom_decoder + ) assert model.encoder == custom_encoder assert not model.model_config.uses_default_encoder @@ -109,7 +139,7 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): assert not model.model_config.uses_default_decoder def test_misc_manifold_func(self): - + manifold = PoincareBall(dim=2, c=0.7) x = torch.randn(10, 2) y = torch.randn(10, 2) @@ -120,7 +150,6 @@ def test_misc_manifold_func(self): manifold.normdist2plane(x, x, x, signed=True, norm=True) - class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): @@ -133,7 +162,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -211,7 +242,9 @@ def test_full_custom_model_saving( tmpdir.mkdir("dummy_folder") dir_path = dir_path = 
os.path.join(tmpdir, "dummy_folder") - model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + model = PoincareVAE( + model_configs, encoder=custom_encoder, decoder=custom_decoder + ) model.state_dict()["encoder.layers.0.0.weight"][0] = 0 @@ -223,7 +256,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -247,7 +280,9 @@ def test_raises_missing_files( tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") - model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + model = PoincareVAE( + model_configs, encoder=custom_encoder, decoder=custom_decoder + ) model.state_dict()["encoder.layers.0.0.weight"][0] = 0 @@ -306,14 +341,15 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -328,23 +364,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return PoincareVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -355,9 +398,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return PoincareVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -427,30 +469,27 @@ def vae(self, model_configs, custom_encoder, custom_decoder, request): model = PoincareVAE(model_configs, decoder=custom_decoder) else: - model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) - - return model - - @pytest.fixture(params=[Adam]) - def optimizers(self, request, vae, training_configs): - if request.param is not None: - optimizer = request.param( - vae.parameters(), lr=training_configs.learning_rate + model = PoincareVAE( + model_configs, encoder=custom_encoder, decoder=custom_decoder ) - else: - optimizer = None - - return optimizer + return model - def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vae, train_dataset, training_configs): trainer = BaseTrainer( model=vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = 
trainer.train_step(epoch=1) @@ -465,14 +504,7 @@ def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -488,16 +520,7 @@ def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_predict_step( - self, vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -513,21 +536,11 @@ def test_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_vae_main_train_loop( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -543,19 +556,10 @@ def test_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -637,21 +641,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -699,19 +694,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -759,7 +745,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs): + def test_vae_training_pipeline(self, vae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -819,15 +805,30 @@ def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, 
training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data - @pytest.fixture(params=[ - PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=7, prior_distribution="wrapped_normal", curvature=0.2), - PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=2, prior_distribution="riemannian_normal", curvature=0.7) - ]) + @pytest.fixture( + params=[ + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=7, + prior_distribution="wrapped_normal", + curvature=0.2, + ), + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=2, + prior_distribution="riemannian_normal", + curvature=0.7, + ), + ] + ) def ae_config(self, request): return request.param @@ -842,7 +843,7 @@ def ae_model(self, ae_config): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -857,7 +858,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py index 80d95cb1..35809316 100644 --- a/tests/test_RHVAE.py +++ b/tests/test_RHVAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError -from pythae.models import RHVAE, RHVAEConfig, AutoModel -from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.models import RHVAE, AutoModel, RHVAEConfig from pythae.models.rhvae.rhvae_config import RHVAEConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) +from pythae.trainers import BaseTrainer, BaseTrainerConfig from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -141,7 +146,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -289,7 +296,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "metric.pkl", - "environment.json" + "environment.json", ] ) @@ -398,21 +405,24 @@ def test_model_train_output(self, rhvae, demo_data): rhvae.train() out = rhvae(demo_data) - assert set( - [ - "loss", - "recon_x", - "z", - "z0", - "rho", - "eps0", - "gamma", - "mu", - "log_var", - "G_inv", - "G_log_det", - ] - ) == set(out.keys()) + assert ( + set( + [ + "loss", + "recon_x", + "z", + "z0", + "rho", + "eps0", + "gamma", + "mu", + "log_var", + "G_inv", 
+ "G_log_det", + ] + ) + == set(out.keys()) + ) rhvae.update() @@ -423,33 +433,37 @@ def test_model_output(self, rhvae, demo_data): rhvae.eval() out = rhvae(demo_data) - assert set( - [ - "loss", - "recon_x", - "z", - "z0", - "rho", - "eps0", - "gamma", - "mu", - "log_var", - "G_inv", - "G_log_det", - ] - ) == set(out.keys()) + assert ( + set( + [ + "loss", + "recon_x", + "z", + "z0", + "rho", + "eps0", + "gamma", + "mu", + "log_var", + "G_inv", + "G_log_det", + ] + ) + == set(out.keys()) + ) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -464,23 +478,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RHVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -491,9 +512,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RHVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -578,26 +598,22 @@ def rhvae( return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, rhvae, training_configs): - if request.param is not None: - optimizer = request.param( - rhvae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_rhvae_train_step(self, rhvae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, rhvae, train_dataset, training_configs): trainer = BaseTrainer( model=rhvae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + return optimizer + + def test_rhvae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -612,14 +628,7 @@ def test_rhvae_train_step(self, rhvae, train_dataset, training_configs, optimize ] ) - def test_rhvae_eval_step(self, rhvae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rhvae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -635,16 +644,7 @@ def test_rhvae_eval_step(self, rhvae, train_dataset, training_configs, optimizer ] ) - def test_rhvae_predict_step( - self, rhvae, train_dataset, 
training_configs, optimizers - ): - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rhvae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -660,21 +660,11 @@ def test_rhvae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_rhvae_main_train_loop( - self, tmpdir, rhvae, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rhvae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -690,19 +680,10 @@ def test_rhvae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, rhvae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, rhvae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -798,21 +779,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, rhvae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, rhvae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -867,19 +839,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, rhvae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, rhvae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rhvae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -937,9 +900,7 @@ def test_final_model_saving( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.metric.cpu()) == type(model.metric.cpu()) - def test_rhvae_training_pipeline( - self, tmpdir, rhvae, train_dataset, training_configs - ): + def test_rhvae_training_pipeline(self, rhvae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -1009,10 +970,13 @@ def test_rhvae_training_pipeline( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.metric.cpu()) == type(model.metric.cpu()) + class Test_RHVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -1024,7 +988,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -1039,7 
+1003,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index e3090f92..e978b30a 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -3,14 +3,13 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import SVAE, AutoModel, SVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import SVAE, SVAEConfig, AutoModel +from pythae.pipelines import GenerationPipeline, TrainingPipeline from pythae.samplers import HypersphereUniformSamplerConfig from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_SVAE_Conv, @@ -116,7 +115,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -206,7 +207,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -289,14 +290,15 @@ def test_model_train_output(self, svae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -311,23 +313,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return SVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -338,9 +347,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return SVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -414,26 +422,21 @@ def svae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, svae, training_configs): - if request.param is not None: - optimizer = request.param( - svae.parameters(), 
lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_svae_train_step(self, svae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, svae, train_dataset, training_configs): trainer = BaseTrainer( model=svae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_svae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -448,14 +451,7 @@ def test_svae_train_step(self, svae, train_dataset, training_configs, optimizers ] ) - def test_svae_eval_step(self, svae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=svae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_svae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -471,16 +467,7 @@ def test_svae_eval_step(self, svae, train_dataset, training_configs, optimizers) ] ) - def test_svae_predict_step( - self, svae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=svae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_svae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -496,21 +483,11 @@ def test_svae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_svae_main_train_loop( - self, tmpdir, svae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=svae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_svae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -526,19 +503,10 @@ def test_svae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, svae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, svae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=svae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -620,21 +588,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, svae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, svae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=svae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -682,19 +641,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, svae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, svae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=svae, - 
train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -742,9 +692,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_svae_training_pipeline( - self, tmpdir, svae, train_dataset, training_configs - ): + def test_svae_training_pipeline(self, svae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -804,20 +752,19 @@ def test_svae_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_SVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): return SVAE(SVAEConfig(input_dim=(1, 28, 28), latent_dim=7)) - @pytest.fixture( - params=[ - HypersphereUniformSamplerConfig() - ] - ) + @pytest.fixture(params=[HypersphereUniformSamplerConfig()]) def sampler_configs(self, request): return request.param @@ -830,7 +777,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VAE.py b/tests/test_VAE.py index a244b837..f914daf8 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import VAE, AutoModel, VAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import VAE, VAEConfig, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -116,7 +121,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -206,7 +213,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -289,14 +296,15 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, 
"data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -311,23 +319,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -338,9 +353,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -414,26 +428,21 @@ def vae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, vae, training_configs): - if request.param is not None: - optimizer = request.param( - vae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vae, train_dataset, training_configs): trainer = BaseTrainer( model=vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -448,14 +457,7 @@ def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -471,16 +473,7 @@ def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_predict_step( - self, vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -496,21 +489,11 @@ def test_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_vae_main_train_loop( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - 
training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -526,19 +509,10 @@ def test_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -620,21 +594,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -682,19 +647,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -742,7 +698,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs): + def test_vae_training_pipeline(self, vae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -802,10 +758,13 @@ def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -817,7 +776,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -832,7 +791,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py index 410127ef..8cd2b2b4 100644 --- a/tests/test_VAEGAN.py +++ b/tests/test_VAEGAN.py @@ -1,25 +1,31 @@ import os -import numpy as np from copy import deepcopy +import numpy as np import pytest import torch from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import VAEGAN, AutoModel, VAEGANConfig from pythae.models.base.base_utils import 
ModelOutput -from pythae.models import VAEGAN, VAEGANConfig, AutoModel +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import ( + BaseTrainerConfig, CoupledOptimizerAdversarialTrainer, CoupledOptimizerAdversarialTrainerConfig, - BaseTrainerConfig, ) -from pythae.pipelines import TrainingPipeline, GenerationPipeline -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig from tests.data.custom_architectures import ( Decoder_AE_Conv, - Encoder_VAE_Conv, Discriminator_MLP_Custom, + Encoder_VAE_Conv, NetBadInheritance, ) @@ -171,7 +177,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -302,7 +310,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "discriminator.pkl", - "environment.json" + "environment.json", ] ) @@ -398,32 +406,36 @@ def test_model_train_output(self, vaegan, demo_data): assert isinstance(out, ModelOutput) - assert set( - [ - "loss", - "recon_loss", - "encoder_loss", - "decoder_loss", - "discriminator_loss", - "recon_x", - "z", - "update_discriminator", - "update_encoder", - "update_decoder", - ] - ) == set(out.keys()) + assert ( + set( + [ + "loss", + "recon_loss", + "encoder_loss", + "decoder_loss", + "discriminator_loss", + "recon_x", + "z", + "update_discriminator", + "update_encoder", + "update_decoder", + ] + ) + == set(out.keys()) + ) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -438,23 +450,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAEGAN(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -465,9 +484,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAEGAN(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -579,39 +597,21 @@ def vaegan( return model - @pytest.fixture(params=[Adam]) - def 
optimizers(self, request, vaegan, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - vaegan.encoder.parameters(), lr=training_configs.learning_rate - ) - decoder_optimizer = request.param( - vaegan.decoder.parameters(), lr=training_configs.learning_rate - ) - discriminator_optimizer = request.param( - vaegan.discriminator.parameters(), - lr=training_configs.learning_rate, - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - discriminator_optimizer = None - - return (encoder_optimizer, decoder_optimizer, discriminator_optimizer) - - def test_vaegan_train_step( - self, vaegan, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, vaegan, train_dataset, training_configs): trainer = CoupledOptimizerAdversarialTrainer( model=vaegan, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[1], ) + trainer.prepare_training() + + return trainer + + def test_vaegan_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -626,18 +626,7 @@ def test_vaegan_train_step( ] ) - def test_vaegan_eval_step( - self, vaegan, train_dataset, training_configs, optimizers - ): - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[1], - ) + def test_vaegan_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -653,18 +642,7 @@ def test_vaegan_eval_step( ] ) - def test_vaegan_predict_step( - self, vaegan, train_dataset, training_configs, optimizers - ): - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[1], - ) + def test_vaegan_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -680,23 +658,11 @@ def test_vaegan_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_vaegan_main_train_loop( - self, tmpdir, vaegan, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[1], - ) + def test_vaegan_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -712,21 +678,10 @@ def test_vaegan_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vaegan, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vaegan, trainer, training_configs): dir_path = training_configs.output_dir - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - 
decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[1], - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -877,22 +832,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vaegan, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vaegan, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - model = deepcopy(trainer.model) trainer.train() @@ -953,20 +898,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vaegan, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vaegan, trainer, training_configs): dir_path = training_configs.output_dir - trainer = CoupledOptimizerAdversarialTrainer( - model=vaegan, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -1022,9 +957,7 @@ def test_final_model_saving( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) - def test_vaegan_training_pipeline( - self, tmpdir, vaegan, train_dataset, training_configs - ): + def test_vaegan_training_pipeline(self, vaegan, train_dataset, training_configs): with pytest.raises(AssertionError): pipeline = TrainingPipeline( @@ -1097,10 +1030,13 @@ def test_vaegan_training_pipeline( assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) + class Test_VAEGAN_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -1112,7 +1048,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -1127,7 +1063,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index 09eb68ca..6e9b7274 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import VAE_IAF, AutoModel, VAE_IAF_Config from pythae.models.base.base_utils import ModelOutput -from pythae.models import VAE_IAF, VAE_IAF_Config, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + 
IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -122,7 +127,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -212,7 +219,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -295,14 +302,15 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -317,23 +325,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE_IAF(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -344,9 +359,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE_IAF(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -422,26 +436,21 @@ def vae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, vae, training_configs): - if request.param is not None: - optimizer = request.param( - vae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vae, train_dataset, training_configs): trainer = BaseTrainer( model=vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -456,14 +465,7 @@ def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - 
model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -479,16 +481,7 @@ def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_predict_step( - self, vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -504,21 +497,11 @@ def test_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_vae_main_train_loop( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -534,19 +517,10 @@ def test_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -628,21 +602,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -690,19 +655,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -750,7 +706,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs): + def test_vae_training_pipeline(self, vae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -810,10 +766,13 @@ def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VAE_Generation: @pytest.fixture def train_data(self): - return 
torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -825,7 +784,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -840,7 +799,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index f392c46a..91d121b0 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -3,14 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import AutoModel, VAE_LinNF, VAE_LinNF_Config from pythae.models.base.base_utils import ModelOutput -from pythae.models import VAE_LinNF, VAE_LinNF_Config, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -136,7 +141,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -226,7 +233,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -309,14 +316,15 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -331,23 +339,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE_LinNF(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, 
"data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -358,14 +373,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAE_LinNF(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape - class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -437,26 +450,21 @@ def vae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, vae, training_configs): - if request.param is not None: - optimizer = request.param( - vae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vae, train_dataset, training_configs): trainer = BaseTrainer( model=vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -471,14 +479,7 @@ def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -494,16 +495,7 @@ def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_predict_step( - self, vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -519,21 +511,11 @@ def test_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_vae_main_train_loop( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -549,19 +531,10 @@ def test_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -643,21 +616,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, 
tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -705,19 +669,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -765,7 +720,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs): + def test_vae_training_pipeline(self, vae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -825,10 +780,13 @@ def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -840,7 +798,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -855,7 +813,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 3a747bfd..548f68aa 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -3,14 +3,13 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import VAMP, AutoModel, VAMPConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import VAMP, VAMPConfig, AutoModel +from pythae.pipelines import GenerationPipeline, TrainingPipeline from pythae.samplers import VAMPSamplerConfig from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -120,7 +119,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -210,7 +211,7 @@ def test_full_custom_model_saving( "model.pt", 
"encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -293,14 +294,15 @@ def test_model_train_output(self, vamp, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -315,23 +317,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAMP(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -342,14 +351,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VAMP(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape - class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -419,26 +426,21 @@ def vamp(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, vamp, training_configs): - if request.param is not None: - optimizer = request.param( - vamp.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_vamp_train_step(self, vamp, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vamp, train_dataset, training_configs): trainer = BaseTrainer( model=vamp, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vamp_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -453,16 +455,7 @@ def test_vamp_train_step(self, vamp, train_dataset, training_configs, optimizers ] ) - def test_vamp_predict_step( - self, vamp, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vamp, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vamp_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -470,21 +463,11 @@ def test_vamp_predict_step( step_1_model_state_dict = deepcopy(trainer.model.state_dict()) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_vamp_main_train_loop( - self, tmpdir, vamp, train_dataset, training_configs, optimizers - ): + 
assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=vamp, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vamp_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -500,19 +483,10 @@ def test_vamp_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vamp, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vamp, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vamp, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -594,21 +568,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vamp, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vamp, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vamp, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -656,19 +621,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vamp, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vamp, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vamp, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -716,9 +672,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vamp_training_pipeline( - self, tmpdir, vamp, train_dataset, training_configs - ): + def test_vamp_training_pipeline(self, vamp, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -778,20 +732,19 @@ def test_vamp_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VAMP_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): return VAMP(VAMPConfig(input_dim=(1, 28, 28), latent_dim=7)) - @pytest.fixture( - params=[ - VAMPSamplerConfig() - ] - ) + @pytest.fixture(params=[VAMPSamplerConfig()]) def sampler_configs(self, request): return request.param @@ -804,7 +757,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py index b053c414..9ed177ab 100644 --- a/tests/test_VQVAE.py +++ b/tests/test_VQVAE.py @@ -1,18 +1,16 @@ import os +from copy import deepcopy import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop - -from copy import deepcopy from 
pythae.customexception import BadInheritanceError +from pythae.models import VQVAE, AutoModel, VQVAEConfig from pythae.models.base.base_utils import ModelOutput -from pythae.models import VQVAE, VQVAEConfig, AutoModel from pythae.models.vq_vae.vq_vae_utils import Quantizer, QuantizerEMA +from pythae.pipelines import GenerationPipeline, TrainingPipeline from pythae.samplers import PixelCNNSamplerConfig from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_AE_Conv, @@ -38,7 +36,7 @@ def model_configs_no_input_dim(request): quantization_loss_factor=0.18, latent_dim=16, use_ema=True, - decay=0.001 + decay=0.001, ), ] ) @@ -135,7 +133,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -225,7 +225,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -301,19 +301,22 @@ def test_model_train_output(self, vae, demo_data): assert isinstance(out, ModelOutput) - assert set(["loss", "recon_loss", "vq_loss", "recon_x", "z", "quantized_indices"]) == set(out.keys()) + assert set( + ["loss", "recon_loss", "vq_loss", "recon_x", "z", "quantized_indices"] + ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -328,23 +331,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VQVAE(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -355,9 +365,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return VQVAE(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -405,26 +414,21 @@ def vae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[None, Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, vae, training_configs): - if request.param is not None: - optimizer = request.param( - vae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = 
None - - return optimizer - - def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, vae, train_dataset, training_configs): trainer = BaseTrainer( model=vae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_vae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -439,14 +443,7 @@ def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -462,16 +459,7 @@ def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers): ] ) - def test_vae_predict_step( - self, vae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -487,21 +475,11 @@ def test_vae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_vae_main_train_loop( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_vae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -517,19 +495,10 @@ def test_vae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -611,21 +580,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, vae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -673,19 +633,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, vae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, vae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=vae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = 
deepcopy(trainer._best_model) @@ -733,7 +684,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs): + def test_vae_training_pipeline(self, vae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -791,20 +742,19 @@ def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_VQVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): return VQVAE(VQVAEConfig(input_dim=(1, 28, 28), latent_dim=4)) - @pytest.fixture( - params=[ - PixelCNNSamplerConfig() - ] - ) + @pytest.fixture(params=[PixelCNNSamplerConfig()]) def sampler_configs(self, request): return request.param @@ -817,7 +767,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_WAE_MMD.py b/tests/test_WAE_MMD.py index ad7e8125..cb24dc1c 100644 --- a/tests/test_WAE_MMD.py +++ b/tests/test_WAE_MMD.py @@ -3,14 +3,18 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import WAE_MMD, AutoModel, WAE_MMD_Config from pythae.models.base.base_utils import ModelOutput -from pythae.models import WAE_MMD, WAE_MMD_Config, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_AE_Conv, @@ -27,7 +31,9 @@ def model_configs_no_input_dim(request): @pytest.fixture( params=[ - WAE_MMD_Config(input_dim=(1, 28, 28), latent_dim=10, kernel_choice="rbf", scales=None), + WAE_MMD_Config( + input_dim=(1, 28, 28), latent_dim=10, kernel_choice="rbf", scales=None + ), WAE_MMD_Config( input_dim=(1, 2, 18), latent_dim=5, reg_weight=1.0, kernel_bandwidth=0.2 ), @@ -118,7 +124,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -208,7 +216,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -291,14 +299,15 @@ def test_model_train_output(self, wae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == 
demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -313,23 +322,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return WAE_MMD(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -340,9 +356,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return WAE_MMD(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -392,26 +407,21 @@ def wae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, wae, training_configs): - if request.param is not None: - optimizer = request.param( - wae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_wae_train_step(self, wae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, wae, train_dataset, training_configs): trainer = BaseTrainer( model=wae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_wae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -426,14 +436,7 @@ def test_wae_train_step(self, wae, train_dataset, training_configs, optimizers): ] ) - def test_wae_eval_step(self, wae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_wae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -449,16 +452,7 @@ def test_wae_eval_step(self, wae, train_dataset, training_configs, optimizers): ] ) - def test_wae_predict_step( - self, wae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_wae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -474,21 +468,11 @@ def test_wae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert 
generated.shape == inputs.shape - def test_wae_main_train_loop( - self, tmpdir, wae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_wae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -504,19 +488,10 @@ def test_wae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, wae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, wae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -598,21 +573,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, wae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, wae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -660,19 +626,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, wae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, wae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=wae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -720,7 +677,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_wae_training_pipeline(self, tmpdir, wae, train_dataset, training_configs): + def test_wae_training_pipeline(self, wae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -780,10 +737,13 @@ def test_wae_training_pipeline(self, tmpdir, wae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_WAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -794,7 +754,7 @@ def ae_model(self): NormalSamplerConfig(), GaussianMixtureSamplerConfig(), MAFSamplerConfig(), - IAFSamplerConfig() + IAFSamplerConfig(), ] ) def sampler_configs(self, request): @@ -809,7 +769,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_adversarial_trainer.py b/tests/test_adversarial_trainer.py index 78ab8e4b..12ed6dba 100644 --- a/tests/test_adversarial_trainer.py +++ b/tests/test_adversarial_trainer.py @@ -2,10 +2,7 @@ from copy 
import deepcopy import pytest -import itertools import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop -from torch.optim.lr_scheduler import StepLR, LinearLR, ExponentialLR from pythae.models import Adversarial_AE, Adversarial_AE_Config from pythae.trainers import AdversarialTrainer, AdversarialTrainerConfig @@ -34,12 +31,13 @@ def training_config(tmpdir): class Test_DataLoader: @pytest.fixture( params=[ - AdversarialTrainerConfig(autoencoder_optim_decay=0), - AdversarialTrainerConfig(batch_size=100, autoencoder_optim_decay=1e-7), + AdversarialTrainerConfig(), + AdversarialTrainerConfig( + per_device_train_batch_size=100, + per_device_eval_batch_size=35, + ), AdversarialTrainerConfig( - batch_size=10, - autoencoder_optim_decay=1e-7, - discriminator_optim_decay=1e-7, + per_device_train_batch_size=10, per_device_eval_batch_size=3 ), ] ) @@ -63,7 +61,10 @@ def test_build_train_data_loader( assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) assert train_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + train_data_loader.batch_size + == trainer.training_config.per_device_train_batch_size + ) def test_build_eval_data_loader( self, model_sample, train_dataset, training_config_batch_size @@ -74,20 +75,32 @@ def test_build_eval_data_loader( training_config=training_config_batch_size, ) - train_data_loader = trainer.get_eval_dataloader(train_dataset) + eval_data_loader = trainer.get_eval_dataloader(train_dataset) - assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) - assert train_data_loader.dataset == train_dataset + assert issubclass(type(eval_data_loader), torch.utils.data.DataLoader) + assert eval_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + eval_data_loader.batch_size + == trainer.training_config.per_device_eval_batch_size + ) class Test_Set_Training_config: @pytest.fixture( params=[ - AdversarialTrainerConfig(autoencoder_optim_decay=0), + AdversarialTrainerConfig(), AdversarialTrainerConfig( - batch_size=10, learning_rate=1e-3, autoencoder_optim_decay=0 + per_device_train_batch_size=10, + per_device_eval_batch_size=10, + autoencoder_learning_rate=1e-3, + discriminator_learning_rate=1e-5, + autoencoder_optimizer_cls="AdamW", + autoencoder_optimizer_params={"weight_decay": 0.01}, + discriminator_optimizer_cls="SGD", + discriminator_optimizer_params={"weight_decay": 0.01}, + autoencoder_scheduler_cls="ExponentialLR", + autoencoder_scheduler_params={"gamma": 0.321}, ), ] ) @@ -115,9 +128,31 @@ def test_set_training_config(self, model_sample, train_dataset, training_configs class Test_Build_Optimizer: + def test_wrong_optimizer_cls(self): + with pytest.raises(AttributeError): + AdversarialTrainerConfig(autoencoder_optimizer_cls="WrongOptim") + + with pytest.raises(AttributeError): + AdversarialTrainerConfig(discriminator_optimizer_cls="WrongOptim") + + def test_wrong_optimizer_params(self): + with pytest.raises(TypeError): + AdversarialTrainerConfig( + autoencoder_optimizer_cls="Adam", + autoencoder_optimizer_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + AdversarialTrainerConfig( + discriminator_optimizer_cls="Adam", + discriminator_optimizer_params={"wrong_config": 1}, + ) + @pytest.fixture( params=[ - AdversarialTrainerConfig(learning_rate=1e-2), + AdversarialTrainerConfig( + autoencoder_learning_rate=1e-2, discriminator_learning_rate=1e-3 + ), 
AdversarialTrainerConfig(), ] ) @@ -125,18 +160,47 @@ def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): + @pytest.fixture( + params=[ + { + "autoencoder_optimizer_cls": "Adagrad", + "autoencoder_optimizer_params": {"lr_decay": 0.1}, + "discriminator_optimizer_cls": "AdamW", + "discriminator_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "autoencoder_optimizer_cls": "SGD", + "autoencoder_optimizer_params": {"momentum": 0.1}, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": {"momentum": 0.9}, + }, + { + "autoencoder_optimizer_cls": "SGD", + "autoencoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": None, + }, + ] + ) + def optimizer_config(self, request, training_configs_learning_rate): - autoencoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - discriminator_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, + optimizer_config = request.param + + # set optim and params to training config + training_configs_learning_rate.autoencoder_optimizer_cls = optimizer_config[ + "autoencoder_optimizer_cls" + ] + training_configs_learning_rate.autoencoder_optimizer_params = optimizer_config[ + "autoencoder_optimizer_params" + ] + training_configs_learning_rate.discriminator_optimizer_cls = optimizer_config[ + "discriminator_optimizer_cls" + ] + training_configs_learning_rate.discriminator_optimizer_params = ( + optimizer_config["discriminator_optimizer_params"] ) - return (autoencoder_optimizer, discriminator_optimizer) + + return optimizer_config def test_default_optimizer_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -146,83 +210,146 @@ def test_default_optimizer_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - autoencoder_optimizer=None, - discriminator_optimizer=None, ) + trainer.set_autoencoder_optimizer() + trainer.set_discriminator_optimizer() + assert issubclass(type(trainer.autoencoder_optimizer), torch.optim.Adam) assert ( trainer.autoencoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.autoencoder_learning_rate ) assert issubclass(type(trainer.discriminator_optimizer), torch.optim.Adam) assert ( trainer.discriminator_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.discriminator_learning_rate ) def test_set_custom_optimizer( - self, model_sample, train_dataset, training_configs_learning_rate, optimizers + self, + model_sample, + train_dataset, + training_configs_learning_rate, + optimizer_config, ): trainer = AdversarialTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], ) - assert issubclass(type(trainer.autoencoder_optimizer), type(optimizers[0])) + trainer.set_autoencoder_optimizer() + trainer.set_discriminator_optimizer() + + assert issubclass( + type(trainer.autoencoder_optimizer), + getattr(torch.optim, optimizer_config["autoencoder_optimizer_cls"]), + ) assert ( 
trainer.autoencoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.autoencoder_learning_rate ) + if optimizer_config["autoencoder_optimizer_params"] is not None: + assert all( + [ + trainer.autoencoder_optimizer.defaults[key] + == optimizer_config["autoencoder_optimizer_params"][key] + for key in optimizer_config["autoencoder_optimizer_params"].keys() + ] + ) - assert issubclass(type(trainer.discriminator_optimizer), type(optimizers[1])) + assert issubclass( + type(trainer.discriminator_optimizer), + getattr(torch.optim, optimizer_config["discriminator_optimizer_cls"]), + ) assert ( trainer.discriminator_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.discriminator_learning_rate ) + if optimizer_config["discriminator_optimizer_params"] is not None: + assert all( + [ + trainer.discriminator_optimizer.defaults[key] + == optimizer_config["discriminator_optimizer_params"][key] + for key in optimizer_config["discriminator_optimizer_params"].keys() + ] + ) + class Test_Build_Scheduler: - @pytest.fixture(params=[AdversarialTrainerConfig(), AdversarialTrainerConfig(learning_rate=1e-5)]) + def test_wrong_scheduler_cls(self): + with pytest.raises(AttributeError): + AdversarialTrainerConfig(autoencoder_scheduler_cls="WrongOptim") + + with pytest.raises(AttributeError): + AdversarialTrainerConfig(discriminator_scheduler_cls="WrongOptim") + + def test_wrong_scheduler_params(self): + with pytest.raises(TypeError): + AdversarialTrainerConfig( + autoencoder_scheduler_cls="ReduceLROnPlateau", + autoencoder_scheduler_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + AdversarialTrainerConfig( + discriminator_scheduler_cls="ReduceLROnPlateau", + discriminator_scheduler_params={"wrong_config": 1}, + ) + + @pytest.fixture( + params=[ + AdversarialTrainerConfig(), + AdversarialTrainerConfig(learning_rate=1e-5), + ] + ) def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): - - autoencoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - discriminator_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - return (autoencoder_optimizer, discriminator_optimizer) - @pytest.fixture( params=[ - (StepLR, {"step_size": 1}), - (LinearLR, {"start_factor": 0.01}), - (ExponentialLR, {"gamma": 0.1}), + { + "autoencoder_scheduler_cls": "StepLR", + "autoencoder_scheduler_params": {"step_size": 1}, + "discriminator_scheduler_cls": "LinearLR", + "discriminator_scheduler_params": None, + }, + { + "autoencoder_scheduler_cls": None, + "autoencoder_scheduler_params": None, + "discriminator_scheduler_cls": "ExponentialLR", + "discriminator_scheduler_params": {"gamma": 0.1}, + }, + { + "autoencoder_scheduler_cls": "ReduceLROnPlateau", + "autoencoder_scheduler_params": {"patience": 12}, + "discriminator_scheduler_cls": None, + "discriminator_scheduler_params": None, + }, ] ) - def schedulers( - self, request, optimizers - ): - if request.param[0] is not None: - autoencoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - discriminator_scheduler = request.param[0](optimizers[1], **request.param[1]) + def 
scheduler_config(self, request, training_configs_learning_rate): - else: - autoencoder_scheduler = None - discriminator_scheduler = None + scheduler_config = request.param + + # set scheduler and params to training config + training_configs_learning_rate.autoencoder_scheduler_cls = scheduler_config[ + "autoencoder_scheduler_cls" + ] + training_configs_learning_rate.autoencoder_scheduler_params = scheduler_config[ + "autoencoder_scheduler_params" + ] + training_configs_learning_rate.discriminator_scheduler_cls = scheduler_config[ + "discriminator_scheduler_cls" + ] + training_configs_learning_rate.discriminator_scheduler_params = ( + scheduler_config["discriminator_scheduler_params"] + ) - return (autoencoder_scheduler, discriminator_scheduler) + return request.param def test_default_scheduler_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -232,39 +359,99 @@ def test_default_scheduler_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - autoencoder_optimizer=None, - discriminator_optimizer=None ) - assert issubclass( - type(trainer.autoencoder_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + trainer.set_autoencoder_optimizer() + trainer.set_autoencoder_scheduler() + trainer.set_discriminator_optimizer() + trainer.set_discriminator_scheduler() - assert issubclass( - type(trainer.discriminator_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + assert trainer.autoencoder_scheduler is None + assert trainer.discriminator_scheduler is None def test_set_custom_scheduler( self, model_sample, train_dataset, training_configs_learning_rate, - optimizers, - schedulers, + scheduler_config, ): trainer = AdversarialTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - autoencoder_optimizer=optimizers[0], - autoencoder_scheduler=schedulers[0], - discriminator_optimizer=optimizers[1], - discriminator_scheduler=schedulers[1] - ) - assert issubclass(type(trainer.autoencoder_scheduler), type(schedulers[0])) - assert issubclass(type(trainer.discriminator_scheduler), type(schedulers[1])) + trainer.set_autoencoder_optimizer() + trainer.set_autoencoder_scheduler() + trainer.set_discriminator_optimizer() + trainer.set_discriminator_scheduler() + + if scheduler_config["autoencoder_scheduler_cls"] is None: + assert trainer.autoencoder_scheduler is None + else: + assert issubclass( + type(trainer.autoencoder_scheduler), + getattr( + torch.optim.lr_scheduler, + scheduler_config["autoencoder_scheduler_cls"], + ), + ) + if scheduler_config["autoencoder_scheduler_params"] is not None: + assert all( + [ + trainer.autoencoder_scheduler.state_dict()[key] + == scheduler_config["autoencoder_scheduler_params"][key] + for key in scheduler_config[ + "autoencoder_scheduler_params" + ].keys() + ] + ) + + if scheduler_config["discriminator_scheduler_cls"] is None: + assert trainer.discriminator_scheduler is None + + else: + assert issubclass( + type(trainer.discriminator_scheduler), + getattr( + torch.optim.lr_scheduler, + scheduler_config["discriminator_scheduler_cls"], + ), + ) + if scheduler_config["discriminator_scheduler_params"] is not None: + assert all( + [ + trainer.discriminator_scheduler.state_dict()[key] + == scheduler_config["discriminator_scheduler_params"][key] + for key in scheduler_config[ + "discriminator_scheduler_params" + ].keys() + ] + ) + + +class Test_Device_Checks: + def test_set_environ_variable(self): + os.environ["LOCAL_RANK"] = "1" + 
os.environ["WORLD_SIZE"] = "4" + os.environ["RANK"] = "3" + os.environ["MASTER_ADDR"] = "314" + os.environ["MASTER_PORT"] = "222" + + trainer_config = AdversarialTrainerConfig() + + assert int(trainer_config.local_rank) == 1 + assert int(trainer_config.world_size) == 4 + assert int(trainer_config.rank) == 3 + assert trainer_config.master_addr == "314" + assert trainer_config.master_port == "222" + + del os.environ["LOCAL_RANK"] + del os.environ["WORLD_SIZE"] + del os.environ["RANK"] + del os.environ["MASTER_ADDR"] + del os.environ["MASTER_PORT"] @pytest.mark.slow @@ -356,58 +543,102 @@ def ae( return model - @pytest.fixture(params=[None, Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, ae, training_configs): - if request.param is not None: - autoencoder_optimizer = request.param( - itertools.chain(ae.encoder.parameters(), ae.decoder.parameters()), lr=training_configs.learning_rate - ) - - discriminator_optimizer = request.param( - ae.discriminator.parameters(), lr=training_configs.learning_rate - ) + @pytest.fixture( + params=[ + { + "autoencoder_optimizer_cls": "Adagrad", + "autoencoder_optimizer_params": {"lr_decay": 0.1}, + "discriminator_optimizer_cls": "AdamW", + "discriminator_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "autoencoder_optimizer_cls": "SGD", + "autoencoder_optimizer_params": {"momentum": 0.1}, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": {"momentum": 0.9}, + }, + { + "autoencoder_optimizer_cls": "SGD", + "autoencoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": None, + }, + ] + ) + def optimizer_config(self, request): - else: - autoencoder_optimizer = None - discriminator_optimizer = None + optimizer_config = request.param - return (autoencoder_optimizer, discriminator_optimizer) + return optimizer_config @pytest.fixture( params=[ - (None, None), - (StepLR, {"step_size": 1, "gamma": 0.99}), - (LinearLR, {"start_factor": 0.99}), - (ExponentialLR, {"gamma": 0.99}), + { + "autoencoder_scheduler_cls": "LinearLR", + "autoencoder_scheduler_params": None, + "discriminator_scheduler_cls": "LinearLR", + "discriminator_scheduler_params": None, + }, + { + "autoencoder_scheduler_cls": None, + "autoencoder_scheduler_params": None, + "discriminator_scheduler_cls": "ExponentialLR", + "discriminator_scheduler_params": {"gamma": 0.13}, + }, + { + "autoencoder_scheduler_cls": "ReduceLROnPlateau", + "autoencoder_scheduler_params": {"patience": 12}, + "discriminator_scheduler_cls": None, + "discriminator_scheduler_params": None, + }, ] ) - def schedulers(self, request, optimizers): - if request.param[0] is not None and optimizers[0] is not None: - autoencoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - - else: - autoencoder_scheduler = None - - if request.param[0] is not None and optimizers[1] is not None: - discriminator_scheduler = request.param[0](optimizers[1], **request.param[1]) - - else: - discriminator_scheduler = None + def scheduler_config(self, request): + return request.param - return (autoencoder_scheduler, discriminator_scheduler) + @pytest.fixture + def trainer( + self, ae, train_dataset, optimizer_config, scheduler_config, training_configs + ): + training_configs.autoencoder_optimizer_cls = optimizer_config[ + "autoencoder_optimizer_cls" + ] + training_configs.autoencoder_optimizer_params = optimizer_config[ + "autoencoder_optimizer_params" + ] + training_configs.discriminator_optimizer_cls = optimizer_config[ + 
"discriminator_optimizer_cls" + ] + training_configs.discriminator_optimizer_params = optimizer_config[ + "discriminator_optimizer_params" + ] + training_configs.autoencoder_scheduler_cls = scheduler_config[ + "autoencoder_scheduler_cls" + ] + training_configs.autoencoder_scheduler_params = scheduler_config[ + "autoencoder_scheduler_params" + ] + training_configs.discriminator_scheduler_cls = scheduler_config[ + "discriminator_scheduler_cls" + ] + training_configs.discriminator_scheduler_params = scheduler_config[ + "discriminator_scheduler_params" + ] - def test_train_step(self, ae, train_dataset, training_configs, optimizers, schedulers): trainer = AdversarialTrainer( model=ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - autoencoder_scheduler=schedulers[0], - discriminator_scheduler=schedulers[1] ) + trainer.prepare_training() + + return trainer + + def test_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=3) @@ -417,26 +648,21 @@ def test_train_step(self, ae, train_dataset, training_configs, optimizers, sched # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "discriminator" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) - + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) - def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedulers): - trainer = AdversarialTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - autoencoder_scheduler=schedulers[0], - discriminator_scheduler=schedulers[1] - ) + def test_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -452,20 +678,7 @@ def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedu ] ) - def test_main_train_loop( - self, ae, train_dataset, training_configs, optimizers, schedulers - ): - - trainer = AdversarialTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - autoencoder_optimizer=optimizers[0], - discriminator_optimizer=optimizers[1], - autoencoder_scheduler=schedulers[0], - discriminator_scheduler=schedulers[1] - ) + def test_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -476,13 +689,20 @@ def test_main_train_loop( # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "discriminator" in 
key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) + class Test_Logging: @pytest.fixture diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py index 0f5ff164..7be99742 100644 --- a/tests/test_auto_model.py +++ b/tests/test_auto_model.py @@ -1,19 +1,23 @@ -import pytest import os +import pytest + from pythae.models import AutoModel PATH = os.path.dirname(os.path.abspath(__file__)) + @pytest.fixture() def corrupted_config(): return os.path.join(PATH, "data", "corrupted_config") + def test_raises_file_not_found(): - + with pytest.raises(FileNotFoundError): - AutoModel.load_from_folder('wrong_file_dir') + AutoModel.load_from_folder("wrong_file_dir") + def test_raises_name_error(corrupted_config): with pytest.raises(NameError): - AutoModel.load_from_folder(corrupted_config) \ No newline at end of file + AutoModel.load_from_folder(corrupted_config) diff --git a/tests/test_baseAE.py b/tests/test_baseAE.py index c02c90c1..c0081cb0 100644 --- a/tests/test_baseAE.py +++ b/tests/test_baseAE.py @@ -1,8 +1,8 @@ import os +import shutil import pytest import torch -import shutil from pythae.customexception import BadInheritanceError from pythae.models import BaseAE, BaseAEConfig @@ -85,7 +85,9 @@ def test_default_model_saving(self, tmpdir, model_config_with_input_dim): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = BaseAE.load_from_folder(dir_path) diff --git a/tests/test_baseSampler.py b/tests/test_baseSampler.py index ea0cf5b0..ddf6d173 100644 --- a/tests/test_baseSampler.py +++ b/tests/test_baseSampler.py @@ -64,168 +64,3 @@ def test_save_image_tensor(self, img_tensors, tmpdir, sampler_sample): rec_img = torch.tensor(imread(img_path)) / 255.0 assert 1 >= rec_img.max() >= 0 - - -# class Test_Sampler_Set_up: -# @pytest.fixture( -# params=[ -# BaseSamplerConfig( -# batch_size=1 -# ), # (target full batch number, target last full batch size, target_batch_number) -# BaseSamplerConfig(), -# ] -# ) -# def sampler_config(self, tmpdir, request): -# return request.param -# -# def test_sampler_set_up(self, model_sample, sampler_config): -# sampler = BaseSampler(model=model_sample, sampler_config=sampler_config) -# -# assert sampler.batch_size == sampler_config.batch_size -# assert sampler.samples_per_save == sampler_config.samples_per_save - - -# class Test_RHVAE_Sampler: -# @pytest.fixture( -# params=[ -# RHVAESamplerConfig(batch_size=1, mcmc_steps_nbr=15, samples_per_save=5), -# RHVAESamplerConfig(batch_size=2, mcmc_steps_nbr=15, samples_per_save=1), -# RHVAESamplerConfig( -# batch_size=3, n_lf=1, eps_lf=0.01, mcmc_steps_nbr=10, samples_per_save=5 -# ), -# RHVAESamplerConfig( -# batch_size=3, n_lf=1, eps_lf=0.01, mcmc_steps_nbr=10, samples_per_save=3 -# ), -# RHVAESamplerConfig( -# batch_size=10, -# n_lf=1, -# eps_lf=0.01, -# mcmc_steps_nbr=10, -# samples_per_save=3, -# ), -# ] -# ) -# def rhvae_sampler_config(self, tmpdir, request): -# tmpdir.mkdir("dummy_folder") -# request.param.output_dir = os.path.join(tmpdir, "dummy_folder") -# return request.param -# -# @pytest.fixture( -# params=[ -# np.random.randint(1, 15), -# np.random.randint(1, 15), -# np.random.randint(1, 15), -# ] -# ) -# def samples_number(self, request): -# return request.param -# -# @pytest.fixture( 
-# params=[ -# RHVAE(RHVAEConfig(input_dim=784, latent_dim=2)), -# RHVAE(RHVAEConfig(input_dim=784, latent_dim=3)), -# ] -# ) -# def rhvae_sample(self, request): -# return request.param -# -# def test_hmc_sampling(self, rhvae_sample, rhvae_sampler_config): -# -# # simulates a trained model -# # rhvae_sample.centroids_tens = torch.randn(20, rhvae_sample.latent_dim) -# # rhvae_sample.M_tens = torch.randn(20, rhvae_sample.latent_dim, rhvae_sample.latent_dim) -# -# sampler = RHVAESampler(model=rhvae_sample, sampler_config=rhvae_sampler_config) -# -# out = sampler.hmc_sampling(rhvae_sampler_config.batch_size) -# -# assert out.shape == (rhvae_sampler_config.batch_size, rhvae_sample.latent_dim) -# -# assert sampler.eps_lf == rhvae_sampler_config.eps_lf -# -# assert all( -# [ -# not torch.equal(out[i], out[j]) -# for i in range(len(out)) -# for j in range(i + 1, len(out)) -# ] -# ) -# -# def test_sampling_loop_saving( -# self, tmpdir, rhvae_sample, rhvae_sampler_config, samples_number -# ): -# -# sampler = RHVAESampler(model=rhvae_sample, sampler_config=rhvae_sampler_config) -# sampler.sample(samples_number=samples_number) -# -# generation_folder = os.path.join(tmpdir, "dummy_folder") -# generation_folder_list = os.listdir(generation_folder) -# -# assert f"generation_{sampler._sampling_signature}" in generation_folder_list -# -# data_folder = os.path.join( -# generation_folder, f"generation_{sampler._sampling_signature}" -# ) -# files_list = os.listdir(data_folder) -# -# full_data_file_nbr = int(samples_number / rhvae_sampler_config.samples_per_save) -# last_file_data_nbr = samples_number % rhvae_sampler_config.samples_per_save -# -# if last_file_data_nbr == 0: -# expected_num_of_data_files = full_data_file_nbr -# else: -# expected_num_of_data_files = full_data_file_nbr + 1 -# -# assert len(files_list) == 1 + expected_num_of_data_files -# -# assert "sampler_config.json" in files_list -# -# assert all( -# [ -# f"generated_data_{rhvae_sampler_config.samples_per_save}_{i}.pt" -# in files_list -# for i in range(full_data_file_nbr) -# ] -# ) -# -# if last_file_data_nbr > 0: -# assert ( -# f"generated_data_{last_file_data_nbr}_{expected_num_of_data_files-1}.pt" -# in files_list -# ) -# -# data_rec = [] -# -# for i in range(full_data_file_nbr): -# data_rec.append( -# torch.load( -# os.path.join( -# data_folder, -# "generated_data_" -# f"{rhvae_sampler_config.samples_per_save}_{i}.pt", -# ) -# ) -# ) -# -# if last_file_data_nbr > 0: -# data_rec.append( -# torch.load( -# os.path.join( -# data_folder, -# f"generated_data_" -# f"{last_file_data_nbr}_{expected_num_of_data_files-1}.pt", -# ) -# ) -# ) -# -# data_rec = torch.cat(data_rec) -# assert data_rec.shape[0] == samples_number -# -# # check sampler_config -# -# sampler_config_rec = RHVAESamplerConfig.from_json_file( -# os.path.join(data_folder, "sampler_config.json") -# ) -# -# assert sampler_config_rec.__dict__ == rhvae_sampler_config.__dict__ -# diff --git a/tests/test_base_trainer.py b/tests/test_base_trainer.py index 5f2a87b8..5349c788 100644 --- a/tests/test_base_trainer.py +++ b/tests/test_base_trainer.py @@ -3,11 +3,9 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop -from torch.optim.lr_scheduler import StepLR, LinearLR, ExponentialLR from pythae.customexception import ModelError -from pythae.models import BaseAE, BaseAEConfig, AE, AEConfig, VAE, VAEConfig, RHVAE, RHVAEConfig +from pythae.models import AE, RHVAE, VAE, AEConfig, RHVAEConfig, VAEConfig from pythae.trainers import BaseTrainer, 
BaseTrainerConfig from tests.data.custom_architectures import * @@ -21,7 +19,7 @@ def train_dataset(): @pytest.fixture() def model_sample(): - return BaseAE(BaseAEConfig(input_dim=(1, 28, 28))) + return AE(AEConfig(input_dim=(1, 28, 28))) @pytest.fixture @@ -35,8 +33,12 @@ class Test_DataLoader: @pytest.fixture( params=[ BaseTrainerConfig(), - BaseTrainerConfig(batch_size=100), - BaseTrainerConfig(batch_size=10), + BaseTrainerConfig( + per_device_train_batch_size=35, per_device_eval_batch_size=100 + ), + BaseTrainerConfig( + per_device_train_batch_size=3, per_device_eval_batch_size=10 + ), ] ) def training_config_batch_size(self, request, tmpdir): @@ -59,7 +61,10 @@ def test_build_train_data_loader( assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) assert train_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + train_data_loader.batch_size + == trainer.training_config.per_device_train_batch_size + ) def test_build_eval_data_loader( self, model_sample, train_dataset, training_config_batch_size @@ -70,12 +75,15 @@ def test_build_eval_data_loader( training_config=training_config_batch_size, ) - train_data_loader = trainer.get_eval_dataloader(train_dataset) + eval_data_loader = trainer.get_eval_dataloader(train_dataset) - assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) - assert train_data_loader.dataset == train_dataset + assert issubclass(type(eval_data_loader), torch.utils.data.DataLoader) + assert eval_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + eval_data_loader.batch_size + == trainer.training_config.per_device_eval_batch_size + ) class Test_Set_Training_config: @@ -83,7 +91,15 @@ class Test_Set_Training_config: params=[ None, BaseTrainerConfig(), - BaseTrainerConfig(batch_size=10, learning_rate=1e-5), + BaseTrainerConfig( + per_device_train_batch_size=10, + per_device_eval_batch_size=20, + learning_rate=1e-5, + optimizer_cls="AdamW", + optimizer_params={"weight_decay": 0.01}, + scheduler_cls="ExponentialLR", + scheduler_params={"gamma": 0.321}, + ), ] ) def training_configs(self, request, tmpdir): @@ -113,18 +129,39 @@ def test_set_training_config(self, model_sample, train_dataset, training_configs class Test_Build_Optimizer: + def test_wrong_optimizer_cls(self): + with pytest.raises(AttributeError): + BaseTrainerConfig(optimizer_cls="WrongOptim") + + def test_wrong_optimizer_params(self): + with pytest.raises(TypeError): + BaseTrainerConfig( + optimizer_cls="Adam", optimizer_params={"wrong_config": 1} + ) + @pytest.fixture(params=[BaseTrainerConfig(), BaseTrainerConfig(learning_rate=1e-5)]) def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): + @pytest.fixture( + params=[ + {"optimizer_cls": "Adagrad", "optimizer_params": {"lr_decay": 0.1}}, + {"optimizer_cls": "AdamW", "optimizer_params": {"betas": (0.1234, 0.4321)}}, + {"optimizer_cls": "SGD", "optimizer_params": None}, + ] + ) + def optimizer_config(self, request, training_configs_learning_rate): - optimizer = request.param( - model_sample.parameters(), lr=training_configs_learning_rate.learning_rate - ) - return optimizer + optimizer_config = request.param + + # set optim and params to training config + 
training_configs_learning_rate.optimizer_cls = optimizer_config["optimizer_cls"] + training_configs_learning_rate.optimizer_params = optimizer_config[ + "optimizer_params" + ] + + return optimizer_config def test_default_optimizer_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -134,9 +171,10 @@ def test_default_optimizer_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - optimizer=None, ) + trainer.set_optimizer() + assert issubclass(type(trainer.optimizer), torch.optim.Adam) assert ( trainer.optimizer.defaults["lr"] @@ -144,53 +182,92 @@ def test_default_optimizer_building( ) def test_set_custom_optimizer( - self, model_sample, train_dataset, training_configs_learning_rate, optimizers + self, + model_sample, + train_dataset, + training_configs_learning_rate, + optimizer_config, ): + trainer = BaseTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - optimizer=optimizers, ) - assert issubclass(type(trainer.optimizer), type(optimizers)) + trainer.set_optimizer() + + assert issubclass( + type(trainer.optimizer), + getattr(torch.optim, optimizer_config["optimizer_cls"]), + ) assert ( trainer.optimizer.defaults["lr"] == training_configs_learning_rate.learning_rate ) + if optimizer_config["optimizer_params"] is not None: + assert all( + [ + trainer.optimizer.defaults[key] + == optimizer_config["optimizer_params"][key] + for key in optimizer_config["optimizer_params"].keys() + ] + ) class Test_Build_Scheduler: + def test_wrong_scheduler_cls(self): + with pytest.raises(AttributeError): + BaseTrainerConfig(scheduler_cls="WrongOptim") + + def test_wrong_scheduler_params(self): + with pytest.raises(TypeError): + BaseTrainerConfig( + scheduler_cls="ReduceLROnPlateau", scheduler_params={"wrong_config": 1} + ) + @pytest.fixture(params=[BaseTrainerConfig(), BaseTrainerConfig(learning_rate=1e-5)]) def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): + @pytest.fixture( + params=[ + {"optimizer_cls": "Adagrad", "optimizer_params": {"lr_decay": 0.1}}, + {"optimizer_cls": "AdamW", "optimizer_params": {"betas": (0.1234, 0.4321)}}, + {"optimizer_cls": "SGD", "optimizer_params": None}, + ] + ) + def optimizer_config(self, request, training_configs_learning_rate): - optimizer = request.param( - model_sample.parameters(), lr=training_configs_learning_rate.learning_rate - ) - return optimizer + optimizer_config = request.param + + # set optim and params to training config + training_configs_learning_rate.optimizer_cls = optimizer_config["optimizer_cls"] + training_configs_learning_rate.optimizer_params = optimizer_config[ + "optimizer_params" + ] + + return optimizer_config @pytest.fixture( params=[ - (StepLR, {"step_size": 1}), - (LinearLR, {"start_factor": 0.01}), - (ExponentialLR, {"gamma": 0.1}), + {"scheduler_cls": "StepLR", "scheduler_params": {"step_size": 1}}, + {"scheduler_cls": "LinearLR", "scheduler_params": None}, + {"scheduler_cls": "ExponentialLR", "scheduler_params": {"gamma": 3.14}}, ] ) - def schedulers( - self, request, optimizers - ): - if request.param[0] is not None: - scheduler = request.param[0](optimizers, **request.param[1]) + def scheduler_config(self, request, training_configs_learning_rate): - else: - scheduler = None + 
scheduler_config = request.param - return scheduler + # set scheduler and params to training config + training_configs_learning_rate.scheduler_cls = scheduler_config["scheduler_cls"] + training_configs_learning_rate.scheduler_params = scheduler_config[ + "scheduler_params" + ] + + return request.param def test_default_scheduler_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -200,33 +277,65 @@ def test_default_scheduler_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - optimizer=None, ) - assert issubclass( - type(trainer.scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + trainer.set_optimizer() + trainer.set_scheduler() + + assert trainer.scheduler is None def test_set_custom_scheduler( self, model_sample, train_dataset, training_configs_learning_rate, - optimizers, - schedulers, + scheduler_config, ): trainer = BaseTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - optimizer=optimizers, - scheduler=schedulers, ) - assert issubclass(type(trainer.scheduler), type(schedulers)) + trainer.set_optimizer() + trainer.set_scheduler() + + assert issubclass( + type(trainer.scheduler), + getattr(torch.optim.lr_scheduler, scheduler_config["scheduler_cls"]), + ) + if scheduler_config["scheduler_params"] is not None: + assert all( + [ + trainer.scheduler.state_dict()[key] + == scheduler_config["scheduler_params"][key] + for key in scheduler_config["scheduler_params"].keys() + ] + ) class Test_Device_Checks: + def test_set_environ_variable(self): + os.environ["LOCAL_RANK"] = "1" + os.environ["WORLD_SIZE"] = "4" + os.environ["RANK"] = "3" + os.environ["MASTER_ADDR"] = "314" + os.environ["MASTER_PORT"] = "222" + + trainer_config = BaseTrainerConfig() + + assert int(trainer_config.local_rank) == 1 + assert int(trainer_config.world_size) == 4 + assert int(trainer_config.rank) == 3 + assert trainer_config.master_addr == "314" + assert trainer_config.master_port == "222" + + del os.environ["LOCAL_RANK"] + del os.environ["WORLD_SIZE"] + del os.environ["RANK"] + del os.environ["MASTER_ADDR"] + del os.environ["MASTER_PORT"] + @pytest.fixture( params=[ BaseTrainerConfig(num_epochs=3, no_cuda=True), @@ -373,18 +482,22 @@ def rhvae( return model def test_raises_sanity_check_error(self, rhvae, train_dataset, training_config): - trainer = BaseTrainer( - model=rhvae, train_dataset=train_dataset, training_config=training_config - ) - with pytest.raises(ModelError): - trainer._run_model_sanity_check(rhvae, train_dataset) + _ = BaseTrainer( + model=rhvae, + train_dataset=train_dataset, + training_config=training_config, + ) @pytest.mark.slow class Test_Main_Training: @pytest.fixture( - params=[BaseTrainerConfig(num_epochs=3, steps_saving=2, steps_predict=2, learning_rate=1e-5)] + params=[ + BaseTrainerConfig( + num_epochs=3, steps_saving=2, steps_predict=2, learning_rate=1e-5 + ) + ] ) def training_configs(self, tmpdir, request): tmpdir.mkdir("dummy_folder") @@ -451,46 +564,49 @@ def ae(self, ae_config, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[None, Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, ae, training_configs): - if request.param is not None: - optimizer = request.param( - ae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer + @pytest.fixture( + params=[ + {"optimizer_cls": "Adagrad", "optimizer_params": {"lr_decay": 0.1}}, + {"optimizer_cls": 
"AdamW", "optimizer_params": {"betas": (0.1234, 0.4321)}}, + {"optimizer_cls": "SGD", "optimizer_params": None}, + ] + ) + def optimizer_config(self, request): + return request.param @pytest.fixture( params=[ - (None, None), - (StepLR, {"step_size": 1, "gamma": 0.99}), - (LinearLR, {"start_factor": 0.99}), - (ExponentialLR, {"gamma": 0.99}), + {"scheduler_cls": "StepLR", "scheduler_params": {"step_size": 0.1}}, + {"scheduler_cls": "ReduceLROnPlateau", "scheduler_params": None}, + {"scheduler_cls": None, "scheduler_params": None}, ] ) - def schedulers(self, request, optimizers): - if request.param[0] is not None and optimizers is not None: - scheduler = request.param[0](optimizers, **request.param[1]) + def scheduler_config(self, request): + return request.param - else: - scheduler = None + @pytest.fixture + def trainer( + self, ae, train_dataset, optimizer_config, scheduler_config, training_configs + ): - return scheduler + training_configs.optimizer_cls = optimizer_config["optimizer_cls"] + training_configs.optimizer_params = optimizer_config["optimizer_params"] + training_configs.scheduler_cls = scheduler_config["scheduler_cls"] + training_configs.scheduler_params = scheduler_config["scheduler_params"] - def test_train_step( - self, ae, train_dataset, training_configs, optimizers, schedulers - ): trainer = BaseTrainer( model=ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, - scheduler=schedulers, ) + trainer.prepare_training() + + return trainer + + def test_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -500,22 +616,16 @@ def test_train_step( # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]), key + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ), key if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) - def test_eval_step( - self, ae, train_dataset, training_configs, optimizers, schedulers - ): - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - scheduler=schedulers, - ) + def test_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -531,38 +641,16 @@ def test_eval_step( ] ) - def test_predict_step( - self, ae, train_dataset, training_configs, optimizers, schedulers - ): - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - scheduler=schedulers, - ) + def test_predict_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) + _ = deepcopy(trainer.model.state_dict()) true_data, recon, gene = trainer.predict(trainer.model) - assert true_data.reshape(3, -1).shape == recon.reshape(3, -1).shape assert gene.reshape(3, -1).shape[1:] == true_data.reshape(3, -1).shape[1:] - def test_main_train_loop( - self, tmpdir, ae, train_dataset, training_configs, optimizers, schedulers - ): - - trainer = BaseTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - 
scheduler=schedulers, - ) + def test_main_train_loop(self, trainer, training_configs): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -573,13 +661,20 @@ def test_main_train_loop( # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) # check changed lr with custom schedulers - if type(trainer.scheduler) != torch.optim.lr_scheduler.ReduceLROnPlateau: + if ( + type(trainer.scheduler) != torch.optim.lr_scheduler.ReduceLROnPlateau + and trainer.scheduler is not None + ): assert training_configs.learning_rate != trainer.scheduler.get_last_lr() diff --git a/tests/test_config.py b/tests/test_config.py index 5f7f8600..f7268735 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,9 +5,9 @@ from pydantic import ValidationError from pythae.config import BaseConfig -from pythae.models import BaseAEConfig, AEConfig +from pythae.models import AEConfig, BaseAEConfig from pythae.samplers import BaseSamplerConfig, NormalSamplerConfig -from pythae.trainers import BaseTrainerConfig, AdversarialTrainerConfig +from pythae.trainers import AdversarialTrainerConfig, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -35,7 +35,12 @@ def corrupted_config_path(self): ], [ os.path.join(PATH, "data/baseAE/configs/training_config00.json"), - BaseTrainerConfig(batch_size=13, num_epochs=2, learning_rate=1e-5), + BaseTrainerConfig( + per_device_train_batch_size=13, + per_device_eval_batch_size=42, + num_epochs=2, + learning_rate=1e-5, + ), ], [ os.path.join(PATH, "data/baseAE/configs/generation_config00.json"), @@ -121,7 +126,11 @@ def test_save_json(self, tmpdir, model_configs): @pytest.fixture( params=[ BaseTrainerConfig(), - BaseTrainerConfig(learning_rate=100, batch_size=15), + BaseTrainerConfig( + learning_rate=100, + per_device_train_batch_size=15, + per_device_eval_batch_size=23, + ), ] ) def training_configs(self, request): diff --git a/tests/test_coupled_optimizers_adversarial_trainer.py b/tests/test_coupled_optimizers_adversarial_trainer.py index 7cfb30d3..f3a1a2ed 100644 --- a/tests/test_coupled_optimizers_adversarial_trainer.py +++ b/tests/test_coupled_optimizers_adversarial_trainer.py @@ -3,8 +3,6 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop -from torch.optim.lr_scheduler import StepLR, LinearLR, ExponentialLR from pythae.models import VAEGAN, VAEGANConfig from pythae.trainers import ( @@ -17,6 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" + @pytest.fixture def train_dataset(): return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) @@ -37,12 +36,12 @@ def training_config(tmpdir): class Test_DataLoader: @pytest.fixture( params=[ - CoupledOptimizerAdversarialTrainerConfig(decoder_optim_decay=0), + CoupledOptimizerAdversarialTrainerConfig(), CoupledOptimizerAdversarialTrainerConfig( - batch_size=100, encoder_optim_decay=1e-7 + per_device_train_batch_size=100, per_device_eval_batch_size=35 ), CoupledOptimizerAdversarialTrainerConfig( - batch_size=10, encoder_optim_decay=1e-7, decoder_optim_decay=1e-7 + per_device_train_batch_size=10, per_device_eval_batch_size=3 ), ] ) 
@@ -66,7 +65,10 @@ def test_build_train_data_loader( assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) assert train_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + train_data_loader.batch_size + == trainer.training_config.per_device_train_batch_size + ) def test_build_eval_data_loader( self, model_sample, train_dataset, training_config_batch_size @@ -77,22 +79,35 @@ def test_build_eval_data_loader( training_config=training_config_batch_size, ) - train_data_loader = trainer.get_eval_dataloader(train_dataset) + eval_data_loader = trainer.get_eval_dataloader(train_dataset) - assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) - assert train_data_loader.dataset == train_dataset + assert issubclass(type(eval_data_loader), torch.utils.data.DataLoader) + assert eval_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + eval_data_loader.batch_size + == trainer.training_config.per_device_eval_batch_size + ) class Test_Set_Training_config: @pytest.fixture( params=[ + CoupledOptimizerAdversarialTrainerConfig(), CoupledOptimizerAdversarialTrainerConfig( - decoder_optim_decay=0, discriminator_optim_decay=0.7 - ), - CoupledOptimizerAdversarialTrainerConfig( - batch_size=10, learning_rate=1e-5, encoder_optim_decay=0 + per_device_train_batch_size=10, + per_device_eval_batch_size=3, + encoder_learning_rate=1e-5, + decoder_learning_rate=1e-3, + discriminator_learning_rate=1e-6, + encoder_optimizer_cls="AdamW", + encoder_optimizer_params={"weight_decay": 0.01}, + decoder_optimizer_cls="Adam", + decoder_optimizer_params=None, + discriminator_optimizer_cls="SGD", + discriminator_optimizer_params={"weight_decay": 0.01}, + encoder_scheduler_cls="ExponentialLR", + encoder_scheduler_params={"gamma": 0.321}, ), ] ) @@ -120,9 +135,44 @@ def test_set_training_config(self, model_sample, train_dataset, training_configs class Test_Build_Optimizer: + def test_wrong_optimizer_cls(self): + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig(encoder_optimizer_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig(decoder_optimizer_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig( + discriminator_optimizer_cls="WrongOptim" + ) + + def test_wrong_optimizer_params(self): + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + encoder_optimizer_cls="Adam", + encoder_optimizer_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + decoder_optimizer_cls="Adam", + decoder_optimizer_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + discriminator_optimizer_cls="Adam", + discriminator_optimizer_params={"wrong_config": 1}, + ) + @pytest.fixture( params=[ - CoupledOptimizerAdversarialTrainerConfig(learning_rate=1e-6), + CoupledOptimizerAdversarialTrainerConfig( + encoder_learning_rate=1e-6, + decoder_learning_rate=1e-5, + discriminator_learning_rate=1e-3, + ), CoupledOptimizerAdversarialTrainerConfig(), ] ) @@ -130,22 +180,59 @@ def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, 
training_configs_learning_rate): + @pytest.fixture( + params=[ + { + "encoder_optimizer_cls": "Adagrad", + "encoder_optimizer_params": {"lr_decay": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": {"momentum": 0.9}, + "discriminator_optimizer_cls": "AdamW", + "discriminator_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": {"momentum": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": {"momentum": 0.9}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": None, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": None, + }, + ] + ) + def optimizer_config(self, request, training_configs_learning_rate): - encoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - decoder_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - discriminator_optimizer = request.param( - model_sample.discriminator.parameters(), - lr=training_configs_learning_rate.learning_rate, + optimizer_config = request.param + + # set optim and params to training config + training_configs_learning_rate.encoder_optimizer_cls = optimizer_config[ + "encoder_optimizer_cls" + ] + training_configs_learning_rate.encoder_optimizer_params = optimizer_config[ + "encoder_optimizer_params" + ] + training_configs_learning_rate.decoder_optimizer_cls = optimizer_config[ + "decoder_optimizer_cls" + ] + training_configs_learning_rate.decoder_optimizer_params = optimizer_config[ + "decoder_optimizer_params" + ] + training_configs_learning_rate.discriminator_optimizer_cls = optimizer_config[ + "discriminator_optimizer_cls" + ] + training_configs_learning_rate.discriminator_optimizer_params = ( + optimizer_config["discriminator_optimizer_params"] ) - return (encoder_optimizer, decoder_optimizer, discriminator_optimizer) + + return optimizer_config def test_default_optimizer_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -155,102 +242,194 @@ def test_default_optimizer_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=None, - decoder_optimizer=None, - discriminator_optimizer=None, ) + trainer.set_encoder_optimizer() + trainer.set_decoder_optimizer() + trainer.set_discriminator_optimizer() + assert issubclass(type(trainer.encoder_optimizer), torch.optim.Adam) assert ( trainer.encoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.encoder_learning_rate ) assert issubclass(type(trainer.decoder_optimizer), torch.optim.Adam) assert ( trainer.decoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.decoder_learning_rate ) assert issubclass(type(trainer.discriminator_optimizer), torch.optim.Adam) assert ( trainer.discriminator_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.discriminator_learning_rate ) def test_set_custom_optimizer( - self, model_sample, train_dataset, training_configs_learning_rate, optimizers + self, + model_sample, + train_dataset, + training_configs_learning_rate, + optimizer_config, ): trainer = 
CoupledOptimizerAdversarialTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2], ) - assert issubclass(type(trainer.encoder_optimizer), type(optimizers[0])) + trainer.set_encoder_optimizer() + trainer.set_decoder_optimizer() + trainer.set_discriminator_optimizer() + + assert issubclass( + type(trainer.encoder_optimizer), + getattr(torch.optim, optimizer_config["encoder_optimizer_cls"]), + ) assert ( trainer.encoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate - ) + == training_configs_learning_rate.encoder_learning_rate + ) + if optimizer_config["encoder_optimizer_params"] is not None: + assert all( + [ + trainer.encoder_optimizer.defaults[key] + == optimizer_config["encoder_optimizer_params"][key] + for key in optimizer_config["encoder_optimizer_params"].keys() + ] + ) - assert issubclass(type(trainer.decoder_optimizer), type(optimizers[1])) + assert issubclass( + type(trainer.decoder_optimizer), + getattr(torch.optim, optimizer_config["decoder_optimizer_cls"]), + ) assert ( trainer.decoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.decoder_learning_rate + ) + if optimizer_config["decoder_optimizer_params"] is not None: + assert all( + [ + trainer.decoder_optimizer.defaults[key] + == optimizer_config["decoder_optimizer_params"][key] + for key in optimizer_config["decoder_optimizer_params"].keys() + ] + ) + + assert issubclass( + type(trainer.discriminator_optimizer), + getattr(torch.optim, optimizer_config["discriminator_optimizer_cls"]), ) - assert issubclass(type(trainer.discriminator_optimizer), type(optimizers[2])) assert ( trainer.discriminator_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate - ) + == training_configs_learning_rate.discriminator_learning_rate + ) + if optimizer_config["discriminator_optimizer_params"] is not None: + assert all( + [ + trainer.discriminator_optimizer.defaults[key] + == optimizer_config["discriminator_optimizer_params"][key] + for key in optimizer_config["discriminator_optimizer_params"].keys() + ] + ) + class Test_Build_Scheduler: - @pytest.fixture(params=[CoupledOptimizerAdversarialTrainerConfig(), CoupledOptimizerAdversarialTrainerConfig(learning_rate=1e-5)]) + def test_wrong_scheduler_cls(self): + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig(encoder_scheduler_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig(decoder_scheduler_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerAdversarialTrainerConfig( + discriminator_scheduler_cls="WrongOptim" + ) + + def test_wrong_scheduler_params(self): + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + encoder_scheduler_cls="ReduceLROnPlateau", + encoder_scheduler_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + decoder_scheduler_cls="ReduceLROnPlateau", + decoder_scheduler_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerAdversarialTrainerConfig( + discriminator_scheduler_cls="ReduceLROnPlateau", + discriminator_scheduler_params={"wrong_config": 1}, + ) + + @pytest.fixture( + params=[ + CoupledOptimizerAdversarialTrainerConfig(), + CoupledOptimizerAdversarialTrainerConfig(learning_rate=1e-5), + ] 
+ ) def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): - - encoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - decoder_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - discriminator_optimizer = request.param( - model_sample.discriminator.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - return (encoder_optimizer, decoder_optimizer, discriminator_optimizer) - @pytest.fixture( params=[ - (StepLR, {"step_size": 1}), - (LinearLR, {"start_factor": 0.01}), - (ExponentialLR, {"gamma": 0.1}), + { + "encoder_scheduler_cls": "LinearLR", + "encoder_scheduler_params": None, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": "LinearLR", + "discriminator_scheduler_params": None, + }, + { + "encoder_scheduler_cls": "StepLR", + "encoder_scheduler_params": {"step_size": 0.71}, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": "ExponentialLR", + "discriminator_scheduler_params": {"gamma": 0.1}, + }, + { + "encoder_scheduler_cls": None, + "encoder_scheduler_params": {"patience": 12}, + "decoder_scheduler_cls": None, + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": None, + "discriminator_scheduler_params": None, + }, ] ) - def schedulers( - self, request, optimizers - ): - if request.param[0] is not None: - encoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - decoder_scheduler = request.param[0](optimizers[1], **request.param[1]) - discriminator_scheduler = request.param[0](optimizers[2], **request.param[1]) + def scheduler_config(self, request, training_configs_learning_rate): - else: - encoder_scheduler = None - decoder_scheduler = None - discriminator_scheduler = None + scheduler_config = request.param + + # set scheduler and params to training config + training_configs_learning_rate.encoder_scheduler_cls = scheduler_config[ + "encoder_scheduler_cls" + ] + training_configs_learning_rate.encoder_scheduler_params = scheduler_config[ + "encoder_scheduler_params" + ] + training_configs_learning_rate.decoder_scheduler_cls = scheduler_config[ + "decoder_scheduler_cls" + ] + training_configs_learning_rate.decoder_scheduler_params = scheduler_config[ + "decoder_scheduler_params" + ] + training_configs_learning_rate.discriminator_scheduler_cls = scheduler_config[ + "discriminator_scheduler_cls" + ] + training_configs_learning_rate.discriminator_scheduler_params = ( + scheduler_config["discriminator_scheduler_params"] + ) - return (encoder_scheduler, decoder_scheduler, discriminator_scheduler) + return request.param def test_default_scheduler_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -260,46 +439,97 @@ def test_default_scheduler_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=None, - decoder_scheduler=None, - discriminator_optimizer=None - ) - - assert issubclass( - type(trainer.encoder_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau ) - assert issubclass( - type(trainer.decoder_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + 
trainer.set_encoder_optimizer() + trainer.set_encoder_scheduler() + trainer.set_decoder_optimizer() + trainer.set_decoder_scheduler() + trainer.set_discriminator_optimizer() + trainer.set_discriminator_scheduler() - assert issubclass( - type(trainer.discriminator_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + assert trainer.encoder_scheduler is None + assert trainer.decoder_scheduler is None + assert trainer.discriminator_scheduler is None def test_set_custom_scheduler( self, model_sample, train_dataset, training_configs_learning_rate, - optimizers, - schedulers, + scheduler_config, ): trainer = CoupledOptimizerAdversarialTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=optimizers[0], - encoder_scheduler=schedulers[0], - decoder_optimizer=optimizers[1], - decoder_scheduler=schedulers[1], - discriminator_optimizer=optimizers[2], - discriminator_scheduler=schedulers[2] ) - assert issubclass(type(trainer.encoder_scheduler), type(schedulers[0])) - assert issubclass(type(trainer.decoder_scheduler), type(schedulers[1])) - assert issubclass(type(trainer.discriminator_scheduler), type(schedulers[2])) + trainer.set_encoder_optimizer() + trainer.set_encoder_scheduler() + trainer.set_decoder_optimizer() + trainer.set_decoder_scheduler() + trainer.set_discriminator_optimizer() + trainer.set_discriminator_scheduler() + + if scheduler_config["encoder_scheduler_cls"] is None: + assert trainer.encoder_scheduler is None + else: + assert issubclass( + type(trainer.encoder_scheduler), + getattr( + torch.optim.lr_scheduler, scheduler_config["encoder_scheduler_cls"] + ), + ) + if scheduler_config["encoder_scheduler_params"] is not None: + assert all( + [ + trainer.encoder_scheduler.state_dict()[key] + == scheduler_config["encoder_scheduler_params"][key] + for key in scheduler_config["encoder_scheduler_params"].keys() + ] + ) + + if scheduler_config["decoder_scheduler_cls"] is None: + assert trainer.decoder_scheduler is None + else: + assert issubclass( + type(trainer.decoder_scheduler), + getattr( + torch.optim.lr_scheduler, scheduler_config["decoder_scheduler_cls"] + ), + ) + if scheduler_config["decoder_scheduler_params"] is not None: + assert all( + [ + trainer.decoder_scheduler.state_dict()[key] + == scheduler_config["decoder_scheduler_params"][key] + for key in scheduler_config["decoder_scheduler_params"].keys() + ] + ) + + if scheduler_config["discriminator_scheduler_cls"] is None: + assert trainer.discriminator_scheduler is None + + else: + assert issubclass( + type(trainer.discriminator_scheduler), + getattr( + torch.optim.lr_scheduler, + scheduler_config["discriminator_scheduler_cls"], + ), + ) + if scheduler_config["discriminator_scheduler_params"] is not None: + assert all( + [ + trainer.discriminator_scheduler.state_dict()[key] + == scheduler_config["discriminator_scheduler_params"][key] + for key in scheduler_config[ + "discriminator_scheduler_params" + ].keys() + ] + ) + @pytest.mark.slow class Test_Main_Training: @@ -358,78 +588,130 @@ def ae(self, ae_config, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[None, Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, ae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - ae.encoder.parameters(), lr=training_configs.learning_rate - ) - - decoder_optimizer = request.param( - ae.decoder.parameters(), lr=training_configs.learning_rate - ) - - discriminator_optimizer = 
request.param( - ae.discriminator.parameters(), lr=training_configs.learning_rate - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - discriminator_optimizer = None - - return (encoder_optimizer, decoder_optimizer, discriminator_optimizer) - @pytest.fixture( params=[ - (None, None), - (StepLR, {"step_size": 1, "gamma": 0.99}), - (LinearLR, {"start_factor": 0.99}), - (ExponentialLR, {"gamma": 0.99}), + { + "encoder_optimizer_cls": "Adagrad", + "encoder_optimizer_params": {"lr_decay": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": {"momentum": 0.9}, + "discriminator_optimizer_cls": "AdamW", + "discriminator_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": {"momentum": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": {"momentum": 0.9}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": None, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + "discriminator_optimizer_cls": "SGD", + "discriminator_optimizer_params": None, + }, ] ) - def schedulers(self, request, optimizers): - if request.param[0] is not None and optimizers[0] is not None: - encoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - - else: - encoder_scheduler = None - - if request.param[0] is not None and optimizers[1] is not None: - decoder_scheduler = request.param[0](optimizers[1], **request.param[1]) - - else: - decoder_scheduler = None + def optimizer_config(self, request): + return request.param - if request.param[0] is not None and optimizers[2] is not None: - discriminator_scheduler = request.param[0](optimizers[2], **request.param[1]) + @pytest.fixture( + params=[ + { + "encoder_scheduler_cls": "LinearLR", + "encoder_scheduler_params": None, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": "LinearLR", + "discriminator_scheduler_params": None, + }, + { + "encoder_scheduler_cls": "StepLR", + "encoder_scheduler_params": {"step_size": 0.71}, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": "ExponentialLR", + "discriminator_scheduler_params": {"gamma": 0.1}, + }, + { + "encoder_scheduler_cls": None, + "encoder_scheduler_params": {"patience": 12}, + "decoder_scheduler_cls": None, + "decoder_scheduler_params": None, + "discriminator_scheduler_cls": None, + "discriminator_scheduler_params": None, + }, + ] + ) + def scheduler_config(self, request): + return request.param - else: - discriminator_scheduler = None + @pytest.fixture + def trainer( + self, ae, train_dataset, optimizer_config, scheduler_config, training_configs + ): - return (encoder_scheduler, decoder_scheduler, discriminator_scheduler) + training_configs.encoder_optimizer_cls = optimizer_config[ + "encoder_optimizer_cls" + ] + training_configs.encoder_optimizer_params = optimizer_config[ + "encoder_optimizer_params" + ] + training_configs.decoder_optimizer_cls = optimizer_config[ + "decoder_optimizer_cls" + ] + training_configs.decoder_optimizer_params = optimizer_config[ + "decoder_optimizer_params" + ] + training_configs.discriminator_optimizer_cls = optimizer_config[ + "discriminator_optimizer_cls" + ] + training_configs.discriminator_optimizer_params = optimizer_config[ + "discriminator_optimizer_params" + ] - @pytest.fixture(params=[ - torch.randint(2, (3,)), - 
torch.randint(2, (3,)), - torch.randint(2, (3,)) - ]) - def optimizer_updates(self, request): - return request.param + training_configs.encoder_scheduler_cls = scheduler_config[ + "encoder_scheduler_cls" + ] + training_configs.encoder_scheduler_params = scheduler_config[ + "encoder_scheduler_params" + ] + training_configs.decoder_scheduler_cls = scheduler_config[ + "decoder_scheduler_cls" + ] + training_configs.decoder_scheduler_params = scheduler_config[ + "decoder_scheduler_params" + ] + training_configs.discriminator_scheduler_cls = scheduler_config[ + "discriminator_scheduler_cls" + ] + training_configs.discriminator_scheduler_params = scheduler_config[ + "discriminator_scheduler_params" + ] - def test_train_step(self, ae, train_dataset, training_configs, optimizers, schedulers): trainer = CoupledOptimizerAdversarialTrainer( model=ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2], - encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1], - discriminator_scheduler=schedulers[2] ) + trainer.prepare_training() + + return trainer + + @pytest.fixture( + params=[torch.randint(2, (3,)), torch.randint(2, (3,)), torch.randint(2, (3,))] + ) + def optimizer_updates(self, request): + return request.param + + def test_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -444,20 +726,11 @@ def test_train_step(self, ae, train_dataset, training_configs, optimizers, sched ] ) - def test_train_2_steps_updates(self, ae, train_dataset, training_configs, optimizers, optimizer_updates): - trainer = CoupledOptimizerAdversarialTrainer( - model=ae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2], - ) + def test_train_2_steps_updates(self, ae, train_dataset, trainer, optimizer_updates): start_model_state_dict = deepcopy(trainer.model.state_dict()) - model_output = ae( - {"data": train_dataset.data.to(device)}) + model_output = ae({"data": train_dataset.data.to(device)}) model_output.update_encoder = False model_output.update_decoder = False @@ -475,8 +748,7 @@ def test_train_2_steps_updates(self, ae, train_dataset, training_configs, optimi ] ) - model_output = ae( - {"data": train_dataset.data.to(device)}) + model_output = ae({"data": train_dataset.data.to(device)}) model_output.update_encoder = bool(optimizer_updates[0]) model_output.update_decoder = bool(optimizer_updates[1]) @@ -490,39 +762,44 @@ def test_train_2_steps_updates(self, ae, train_dataset, training_configs, optimi for key in start_model_state_dict.keys(): if "encoder" in key: if bool(optimizer_updates[0]): - assert not torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) + assert torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) if "decoder" in key: if bool(optimizer_updates[1]): - assert not torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) - + assert 
torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) + if "discriminator" in key: if bool(optimizer_updates[2]): - assert not torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], step_2_model_state_dict[key]) + assert torch.equal( + step_1_model_state_dict[key], step_2_model_state_dict[key] + ) - def test_train_step_various_updates(self, ae, train_dataset, training_configs, optimizers, optimizer_updates): - trainer = CoupledOptimizerAdversarialTrainer( - model=ae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2] - ) + def test_train_step_various_updates( + self, ae, train_dataset, trainer, optimizer_updates + ): start_model_state_dict = deepcopy(trainer.model.state_dict()) - model_output = ae( - {"data": train_dataset.data.to(device)}) + model_output = ae({"data": train_dataset.data.to(device)}) model_output.update_encoder = bool(optimizer_updates[0]) model_output.update_decoder = bool(optimizer_updates[1]) @@ -536,38 +813,38 @@ def test_train_step_various_updates(self, ae, train_dataset, training_configs, o for key in start_model_state_dict.keys(): if "encoder" in key: if bool(optimizer_updates[0]): - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) if "decoder" in key: if bool(optimizer_updates[1]): - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) - + assert torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) + if "discriminator" in key: if bool(optimizer_updates[2]): - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) else: - assert torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) - def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedulers): - trainer = CoupledOptimizerAdversarialTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2], - encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1], - discriminator_scheduler=schedulers[2] - ) + def test_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -583,22 +860,7 @@ def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedu ] ) - def test_main_train_loop( - self, tmpdir, ae, train_dataset, training_configs, optimizers, schedulers - ): - - trainer = CoupledOptimizerAdversarialTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - 
encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - discriminator_optimizer=optimizers[2], - encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1], - discriminator_scheduler=schedulers[2] - ) + def test_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) diff --git a/tests/test_coupled_optimizers_trainer.py b/tests/test_coupled_optimizers_trainer.py index 1471b030..d58e1105 100644 --- a/tests/test_coupled_optimizers_trainer.py +++ b/tests/test_coupled_optimizers_trainer.py @@ -3,8 +3,6 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop -from torch.optim.lr_scheduler import StepLR, LinearLR, ExponentialLR from pythae.models import RAE_L2, RAE_L2_Config from pythae.trainers import CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig @@ -33,10 +31,12 @@ def training_config(tmpdir): class Test_DataLoader: @pytest.fixture( params=[ - CoupledOptimizerTrainerConfig(decoder_optim_decay=0), - CoupledOptimizerTrainerConfig(batch_size=100, encoder_optim_decay=1e-7), + CoupledOptimizerTrainerConfig(), + CoupledOptimizerTrainerConfig( + per_device_train_batch_size=100, per_device_eval_batch_size=35 + ), CoupledOptimizerTrainerConfig( - batch_size=10, encoder_optim_decay=1e-7, decoder_optim_decay=1e-7 + per_device_train_batch_size=10, per_device_eval_batch_size=3 ), ] ) @@ -60,7 +60,10 @@ def test_build_train_data_loader( assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) assert train_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + train_data_loader.batch_size + == trainer.training_config.per_device_train_batch_size + ) def test_build_eval_data_loader( self, model_sample, train_dataset, training_config_batch_size @@ -71,20 +74,32 @@ def test_build_eval_data_loader( training_config=training_config_batch_size, ) - train_data_loader = trainer.get_eval_dataloader(train_dataset) + eval_data_loader = trainer.get_eval_dataloader(train_dataset) - assert issubclass(type(train_data_loader), torch.utils.data.DataLoader) - assert train_data_loader.dataset == train_dataset + assert issubclass(type(eval_data_loader), torch.utils.data.DataLoader) + assert eval_data_loader.dataset == train_dataset - assert train_data_loader.batch_size == trainer.training_config.batch_size + assert ( + eval_data_loader.batch_size + == trainer.training_config.per_device_eval_batch_size + ) class Test_Set_Training_config: @pytest.fixture( params=[ - CoupledOptimizerTrainerConfig(decoder_optim_decay=0), + CoupledOptimizerTrainerConfig(), CoupledOptimizerTrainerConfig( - batch_size=10, learning_rate=1e-5, encoder_optim_decay=0 + per_device_train_batch_size=10, + per_device_eval_batch_size=10, + encoder_learning_rate=1e-5, + decoder_learning_rate=1e-3, + encoder_optimizer_cls="AdamW", + encoder_optimizer_params={"weight_decay": 0.01}, + decoder_optimizer_cls="SGD", + decoder_optimizer_params={"weight_decay": 0.01}, + encoder_scheduler_cls="ExponentialLR", + encoder_scheduler_params={"gamma": 0.321}, ), ] ) @@ -112,6 +127,26 @@ def test_set_training_config(self, model_sample, train_dataset, training_configs class Test_Build_Optimizer: + def test_wrong_optimizer_cls(self): + with pytest.raises(AttributeError): + CoupledOptimizerTrainerConfig(encoder_optimizer_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerTrainerConfig(decoder_optimizer_cls="WrongOptim") + + def test_wrong_optimizer_params(self): + with 
pytest.raises(TypeError): + CoupledOptimizerTrainerConfig( + encoder_optimizer_cls="Adam", + encoder_optimizer_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerTrainerConfig( + decoder_optimizer_cls="Adam", + decoder_optimizer_params={"wrong_config": 1}, + ) + @pytest.fixture( params=[ CoupledOptimizerTrainerConfig(learning_rate=1e-5), @@ -122,18 +157,47 @@ def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): + @pytest.fixture( + params=[ + { + "encoder_optimizer_cls": "Adagrad", + "encoder_optimizer_params": {"lr_decay": 0.1}, + "decoder_optimizer_cls": "AdamW", + "decoder_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": {"momentum": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": {"momentum": 0.9}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": None, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + }, + ] + ) + def optimizer_config(self, request, training_configs_learning_rate): - encoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - decoder_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - return (encoder_optimizer, decoder_optimizer) + optimizer_config = request.param + + # set optim and params to training config + training_configs_learning_rate.encoder_optimizer_cls = optimizer_config[ + "encoder_optimizer_cls" + ] + training_configs_learning_rate.encoder_optimizer_params = optimizer_config[ + "encoder_optimizer_params" + ] + training_configs_learning_rate.decoder_optimizer_cls = optimizer_config[ + "decoder_optimizer_cls" + ] + training_configs_learning_rate.decoder_optimizer_params = optimizer_config[ + "decoder_optimizer_params" + ] + + return optimizer_config def test_default_optimizer_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -143,83 +207,146 @@ def test_default_optimizer_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=None, - decoder_optimizer=None, ) + trainer.set_encoder_optimizer() + trainer.set_decoder_optimizer() + assert issubclass(type(trainer.encoder_optimizer), torch.optim.Adam) assert ( trainer.encoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.encoder_learning_rate ) assert issubclass(type(trainer.decoder_optimizer), torch.optim.Adam) assert ( trainer.decoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.decoder_learning_rate ) def test_set_custom_optimizer( - self, model_sample, train_dataset, training_configs_learning_rate, optimizers + self, + model_sample, + train_dataset, + training_configs_learning_rate, + optimizer_config, ): trainer = CoupledOptimizerTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], ) - assert issubclass(type(trainer.encoder_optimizer), type(optimizers[0])) + trainer.set_encoder_optimizer() + trainer.set_decoder_optimizer() + + 
assert issubclass( + type(trainer.encoder_optimizer), + getattr(torch.optim, optimizer_config["encoder_optimizer_cls"]), + ) assert ( trainer.encoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.encoder_learning_rate ) + if optimizer_config["encoder_optimizer_params"] is not None: + assert all( + [ + trainer.encoder_optimizer.defaults[key] + == optimizer_config["encoder_optimizer_params"][key] + for key in optimizer_config["encoder_optimizer_params"].keys() + ] + ) - assert issubclass(type(trainer.decoder_optimizer), type(optimizers[1])) + assert issubclass( + type(trainer.decoder_optimizer), + getattr(torch.optim, optimizer_config["decoder_optimizer_cls"]), + ) assert ( trainer.decoder_optimizer.defaults["lr"] - == training_configs_learning_rate.learning_rate + == training_configs_learning_rate.decoder_learning_rate ) + if optimizer_config["decoder_optimizer_params"] is not None: + assert all( + [ + trainer.decoder_optimizer.defaults[key] + == optimizer_config["decoder_optimizer_params"][key] + for key in optimizer_config["decoder_optimizer_params"].keys() + ] + ) + class Test_Build_Scheduler: - @pytest.fixture(params=[CoupledOptimizerTrainerConfig(), CoupledOptimizerTrainerConfig(learning_rate=1e-5)]) + def test_wrong_scheduler_cls(self): + with pytest.raises(AttributeError): + CoupledOptimizerTrainerConfig(encoder_scheduler_cls="WrongOptim") + + with pytest.raises(AttributeError): + CoupledOptimizerTrainerConfig(decoder_scheduler_cls="WrongOptim") + + def test_wrong_scheduler_params(self): + with pytest.raises(TypeError): + CoupledOptimizerTrainerConfig( + encoder_scheduler_cls="ReduceLROnPlateau", + encoder_scheduler_params={"wrong_config": 1}, + ) + + with pytest.raises(TypeError): + CoupledOptimizerTrainerConfig( + decoder_scheduler_cls="ReduceLROnPlateau", + decoder_scheduler_params={"wrong_config": 1}, + ) + + @pytest.fixture( + params=[ + CoupledOptimizerTrainerConfig(), + CoupledOptimizerTrainerConfig(learning_rate=1e-5), + ] + ) def training_configs_learning_rate(self, tmpdir, request): request.param.output_dir = tmpdir.mkdir("dummy_folder") return request.param - @pytest.fixture(params=[Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, model_sample, training_configs_learning_rate): - - autoencoder_optimizer = request.param( - model_sample.encoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - discriminator_optimizer = request.param( - model_sample.decoder.parameters(), - lr=training_configs_learning_rate.learning_rate, - ) - return (autoencoder_optimizer, discriminator_optimizer) - @pytest.fixture( params=[ - (StepLR, {"step_size": 1}), - (LinearLR, {"start_factor": 0.01}), - (ExponentialLR, {"gamma": 0.1}), + { + "encoder_scheduler_cls": "StepLR", + "encoder_scheduler_params": {"step_size": 1}, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + }, + { + "encoder_scheduler_cls": None, + "encoder_scheduler_params": None, + "decoder_scheduler_cls": "ExponentialLR", + "decoder_scheduler_params": {"gamma": 0.1}, + }, + { + "encoder_scheduler_cls": "ReduceLROnPlateau", + "encoder_scheduler_params": {"patience": 12}, + "decoder_scheduler_cls": None, + "decoder_scheduler_params": None, + }, ] ) - def schedulers( - self, request, optimizers - ): - if request.param[0] is not None: - encoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - decoder_scheduler = request.param[0](optimizers[1], **request.param[1]) + def 
scheduler_config(self, request, training_configs_learning_rate): - else: - encoder_scheduler = None - decoder_scheduler = None + scheduler_config = request.param + + # set scheduler and params to training config + training_configs_learning_rate.encoder_scheduler_cls = scheduler_config[ + "encoder_scheduler_cls" + ] + training_configs_learning_rate.encoder_scheduler_params = scheduler_config[ + "encoder_scheduler_params" + ] + training_configs_learning_rate.decoder_scheduler_cls = scheduler_config[ + "decoder_scheduler_cls" + ] + training_configs_learning_rate.decoder_scheduler_params = scheduler_config[ + "decoder_scheduler_params" + ] - return (encoder_scheduler, decoder_scheduler) + return request.param def test_default_scheduler_building( self, model_sample, train_dataset, training_configs_learning_rate @@ -229,42 +356,77 @@ def test_default_scheduler_building( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=None, - decoder_optimizer=None ) - assert issubclass( - type(trainer.encoder_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + trainer.set_encoder_optimizer() + trainer.set_encoder_scheduler() + trainer.set_decoder_optimizer() + trainer.set_decoder_scheduler() - assert issubclass( - type(trainer.decoder_scheduler), torch.optim.lr_scheduler.ReduceLROnPlateau - ) + assert trainer.encoder_scheduler is None + assert trainer.decoder_scheduler is None def test_set_custom_scheduler( self, model_sample, train_dataset, training_configs_learning_rate, - optimizers, - schedulers, + scheduler_config, ): trainer = CoupledOptimizerTrainer( model=model_sample, train_dataset=train_dataset, training_config=training_configs_learning_rate, - encoder_optimizer=optimizers[0], - encoder_scheduler=schedulers[0], - decoder_optimizer=optimizers[1], - decoder_scheduler=schedulers[1] ) - assert issubclass(type(trainer.encoder_scheduler), type(schedulers[0])) - assert issubclass(type(trainer.decoder_scheduler), type(schedulers[1])) + trainer.set_encoder_optimizer() + trainer.set_encoder_scheduler() + trainer.set_decoder_optimizer() + trainer.set_decoder_scheduler() + + if scheduler_config["encoder_scheduler_cls"] is None: + assert trainer.encoder_scheduler is None + else: + assert issubclass( + type(trainer.encoder_scheduler), + getattr( + torch.optim.lr_scheduler, scheduler_config["encoder_scheduler_cls"] + ), + ) + if scheduler_config["encoder_scheduler_params"] is not None: + assert all( + [ + trainer.encoder_scheduler.state_dict()[key] + == scheduler_config["encoder_scheduler_params"][key] + for key in scheduler_config["encoder_scheduler_params"].keys() + ] + ) + + if scheduler_config["decoder_scheduler_cls"] is None: + assert trainer.decoder_scheduler is None + + else: + assert issubclass( + type(trainer.decoder_scheduler), + getattr( + torch.optim.lr_scheduler, scheduler_config["decoder_scheduler_cls"] + ), + ) + if scheduler_config["decoder_scheduler_params"] is not None: + assert all( + [ + trainer.decoder_scheduler.state_dict()[key] + == scheduler_config["decoder_scheduler_params"][key] + for key in scheduler_config["decoder_scheduler_params"].keys() + ] + ) + @pytest.mark.slow class Test_Main_Training: - @pytest.fixture(params=[CoupledOptimizerTrainerConfig(num_epochs=3, learning_rate=1e-4)]) + @pytest.fixture( + params=[CoupledOptimizerTrainerConfig(num_epochs=3, learning_rate=1e-4)] + ) def training_configs(self, tmpdir, request): tmpdir.mkdir("dummy_folder") dir_path = os.path.join(tmpdir, "dummy_folder") @@ -311,57 
+473,99 @@ def ae(self, ae_config, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[None, Adagrad, Adam, Adadelta, SGD, RMSprop]) - def optimizers(self, request, ae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - ae.encoder.parameters(), lr=training_configs.learning_rate - ) - - decoder_optimizer = request.param( - ae.decoder.parameters(), lr=training_configs.learning_rate - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - - return (encoder_optimizer, decoder_optimizer) + @pytest.fixture( + params=[ + { + "encoder_optimizer_cls": "Adagrad", + "encoder_optimizer_params": {"lr_decay": 0.1}, + "decoder_optimizer_cls": "AdamW", + "decoder_optimizer_params": {"betas": (0.1234, 0.4321)}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": {"momentum": 0.1}, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": {"momentum": 0.9}, + }, + { + "encoder_optimizer_cls": "SGD", + "encoder_optimizer_params": None, + "decoder_optimizer_cls": "SGD", + "decoder_optimizer_params": None, + }, + ] + ) + def optimizer_config(self, request): + return request.param @pytest.fixture( params=[ - (None, None), - (StepLR, {"step_size": 1, "gamma": 0.99}), - (LinearLR, {"start_factor": 0.99}), - (ExponentialLR, {"gamma": 0.99}), + { + "encoder_scheduler_cls": "LinearLR", + "encoder_scheduler_params": None, + "decoder_scheduler_cls": "LinearLR", + "decoder_scheduler_params": None, + }, + { + "encoder_scheduler_cls": None, + "encoder_scheduler_params": None, + "decoder_scheduler_cls": "ExponentialLR", + "decoder_scheduler_params": {"gamma": 0.012}, + }, + { + "encoder_scheduler_cls": "ReduceLROnPlateau", + "encoder_scheduler_params": {"patience": 12}, + "decoder_scheduler_cls": None, + "decoder_scheduler_params": None, + }, ] ) - def schedulers(self, request, optimizers): - if request.param[0] is not None and optimizers[0] is not None: - encoder_scheduler = request.param[0](optimizers[0], **request.param[1]) - - else: - encoder_scheduler = None - - if request.param[0] is not None and optimizers[1] is not None: - decoder_scheduler = request.param[0](optimizers[1], **request.param[1]) + def scheduler_config(self, request): + return request.param - else: - decoder_scheduler = None + @pytest.fixture + def trainer( + self, ae, train_dataset, optimizer_config, scheduler_config, training_configs + ): - return (encoder_scheduler, decoder_scheduler) + training_configs.encoder_optimizer_cls = optimizer_config[ + "encoder_optimizer_cls" + ] + training_configs.encoder_optimizer_params = optimizer_config[ + "encoder_optimizer_params" + ] + training_configs.decoder_optimizer_cls = optimizer_config[ + "decoder_optimizer_cls" + ] + training_configs.decoder_optimizer_params = optimizer_config[ + "decoder_optimizer_params" + ] + training_configs.encoder_scheduler_cls = scheduler_config[ + "encoder_scheduler_cls" + ] + training_configs.encoder_scheduler_params = scheduler_config[ + "encoder_scheduler_params" + ] + training_configs.decoder_scheduler_cls = scheduler_config[ + "decoder_scheduler_cls" + ] + training_configs.decoder_scheduler_params = scheduler_config[ + "decoder_scheduler_params" + ] - def test_train_step(self, ae, train_dataset, training_configs, optimizers, schedulers): trainer = CoupledOptimizerTrainer( model=ae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - 
encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1] ) + trainer.prepare_training() + + return trainer + + def test_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -371,22 +575,16 @@ def test_train_step(self, ae, train_dataset, training_configs, optimizers, sched # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]), key + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ), key if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) - def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedulers): - trainer = CoupledOptimizerTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1] - ) + def test_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -402,20 +600,7 @@ def test_eval_step(self, ae, train_dataset, training_configs, optimizers, schedu ] ) - def test_main_train_loop( - self, ae, train_dataset, training_configs, optimizers, schedulers - ): - - trainer = CoupledOptimizerTrainer( - model=ae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - encoder_scheduler=schedulers[0], - decoder_scheduler=schedulers[1] - ) + def test_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -426,10 +611,14 @@ def test_main_train_loop( # check that weights were updated for key in start_model_state_dict.keys(): if "encoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]), key + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ), key if "decoder" in key: - assert not torch.equal(step_1_model_state_dict[key], start_model_state_dict[key]) + assert not torch.equal( + step_1_model_state_dict[key], start_model_state_dict[key] + ) class Test_Logging: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index fd287126..0ce9dcca 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -7,14 +7,17 @@ PATH = os.path.dirname(os.path.abspath(__file__)) + @pytest.fixture def data(): return torch.randn(1000, 2) + @pytest.fixture def labels(): return torch.randint(10, (1000,)) + class Test_Dataset: def test_dataset_call(self, data, labels): dataset = BaseDataset(data, labels) diff --git a/tests/test_gaussian_mixture_sampler.py b/tests/test_gaussian_mixture_sampler.py index 2e06f0fb..7a16ae1e 100644 --- a/tests/test_gaussian_mixture_sampler.py +++ b/tests/test_gaussian_mixture_sampler.py @@ -4,9 +4,14 @@ import pytest import torch -from pythae.models import AE, AEConfig, VAE, VAEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, GaussianMixtureSampler, GaussianMixtureSamplerConfig +from pythae.models import AE, VAE, AEConfig, VAEConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + GaussianMixtureSampler, + GaussianMixtureSamplerConfig, + NormalSampler, + 
NormalSamplerConfig, +) PATH = os.path.dirname(os.path.abspath(__file__)) @@ -142,16 +147,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -166,13 +174,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, GaussianMixtureSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_hypersphere_uniform_sampler.py b/tests/test_hypersphere_uniform_sampler.py index 1ba3033f..682d46a9 100644 --- a/tests/test_hypersphere_uniform_sampler.py +++ b/tests/test_hypersphere_uniform_sampler.py @@ -4,9 +4,14 @@ import pytest import torch -from pythae.models import AE, AEConfig, VAE, VAEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, HypersphereUniformSampler, HypersphereUniformSamplerConfig +from pythae.models import AE, VAE, AEConfig, VAEConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + HypersphereUniformSampler, + HypersphereUniformSamplerConfig, + NormalSampler, + NormalSamplerConfig, +) PATH = os.path.dirname(os.path.abspath(__file__)) @@ -23,6 +28,7 @@ def dummy_data(): def model(request): return request.param + @pytest.fixture( params=[ HypersphereUniformSamplerConfig(), @@ -32,6 +38,7 @@ def model(request): def sampler_config(request): return request.param + @pytest.fixture() def sampler(model, sampler_config): return HypersphereUniformSampler( @@ -132,16 +139,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -156,15 +166,16 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, HypersphereUniformSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None assert "sampler_config.json" not in os.listdir(dir_path) - assert len(os.listdir(dir_path)) == num_samples \ No newline at end of file + assert len(os.listdir(dir_path)) == num_samples 
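For orientation, the `GenerationPipeline` call pattern shared by these sampler tests looks roughly as follows. This is a minimal sketch: `my_model` and `my_train_data` are hypothetical placeholders, and the pipeline constructor arguments (`model=`, `sampler_config=`) are assumed rather than shown in the hunks here.

```python
# Minimal sketch (placeholders and constructor signature are assumptions):
# fit a sampler on training data and generate samples through the pipeline.
from pythae.pipelines import GenerationPipeline
from pythae.samplers import GaussianMixtureSamplerConfig

pipe = GenerationPipeline(
    model=my_model,                              # placeholder: a trained AE/VAE
    sampler_config=GaussianMixtureSamplerConfig(),
)

gen_data = pipe(
    num_samples=10,
    batch_size=5,
    output_dir="generated_samples",
    return_gen=True,           # return the generated tensor
    save_sampler_config=True,  # also writes sampler_config.json to output_dir
    train_data=my_train_data,  # placeholder: tensors used to fit the sampler
    eval_data=None,
)
# Expected: tuple(gen_data.shape) == (num_samples,) + tuple(my_model.model_config.input_dim)
```

With `return_gen=False` the call returns `None` and only writes the generated samples (and, if requested, the sampler config) to `output_dir`, which is what the second half of each `test_generation_pipeline` asserts.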
diff --git a/tests/test_iaf_sampler.py b/tests/test_iaf_sampler.py index eff123de..5e5edce9 100644 --- a/tests/test_iaf_sampler.py +++ b/tests/test_iaf_sampler.py @@ -1,13 +1,18 @@ import os +from copy import deepcopy import pytest import torch -from copy import deepcopy -from pythae.models import VAE, VAEConfig, AE, AEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, IAFSampler, IAFSamplerConfig -from pythae.trainers import BaseTrainerConfig +from pythae.models import AE, VAE, AEConfig, VAEConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + IAFSampler, + IAFSamplerConfig, + NormalSampler, + NormalSamplerConfig, +) +from pythae.trainers import BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -192,16 +197,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -216,13 +224,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, IAFSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index fbbe33a7..016d558e 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -3,16 +3,19 @@ import pytest import torch -from torch.optim import Adam from pythae.customexception import BadInheritanceError -from pythae.models.base.base_utils import ModelOutput from pythae.models import INFOVAE_MMD, INFOVAE_MMD_Config - +from pythae.models.base.base_utils import ModelOutput +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, + TwoStageVAESamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig - -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_VAE_Conv, @@ -125,7 +128,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = INFOVAE_MMD.load_from_folder(dir_path) @@ -217,7 +222,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -302,14 +307,15 @@ def test_model_train_output(self, info_vae_mmd, demo_data): assert out.z.shape[0] == 
demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -324,23 +330,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return INFOVAE_MMD(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -351,12 +364,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return INFOVAE_MMD(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): @@ -428,28 +441,21 @@ def info_vae_mmd(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, info_vae_mmd, training_configs): - if request.param is not None: - optimizer = request.param( - info_vae_mmd.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_info_vae_mmd_train_step( - self, info_vae_mmd, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, info_vae_mmd, train_dataset, training_configs): trainer = BaseTrainer( model=info_vae_mmd, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_info_vae_mmd_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -464,16 +470,7 @@ def test_info_vae_mmd_train_step( ] ) - def test_info_vae_mmd_eval_step( - self, info_vae_mmd, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_info_vae_mmd_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -489,16 +486,7 @@ def test_info_vae_mmd_eval_step( ] ) - def test_info_vae_mmd_predict_step( - self, info_vae_mmd, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_info_vae_mmd_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -514,21 +502,11 @@ def test_info_vae_mmd_predict_step( ] ) 
- assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape - - def test_info_vae_mmd_main_train_loop( - self, tmpdir, info_vae_mmd, train_dataset, training_configs, optimizers - ): + assert generated.shape == inputs.shape - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_info_vae_mmd_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -544,19 +522,10 @@ def test_info_vae_mmd_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, info_vae_mmd, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, info_vae_mmd, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -639,20 +608,13 @@ def test_checkpoint_saving( ) def test_checkpoint_saving_during_training( - self, tmpdir, info_vae_mmd, train_dataset, training_configs, optimizers + self, info_vae_mmd, trainer, training_configs ): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -700,19 +662,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, info_vae_mmd, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, info_vae_mmd, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=info_vae_mmd, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -824,10 +777,13 @@ def test_info_vae_mmd_training_pipeline( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_InfoVAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -839,7 +795,7 @@ def ae_model(self): GaussianMixtureSamplerConfig(), MAFSamplerConfig(), IAFSamplerConfig(), - TwoStageVAESamplerConfig() + TwoStageVAESamplerConfig(), ] ) def sampler_configs(self, request): @@ -854,7 +810,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_maf_sampler.py b/tests/test_maf_sampler.py index 03af31a5..e98d3cc5 100644 --- a/tests/test_maf_sampler.py +++ b/tests/test_maf_sampler.py @@ -1,13 +1,18 @@ import os +from copy import deepcopy import pytest import torch -from copy import deepcopy -from pythae.models import VAE, VAEConfig, AE, AEConfig -from pythae.samplers import 
NormalSampler, NormalSamplerConfig, MAFSampler, MAFSamplerConfig -from pythae.trainers import BaseTrainerConfig +from pythae.models import AE, VAE, AEConfig, VAEConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + MAFSampler, + MAFSamplerConfig, + NormalSampler, + NormalSamplerConfig, +) +from pythae.trainers import BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -192,16 +197,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -216,15 +224,16 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, MAFSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None assert "sampler_config.json" not in os.listdir(dir_path) - assert len(os.listdir(dir_path)) == num_samples \ No newline at end of file + assert len(os.listdir(dir_path)) == num_samples diff --git a/tests/test_nn_benchmark.py b/tests/test_nn_benchmark.py index e19c494a..53f9f7ec 100644 --- a/tests/test_nn_benchmark.py +++ b/tests/test_nn_benchmark.py @@ -1,11 +1,11 @@ +import numpy as np import pytest import torch -import numpy as np from pythae.models import AEConfig, VAEConfig -from pythae.models.nn.benchmarks.mnist import * from pythae.models.nn.benchmarks.celeba import * from pythae.models.nn.benchmarks.cifar import * +from pythae.models.nn.benchmarks.mnist import * from pythae.models.nn.default_architectures import * device = "cuda" if torch.cuda.is_available() else "cpu" @@ -20,10 +20,12 @@ def ae_mnist_config(request): return request.param + @pytest.fixture() def mnist_like_data(): return torch.rand(3, 1, 28, 28).to(device) + #### CIFAR configs #### @pytest.fixture( params=[ @@ -34,10 +36,12 @@ def mnist_like_data(): def ae_cifar_config(request): return request.param + @pytest.fixture() def cifar_like_data(): return torch.rand(3, 3, 32, 32).to(device) + #### CELEBA configs #### @pytest.fixture( params=[ @@ -48,13 +52,13 @@ def cifar_like_data(): def ae_celeba_config(request): return request.param + @pytest.fixture() def celeba_like_data(): return torch.rand(3, 3, 64, 64).to(device) class Test_MNIST_Default: - @pytest.fixture(params=[[1], None, [-1]]) def recon_layers_default(self, request): return request.param @@ -95,9 +99,8 @@ def test_ae_encoding_decoding_default( if recon_layers_default is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers_default: @@ -154,16 +157,14 @@ def test_vae_encoding_decoding_default( assert "log_covariance" in encoder_embed.keys() assert 
"reconstruction" in decoder_recon.keys() - if recon_layers_default is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers_default: @@ -173,12 +174,13 @@ def test_vae_encoding_decoding_default( if -1 in recon_layers_default: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim + encoder_embed["log_covariance"].shape[1] + == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_svae_encoding_decoding_default( self, ae_mnist_config, mnist_like_data, recon_layers_default @@ -214,7 +216,7 @@ def test_svae_encoding_decoding_default( if lev != -1: assert f"embedding_layer_{lev}" in encoder_embed.keys() assert f"reconstruction_layer_{lev}" in decoder_recon.keys() - + else: assert "embedding" in encoder_embed.keys() assert "log_concentration" in encoder_embed.keys() @@ -224,9 +226,8 @@ def test_svae_encoding_decoding_default( assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers_default: @@ -237,9 +238,9 @@ def test_svae_encoding_decoding_default( assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_discriminator_default( self, ae_mnist_config, mnist_like_data, recon_layers_default @@ -276,15 +277,13 @@ def test_discriminator_default( if -1 in recon_layers_default: assert scores["embedding"].shape[1] == 1 -class Test_MNIST_ConvNets: +class Test_MNIST_ConvNets: @pytest.fixture(params=[[3, 4], [np.random.randint(1, 5)], [1, 2, 4, -1], None]) def recon_layers(self, request): return request.param - def test_ae_encoding_decoding( - self, ae_mnist_config, mnist_like_data, recon_layers - ): + def test_ae_encoding_decoding(self, ae_mnist_config, mnist_like_data, recon_layers): encoder = Encoder_Conv_AE_MNIST(ae_mnist_config).to(device) decoder = Decoder_Conv_AE_MNIST(ae_mnist_config).to(device) @@ -296,9 +295,7 @@ def test_ae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -307,7 +304,7 @@ def test_ae_encoding_decoding( else: for lev in recon_layers: - if lev !=-1: + if lev != -1: assert f"embedding_layer_{lev}" in encoder_embed.keys() assert f"reconstruction_layer_{lev}" in decoder_recon.keys() @@ -318,9 +315,8 @@ def test_ae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - 
decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -371,9 +367,7 @@ def test_vae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -392,16 +386,14 @@ def test_vae_encoding_decoding( assert "log_covariance" in encoder_embed.keys() assert "reconstruction" in decoder_recon.keys() - if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -426,12 +418,13 @@ def test_vae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim + encoder_embed["log_covariance"].shape[1] + == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_svae_encoding_decoding( self, ae_mnist_config, mnist_like_data, recon_layers @@ -452,9 +445,7 @@ def test_svae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -473,14 +464,12 @@ def test_svae_encoding_decoding( assert "log_concentration" in encoder_embed.keys() assert "reconstruction" in decoder_recon.keys() - if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -506,21 +495,17 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) - def test_discriminator( - self, ae_mnist_config, mnist_like_data, recon_layers - ): + def test_discriminator(self, ae_mnist_config, mnist_like_data, recon_layers): ae_mnist_config.discriminator_input_dim = (1, 28, 28) discriminator = Discriminator_Conv_MNIST(ae_mnist_config).to(device) - scores = discriminator( - mnist_like_data, output_layer_levels=recon_layers - ) + scores = discriminator(mnist_like_data, output_layer_levels=recon_layers) if recon_layers is None: assert "embedding" in scores.keys() @@ -550,17 +535,15 @@ def test_discriminator( assert scores[f"embedding_layer_4"].shape[1] == 1024 if -1 in recon_layers: - assert scores["embedding"].shape[1] == 1 - + assert 
scores["embedding"].shape[1] == 1 + class Test_MNIST_ResNets: @pytest.fixture(params=[[3, 4], [np.random.randint(1, 5)], [1, 2, 4, -1], None]) def recon_layers(self, request): return request.param - def test_ae_encoding_decoding( - self, ae_mnist_config, mnist_like_data, recon_layers - ): + def test_ae_encoding_decoding(self, ae_mnist_config, mnist_like_data, recon_layers): encoder = Encoder_ResNet_AE_MNIST(ae_mnist_config).to(device) decoder = Decoder_ResNet_AE_MNIST(ae_mnist_config).to(device) @@ -572,9 +555,7 @@ def test_ae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -589,14 +570,13 @@ def test_ae_encoding_decoding( else: assert "embedding" in encoder_embed.keys() - assert "reconstruction" in decoder_recon.keys() + assert "reconstruction" in decoder_recon.keys() if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -614,7 +594,7 @@ def test_ae_encoding_decoding( if 4 in recon_layers: assert encoder_embed[f"embedding_layer_4"].shape[1] == 128 assert decoder_recon[f"reconstruction_layer_4"].shape[1] == 64 - + if 5 in recon_layers: assert ( decoder_recon[f"reconstruction_layer_5"].shape[1:] @@ -624,9 +604,9 @@ def test_ae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_vae_encoding_decoding( self, ae_mnist_config, mnist_like_data, recon_layers @@ -642,9 +622,7 @@ def test_vae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -669,9 +647,8 @@ def test_vae_encoding_decoding( encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -689,7 +666,7 @@ def test_vae_encoding_decoding( if 4 in recon_layers: assert encoder_embed[f"embedding_layer_4"].shape[1] == 128 assert decoder_recon[f"reconstruction_layer_4"].shape[1] == 64 - + if 5 in recon_layers: assert ( decoder_recon[f"reconstruction_layer_5"].shape[1:] @@ -699,12 +676,13 @@ def test_vae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_mnist_config.latent_dim + encoder_embed["log_covariance"].shape[1] + == ae_mnist_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_svae_encoding_decoding( self, ae_mnist_config, mnist_like_data, 
recon_layers @@ -720,9 +698,7 @@ def test_svae_encoding_decoding( assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -745,9 +721,8 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -765,7 +740,7 @@ def test_svae_encoding_decoding( if 4 in recon_layers: assert encoder_embed[f"embedding_layer_4"].shape[1] == 128 assert decoder_recon[f"reconstruction_layer_4"].shape[1] == 64 - + if 5 in recon_layers: assert ( decoder_recon[f"reconstruction_layer_5"].shape[1:] @@ -776,10 +751,9 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) - + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) def test_vqvae_encoding_decoding( self, ae_mnist_config, mnist_like_data, recon_layers @@ -789,15 +763,16 @@ def test_vqvae_encoding_decoding( embedding = encoder(mnist_like_data).embedding - assert embedding.shape[:2] == (mnist_like_data.shape[0], ae_mnist_config.latent_dim) + assert embedding.shape[:2] == ( + mnist_like_data.shape[0], + ae_mnist_config.latent_dim, + ) reconstruction = decoder(embedding).reconstruction assert reconstruction.shape == mnist_like_data.shape - encoder_embed = encoder( - mnist_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(mnist_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -817,9 +792,8 @@ def test_vqvae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_mnist_config.input_dim + ) else: if 1 in recon_layers: @@ -847,10 +821,11 @@ def test_vqvae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_mnist_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_mnist_config.input_dim - ) - + decoder_recon[f"reconstruction"].shape[1:] + == ae_mnist_config.input_dim + ) + + class Test_CIFAR_ConvNets: @pytest.fixture(params=[[3, 4], [np.random.randint(1, 5)], [1, 2, 4, -1], None]) def recon_layers(self, request): @@ -888,9 +863,8 @@ def test_ae_encoding_decoding(self, ae_cifar_config, cifar_like_data, recon_laye if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -915,10 +889,9 @@ def test_ae_encoding_decoding(self, ae_cifar_config, cifar_like_data, recon_laye if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - 
decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) - + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) def test_vae_encoding_decoding( self, ae_cifar_config, cifar_like_data, recon_layers @@ -967,9 +940,8 @@ def test_vae_encoding_decoding( encoder_embed["log_covariance"].shape[1] == ae_cifar_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -990,16 +962,17 @@ def test_vae_encoding_decoding( decoder_recon[f"reconstruction_layer_4"].shape[1:] == ae_cifar_config.input_dim ) - + if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_cifar_config.latent_dim + encoder_embed["log_covariance"].shape[1] + == ae_cifar_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) def test_svae_encoding_decoding( self, ae_cifar_config, cifar_like_data, recon_layers @@ -1044,13 +1017,10 @@ def test_svae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim + assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - encoder_embed["log_concentration"].shape[1] == 1 + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim ) - assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) else: if 1 in recon_layers: @@ -1074,26 +1044,19 @@ def test_svae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim + assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - encoder_embed["log_concentration"].shape[1] == 1 + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim ) - assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) - - def test_discriminator( - self, ae_cifar_config, cifar_like_data, recon_layers - ): + def test_discriminator(self, ae_cifar_config, cifar_like_data, recon_layers): ae_cifar_config.discriminator_input_dim = (3, 32, 32) discriminator = Discriminator_Conv_CIFAR(ae_cifar_config).to(device) - scores = discriminator( - cifar_like_data, output_layer_levels=recon_layers - ) + scores = discriminator(cifar_like_data, output_layer_levels=recon_layers) if recon_layers is None: assert "embedding" in scores.keys() @@ -1125,14 +1088,13 @@ def test_discriminator( if -1 in recon_layers: assert scores["embedding"].shape[1] == 1 + class Test_CIFAR_ResNets: @pytest.fixture(params=[[3, 4], [np.random.randint(1, 5)], [1, 2, 4, -1], None]) def recon_layers(self, request): return request.param - def test_ae_encoding_decoding( - self, ae_cifar_config, cifar_like_data, recon_layers - ): + def test_ae_encoding_decoding(self, ae_cifar_config, cifar_like_data, recon_layers): encoder = Encoder_ResNet_AE_CIFAR(ae_cifar_config).to(device) decoder = Decoder_ResNet_AE_CIFAR(ae_cifar_config).to(device) @@ -1144,9 +1106,7 @@ def test_ae_encoding_decoding( assert reconstruction.shape == cifar_like_data.shape - encoder_embed = encoder( - cifar_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(cifar_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if 
recon_layers is None: @@ -1166,9 +1126,8 @@ def test_ae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -1193,10 +1152,9 @@ def test_ae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) - + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) def test_vae_encoding_decoding( self, ae_cifar_config, cifar_like_data, recon_layers @@ -1212,9 +1170,7 @@ def test_vae_encoding_decoding( assert reconstruction.shape == cifar_like_data.shape - encoder_embed = encoder( - cifar_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(cifar_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -1239,9 +1195,8 @@ def test_vae_encoding_decoding( encoder_embed["log_covariance"].shape[1] == ae_cifar_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -1266,12 +1221,13 @@ def test_vae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_cifar_config.latent_dim + encoder_embed["log_covariance"].shape[1] + == ae_cifar_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) def test_svae_encoding_decoding( self, ae_cifar_config, cifar_like_data, recon_layers @@ -1287,9 +1243,7 @@ def test_svae_encoding_decoding( assert reconstruction.shape == cifar_like_data.shape - encoder_embed = encoder( - cifar_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(cifar_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -1312,9 +1266,8 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -1340,9 +1293,9 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) def test_vqvae_encoding_decoding( self, ae_cifar_config, cifar_like_data, recon_layers @@ -1352,15 +1305,16 @@ def test_vqvae_encoding_decoding( embedding = encoder(cifar_like_data).embedding - assert embedding.shape[:2] == (cifar_like_data.shape[0], ae_cifar_config.latent_dim) + assert embedding.shape[:2] == ( + cifar_like_data.shape[0], + ae_cifar_config.latent_dim, + ) reconstruction = decoder(embedding).reconstruction assert 
reconstruction.shape == cifar_like_data.shape - encoder_embed = encoder( - cifar_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(cifar_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -1380,9 +1334,8 @@ def test_vqvae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_cifar_config.input_dim + ) else: if 1 in recon_layers: @@ -1407,16 +1360,16 @@ def test_vqvae_encoding_decoding( if -1 in recon_layers: assert encoder_embed["embedding"].shape[1] == ae_cifar_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_cifar_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_cifar_config.input_dim + ) -class Test_CELEBA_ConvNets: +class Test_CELEBA_ConvNets: @pytest.fixture(params=[[3, 4], [np.random.randint(1, 5)], [1, 2, 4, -1], None]) def recon_layers(self, request): return request.param - + def test_ae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers ): @@ -1454,9 +1407,8 @@ def test_ae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1482,11 +1434,13 @@ def test_ae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) + assert ( + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_vae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers @@ -1535,9 +1489,8 @@ def test_vae_encoding_decoding( encoder_embed["log_covariance"].shape[1] == ae_celeba_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1563,14 +1516,17 @@ def test_vae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_celeba_config.latent_dim + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + encoder_embed["log_covariance"].shape[1] + == ae_celeba_config.latent_dim + ) + assert ( + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_svae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers @@ -1614,9 +1570,8 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1642,12 +1597,14 @@ def test_svae_encoding_decoding( ) if -1 in recon_layers: 
- assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + assert ( + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_discriminator(self, ae_celeba_config, celeba_like_data, recon_layers): @@ -1692,11 +1649,10 @@ def test_discriminator(self, ae_celeba_config, celeba_like_data, recon_layers): class Test_CELEBA_ResNets: - @pytest.fixture(params=[[3, 4], [np.random.randint(1, 6)], [1, 2, 4, -1], None]) def recon_layers(self, request): return request.param - + def test_ae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers ): @@ -1734,9 +1690,8 @@ def test_ae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1766,11 +1721,13 @@ def test_ae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) + assert ( + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_vae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers @@ -1819,9 +1776,8 @@ def test_vae_encoding_decoding( encoder_embed["log_covariance"].shape[1] == ae_celeba_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1851,14 +1807,17 @@ def test_vae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - encoder_embed["log_covariance"].shape[1] == ae_celeba_config.latent_dim + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) + assert ( + encoder_embed["log_covariance"].shape[1] + == ae_celeba_config.latent_dim ) assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_svae_encoding_decoding( self, ae_celeba_config, celeba_like_data, recon_layers @@ -1902,9 +1861,8 @@ def test_svae_encoding_decoding( assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -1934,13 +1892,14 @@ def test_svae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + assert ( + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) assert encoder_embed["log_concentration"].shape[1] == 1 assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) - + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) def test_vqvae_encoding_decoding( self, 
ae_celeba_config, celeba_like_data, recon_layers @@ -1950,15 +1909,16 @@ def test_vqvae_encoding_decoding( embedding = encoder(celeba_like_data).embedding - assert embedding.shape[:2] == (celeba_like_data.shape[0], ae_celeba_config.latent_dim) + assert embedding.shape[:2] == ( + celeba_like_data.shape[0], + ae_celeba_config.latent_dim, + ) reconstruction = decoder(embedding).reconstruction assert reconstruction.shape == celeba_like_data.shape - encoder_embed = encoder( - celeba_like_data, output_layer_levels=recon_layers - ) + encoder_embed = encoder(celeba_like_data, output_layer_levels=recon_layers) decoder_recon = decoder(embedding, output_layer_levels=recon_layers) if recon_layers is None: @@ -1978,9 +1938,8 @@ def test_vqvae_encoding_decoding( if recon_layers is None: assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) + decoder_recon[f"reconstruction"].shape[1:] == ae_celeba_config.input_dim + ) else: if 1 in recon_layers: @@ -2010,8 +1969,10 @@ def test_vqvae_encoding_decoding( ) if -1 in recon_layers: - assert encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim assert ( - decoder_recon[f"reconstruction"].shape[1:] - == ae_celeba_config.input_dim - ) \ No newline at end of file + encoder_embed["embedding"].shape[1] == ae_celeba_config.latent_dim + ) + assert ( + decoder_recon[f"reconstruction"].shape[1:] + == ae_celeba_config.input_dim + ) diff --git a/tests/test_normal_sampler.py b/tests/test_normal_sampler.py index 11549f44..53e9901a 100644 --- a/tests/test_normal_sampler.py +++ b/tests/test_normal_sampler.py @@ -5,9 +5,9 @@ import pytest import torch -from pythae.models import AE, AEConfig, VAE, VAEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig +from pythae.models import AE, VAE, AEConfig, VAEConfig from pythae.pipelines.generation import GenerationPipeline +from pythae.samplers import NormalSampler, NormalSamplerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -24,6 +24,7 @@ def dummy_data(): def model(request): return request.param + @pytest.fixture( params=[ NormalSamplerConfig(), @@ -33,6 +34,7 @@ def model(request): def sampler_config(request): return request.param + @pytest.fixture() def sampler(model, sampler_config): return NormalSampler(model=model, sampler_config=sampler_config) @@ -129,16 +131,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -153,13 +158,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_pipeline_standalone.py 
b/tests/test_pipeline_standalone.py index efd80898..9f1f17aa 100644 --- a/tests/test_pipeline_standalone.py +++ b/tests/test_pipeline_standalone.py @@ -1,20 +1,20 @@ -import pytest import os -import torch +import pytest +import torch from torch.utils.data import Dataset + +from pythae.customexception import DatasetError from pythae.data.datasets import DatasetOutput +from pythae.models import VAE, FactorVAE, FactorVAEConfig, VAEConfig from pythae.pipelines import * -from pythae.models import VAE, VAEConfig, FactorVAE, FactorVAEConfig -from pythae.trainers import BaseTrainerConfig from pythae.samplers import NormalSampler, NormalSamplerConfig -from pythae.customexception import DatasetError +from pythae.trainers import BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) class CustomWrongOutputDataset(Dataset): - def __init__(self, path) -> None: self.img_path = path @@ -23,23 +23,19 @@ def __len__(self): def __getitem__(self, index) -> dict: data = torch.load(self.img_path).data[index] - return DatasetOutput( - wrong_key=data - ) + return DatasetOutput(wrong_key=data) -class CustomNoLenDataset(Dataset): +class CustomNoLenDataset(Dataset): def __init__(self, path) -> None: self.img_path = path def __getitem__(self, index) -> dict: data = torch.load(self.img_path).data[index] - return DatasetOutput( - data=data - ) + return DatasetOutput(data=data) -class CustomDataset(Dataset): +class CustomDataset(Dataset): def __init__(self, path) -> None: self.img_path = path @@ -48,9 +44,7 @@ def __len__(self): def __getitem__(self, index) -> dict: data = torch.load(self.img_path).data[index] - return DatasetOutput( - data=data - ) + return DatasetOutput(data=data) class Test_Pipeline_Standalone: @@ -60,19 +54,27 @@ def train_dataset(self): @pytest.fixture def custom_train_dataset(self): - return CustomDataset(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) + return CustomDataset( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ) @pytest.fixture def custom_wrong_output_train_dataset(self): - return CustomWrongOutputDataset(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) + return CustomWrongOutputDataset( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ) @pytest.fixture def custom_no_len_train_dataset(self): - return CustomWrongOutputDataset(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) + return CustomWrongOutputDataset( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ) @pytest.fixture def training_pipeline(self, train_dataset): - vae_config = VAEConfig(input_dim=tuple(train_dataset.data[0].shape), latent_dim=2) + vae_config = VAEConfig( + input_dim=tuple(train_dataset.data[0].shape), latent_dim=2 + ) vae = VAE(vae_config) pipe = TrainingPipeline(model=vae) return pipe @@ -96,42 +98,56 @@ def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset): training_pipeline(train_dataset.data) assert isinstance(training_pipeline.model, VAE) + def test_training_pipeline_wrong_output_dataset( + self, + tmpdir, + training_pipeline, + train_dataset, + custom_wrong_output_train_dataset, + ): - def test_training_pipeline_wrong_output_dataset(self, tmpdir, training_pipeline, train_dataset, custom_wrong_output_train_dataset): - tmpdir.mkdir("dummy_folder") dir_path = os.path.join(tmpdir, "dummy_folder") training_pipeline.training_config.output_dir = dir_path training_pipeline.training_config.num_epochs = 1 - + with pytest.raises(DatasetError): 
training_pipeline(train_data=custom_wrong_output_train_dataset) - + with pytest.raises(DatasetError): - training_pipeline(train_data=train_dataset.data, eval_data=custom_wrong_output_train_dataset) + training_pipeline( + train_data=train_dataset.data, + eval_data=custom_wrong_output_train_dataset, + ) + def test_training_pipeline_no_len_dataset( + self, tmpdir, training_pipeline, train_dataset, custom_no_len_train_dataset + ): - def test_training_pipeline_no_len_dataset(self, tmpdir, training_pipeline, train_dataset, custom_no_len_train_dataset): - tmpdir.mkdir("dummy_folder") dir_path = os.path.join(tmpdir, "dummy_folder") training_pipeline.training_config.output_dir = dir_path training_pipeline.training_config.num_epochs = 1 - + with pytest.raises(DatasetError): training_pipeline(train_data=custom_no_len_train_dataset) - - with pytest.raises(DatasetError): - training_pipeline(train_data=train_dataset.data, eval_data=custom_no_len_train_dataset) + with pytest.raises(DatasetError): + training_pipeline( + train_data=train_dataset.data, eval_data=custom_no_len_train_dataset + ) - def test_training_pipleine_custom_dataset(self, tmpdir, training_pipeline, train_dataset, custom_train_dataset): + def test_training_pipleine_custom_dataset( + self, tmpdir, training_pipeline, train_dataset, custom_train_dataset + ): tmpdir.mkdir("dummy_folder") dir_path = os.path.join(tmpdir, "dummy_folder") training_pipeline.training_config.output_dir = dir_path training_pipeline.training_config.num_epochs = 1 - training_pipeline(train_data=custom_train_dataset, eval_data=custom_train_dataset) + training_pipeline( + train_data=custom_train_dataset, eval_data=custom_train_dataset + ) assert training_pipeline.trainer.train_dataset == custom_train_dataset assert training_pipeline.trainer.eval_dataset == custom_train_dataset @@ -139,21 +155,25 @@ def test_training_pipleine_custom_dataset(self, tmpdir, training_pipeline, train def test_generation_pipeline(self, tmpdir, train_dataset): with pytest.raises(NotImplementedError): - pipe = GenerationPipeline(model=VAE(VAEConfig(input_dim=(1, 2, 3))), sampler_config=BaseTrainerConfig()) - + pipe = GenerationPipeline( + model=VAE(VAEConfig(input_dim=(1, 2, 3))), + sampler_config=BaseTrainerConfig(), + ) + tmpdir.mkdir("dummy_folder") dir_path = os.path.join(tmpdir, "dummy_folder") pipe = GenerationPipeline(model=VAE(VAEConfig(input_dim=(1, 2, 3)))) assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=1, + gen_data = pipe( + num_samples=1, batch_size=10, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=train_dataset.data, - eval_data=None + eval_data=None, ) assert tuple(gen_data.shape) == (1,) + (1, 2, 3) diff --git a/tests/test_pixelcnn_sampler.py b/tests/test_pixelcnn_sampler.py index 894628fb..60f89501 100644 --- a/tests/test_pixelcnn_sampler.py +++ b/tests/test_pixelcnn_sampler.py @@ -1,13 +1,18 @@ import os +from copy import deepcopy import pytest import torch -from copy import deepcopy from pythae.models import VQVAE, VQVAEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, PixelCNNSampler, PixelCNNSamplerConfig -from pythae.trainers import BaseTrainerConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + NormalSampler, + NormalSamplerConfig, + PixelCNNSampler, + PixelCNNSamplerConfig, +) +from pythae.trainers import BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -61,7 +66,9 @@ def 
test_save_config(self, tmpdir, sampler): assert os.path.isfile(sampler_config_file) - generation_config_rec = PixelCNNSamplerConfig.from_json_file(sampler_config_file) + generation_config_rec = PixelCNNSamplerConfig.from_json_file( + sampler_config_file + ) assert generation_config_rec.__dict__ == sampler.sampler_config.__dict__ @@ -192,16 +199,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -216,13 +226,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, PixelCNNSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_planar_flow.py b/tests/test_planar_flow.py index f4f726d7..5f7338a3 100644 --- a/tests/test_planar_flow.py +++ b/tests/test_planar_flow.py @@ -1,21 +1,16 @@ -import pytest import os -import torch -import numpy as np -import shutil - from copy import deepcopy -from torch.optim import Adam -from pythae.models.base.base_utils import ModelOutput -from pythae.models.normalizing_flows import PlanarFlow, PlanarFlowConfig -from pythae.models.normalizing_flows import NFModel +import numpy as np +import pytest +import torch + from pythae.data.datasets import BaseDataset from pythae.models import AutoModel - - -from pythae.trainers import BaseTrainer, BaseTrainerConfig +from pythae.models.base.base_utils import ModelOutput +from pythae.models.normalizing_flows import NFModel, PlanarFlow, PlanarFlowConfig from pythae.pipelines import TrainingPipeline +from pythae.trainers import BaseTrainer, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -79,7 +74,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -195,31 +192,24 @@ def prior(self, model_configs, request): torch.eye(np.prod(model_configs.input_dim)).to(device), ) - @pytest.fixture(params=[Adam]) - def optimizers(self, request, planar_flow, training_configs): - if request.param is not None: - optimizer = request.param( - planar_flow.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_planar_flow_train_step( - self, planar_flow, prior, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, planar_flow, prior, train_dataset, training_configs): nf_model = NFModel(prior=prior, flow=planar_flow) trainer = BaseTrainer( model=nf_model, train_dataset=train_dataset, + eval_dataset=train_dataset, 
training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_planar_flow_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -234,19 +224,7 @@ def test_planar_flow_train_step( ] ) - def test_planar_flow_eval_step( - self, planar_flow, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=planar_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_planar_flow_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -262,18 +240,7 @@ def test_planar_flow_eval_step( ] ) - def test_planar_flow_main_train_loop( - self, planar_flow, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=planar_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_planar_flow_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -289,21 +256,10 @@ def test_planar_flow_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, planar_flow, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=planar_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -368,23 +324,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, planar_flow, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=planar_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model.flow) trainer.train() @@ -418,21 +363,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, planar_flow, prior, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=planar_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model.flow) @@ -464,7 +398,7 @@ def test_final_model_saving( ) def test_planar_flow_training_pipeline( - self, tmpdir, planar_flow, prior, train_dataset, training_configs + self, planar_flow, prior, train_dataset, training_configs ): dir_path = training_configs.output_dir diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 1d25794b..ec7aec0f 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -4,8 +4,8 @@ import pytest import torch -from pythae.data.preprocessors import DataProcessor from pythae.data.datasets import BaseDataset +from pythae.data.preprocessors import 
DataProcessor PATH = os.path.dirname(os.path.abspath(__file__)) diff --git a/tests/test_pvae_sampler.py b/tests/test_pvae_sampler.py index 1333feac..dcd75db1 100644 --- a/tests/test_pvae_sampler.py +++ b/tests/test_pvae_sampler.py @@ -5,8 +5,13 @@ import torch from pythae.models import PoincareVAE, PoincareVAEConfig -from pythae.samplers import PoincareDiskSampler, PoincareDiskSamplerConfig, NormalSampler, NormalSamplerConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + NormalSampler, + NormalSamplerConfig, + PoincareDiskSampler, + PoincareDiskSamplerConfig, +) PATH = os.path.dirname(os.path.abspath(__file__)) @@ -19,13 +24,28 @@ def dummy_data(): @pytest.fixture( params=[ - PoincareVAE(PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=7, prior_distribution="wrapped_normal", curvature=0.2)), - PoincareVAE(PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=2, prior_distribution="riemannian_normal", curvature=0.7)) + PoincareVAE( + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=7, + prior_distribution="wrapped_normal", + curvature=0.2, + ) + ), + PoincareVAE( + PoincareVAEConfig( + input_dim=(1, 28, 28), + latent_dim=2, + prior_distribution="riemannian_normal", + curvature=0.7, + ) + ), ] ) def model(request): return request.param + @pytest.fixture( params=[ PoincareDiskSamplerConfig(), @@ -58,7 +78,9 @@ def test_save_config(self, tmpdir, sampler): assert os.path.isfile(sampler_config_file) - generation_config_rec = PoincareDiskSamplerConfig.from_json_file(sampler_config_file) + generation_config_rec = PoincareDiskSamplerConfig.from_json_file( + sampler_config_file + ) assert generation_config_rec.__dict__ == sampler.sampler_config.__dict__ @@ -142,16 +164,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -166,13 +191,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, PoincareDiskSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_radial_flow.py b/tests/test_radial_flow.py index 23c420ba..1b9f9cda 100644 --- a/tests/test_radial_flow.py +++ b/tests/test_radial_flow.py @@ -1,21 +1,18 @@ -import pytest import os -import torch -import numpy as np import shutil - from copy import deepcopy + +import numpy as np +import pytest +import torch from torch.optim import Adam -from pythae.models.base.base_utils import ModelOutput -from pythae.models.normalizing_flows import RadialFlow, RadialFlowConfig -from pythae.models.normalizing_flows import NFModel from pythae.data.datasets import BaseDataset from pythae.models import AutoModel - - -from pythae.trainers import BaseTrainer, BaseTrainerConfig +from pythae.models.base.base_utils 
import ModelOutput +from pythae.models.normalizing_flows import NFModel, RadialFlow, RadialFlowConfig from pythae.pipelines import TrainingPipeline +from pythae.trainers import BaseTrainer, BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -73,7 +70,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -189,31 +188,24 @@ def prior(self, model_configs, request): torch.eye(np.prod(model_configs.input_dim)).to(device), ) - @pytest.fixture(params=[Adam]) - def optimizers(self, request, radial_flow, training_configs): - if request.param is not None: - optimizer = request.param( - radial_flow.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_radial_flow_train_step( - self, radial_flow, prior, train_dataset, training_configs, optimizers - ): + @pytest.fixture + def trainer(self, radial_flow, prior, train_dataset, training_configs): nf_model = NFModel(prior=prior, flow=radial_flow) trainer = BaseTrainer( model=nf_model, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_radial_flow_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -228,19 +220,7 @@ def test_radial_flow_train_step( ] ) - def test_radial_flow_eval_step( - self, radial_flow, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=radial_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_radial_flow_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -256,18 +236,7 @@ def test_radial_flow_eval_step( ] ) - def test_radial_flow_main_train_loop( - self, radial_flow, prior, train_dataset, training_configs, optimizers - ): - - nf_model = NFModel(prior=prior, flow=radial_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_radial_flow_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -283,21 +252,10 @@ def test_radial_flow_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, radial_flow, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=radial_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -362,23 +320,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, radial_flow, prior, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - nf_model = 
NFModel(prior=prior, flow=radial_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model.flow) trainer.train() @@ -412,21 +359,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, radial_flow, prior, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, trainer, training_configs): dir_path = training_configs.output_dir - nf_model = NFModel(prior=prior, flow=radial_flow) - - trainer = BaseTrainer( - model=nf_model, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model.flow) diff --git a/tests/test_rae_gp.py b/tests/test_rae_gp.py index a2b1a61c..fb84a3be 100644 --- a/tests/test_rae_gp.py +++ b/tests/test_rae_gp.py @@ -3,14 +3,18 @@ import pytest import torch -from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop from pythae.customexception import BadInheritanceError +from pythae.models import RAE_GP, AutoModel, RAE_GP_Config from pythae.models.base.base_utils import ModelOutput -from pythae.models import RAE_GP, RAE_GP_Config, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, +) from pythae.trainers import BaseTrainer, BaseTrainerConfig -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_AE_Conv, @@ -118,7 +122,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -208,7 +214,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -291,14 +297,15 @@ def test_model_train_output(self, rae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -313,23 +320,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RAE_GP(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + 
torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -340,14 +354,12 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RAE_GP(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape - @pytest.mark.slow class Test_RAE_GP_Training: @pytest.fixture @@ -393,26 +405,21 @@ def rae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, rae, training_configs): - if request.param is not None: - optimizer = request.param( - rae.parameters(), lr=training_configs.learning_rate - ) - - else: - optimizer = None - - return optimizer - - def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, rae, train_dataset, training_configs): trainer = BaseTrainer( model=rae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - optimizer=optimizers, ) + trainer.prepare_training() + + return trainer + + def test_rae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -427,14 +434,7 @@ def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -450,16 +450,7 @@ def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_predict_step( - self, rae, train_dataset, training_configs, optimizers - ): - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -475,21 +466,11 @@ def test_rae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_rae_main_train_loop( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): - - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) + def test_rae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -505,19 +486,10 @@ def test_rae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, rae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -599,21 +571,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): 
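The hunks above drop the per-test `optimizers` fixtures: the trainer is now built from the training configuration alone, with `prepare_training()` called before any train/eval step. A minimal standalone sketch of that pattern, assuming a pythae version with this trainer API (the toy data, model and output directory below are illustrative and not part of the patch):

```python
# Sketch of the refactored trainer setup used by the fixtures above:
# no explicit optimizer is passed; the training config drives everything
# and prepare_training() is called before stepping or training.
# Assumes a pythae version exposing this API; names below are illustrative.
import torch

from pythae.data.datasets import BaseDataset
from pythae.models import RAE_GP, RAE_GP_Config
from pythae.trainers import BaseTrainer, BaseTrainerConfig

data = torch.rand(100, 1, 28, 28)                    # toy MNIST-shaped inputs
train_dataset = BaseDataset(data, torch.zeros(100))  # labels are unused here

model = RAE_GP(RAE_GP_Config(input_dim=(1, 28, 28), latent_dim=10))

training_config = BaseTrainerConfig(
    output_dir="my_rae_gp",   # hypothetical output directory
    num_epochs=3,
    learning_rate=1e-4,       # optimizer settings now live in the config
)

trainer = BaseTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
    training_config=training_config,  # no optimizer= argument anymore
)
trainer.prepare_training()  # sets up training from the config, as the fixtures do
trainer.train()
```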
+ def test_checkpoint_saving_during_training(self, rae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - model = deepcopy(trainer.model) trainer.train() @@ -661,19 +624,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, rae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = BaseTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - optimizer=optimizers, - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -721,7 +675,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_configs): + def test_rae_training_pipeline(self, rae, train_dataset, training_configs): dir_path = training_configs.output_dir @@ -781,10 +735,13 @@ def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_RAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -795,7 +752,7 @@ def ae_model(self): NormalSamplerConfig(), GaussianMixtureSamplerConfig(), MAFSamplerConfig(), - IAFSamplerConfig() + IAFSamplerConfig(), ] ) def sampler_configs(self, request): @@ -810,7 +767,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_rae_l2.py b/tests/test_rae_l2.py index be5c45cb..fc023346 100644 --- a/tests/test_rae_l2.py +++ b/tests/test_rae_l2.py @@ -6,15 +6,20 @@ from torch.optim import Adam from pythae.customexception import BadInheritanceError +from pythae.models import RAE_L2, AutoModel, RAE_L2_Config from pythae.models.base.base_utils import ModelOutput -from pythae.models import RAE_L2, RAE_L2_Config, AutoModel -from pythae.samplers import NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, IAFSamplerConfig +from pythae.pipelines import GenerationPipeline, TrainingPipeline +from pythae.samplers import ( + GaussianMixtureSamplerConfig, + IAFSamplerConfig, + MAFSamplerConfig, + NormalSamplerConfig, +) from pythae.trainers import ( + BaseTrainerConfig, CoupledOptimizerTrainer, CoupledOptimizerTrainerConfig, - BaseTrainerConfig, ) -from pythae.pipelines import TrainingPipeline, GenerationPipeline from tests.data.custom_architectures import ( Decoder_AE_Conv, Encoder_AE_Conv, @@ -122,7 +127,9 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + assert set(os.listdir(dir_path)) == set( + 
["model_config.json", "model.pt", "environment.json"] + ) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -212,7 +219,7 @@ def test_full_custom_model_saving( "model.pt", "encoder.pkl", "decoder.pkl", - "environment.json" + "environment.json", ] ) @@ -288,30 +295,35 @@ def test_model_train_output(self, rae, demo_data): assert isinstance(out, ModelOutput) - assert set([ - "loss", - "recon_loss", - "encoder_loss", - "decoder_loss", - "update_encoder", - "update_decoder", - "embedding_loss", - "recon_x", - "z"]) == set( - out.keys() + assert ( + set( + [ + "loss", + "recon_loss", + "encoder_loss", + "decoder_loss", + "update_encoder", + "update_decoder", + "embedding_loss", + "recon_x", + "z", + ] + ) + == set(out.keys()) ) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape + class Test_Model_interpolate: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -326,23 +338,30 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RAE_L2(model_configs) - def test_interpolate(self, ae, demo_data, granularity): with pytest.raises(AssertionError): ae.interpolate(demo_data, demo_data[1:], granularity) interp = ae.interpolate(demo_data, demo_data, granularity) - assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + assert ( + tuple(interp.shape) + == ( + demo_data.shape[0], + granularity, + ) + + (demo_data.shape[1:]) + ) + class Test_Model_reconstruct: @pytest.fixture( params=[ torch.randn(3, 2, 3, 1), torch.randn(3, 2, 2), - torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ - : - ]['data'] + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ + "data" + ], ] ) def demo_data(self, request): @@ -353,9 +372,8 @@ def ae(self, model_configs, demo_data): model_configs.input_dim = tuple(demo_data[0].shape) return RAE_L2(model_configs) - def test_reconstruct(self, ae, demo_data): - + recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -371,10 +389,17 @@ def train_dataset(self): CoupledOptimizerTrainerConfig( num_epochs=3, steps_saving=2, - learning_rate=1e-5, - encoder_optim_decay=1e-3, - decoder_optim_decay=1e-3, - ) + encoder_learning_rate=1e-5, + decoder_learning_rate=1e-6, + ), + CoupledOptimizerTrainerConfig( + num_epochs=3, + steps_saving=2, + encoder_learning_rate=1e-5, + decoder_learning_rate=1e-6, + encoder_optimizer_cls="AdamW", + decoder_optimizer_cls="SGD", + ), ] ) def training_configs(self, tmpdir, request): @@ -413,31 +438,21 @@ def rae(self, model_configs, custom_encoder, custom_decoder, request): return model - @pytest.fixture(params=[Adam]) - def optimizers(self, request, rae, training_configs): - if request.param is not None: - encoder_optimizer = request.param( - rae.encoder.parameters(), lr=training_configs.learning_rate - ) - decoder_optimizer = request.param( - rae.decoder.parameters(), lr=training_configs.learning_rate - ) - - else: - encoder_optimizer = None - decoder_optimizer = None - - return (encoder_optimizer, decoder_optimizer) - - def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): + @pytest.fixture + def trainer(self, rae, train_dataset, training_configs): trainer = 
CoupledOptimizerTrainer( model=rae, train_dataset=train_dataset, + eval_dataset=train_dataset, training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], ) + trainer.prepare_training() + + return trainer + + def test_rae_train_step(self, trainer): + start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -452,15 +467,7 @@ def test_rae_train_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_rae_eval_step(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -476,17 +483,7 @@ def test_rae_eval_step(self, rae, train_dataset, training_configs, optimizers): ] ) - def test_rae_predict_step( - self, rae, train_dataset, training_configs, optimizers - ): - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_rae_predict_step(self, trainer, train_dataset): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -502,22 +499,11 @@ def test_rae_predict_step( ] ) - assert torch.equal(inputs.cpu(), train_dataset.data.cpu()) + assert inputs.cpu() in train_dataset.data assert recon.shape == inputs.shape - assert generated.shape == inputs.shape + assert generated.shape == inputs.shape - def test_rae_main_train_loop( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): - - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - eval_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) + def test_rae_main_train_loop(self, trainer): start_model_state_dict = deepcopy(trainer.model.state_dict()) @@ -533,20 +519,10 @@ def test_rae_main_train_loop( ] ) - def test_checkpoint_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving(self, rae, trainer, training_configs): dir_path = training_configs.output_dir - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) - # Make a training step step_1_loss = trainer.train_step(epoch=1) @@ -660,22 +636,12 @@ def test_checkpoint_saving( ] ) - def test_checkpoint_saving_during_training( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_checkpoint_saving_during_training(self, rae, trainer, training_configs): # target_saving_epoch = training_configs.steps_saving dir_path = training_configs.output_dir - trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) - model = deepcopy(trainer.model) trainer.train() @@ -728,20 +694,10 @@ def test_checkpoint_saving_during_training( ] ) - def test_final_model_saving( - self, tmpdir, rae, train_dataset, training_configs, optimizers - ): + def test_final_model_saving(self, rae, trainer, training_configs): dir_path = training_configs.output_dir - 
trainer = CoupledOptimizerTrainer( - model=rae, - train_dataset=train_dataset, - training_config=training_configs, - encoder_optimizer=optimizers[0], - decoder_optimizer=optimizers[1], - ) - trainer.train() model = deepcopy(trainer._best_model) @@ -789,7 +745,7 @@ def test_final_model_saving( assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) - def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_configs): + def test_rae_training_pipeline(self, rae, train_dataset, training_configs): with pytest.raises(AssertionError): pipeline = TrainingPipeline(model=rae, training_config=BaseTrainerConfig()) @@ -808,7 +764,6 @@ def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_config ) # check decays are set accordingly to model params - assert pipeline.trainer.encoder_optimizer.param_groups[0]["weight_decay"] == 0.0 assert ( pipeline.trainer.decoder_optimizer.param_groups[0]["weight_decay"] == rae.model_config.reg_weight @@ -859,10 +814,13 @@ def test_rae_training_pipeline(self, tmpdir, rae, train_dataset, training_config assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) + class Test_RAE_Generation: @pytest.fixture def train_data(self): - return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data + return torch.load( + os.path.join(PATH, "data/mnist_clean_train_dataset_sample") + ).data @pytest.fixture() def ae_model(self): @@ -888,7 +846,7 @@ def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data return_gen=True, train_data=train_data, eval_data=train_data, - training_config=BaseTrainerConfig(num_epochs=1) + training_config=BaseTrainerConfig(num_epochs=1), ) - assert gen_data.shape[0] == 11 \ No newline at end of file + assert gen_data.shape[0] == 11 diff --git a/tests/test_rhvae_sampler.py b/tests/test_rhvae_sampler.py index 2949b45a..b662b29d 100644 --- a/tests/test_rhvae_sampler.py +++ b/tests/test_rhvae_sampler.py @@ -5,8 +5,13 @@ import torch from pythae.models import RHVAE, RHVAEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, RHVAESampler, RHVAESamplerConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + NormalSampler, + NormalSamplerConfig, + RHVAESampler, + RHVAESamplerConfig, +) PATH = os.path.dirname(os.path.abspath(__file__)) @@ -30,7 +35,7 @@ def model(request): params=[ RHVAESamplerConfig(n_lf=1, mcmc_steps_nbr=2, eps_lf=0.00001), RHVAESamplerConfig(n_lf=3, mcmc_steps_nbr=1, eps_lf=0.001), - None + None, ] ) def sampler_config(request): @@ -137,16 +142,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -161,15 +169,16 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, RHVAESampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = 
pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None assert "sampler_config.json" not in os.listdir(dir_path) - assert len(os.listdir(dir_path)) == num_samples \ No newline at end of file + assert len(os.listdir(dir_path)) == num_samples diff --git a/tests/test_training_callbacks.py b/tests/test_training_callbacks.py index 1f7729d8..68bb2b17 100644 --- a/tests/test_training_callbacks.py +++ b/tests/test_training_callbacks.py @@ -1,20 +1,21 @@ +import os +from collections import Counter + import pytest import torch -import os -from pythae.trainers.training_callbacks import * -from collections import Counter -from pythae.trainers import * from pythae.models import ( AE, - AEConfig, RAE_L2, - RAE_L2_Config, + VAEGAN, Adversarial_AE, Adversarial_AE_Config, - VAEGAN, + AEConfig, + RAE_L2_Config, VAEGANConfig, ) +from pythae.trainers import * +from pythae.trainers.training_callbacks import * PATH = os.path.dirname(os.path.abspath(__file__)) @@ -110,9 +111,7 @@ def init_callback(): @pytest.fixture() def dummy_handler(init_callback): - return CallbackHandler( - callbacks=[init_callback], model=None, optimizer=None, scheduler=None - ) + return CallbackHandler(callbacks=[init_callback], model=None) @pytest.fixture(params=[MetricConsolePrinterCallback(), CustomCallback()]) @@ -150,6 +149,8 @@ def test_init_trainer(self, ae, train_dataset, callbacks): trainer = BaseTrainer(model=ae, train_dataset=train_dataset) + trainer.prepare_training() + assert callbacks not in trainer.callback_handler.callbacks assert ProgressBarCallback().__class__ in [ @@ -160,6 +161,8 @@ def test_init_trainer(self, ae, train_dataset, callbacks): model=ae, callbacks=[callbacks], train_dataset=train_dataset ) + trainer.prepare_training() + assert callbacks in trainer.callback_handler.callbacks assert ProgressBarCallback().__class__ in [ @@ -235,6 +238,8 @@ def test_trainer_callback_calls( callbacks=[dummy_callback], ) + trainer.prepare_training() + assert "on_train_step_begin" not in dummy_callback.step_list assert "on_train_step_end" not in dummy_callback.step_list trainer.train_step(epoch=1) diff --git a/tests/test_two_stage_sampler.py b/tests/test_two_stage_sampler.py index 9d9d7013..dc757b08 100644 --- a/tests/test_two_stage_sampler.py +++ b/tests/test_two_stage_sampler.py @@ -1,14 +1,19 @@ import os +from copy import deepcopy import numpy as np import pytest import torch -from copy import deepcopy -from pythae.models import VAE, VAEConfig, VAMP, VAMPConfig, AE, AEConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, TwoStageVAESampler, TwoStageVAESamplerConfig -from pythae.trainers import BaseTrainerConfig +from pythae.models import AE, VAE, VAMP, AEConfig, VAEConfig, VAMPConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + NormalSampler, + NormalSamplerConfig, + TwoStageVAESampler, + TwoStageVAESamplerConfig, +) +from pythae.trainers import BaseTrainerConfig PATH = os.path.dirname(os.path.abspath(__file__)) @@ -208,16 +213,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - 
eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -232,13 +240,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, TwoStageVAESampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None diff --git a/tests/test_vamp_sampler.py b/tests/test_vamp_sampler.py index 9e568c2d..55089c88 100644 --- a/tests/test_vamp_sampler.py +++ b/tests/test_vamp_sampler.py @@ -5,8 +5,13 @@ import torch from pythae.models import VAMP, VAMPConfig -from pythae.samplers import NormalSampler, NormalSamplerConfig, VAMPSampler, VAMPSamplerConfig from pythae.pipelines import GenerationPipeline +from pythae.samplers import ( + NormalSampler, + NormalSamplerConfig, + VAMPSampler, + VAMPSamplerConfig, +) PATH = os.path.dirname(os.path.abspath(__file__)) @@ -26,6 +31,7 @@ def dummy_data(): def model(request): return request.param + @pytest.fixture( params=[ VAMPSamplerConfig(), @@ -142,16 +148,19 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, NormalSampler) assert pipe.sampler.sampler_config == NormalSamplerConfig() - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=True, save_sampler_config=True, train_data=dummy_data, - eval_data=None + eval_data=None, ) - assert tuple(gen_data.shape) == (num_samples,) + tuple(model.model_config.input_dim) + assert tuple(gen_data.shape) == (num_samples,) + tuple( + model.model_config.input_dim + ) assert len(os.listdir(dir_path)) == num_samples + 1 assert "sampler_config.json" in os.listdir(dir_path) @@ -166,13 +175,14 @@ def test_generation_pipeline( assert isinstance(pipe.sampler, VAMPSampler) assert pipe.sampler.sampler_config == sampler_config - gen_data = pipe(num_samples=num_samples, + gen_data = pipe( + num_samples=num_samples, batch_size=batch_size, output_dir=dir_path, return_gen=False, save_sampler_config=False, train_data=dummy_data, - eval_data=dummy_data + eval_data=dummy_data, ) assert gen_data is None
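The sampler tests above all exercise the same `GenerationPipeline` call signature (`num_samples`, `batch_size`, `output_dir`, `return_gen`, `save_sampler_config`, `train_data`, `eval_data`). A minimal sketch of that usage outside the test suite, assuming a pythae version with this API (the model, sampler choice and output directory below are illustrative only):

```python
# Sketch of the GenerationPipeline call signature these sampler tests drive.
# Assumes a pythae version exposing this API; model, sampler config and
# output directory are illustrative, not taken from the patch.
import torch

from pythae.models import VAMP, VAMPConfig
from pythae.pipelines import GenerationPipeline
from pythae.samplers import VAMPSamplerConfig

dummy_data = torch.rand(10, 1, 28, 28)  # stand-in training data

model = VAMP(VAMPConfig(input_dim=(1, 28, 28), latent_dim=4))

# The pipeline wraps the model together with the chosen sampler configuration.
pipe = GenerationPipeline(model=model, sampler_config=VAMPSamplerConfig())

gen_data = pipe(
    num_samples=5,
    batch_size=2,
    output_dir="generated_samples",  # samples + sampler_config.json are written here
    return_gen=True,                 # also return the generated tensor
    save_sampler_config=True,
    train_data=dummy_data,           # available to samplers that need fitting
    eval_data=None,
)
# As asserted in the tests: (num_samples,) + input_dim
assert tuple(gen_data.shape) == (5, 1, 28, 28)
```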