Skip to content

Commit

Permalink
Merge pull request #19 from gsd-authors/max_entropy
Browse files Browse the repository at this point in the history
Max entropy
  • Loading branch information
krzysztofrusek authored Jan 15, 2024
2 parents b681496 + 25fd150 commit f5af92f
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 14 deletions.
18 changes: 18 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ Besides the high-level API one can use optimizers form `scipy` or `tensorflow_pr

::: gsd.experimental.OptState

### Maximum entropy

The GSD can be considered a whole family of distributions
with the following properties:

1. It is a distribution over $[1,N]$
2. The first parameter represents the expectation value
3. It covers all possible variances

Another distribution that has similar properties and can be considered a member
of the GSD family is the maximum entropy distribution.

::: gsd.experimental.MaxEntropyGSD
:docstring:
:members: __init__






2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
site_name: GSD
site_description: The documentation for the reference implementation of generalised score distribution in python.
repo_url: https://github.com/gsd-authors/gsd
repo_name: gsd-authors/gsd

theme:
name: "material"
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ include = [
[tool.hatch.envs.default]
dependencies=["jaxlib>=0.4.6"]

[project.optional-dependencies]
experimental = [
"optimistix>=0.0.6",
]

[tool.hatch.envs.default.scripts]
test = "python -m unittest discover -p '*test.py'"

2 changes: 1 addition & 1 deletion src/gsd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.1dev'
__version__ = '0.2.1'
from gsd.fit import GSDParams as GSDParams
from gsd.fit import fit_moments as fit_moments
from gsd.gsd import (log_prob as log_prob,
Expand Down
2 changes: 2 additions & 0 deletions src/gsd/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
from .fit import OptState as OptState
from .fit import fit_mle as fit_mle
from .fit import fit_mle_grid as fit_mle_grid

from .max_entropy import MaxEntropyGSD as MaxEntropyGSD
132 changes: 132 additions & 0 deletions src/gsd/experimental/max_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jaxtyping import Array, Float, Int, PRNGKeyArray

import gsd
from gsd import GSDParams
from gsd.gsd import vmin


@jax.jit
def vmax(mean: Array, N: Int) -> Array:
    """
    Return the maximal variance of a categorical distribution supported on
    Z[1,N] that has the given mean.

    The value is computed as ``(mean - 1) * (N - mean)``.

    :param mean: Expectation value of the distribution.
    :param N: Upper end of the support Z[1,N].
    :return: The maximal variance.
    """
    distance_to_lower_end = mean - 1.0
    distance_to_upper_end = N - mean
    return distance_to_lower_end * distance_to_upper_end


def _lagrange_log_probs(lagrage: tuple, dist: 'MaxEntropyGSD'):
    """
    Evaluate the log-probabilities on the support for given multipliers.

    :param lagrage: Triple of Lagrange multipliers: a constant term, the
        coefficient of the support values and the coefficient of the
        squared deviations from the mean (in that order).
    :param dist: Object providing ``support`` and ``squred_diff``.
    :return: One value per support point.
    """
    constant_mult, support_mult, deviation_mult = lagrage
    support_term = dist.support * support_mult
    deviation_term = deviation_mult * dist.squred_diff
    return constant_mult + support_term + deviation_term - 1.0


def _implicit_log_probs(lagrage: tuple, d: 'MaxEntropyGSD'):
    """
    Residuals of the three constraints for a given set of multipliers.

    The probabilities implied by ``lagrage`` are compared against the
    normalisation, mean and variance targets; all three residuals are zero
    when the constraints hold.

    :param lagrage: Triple of Lagrange multipliers.
    :param d: Object providing ``support``, ``squred_diff``, ``mean`` and
        ``sigma``.
    :return: Tuple ``(sum(p) - 1, E[i] - mean, E[(i - mean)^2] - sigma^2)``.
    """
    probs = jnp.exp(_lagrange_log_probs(lagrage, d))

    # NOTE: the residuals are formed in probability space; the original
    # author also sketched a log-space variant based on jax.nn.logsumexp.
    normalisation_residual = jnp.sum(probs) - 1.0
    mean_residual = jnp.dot(probs, d.support) - d.mean
    variance_residual = jnp.dot(probs, d.squred_diff) - d.sigma ** 2
    return normalisation_residual, mean_residual, variance_residual


def _explicit_log_probs(dist: 'MaxEntropyGSD'):
    """
    Solve for the Lagrange multipliers and return the log-probabilities.

    A Newton root find is run on ``_implicit_log_probs`` starting from
    ``-0.01`` for every multiplier.

    :param dist: Distribution whose constraints are to be satisfied.
    :return: Log-probabilities evaluated at the multipliers found.
    """
    start = tuple(jnp.asarray(value) for value in (-0.01, -0.01, -0.01))
    newton = optx.Newton(rtol=1e-8, atol=1e-8)
    # throw=False: a failed solve does not raise; whatever value the solver
    # ended on is used below.
    solution = optx.root_find(
        _implicit_log_probs,
        newton,
        start,
        args=dist,
        max_steps=int(1e4),
        throw=False,
    )
    return _lagrange_log_probs(solution.value, dist)


class MaxEntropyGSD(eqx.Module):
    r"""
    Maximum entropy distribution supported on `Z[1,N]`

    This distribution is defined to fulfill the following conditions on $p_i$

    * Maximize $H= -\sum_i p_i\log(p_i)$ subject to
    * $\sum p_i=1$
    * $\sum i p_i= \mu$
    * $\sum (i-\mu)^2 p_i= \sigma^2$

    :param mean: Expectation value of the distribution.
    :param sigma: Standard deviation of the distribution.
    :param N: Number of responses
    """
    mean: Float[Array, ""]
    sigma: Float[Array, ""]  # std
    N: int = eqx.field(static=True)

    def log_prob(self, x: Int[Array, ""]):
        """Log-probability of response ``x``.

        :param x: Response on the support ``1..N``; it is shifted by one to
            index into the vector of log-probabilities.
        :return: Log-probability of ``x``.
        """
        lp = _explicit_log_probs(self)
        return lp[x - 1]

    def prob(self, x: Int[Array, ""]):
        """Probability of response ``x`` (exponential of ``log_prob``)."""
        return jnp.exp(self.log_prob(x))

    @property
    def support(self):
        """The support points ``1, 2, ..., N``."""
        return jnp.arange(1, self.N + 1)

    @property
    def squred_diff(self):
        """Squared deviation of every support point from ``mean``."""
        return jnp.square((self.support - self.mean))

    def variance(self):
        """Variance of the distribution, i.e. ``sigma`` squared."""
        return jnp.square(self.sigma)

    def stddev(self):
        """Standard deviation of the distribution."""
        # variance() was previously missing, which made this method raise
        # AttributeError.
        return jnp.sqrt(self.variance())

    def vmax(self):
        """Maximal variance on ``Z[1,N]`` for this mean."""
        return (self.mean - 1.0) * (self.N - self.mean)

    def vmin(self):
        """Minimal variance for this mean, as computed by ``gsd.gsd.vmin``."""
        return vmin(self.mean)

    @property
    def all_log_probs(self):
        """Log-probabilities of all support points.

        The Lagrange multipliers are solved for on every access.
        """
        lp = _explicit_log_probs(self)
        return lp

    @jax.jit
    def entropy(self):
        """Entropy $-\\sum_i p_i \\log(p_i)$ of the distribution."""
        lp = self.all_log_probs
        return -jnp.dot(lp, jnp.exp(lp))

    def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
        """Draw samples from the distribution.

        :param key: PRNG key.
        :param axis: Forwarded to ``jax.random.categorical``.
        :param shape: Forwarded to ``jax.random.categorical``.
        :return: Samples expressed on the support (starting at
            ``support[0]``) rather than as zero-based category indices.
        """
        lp = self.all_log_probs
        return jax.random.categorical(key, lp, axis, shape) + self.support[0]

    @staticmethod
    def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
        """Create a maximum entropy distribution from GSD parameters.

        The mean and variance of the GSD given by ``theta`` are used as the
        constraints of the new distribution.

        :param theta: Parameters of a GSD distribution.
        :param N: Support size
        :return: A distribution object
        """
        return MaxEntropyGSD(
            mean=gsd.mean(theta.psi, theta.rho),
            sigma=jnp.sqrt(gsd.variance(theta.psi, theta.rho)),
            N=N
        )

# The class body defines no explicit ``__init__``, so the constructor's
# documentation is attached here after the class has been created.
# NOTE(review): ``:ref:`` is Sphinx markup, while the project docs are built
# with mkdocs - confirm that the cross-reference renders as intended.
MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD
:param mean: Expectation value of the distribution.
:param sigma: Standard deviation of the distribution.
:param N: Number of responses
.. note::
An alternative way to construct this distribution is by use of
:ref:`from_gsd`
"""
8 changes: 5 additions & 3 deletions src/gsd/gsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def logbinom(n: ArrayLike, k: ArrayLike) -> Array:


def vmin(psi: ArrayLike) -> Array:
"""Compute the minimal possible variance for give mean
"""Compute the minimal possible variance for categorical distribution
supported on Z[1,N] for a give mean
:param psi: mean
:return: variance
Expand All @@ -28,12 +29,13 @@ def vmin(psi: ArrayLike) -> Array:


def vmax(psi: ArrayLike, N: ArrayLike = 5) -> Array:
    """Compute the maximal possible variance for categorical distribution
    supported on Z[1,N] for a given mean

    :param psi: mean
    :param N: upper end of the support; defaults to 5, the value that was
        previously hard-coded, so existing one-argument calls are unchanged
    :return: variance
    """
    # N is an explicit argument so the function does not depend on a
    # module-level name.
    return (psi - 1.0) * (N - psi)


def _C(Vmax: ArrayLike, Vmin: ArrayLike) -> Array:
Expand Down
37 changes: 27 additions & 10 deletions tests/experimental_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from jax import config
config.update("jax_enable_x64", True)

config.update("jax_enable_x64", True)
from gsd.experimental.max_entropy import MaxEntropyGSD
import unittest # noqa: E402

import jax
Expand Down Expand Up @@ -45,7 +46,8 @@ def test_fit_grid3(self):
data = jnp.asarray([7, 25., 0, 0, 0])
hat = est(data)
theta = fit_mle_grid(data, num, False)
jax.tree_util.tree_map(lambda a,b: self.assertAlmostEqual(a,b,2), hat, theta)
jax.tree_util.tree_map(lambda a, b: self.assertAlmostEqual(a, b, 2),
hat, theta)

...

Expand All @@ -68,7 +70,7 @@ def test_sample_fit(self):
k = jax.random.key(12)
th = GSDParams(psi=4.2, rho=.92)
th = jax.tree_util.tree_map(jnp.asarray, th)
s = gsd.sample(th.psi, th.rho, (100000,),k)
s = gsd.sample(th.psi, th.rho, (100000,), k)
data = gsd.sufficient_statistic(s)
num = GSDParams(512, 128)
grid = GridEstimator.make(num)
Expand All @@ -79,24 +81,39 @@ def test_sample_fit(self):

def test_g_test(self):
# https://github.com/Qub3k/gsd-acm-mm/blob/master/Data_Analysis/G-test_results/G_test_on_real_data_chunk000_of_872.csv
data = jnp.asarray([0,0,1,10,13.])
data = jnp.asarray([0, 0, 1, 10, 13.])
num = GSDParams(512, 128)
grid = GridEstimator.make(num)


hat = grid(data)
self.assertTrue(np.allclose(hat.psi, 4.5, 0.001))
self.assertTrue(np.allclose(hat.rho, 0.935, 0.01))

p = bootstrap.prob(hat)
# 0.09459716927725387
t = bootstrap.t_statistic(data,p)
self.assertAlmostEqual(t,0.09459716927725387,2)
t = bootstrap.t_statistic(data, p)
self.assertAlmostEqual(t, 0.09459716927725387, 2)

# 0.4957
pv = bootstrap.pp_plot_data(data,lambda x: grid(x) ,jax.random.key(44),9999)
pv = bootstrap.pp_plot_data(data, lambda x: grid(x),
jax.random.key(44), 9999)

self.assertAlmostEqual(pv, 0.4957, 1)

...


self.assertAlmostEqual(pv,0.4957,1)
class MaxEntropyTestCase(unittest.TestCase):
    """Smoke tests for ``MaxEntropyGSD``."""

    def test_maxentropy(self):
        """The mean is stored as given and sampling honours ``shape``."""
        me = MaxEntropyGSD(mean=3.2, sigma=0.2, N=5)
        self.assertAlmostEqual(me.mean, 3.2)

        # Without ``shape`` a single draw comes back as a scalar array.
        s = me.sample(jax.random.key(44))
        self.assertEqual(s.shape, ())

        # Shapes are exact integers, so compare them exactly.
        s2 = me.sample(jax.random.key(44), shape=(5,))
        self.assertEqual(s2.shape, (5,))

    def test_probs(self):
        """Probabilities derived from GSD parameters sum to one."""
        me = MaxEntropyGSD.from_gsd(GSDParams(psi=3.2, rho=0.9), 5)
        lp = me.all_log_probs
        p = np.exp(lp)
        self.assertAlmostEqual(p.sum(), 1)

0 comments on commit f5af92f

Please sign in to comment.