Probaforms
is a Python library of conditional Generative Adversarial Networks, Normalizing Flows, Variational Autoencoders, and other generative models for tabular data. All models share a scikit-learn-like interface, so they can be put to use quickly in a variety of science and engineering applications.
- Conditional Variational Autoencoder (CVAE)
- Wasserstein GAN (WGAN)
- Real NVP
pip install probaforms
or
git clone https://github.com/HSE-LAMBDA/probaforms.git
cd probaforms
pip install -e .
or
poetry install
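After any of the install routes above, a quick way to confirm that the package is visible to the current interpreter is to ask the standard library for its installed version. This is only a minimal sketch: the helper name `installed_version` is hypothetical and is not part of probaforms, and the distribution name `probaforms` is assumed to match the one used in the `pip install` command.

```python
from importlib import metadata


def installed_version(package="probaforms"):
    """Return the installed version string, or None if the package is missing.

    Hypothetical helper for illustration; not part of probaforms itself.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string if the install succeeded, otherwise None
print(installed_version())
```

If this prints `None`, the package was most likely installed into a different environment than the one running the script.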
(See more examples in the documentation.)
The following code snippet generates noisy synthetic data, fits a conditional generative model, samples new objects, and displays the results.
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from probaforms.models import RealNVP
# generate sample X with conditions C
X, y = make_moons(n_samples=1000, noise=0.1)
C = y.reshape(-1, 1)
# fit normalizing flow model
model = RealNVP(lr=0.01, n_epochs=100)
model.fit(X, C)
# sample new objects
X_gen = model.sample(C)
# display the results
plt.scatter(X_gen[y==0, 0], X_gen[y==0, 1])
plt.scatter(X_gen[y==1, 0], X_gen[y==1, 1])
plt.show()
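Besides looking at the scatter plot, a rough numeric sanity check can help: for each condition value, compare the mean of the real points with the mean of the generated points. The sketch below is one possible check, not something probaforms provides. The function `per_condition_mean_gap` is a hypothetical helper, and the arrays are stand-in data so that the block runs without a trained model; it assumes the generated rows are aligned with the conditions, as they are when sampling with `model.sample(C)`.

```python
import numpy as np


def per_condition_mean_gap(X_real, X_gen, labels):
    """Distance between real and generated feature means, per condition value.

    Hypothetical helper for illustration; not part of probaforms.
    Row i of X_real and X_gen is assumed to share the condition labels[i].
    """
    X_real = np.asarray(X_real, dtype=float)
    X_gen = np.asarray(X_gen, dtype=float)
    labels = np.asarray(labels).ravel()
    gaps = {}
    for value in np.unique(labels):
        mask = labels == value
        diff = X_real[mask].mean(axis=0) - X_gen[mask].mean(axis=0)
        gaps[value.item()] = float(np.linalg.norm(diff))
    return gaps


# Stand-in data: two noisy clusters play the role of real and generated samples
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 500)
centers = np.array([[0.0, 0.5], [1.0, 0.0]])
X_real = centers[labels] + rng.normal(scale=0.1, size=(1000, 2))
X_fake = centers[labels] + rng.normal(scale=0.1, size=(1000, 2))

gaps = per_condition_mean_gap(X_real, X_fake, labels)
print(gaps)  # small values mean the per-condition means agree closely
```

With the snippet above, the corresponding call would be `per_condition_mean_gap(X, X_gen, y)`. Matching means are a necessary but far from sufficient sign of a good fit, so this complements the plot rather than replacing it.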
- Home: https://github.com/HSE-LAMBDA/probaforms
- Documentation: https://hse-lambda.github.io/probaforms
- For usage questions, suggestions, and bug reports, please use the issue page.