diff --git a/src/mud/base.py b/src/mud/base.py index cd29ad4..60bbb24 100644 --- a/src/mud/base.py +++ b/src/mud/base.py @@ -51,21 +51,21 @@ def set_initial(self, distribution=None): self._up = None self._pr = None - def set_predicted(self, distribution=None): + def set_predicted(self, distribution=None, **kwargs): if distribution is None: - distribution = gkde(self.y.T) + distribution = gkde(self.y.T, **kwargs) pred_pdf = distribution.pdf(self.y.T).T else: - pred_pdf = distribution.pdf(self.y) + pred_pdf = distribution.pdf(self.y, **kwargs) self._pr = pred_pdf self._up = None - def fit(self): + def fit(self, **kwargs): if self._in is None: self.set_initial() self._pr = None if self._pr is None: - self.set_predicted() + self.set_predicted(**kwargs) if self._ob is None: self.set_observed() @@ -153,7 +153,7 @@ def fit(self): assert ps_pdf.shape[0] == self.X.shape[0] if np.sum(ps_pdf) == 0: raise ValueError("Posterior numerically unstable.") - self._ps = ps_pdf / np.sum(ps_pdf) + self._ps = ps_pdf def map_point(self): if self._ps is None: