diff --git a/jaxampler/_src/rvs/poisson.py b/jaxampler/_src/rvs/poisson.py index 6ebd0c0..699a528 100644 --- a/jaxampler/_src/rvs/poisson.py +++ b/jaxampler/_src/rvs/poisson.py @@ -27,34 +27,50 @@ class Poisson(DiscreteRV): - def __init__(self, lmbda: Numeric | Any, name: Optional[str] = None) -> None: - shape, self._lmbda = jx_cast(lmbda) + def __init__(self, mu: Numeric | Any, loc: Numeric | Any = 0.0, name: Optional[str] = None) -> None: + shape, self._mu, self._loc = jx_cast(mu, loc) self.check_params() super().__init__(name=name, shape=shape) def check_params(self) -> None: - assert jnp.all(self._lmbda > 0.0), "Lambda must be positive" + assert jnp.all(self._mu > 0.0), "Lambda must be positive" @partial(jit, static_argnums=(0,)) def logpmf_x(self, x: Numeric) -> Numeric: - return jax_poisson.logpmf(x, self._lmbda) + return jax_poisson.logpmf( + k=x, + mu=self._mu, + loc=self._loc, + ) @partial(jit, static_argnums=(0,)) def pmf_x(self, x: Numeric) -> Numeric: - return jax_poisson.pmf(x, self._lmbda) + return jax_poisson.pmf( + k=x, + mu=self._mu, + loc=self._loc, + ) + + @partial(jit, static_argnums=(0,)) + def logcdf_x(self, x: Numeric) -> Numeric: + return jnp.log(self.cdf_x(x)) @partial(jit, static_argnums=(0,)) def cdf_x(self, x: Numeric) -> Numeric: - return jax_poisson.cdf(x, self._lmbda) + return jax_poisson.cdf( + k=x, + mu=self._mu, + loc=self._loc, + ) def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: if key is None: key = self.get_key() new_shape = shape + self._shape - return jax.random.poisson(key, self._lmbda, shape=new_shape) + return jax.random.poisson(key, self._mu, shape=new_shape) def __repr__(self) -> str: - string = f"Poisson(lmbda={self._lmbda}" + string = f"Poisson(lmbda={self._mu}" if self._name is not None: string += f", name={self._name}" string += ")"