Skip to content

Commit

Permalink
Added support for Normalizing Flows
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Jul 20, 2023
1 parent 3c33551 commit f0216dd
Show file tree
Hide file tree
Showing 4 changed files with 2,490 additions and 1 deletion.
1,650 changes: 1,650 additions & 0 deletions examples/simulation_example_SplineFlow.ipynb

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions lightgbmlss/distributions/SplineFlow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
from torch.distributions import identity_transform, SigmoidTransform, SoftplusTransform
from pyro.distributions import Normal
from pyro.distributions.transforms import Spline
from .flow_utils import NormalizingFlowClass
from ..utils import identity_fn


class SplineFlow(NormalizingFlowClass):
"""
Spline Flow class.
The spline flow is a normalizing flow based on element-wise rational spline bijections of linear and quadratic
order (Durkan et al., 2019; Dolatabadi et al., 2020). Rational splines are functions that are comprised of segments
that are the ratio of two polynomials. Rational splines offer an excellent combination of functional flexibility
whilst maintaining a numerically stable inverse.
For more details, see:
- Durkan, C., Bekasov, A., Murray, I. and Papamakarios, G. Neural Spline Flows. NeurIPS 2019.
- Dolatabadi, H. M., Erfani, S. and Leckie, C., Invertible Generative Modeling using Linear Rational Splines. AISTATS 2020.
Source
---------
https://docs.pyro.ai/en/stable/distributions.html#pyro.distributions.transforms.Spline
Arguments
---------
target_support: str
The target support. Options are
- "real": [-inf, inf]
- "positive": [0, inf]
- "positive_integer": [0, 1, 2, 3, ...]
- "unit_interval": [0, 1]
count_bins: int
The number of segments comprising the spline.
bound: float
The quantity "K" determining the bounding box, [-K,K] x [-K,K] of the spline. By adjusting the
"K" value, you can control the size of the bounding box and consequently control the range of inputs that
the spline transform operates on. Larger values of "K" will result in a wider valid range for the spline
transformation, while smaller values will restrict the valid range to a smaller region. Should be chosen
based on the range of the data.
order: str
The order of the spline. Options are "linear" or "quadratic".
stabilization: str
Stabilization method for the Gradient and Hessian. Options are "None", "MAD" or "L2".
loss_fn: str
Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
Hence, using the CRPS disregards any variation in the curvature of the loss function.
"""
def __init__(self,
target_support: str = "real",
count_bins: int = 8,
bound: float = 3.0,
order: str = "linear",
stabilization: str = "None",
loss_fn: str = "nll"
):

# Number of parameters
if order == "quadratic":
n_params = 2*count_bins + (count_bins-1)
elif order == "linear":
n_params = 3*count_bins + (count_bins-1)

# Parameter dictionary
param_dict = {f"param_{i+1}": identity_fn for i in range(n_params)}
torch.distributions.Distribution.set_default_validate_args(False)

# Specify Target Transform
if target_support == "real":
target_transform = identity_transform
discrete = False
elif target_support == "positive":
target_transform = SoftplusTransform()
discrete = False
elif target_support == "positive_integer":
target_transform = SoftplusTransform()
discrete = True
elif target_support == "unit_interval":
target_transform = SigmoidTransform()
discrete = False

# Specify Normalizing Flow Class
super().__init__(base_dist=Normal, # Base distribution, currently only Normal is supported.
flow_transform=Spline,
count_bins=count_bins,
bound=bound,
order=order,
n_dist_param=n_params,
param_dict=param_dict,
target_transform=target_transform,
discrete=discrete,
univariate=True,
stabilization=stabilization,
loss_fn=loss_fn
)
4 changes: 3 additions & 1 deletion lightgbmlss/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""LightGBMLSS - An extension of LightGBM to probabilistic forecasting"""

from . import distribution_utils
from . import flow_utils
from . import zero_inflated
from . import Gaussian
from . import StudentT
Expand All @@ -18,4 +19,5 @@
from . import ZINB
from . import ZAGamma
from . import ZABeta
from . import ZALN
from . import ZALN
from . import SplineFlow
Loading

0 comments on commit f0216dd

Please sign in to comment.