Skip to content

Commit

Permalink
Update loss calculation in metric_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Mar 14, 2024
1 parent e466fb5 commit 3ba04b7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion lightgbmlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
"""
# Target
target = torch.tensor(data.get_label().reshape(-1, 1))
n_obs = target.shape[0]

# Start values (needed to replace NaNs in predt)
start_values = data.get_init_score().reshape(-1, self.n_dist_param)[0, :].tolist()
Expand All @@ -135,7 +136,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
is_higher_better = False
_, loss = self.get_params_loss(predt, target, start_values, requires_grad=False)

return self.loss_fn, loss, is_higher_better
return self.loss_fn, loss / n_obs, is_higher_better

def loss_fn_start_values(self,
params: torch.Tensor,
Expand Down
3 changes: 2 additions & 1 deletion lightgbmlss/distributions/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
"""
# Target
target = torch.tensor(data.get_label().reshape(-1, 1))
n_obs = target.shape[0]

# Start values (needed to replace NaNs in predt)
start_values = data.get_init_score().reshape(-1, self.n_dist_param)[0, :].tolist()
Expand All @@ -151,7 +152,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
is_higher_better = False
_, loss = self.get_params_loss(predt, target, start_values)

return self.loss_fn, loss.detach(), is_higher_better
return self.loss_fn, loss.detach() / n_obs, is_higher_better

def calculate_start_values(self,
target: np.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion lightgbmlss/distributions/mixture_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
"""
# Target
target = torch.tensor(data.get_label().reshape(-1, 1), dtype=torch.float32)
n_obs = target.shape[0]

# Start values (needed to replace NaNs in predt)
start_values = data.get_init_score().reshape(-1, self.n_dist_param)[0, :].tolist()
Expand All @@ -178,7 +179,7 @@ def metric_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[str, float, b
is_higher_better = False
_, loss = self.get_params_loss(predt, target.flatten(), start_values, requires_grad=False)

return self.loss_fn, loss, is_higher_better
return self.loss_fn, loss / n_obs, is_higher_better

def create_mixture_distribution(self,
params: List[torch.Tensor],
Expand Down

0 comments on commit 3ba04b7

Please sign in to comment.