Skip to content

Commit

Permalink
Bump version and ensure float params
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Oct 7, 2024
1 parent e932a30 commit 14b7a52
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,10 @@ class Beta(AbstractDistribution):
cond_shape: ClassVar[None] = None

def __init__(self, alpha: ArrayLike, beta: ArrayLike):
alpha, beta = jnp.broadcast_arrays(alpha, beta)
alpha, beta = jnp.broadcast_arrays(
arraylike_to_array(alpha, dtype=float),
arraylike_to_array(beta, dtype=float),
)
self.alpha = Parameterize(softplus, inv_softplus(alpha))
self.beta = Parameterize(softplus, inv_softplus(beta))
self.shape = alpha.shape
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license = { file = "LICENSE" }
name = "flowjax"
readme = "README.md"
requires-python = ">=3.10"
version = "15.0.0"
version = "15.1.0"

[project.urls]
repository = "https://github.com/danielward27/flowjax"
Expand Down

0 comments on commit 14b7a52

Please sign in to comment.