Skip to content

Commit

Permalink
Update data argument in predict_dist function
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Jul 20, 2023
1 parent 01bb9c6 commit 7b0cb02
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lightgbmlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def draw_samples(self,

def predict_dist(self,
booster: lgb.Booster,
test_set: pd.DataFrame,
data: pd.DataFrame,
start_values: np.ndarray,
pred_type: str = "parameters",
n_samples: int = 1000,
Expand All @@ -350,8 +350,8 @@ def predict_dist(self,
---------
booster : lgb.Booster
Trained model.
test_set : pd.DataFrame
Test data.
data : pd.DataFrame
Data to predict from.
start_values : np.ndarray.
Starting values for each distributional parameter.
pred_type : str
Expand All @@ -374,13 +374,13 @@ def predict_dist(self,
"""

predt = torch.tensor(
booster.predict(test_set, raw_score=True),
booster.predict(data, raw_score=True),
dtype=torch.float32
).reshape(-1, self.n_dist_param)

# Set init_score as starting point for each distributional parameter.
init_score_pred = torch.tensor(
np.ones(shape=(test_set.shape[0], 1))*start_values,
np.ones(shape=(data.shape[0], 1))*start_values,
dtype=torch.float32
)

Expand Down

0 comments on commit 7b0cb02

Please sign in to comment.