Skip to content

Commit

Permalink
Added distribution_arg_names as function input
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Nov 29, 2023
1 parent e6639f5 commit 795f8c7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions lightgbmlss/distributions/SplineFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(self,
order=order,
n_dist_param=n_params,
param_dict=param_dict,
distribution_arg_names=list(param_dict.keys()),
target_transform=target_transform,
discrete=discrete,
univariate=True,
Expand Down
4 changes: 4 additions & 0 deletions lightgbmlss/distributions/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class NormalizingFlowClass:
Number of parameters.
param_dict: Dict[str, Any]
Dictionary that maps parameters to their response scale.
distribution_arg_names: List
List of distributional parameter names.
target_transform: Transform
Specify the target transform.
discrete: bool
Expand All @@ -61,6 +63,7 @@ def __init__(self,
order: Optional[str] = "quadratic",
n_dist_param: int = None,
param_dict: Dict[str, Any] = None,
distribution_arg_names: List = None,
target_transform: Transform = None,
discrete: bool = False,
univariate: bool = True,
Expand All @@ -75,6 +78,7 @@ def __init__(self,
self.order = order
self.n_dist_param = n_dist_param
self.param_dict = param_dict
self.distribution_arg_names = distribution_arg_names
self.target_transform = target_transform
self.discrete = discrete
self.univariate = univariate
Expand Down

0 comments on commit 795f8c7

Please sign in to comment.