Skip to content

Commit

Permalink
Casted sample_weights as tensor as well (#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmg99 authored Jul 11, 2024
1 parent 9a8cf3b commit badd707
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pomegranate/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def summarize(self, X, sample_weight=None):

X, sample_weight = super().summarize(X, sample_weight=sample_weight)
X = _cast_as_tensor(X, dtype=self.means.dtype)

sample_weight = _cast_as_tensor(sample_weight, dtype=self.means.dtype)
if self.covariance_type == 'full':
self._w_sum += torch.sum(sample_weight, dim=0)
self._xw_sum += torch.sum(X * sample_weight, axis=0)
Expand Down

0 comments on commit badd707

Please sign in to comment.