Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soft #20

Merged
merged 2 commits into from
Jan 22, 2024
Merged

Soft #20

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions discussion/softvmin.wl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
(* ExecuteFile["softvmin.wl"] *)

Clear["Global`*"]

fa[x_] := (2-x)(x-1)
fb[x_] := (3-x)(x-2)

Plot[{fa[x],fb[x]},{x,1,3}] //Export["fafb.pdf", # ]&


pows = {0,2,4}
vars=Subscript[b,#]&/@pows

appf[x_] := Total[Subscript[b,#] (x-2)^# &/@pows]

eqs={
appf[2-d]==fa[2-d],
appf[2+d]==fb[2+d],
D[fa[x],x]==D[appf[x],x]/.{x->2-d},
D[fb[x],x]==D[appf[x],x]/.{x->2+d},
D[fa[x],{x,2}]==D[appf[x],{x,2}]/.{x->2-d},
D[fb[x],{x,2}]==D[appf[x],{x,2}]/.{x->2+d},
D[appf[x],x]==0/.{x->2}
}

sol=Solve[
eqs,
vars
]

(* sol = vars/.Solve[
eqs/.{d->0.1},
vars
] *)

sol=sol[[1]]


Plot[{(appf[x]/.sol)/.{d->1/50}, fa[x],fb[x]},{x,1.8,2.2}, PlotRange->{0,1/4}]//Export["appf.pdf", # ]&

Export["sol.txt",(appf[x]/.sol)]

(* Needs["CCodeGenerator`"]

CCodeGenerator[]


c = Compile[ {{x},{d}}, appf[x]/.sol];
file = CCodeStringGenerate[c, "fun"] *)

(* Test cases *)

(appf[x]/.sol)/.{d->1/50, x->1.99}

(appf[x]/.sol)/.{d->1/10, x->2.05}

p = (n+1)/(n+2)
ep = p x + (1-p)(x+1)

v = Simplify[p (x-ep)^2 + (1-p)(x+1-ep)^2]

ExportString[v,"tex"]

sol = Solve[((appf[x]/.sol)/.{x->2})==v,d]

ExportString[sol,"tex"]

N[(d/.sol[[1]])/.{n->24}]

202 changes: 202 additions & 0 deletions examples/softvmin.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [

# HATCH_PYTHON=python3.10
requires-python = ">=3.10"
dependencies=["jax>=0.4.6"]
dependencies=["jax>=0.4.23"]

[project.urls]
Homepage = "https://github.com/gsd-authors/gsd"
Expand All @@ -46,7 +46,7 @@ include = [
]

[tool.hatch.envs.default]
dependencies=["jaxlib>=0.4.6"]
dependencies=["jaxlib>=0.4.23"]

[project.optional-dependencies]
experimental = [
Expand Down
2 changes: 1 addition & 1 deletion src/gsd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.2dev'
__version__ = '0.2.2'
from gsd.fit import GSDParams as GSDParams
from gsd.fit import fit_moments as fit_moments
from gsd.gsd import (log_prob as log_prob,
Expand Down
10 changes: 5 additions & 5 deletions src/gsd/experimental/max_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _explicit_log_probs(dist: 'MaxEntropyGSD'):

lgr = jax.tree_util.tree_map(jnp.asarray, (-0.01, -0.01, -0.01))
sol = optx.root_find(_implicit_log_probs, solver, lgr, args=dist,
max_steps=int(1e4), throw=False)
max_steps=int(1e4), throw=True)
return _lagrange_log_probs(sol.value, dist)


Expand All @@ -66,7 +66,6 @@ class MaxEntropyGSD(eqx.Module):
sigma: Float[Array, ""] # std
N: int = eqx.field(static=True)


def log_prob(self, x: Int[Array, ""]):
lp = _explicit_log_probs(self)
return lp[x - 1]
Expand Down Expand Up @@ -106,7 +105,7 @@ def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
return jax.random.categorical(key, lp, axis, shape) + self.support[0]

@staticmethod
def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
"""Created maxentropy from GSD parameters.

:param theta: Parameters of a GSD distribution.
Expand All @@ -119,6 +118,7 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
N=N
)


MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD

:param mean: Expectation value of the distribution.
Expand All @@ -127,6 +127,6 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':

.. note::
An alternative way to construct this distribution is by use of
:ref:`from_gsd`
:meth:`from_gsd`

"""
"""
32 changes: 31 additions & 1 deletion src/gsd/gsd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -141,3 +141,33 @@ def sufficient_statistic(data: ArrayLike) -> Array:
bins = jnp.arange(0.5, N + 1.5, 1.)
c, _ = jnp.histogram(jnp.asarray(data), bins=bins)
return c


def softvmin_poly(x: Array, c: float, d: float) -> Array:
"""Smooths approximation to `vmin` function.

:param x: An argument, this would be psi
:param d: Cut point of approximation from `[0,0.5)`
:return: An approximated value `x` such that `abs(round(x)-x)<=d`
"""
sq1 = jnp.square(x - c)
sq2 = jnp.square(sq1)

return (3 * d) / 8 - ((-3 + 4 * d) * sq1) / (4 * d) - sq2 / (8 * d ** 3)


def make_softvmin(d: float) -> Callable[[Array], Array]:
"""Create a soft approximation to `vmin` function.

:param d: Cut point of approximation from `[0,0.5)`
:return: A callable returning n approximated value `vmin` for `x`
`abs(round(x)-x)<=d`
"""
def sofvmin(psi: ArrayLike):
psi = jnp.asarray(psi)
c = jax.lax.stop_gradient(jnp.round(psi))
return jnp.where(jnp.abs(psi - c) < d, softvmin_poly(psi, c, d),
vmin(psi)
)

return sofvmin
59 changes: 56 additions & 3 deletions tests/experimental_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from jax import config

from gsd.gsd import make_softvmin, vmax, vmin

config.update("jax_enable_x64", True)
from gsd.experimental.max_entropy import MaxEntropyGSD
import unittest # noqa: E402

import jax
import jax.numpy as jnp
import numpy as np

import gsd
Expand All @@ -14,6 +13,14 @@
from gsd.experimental.fit import GridEstimator
from gsd.fit import log_pmax, pairs, pmax, GSDParams, fit_moments

import equinox as eqx
import optimistix as optx

from gsd.experimental.max_entropy import MaxEntropyGSD, vmax

import jax
import jax.numpy as jnp


class FitTestCase(unittest.TestCase):
def test_pairs(self):
Expand Down Expand Up @@ -117,3 +124,49 @@ def test_probs(self):
lp = me.all_log_probs
p = np.exp(lp)
self.assertAlmostEqual(p.sum(), 1)


def test_fit(self):
def nll(d, x):
m, s = d
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
svmin = make_softvmin(0.1)
smin = jnp.sqrt(svmin(mean))
smax = jnp.sqrt(vmax(mean, N=5))
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
d = MaxEntropyGSD(mean, sigma, N=5)
return -jnp.mean(d.log_prob(x))

# x = jnp.asarray([2, 3, 2, 2, 3, 3, 4])
x = jnp.asarray([2, 2, 2, 2, 2, 2, 2])

eqx.tree_pprint(jax.grad(nll)((0.01, 2.0), x), short_arrays=False)

def fit(x):
solver = optx.BFGS(rtol=1e-2, atol=1e-4)

res = optx.minimise(nll, solver, (-0.0, .0),
args=x,
max_steps=int(1e6),
throw=True)
return res

res = jax.jit(fit)(x)
eqx.tree_pprint(res.value, short_arrays=False)

m, s = res.value
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
smin = jnp.sqrt(vmin(mean))
smax = jnp.sqrt(vmax(mean, N=5))
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
d = MaxEntropyGSD(mean, sigma, N=5)

self.assertAlmostEqual(d.mean,2., places=4)

eqx.tree_pprint(d, short_arrays=False)
eqx.tree_pprint(MaxEntropyGSD(jnp.mean(x), jnp.std(x), N=5),
short_arrays=False)




27 changes: 27 additions & 0 deletions tests/ref_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from jax import config

from gsd.gsd import softvmin_poly, make_softvmin, vmin

config.update("jax_enable_x64", True)

import unittest
Expand Down Expand Up @@ -104,5 +107,29 @@ def test_sufficient_statistic4(self):
# 1, 2 3 4 5
self.assertTrue(np.allclose(ss,c))


class SoftTestCase(unittest.TestCase):
def test_poly(self):
v = softvmin_poly(x=1.99,c=2., d=1/50.)
self.assertAlmostEqual(v, 0.0109938)
v = softvmin_poly(x=2.05,c=2, d=1 / 10.)
self.assertAlmostEqual(v, 0.0529687)

def test_softvmin(self):
svmin = make_softvmin(0.1)
self.assertAlmostEqual(svmin(3.3), vmin(3.3))

for x in [1.5,1.9, 1.95, 2.05, 2.1, 2.2]:
gsvmin = jax.grad(svmin)
g = gsvmin(x)
print(g)
self.assertIsNotNone(g)

ggsvmin = jax.grad(gsvmin)
gg = ggsvmin(x)
print(gg)
self.assertIsNotNone(gg)


if __name__ == '__main__':
unittest.main()
Loading