Skip to content

Commit

Permalink
Added checks for function arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Jul 21, 2023
1 parent 3b2297c commit 797bcb8
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions lightgbmlss/distributions/SplineFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,31 @@ def __init__(self,
loss_fn: str = "nll"
):

# Check if stabilization method is valid.
if not isinstance(stabilization, str):
raise ValueError("stabilization must be a string.")
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Options are 'None', 'MAD' or 'L2'.")

# Check if loss function is valid.
if not isinstance(loss_fn, str):
raise ValueError("loss_fn must be a string.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss_fn. Options are 'nll' or 'crps'.")

# Number of parameters
if not isinstance(order, str):
raise ValueError("order must be a string.")
if order == "quadratic":
n_params = 2*count_bins + (count_bins-1)
elif order == "linear":
n_params = 3*count_bins + (count_bins-1)

# Parameter dictionary
param_dict = {f"param_{i+1}": identity_fn for i in range(n_params)}
torch.distributions.Distribution.set_default_validate_args(False)
else:
raise ValueError("Invalid order specification. Options are 'linear' or 'quadratic'.")

# Specify Target Transform
if not isinstance(target_support, str):
raise ValueError("target_support must be a string.")
if target_support == "real":
target_transform = identity_transform
discrete = False
Expand All @@ -82,6 +96,22 @@ def __init__(self,
elif target_support == "unit_interval":
target_transform = SigmoidTransform()
discrete = False
else:
raise ValueError("Invalid target_support. Options are 'real', 'positive', 'positive_integer' or 'unit_interval'.")

# Check if count_bins is valid
if not isinstance(count_bins, int):
raise ValueError("count_bins must be an integer.")
if count_bins <= 0:
raise ValueError("count_bins must be a positive integer > 0.")

# Check if bound is float
if not isinstance(bound, float):
raise ValueError("bound must be a float.")

# Specify parameter dictionary
param_dict = {f"param_{i + 1}": identity_fn for i in range(n_params)}
torch.distributions.Distribution.set_default_validate_args(False)

# Specify Normalizing Flow Class
super().__init__(base_dist=Normal, # Base distribution, currently only Normal is supported.
Expand Down

0 comments on commit 797bcb8

Please sign in to comment.