Skip to content

Commit

Permalink
update the test to check all response functions
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonRobertPike committed Sep 24, 2024
1 parent 6f24901 commit 023211d
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions tests/test_model/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd

from lightgbmlss.model import *
from lightgbmlss.distributions.Gaussian import *
Expand Down Expand Up @@ -184,12 +185,17 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para
lgblss.dist.n_dist_param * lgblss.dist.n_dist_param * (X_test.shape[1] + 1)
)

for key, func in lgblss.dist.param_dict.items():
if func == identity_fn:
assert np.allclose(
pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1),
pred_params[key], atol=1e-5
)
for key, response_func in lgblss.dist.param_dict.items():
pred_contributions_combined = (
pd.Series(response_func(
torch.tensor(
pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1).values)
)))
assert np.allclose(
pred_contributions_combined,
pred_params[key], atol=1e-5
)


def test_model_plot(self, univariate_data, univariate_lgblss, univariate_params):
# Unpack
Expand Down

0 comments on commit 023211d

Please sign in to comment.