diff --git a/lightgbmlss/distributions/SplineFlow.py b/lightgbmlss/distributions/SplineFlow.py
index c9f08fe..4bb18ab 100644
--- a/lightgbmlss/distributions/SplineFlow.py
+++ b/lightgbmlss/distributions/SplineFlow.py
@@ -59,17 +59,31 @@ def __init__(self,
                  loss_fn: str = "nll"
                  ):
 
+        # Check if stabilization method is valid.
+        if not isinstance(stabilization, str):
+            raise ValueError("stabilization must be a string.")
+        if stabilization not in ["None", "MAD", "L2"]:
+            raise ValueError("Invalid stabilization method. Options are 'None', 'MAD' or 'L2'.")
+
+        # Check if loss function is valid.
+        if not isinstance(loss_fn, str):
+            raise ValueError("loss_fn must be a string.")
+        if loss_fn not in ["nll", "crps"]:
+            raise ValueError("Invalid loss_fn. Options are 'nll' or 'crps'.")
+
         # Number of parameters
+        if not isinstance(order, str):
+            raise ValueError("order must be a string.")
         if order == "quadratic":
             n_params = 2*count_bins + (count_bins-1)
         elif order == "linear":
             n_params = 3*count_bins + (count_bins-1)
-
-        # Parameter dictionary
-        param_dict = {f"param_{i+1}": identity_fn for i in range(n_params)}
-        torch.distributions.Distribution.set_default_validate_args(False)
+        else:
+            raise ValueError("Invalid order specification. Options are 'linear' or 'quadratic'.")
 
         # Specify Target Transform
+        if not isinstance(target_support, str):
+            raise ValueError("target_support must be a string.")
         if target_support == "real":
             target_transform = identity_transform
             discrete = False
@@ -82,6 +96,22 @@ def __init__(self,
         elif target_support == "unit_interval":
             target_transform = SigmoidTransform()
             discrete = False
+        else:
+            raise ValueError("Invalid target_support. Options are 'real', 'positive', 'positive_integer' or 'unit_interval'.")
+
+        # Check if count_bins is valid
+        if not isinstance(count_bins, int):
+            raise ValueError("count_bins must be an integer.")
+        if count_bins <= 0:
+            raise ValueError("count_bins must be a positive integer > 0.")
+
+        # Check if bound is float
+        if not isinstance(bound, float):
+            raise ValueError("bound must be a float.")
+
+        # Specify parameter dictionary
+        param_dict = {f"param_{i + 1}": identity_fn for i in range(n_params)}
+        torch.distributions.Distribution.set_default_validate_args(False)
 
         # Specify Normalizing Flow Class
         super().__init__(base_dist=Normal,  # Base distribution, currently only Normal is supported.