From badd70777334ebc2188d6ffe21ec9b6d7aca2793 Mon Sep 17 00:00:00 2001 From: Daniel Molinuevo <45424856+dmg99@users.noreply.github.com> Date: Thu, 11 Jul 2024 07:12:16 +0200 Subject: [PATCH] Casted sample_weights as tensor as well (#1104) --- pomegranate/distributions/normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pomegranate/distributions/normal.py b/pomegranate/distributions/normal.py index 3ca19c76..26133b60 100644 --- a/pomegranate/distributions/normal.py +++ b/pomegranate/distributions/normal.py @@ -257,7 +257,7 @@ def summarize(self, X, sample_weight=None): X, sample_weight = super().summarize(X, sample_weight=sample_weight) X = _cast_as_tensor(X, dtype=self.means.dtype) - + sample_weight = _cast_as_tensor(sample_weight, dtype=self.means.dtype) if self.covariance_type == 'full': self._w_sum += torch.sum(sample_weight, dim=0) self._xw_sum += torch.sum(X * sample_weight, axis=0)