Skip to content

Commit

Permalink
Merge pull request #6 from wd60622/minor-ticks
Browse files Browse the repository at this point in the history
bounds for all plots
  • Loading branch information
wd60622 authored Aug 26, 2023
2 parents 950f70d + 10acf74 commit b65e5f8
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 32 deletions.
5 changes: 5 additions & 0 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ class Exponential(ContinuousPlotDistMixin, SliceMixin):
def dist(self):
return stats.expon(scale=self.lam)

def __mul__(self, other):
return Gamma(alpha=other, beta=1 / self.lam)

__rmul__ = __mul__


@dataclass
class Gamma(ContinuousPlotDistMixin, SliceMixin):
Expand Down
64 changes: 38 additions & 26 deletions conjugate/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def resolve_label(label: LABEL_INPUT, yy: np.ndarray):
return [f"{label} {i}" for i in range(1, ncols + 1)]

if callable(label):
return [label(i) for i in range(1, ncols + 1)]
return [label(i) for i in range(ncols)]

if isinstance(label, Iterable):
return label
Expand Down Expand Up @@ -68,6 +68,27 @@ def set_max_value(self, value: float) -> "PlotDistMixin":

return self

@property
def min_value(self) -> float:
if not hasattr(self, "_min_value"):
self._min_value = 0.0

return self._min_value

@min_value.setter
def min_value(self, value: float) -> None:
self._min_value = value

def set_min_value(self, value: float) -> "PlotDistMixin":
"""Set the minimum value for plotting."""
self.min_value = value

return self

def set_bounds(self, lower: float, upper: float) -> "PlotDistMixin":
"""Set both the min and max values for plotting."""
return self.set_min_value(lower).set_max_value(upper)

def _reshape_x_values(self, x: np.ndarray) -> np.ndarray:
"""Make sure that the values are ready for plotting."""
for value in asdict(self).values():
Expand Down Expand Up @@ -101,27 +122,6 @@ def plot_pdf(self, ax: Optional[plt.Axes] = None, **kwargs) -> plt.Axes:

return self._create_plot_on_axis(x, ax, **kwargs)

@property
def min_value(self) -> float:
if not hasattr(self, "_min_value"):
self._min_value = 0.0

return self._min_value

@min_value.setter
def min_value(self, value: float) -> None:
self._min_value = value

def set_min_value(self, value: float) -> "ContinuousPlotDistMixin":
"""Set the minimum value for plotting."""
self.min_value = value

return self

def set_bounds(self, lower: float, upper: float) -> "ContinuousPlotDistMixin":
"""Set both the min and max values for plotting."""
return self.set_min_value(lower).set_max_value(upper)

def _create_x_values(self) -> np.ndarray:
return np.linspace(self.min_value, self.max_value, 100)

Expand Down Expand Up @@ -180,21 +180,33 @@ def plot_pmf(
return self._create_plot_on_axis(x, ax, mark, **kwargs)

def _create_x_values(self) -> np.ndarray:
return np.arange(0, self.max_value + 1, 1)
return np.arange(self.min_value, self.max_value + 1, 1)

def _create_plot_on_axis(self, x, ax, mark, **kwargs) -> plt.Axes:
def _create_plot_on_axis(
self, x, ax, mark, conditional: bool = False, **kwargs
) -> plt.Axes:
yy = self.dist.pmf(x)
if conditional:
yy = yy / np.sum(yy)
ylabel = f"Conditional Probability $f(x|{self.min_value} \\leq x \\leq {self.max_value})$"
else:
ylabel = "Probability $f(x)$"

if "label" in kwargs:
label = kwargs.pop("label")
label = resolve_label(label, yy)
else:
label = None

ax.plot(x, yy, mark, label=label, **kwargs)
if self.max_value <= 15:

if self.max_value - self.min_value < 15:
ax.set_xticks(x.ravel())
else:
ax.set_xticks(x.ravel(), minor=True)
ax.set_xticks(x[::5].ravel())

ax.set_xlabel("Domain")
ax.set_ylabel("Probability $f(x)$")
ax.set_ylabel(ylabel)
ax.set_ylim(0, None)
return ax
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "conjugate-models"
version = "0.1.2"
version = "0.1.3"
description = "Bayesian Conjugate Models in Python"
authors = ["Will Dean <wd60622@gmail.com>"]
license = "MIT"
Expand Down
21 changes: 17 additions & 4 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Beta,
Dirichlet,
Gamma,
Exponential,
NegativeBinomial,
Poisson,
)
Expand Down Expand Up @@ -76,14 +77,15 @@ def test_slicing(alpha, beta, result_alpha, result_beta):


@pytest.mark.parametrize(
"dist, result_int",
"dist, result_dist",
[
(Poisson(1), Poisson(2)),
(Gamma(1, 1), Gamma(2, 1)),
(NegativeBinomial(1, 1), NegativeBinomial(2, 1)),
(Exponential(1), Gamma(2, 1)),
],
)
def test_distribution(dist, result_int) -> None:
def test_distribution(dist, result_dist) -> None:
if hasattr(dist, "plot_pdf"):
with pytest.raises(ValueError):
dist.plot_pdf()
Expand All @@ -100,8 +102,19 @@ def test_distribution(dist, result_int) -> None:
ax = dist.plot_pmf()
assert isinstance(ax, plt.Axes)

other = 2 * dist
assert other == result_int
if result_dist is not None:
other = 2 * dist
assert other == result_dist

lower, upper = -20, 20

assert dist.min_value != lower
assert dist.max_value != upper

dist.set_bounds(lower, upper)

assert dist.min_value == lower
assert dist.max_value == upper


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
("label", np.ones(shape=(3, 2)), ["label 1", "label 2"]),
("label", np.ones(shape=(2, 3)), ["label 1", "label 2", "label 3"]),
(
lambda i: f"another {i} label",
lambda i: f"another {i + 1} label",
np.ones(shape=(2, 3)),
["another 1 label", "another 2 label", "another 3 label"],
),
Expand Down

0 comments on commit b65e5f8

Please sign in to comment.