Skip to content

Commit

Permalink
Fix formating and bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztofrusek committed Jan 22, 2024
1 parent a54be1d commit 4c8a75d
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 220 deletions.
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,19 @@ experimental = [
[tool.hatch.envs.default.scripts]
test = "python -m unittest discover -p '*test.py'"

[tool.ruff]
extend-include = ["*.ipynb"]
fixable = ["I001", "F401"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
ignore-init-module-imports = true
select = ["E", "F", "I001"]
src = []

[tool.ruff.isort]
combine-as-imports = true
lines-after-imports = 2
order-by-type = false

[tool.pyright]
reportIncompatibleMethodOverride = true
include = ["src", "tests"]
17 changes: 9 additions & 8 deletions src/gsd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
__version__ = '0.2.2'
from gsd.fit import GSDParams as GSDParams
from gsd.fit import fit_moments as fit_moments
from gsd.gsd import (log_prob as log_prob,
sample as sample,
mean as mean,
variance as variance,
sufficient_statistic as sufficient_statistic)
__version__ = "0.2.3dev"
from gsd.fit import fit_moments as fit_moments, GSDParams as GSDParams
from gsd.gsd import (
log_prob as log_prob,
mean as mean,
sample as sample,
sufficient_statistic as sufficient_statistic,
variance as variance,
)
from gsd.ref_prob import gsd_prob as gsd_prob
10 changes: 5 additions & 5 deletions src/gsd/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from gsd import fit_moments

if __name__ == "__main__":
    # Command-line entry point: estimate the GSD parameters with the method
    # of moments from five counts supplied on the command line via ``-c``.
    parser = argparse.ArgumentParser(description="GSD estimator using moments")
    parser.add_argument("-c", nargs=5, type=int, help="List of 5 counts", required=True)
    args = parser.parse_args()

    # argparse derives the attribute name from the option string, so the
    # five counts given through ``-c`` are available as ``args.c``.
    hat, _ = fit_moments(data=jnp.asarray(args.c, dtype=jnp.float32))
    print(f"psi={hat.psi:.4f} rho={hat.rho:.4f}")
25 changes: 14 additions & 11 deletions src/gsd/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from .bootstrap import g_test as g_test
from .bootstrap import prob as prob
from .bootstrap import t_statistic as t_statistic
from .bootstrap import static_bootstrap as static_bootstrap
from .bootstrap import BootstrapResult as BootstrapResult
from .bootstrap import pp_plot_data as pp_plot_data
from .fit import GridEstimator as GridEstimator
from .fit import OptState as OptState
from .fit import fit_mle as fit_mle
from .fit import fit_mle_grid as fit_mle_grid

from .bootstrap import (
BootstrapResult as BootstrapResult,
g_test as g_test,
pp_plot_data as pp_plot_data,
prob as prob,
static_bootstrap as static_bootstrap,
t_statistic as t_statistic,
)
from .fit import (
fit_mle as fit_mle,
fit_mle_grid as fit_mle_grid,
GridEstimator as GridEstimator,
OptState as OptState,
)
from .max_entropy import MaxEntropyGSD as MaxEntropyGSD
38 changes: 23 additions & 15 deletions src/gsd/experimental/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jax.typing import ArrayLike

from gsd.experimental.fit import Estimator

from .. import GSDParams
from ..gsd import log_prob, sample, sufficient_statistic

Expand Down Expand Up @@ -62,9 +63,9 @@ def prob(x: GSDParams) -> Array:
:param x: Parameters of GSD
:return: An array of probabilities
"""
return jnp.exp(jax.vmap(log_prob, in_axes=(None, None, 0))(x.psi, x.rho,
jnp.arange(1,
6)))
return jnp.exp(
jax.vmap(log_prob, in_axes=(None, None, 0))(x.psi, x.rho, jnp.arange(1, 6))
)


class BootstrapResult(NamedTuple):
Expand All @@ -74,27 +75,34 @@ class BootstrapResult(NamedTuple):


@partial(jax.jit, static_argnums=(1, 3, 4))
def static_bootstrap(
    data: ArrayLike,
    estimator: Estimator,
    key: Array,
    n_bootstrap_samples: int,
    n_total_scores: int,
) -> BootstrapResult:
    """Run a bootstrap around the estimate fitted to ``data``.

    The estimator is applied to ``data``; the fitted ``psi`` and ``rho`` are
    then used to draw an array of shape
    ``(n_bootstrap_samples, n_total_scores)`` with ``sample``. Every row is
    reduced with ``sufficient_statistic`` and re-fitted with the same
    estimator.

    ``estimator``, ``n_bootstrap_samples`` and ``n_total_scores`` are static
    arguments of the jitted function.

    :param data: Input handed unchanged to ``estimator``.
    :param estimator: Callable returning the fitted parameters.
    :param key: Forwarded to ``sample`` (presumably a JAX PRNG key).
    :param n_bootstrap_samples: Number of bootstrap rows to draw.
    :param n_total_scores: Number of draws in each bootstrap row.
    :return: ``BootstrapResult`` holding ``prob`` of the fitted parameters,
        the reduced bootstrap samples and ``prob`` of every re-fitted
        bootstrap estimate.
    """
    fitted = estimator(data)

    draw_shape = (n_bootstrap_samples, n_total_scores)
    raw_draws = sample(fitted.psi, fitted.rho, draw_shape, key)
    resampled_stats = jax.vmap(sufficient_statistic)(raw_draws)

    # lax.map applies the estimator to one bootstrap row at a time.
    refitted = jax.lax.map(estimator, resampled_stats)

    return BootstrapResult(
        probs=prob(fitted),
        bootstrap_samples=resampled_stats,
        bootstrap_probs=jax.vmap(prob)(refitted),
    )


def pp_plot_data(data: ArrayLike, estimator: Estimator, key: Array,
n_bootstrap_samples: int) -> Array:
def pp_plot_data(
data: ArrayLike, estimator: Estimator, key: Array, n_bootstrap_samples: int
) -> Array:
n = int(np.sum(data))
b = static_bootstrap(data, estimator, key, n_bootstrap_samples, n)
p_value = g_test(data, b.probs, b.bootstrap_samples, b.bootstrap_probs)
Expand Down
87 changes: 50 additions & 37 deletions src/gsd/experimental/fit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from functools import partial
from typing import NamedTuple, Callable
from typing import Callable, NamedTuple

import jax
import numpy as np
from jax import numpy as jnp, Array, tree_util as jtu
from jax import Array, numpy as jnp, tree_util as jtu
from jax._src.basearray import ArrayLike

from gsd import GSDParams, fit_moments
from gsd.fit import make_logits, allowed_region
from gsd import fit_moments, GSDParams
from gsd.fit import allowed_region, make_logits


Estimator = Callable[[Array], GSDParams]

Expand All @@ -31,10 +32,14 @@ class OptState(NamedTuple):


@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5])
def fit_mle(data: ArrayLike, max_iterations: int = 100,
log_lr_min: ArrayLike = -15, log_lr_max: ArrayLike = 2.0,
num_lr: ArrayLike = 10, constrain_by_pmax=False, ) -> tuple[
GSDParams, OptState]:
def fit_mle(
data: ArrayLike,
max_iterations: int = 100,
log_lr_min: ArrayLike = -15,
log_lr_max: ArrayLike = 2.0,
num_lr: ArrayLike = 10,
constrain_by_pmax=False,
) -> tuple[GSDParams, OptState]:
"""Finds the maximum likelihood estimator of the GSD parameters. The
algorithm used here is a simple gradient ascent. We use the concept of
projected gradient to enforce constraints for parameters (psi in [1, 5],
Expand All @@ -50,7 +55,7 @@ def fit_mle(data: ArrayLike, max_iterations: int = 100,
learning rate.
:param num_lr: Number of learning rates to check during
the line search.
:return: The fitted parameters and the final opt state, whose params field contains the estimated values of
GSD Parameters
"""
Expand All @@ -65,9 +70,9 @@ def ll(theta: GSDParams) -> Array:

theta0 = fit_moments(data)

rate = jnp.concatenate([jnp.zeros((1,)),
jnp.logspace(log_lr_min, log_lr_max, num_lr,
base=2.0)])
rate = jnp.concatenate(
[jnp.zeros((1,)), jnp.logspace(log_lr_min, log_lr_max, num_lr, base=2.0)]
)

def update(tg, t, lo, hi):
"""
Expand All @@ -94,39 +99,43 @@ def body_fun(state: OptState) -> OptState:
new_lls = jnp.where(jnp.isnan(new_lls), -jnp.inf, new_lls)
max_idx = jnp.argmax(new_lls)
# jax.debug.print("{max_idx}||| {new_lls}",max_idx=max_idx,new_lls=new_lls)
ret = OptState(params=jtu.tree_map(lambda t: t[max_idx], new_theta),
previous_params=t, count=count + 1, )
ret = OptState(
params=jtu.tree_map(lambda t: t[max_idx], new_theta),
previous_params=t,
count=count + 1,
)
# jax.debug.print("body: {0} {1}",*ret.params)
return ret

def cond_fun(state: OptState) -> Array:
tn, tnm1, c = state
should_stop = jnp.logical_or(jnp.all(jnp.array(tn) == jnp.array(tnm1)),
c > max_iterations)
should_stop = jnp.logical_or(
jnp.all(jnp.array(tn) == jnp.array(tnm1)), c > max_iterations
)
# stop on NaN
should_stop = jnp.logical_or(should_stop,
jnp.any(jnp.isnan(jnp.array(tn))))
should_stop = jnp.logical_or(should_stop, jnp.any(jnp.isnan(jnp.array(tn))))
return jnp.logical_not(should_stop)

opt_state = jax.lax.while_loop(cond_fun, body_fun, OptState(params=theta0,
previous_params=jtu.tree_map(
lambda
_: jnp.inf,
theta0),
count=0, ), )
opt_state = jax.lax.while_loop(
cond_fun,
body_fun,
OptState(
params=theta0,
previous_params=jtu.tree_map(lambda _: jnp.inf, theta0),
count=0,
),
)
return opt_state.params, opt_state


def _make_map(psis, rhos, n):
    """Evaluate ``allowed_region`` for every pairing of ``psis`` and ``rhos``.

    The inner ``vmap`` sweeps ``psis`` and the outer one sweeps ``rhos``, so
    the leading axis of the result follows ``rhos`` and the next one follows
    ``psis``.

    :param psis: Values of psi to evaluate.
    :param rhos: Values of rho to evaluate.
    :param n: Passed through as the second argument of ``allowed_region``.
    :return: Output of ``allowed_region`` stacked over the grid.
    """

    def region_at(psi, rho):
        logits = make_logits(GSDParams(psi=psi, rho=rho))
        return allowed_region(logits, n)

    over_psis = jax.vmap(region_at, in_axes=(0, None))
    over_grid = jax.vmap(over_psis, in_axes=(None, 0))
    return over_grid(psis, rhos)


def fit_mle_grid(data: ArrayLike, num: GSDParams,
constrain_by_pmax=False) -> GSDParams:
def fit_mle_grid(data: ArrayLike, num: GSDParams, constrain_by_pmax=False) -> GSDParams:
"""Fit GSD using naive grid search method.
This function uses `numpy` and cannot be used in `jit`
Expand All @@ -141,14 +150,14 @@ def fit_mle_grid(data: ArrayLike, num: GSDParams,
:return: Fitted parameters
"""
lo = GSDParams(psi=1., rho=0.)
hi = GSDParams(psi=5., rho=1.)
lo = GSDParams(psi=1.0, rho=0.0)
hi = GSDParams(psi=5.0, rho=1.0)

grid_exes = jtu.tree_map(jnp.linspace, lo, hi, num)

def ll(psi, rho) -> Array:
ll = jnp.asarray(data) * make_logits(GSDParams(psi=psi, rho=rho))
ll = jnp.where(jnp.isnan(ll), 0., ll)
ll = jnp.where(jnp.isnan(ll), 0.0, ll)
return jnp.sum(ll)

grid_ll = jax.vmap(ll, in_axes=(0, None))
Expand All @@ -169,27 +178,28 @@ def ll(psi, rho) -> Array:


class GridEstimator(NamedTuple):
""" Stateful MLE based on grid search
"""Stateful MLE based on grid search
:param psis: Grid of psi axis
:param rhos: Grid of rho axis
:param lps: Grid of `log_prob` for each answer and each entry in the axes.
"""

psis: Array
rhos: Array
lps: Array

@staticmethod
def make(num: GSDParams)->"GridEstimator":
def make(num: GSDParams) -> "GridEstimator":
"""Make a grid estimator for GSD. This estimator precomputes log
probabilities for each answer on a regular grid.
:param num: Number of grid points
:return: Estimator
"""
lo = GSDParams(psi=1., rho=0.)
hi = GSDParams(psi=5., rho=1.)
lo = GSDParams(psi=1.0, rho=0.0)
hi = GSDParams(psi=5.0, rho=1.0)

grid_exes = jtu.tree_map(jnp.linspace, lo, hi, num)

Expand All @@ -202,8 +212,11 @@ def _make_logis(psi, rho):
logits = jnp.where(logits < small, small, logits)
return logits # ,psi,rho

return GridEstimator(psis=grid_exes.psi, rhos=grid_exes.rho,
lps=_make_logis(grid_exes.psi, grid_exes.rho))
return GridEstimator(
psis=grid_exes.psi,
rhos=grid_exes.rho,
lps=_make_logis(grid_exes.psi, grid_exes.rho),
)

@jax.jit
def __call__(self, data: Array) -> GSDParams:
Expand Down
Loading

0 comments on commit 4c8a75d

Please sign in to comment.