Skip to content

Commit

Permalink
Make the sample state indexable (#425)
Browse files Browse the repository at this point in the history
* fixing #424

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixing myst issues in the docs

* docs execute argument name

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] authored May 2, 2022
1 parent 086bcc2 commit bee076f
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 8 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: ["ubuntu-latest"]
include:
- python-version: "3.8"
- python-version: "3.9"
os: "macos-latest"
- python-version: "3.8"
- python-version: "3.9"
os: "windows-latest"

steps:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ emcee_version.py
.tox
env
.eggs
.coverage.*
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@
"use_repository_button": True,
"use_download_button": True,
}
jupyter_execute_notebooks = "off"
execution_timeout = -1
nb_execution_mode = "off"
nb_execution_timeout = -1
18 changes: 18 additions & 0 deletions src/emcee/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def __init__(
self.blobs = dc(blobs)
self.random_state = dc(random_state)

def __len__(self):
if self.blobs is None:
return 3
return 4

def __repr__(self):
return "State({0}, log_prob={1}, blobs={2}, random_state={3})".format(
self.coords, self.log_prob, self.blobs, self.random_state
Expand All @@ -55,3 +60,16 @@ def __iter__(self):
return iter(
(self.coords, self.log_prob, self.random_state, self.blobs)
)

def __getitem__(self, index):
if index < 0:
return self[len(self) + index]
if index == 0:
return self.coords
elif index == 1:
return self.log_prob
elif index == 2:
return self.random_state
elif index == 3 and self.blobs is not None:
return self.blobs
raise IndexError("Invalid index '{0}'".format(index))
34 changes: 32 additions & 2 deletions src/emcee/tests/unit/test_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-

import numpy as np
import pytest

from emcee import EnsembleSampler
from emcee.state import State


def check_rstate(a, b):
assert all(np.allclose(a_, b_) for a_, b_ in zip(a[1:], b[1:]))


def test_back_compat(seed=1234):
np.random.seed(seed)
coords = np.random.randn(16, 3)
Expand All @@ -18,13 +23,13 @@ def test_back_compat(seed=1234):
assert np.allclose(coords, c)
assert np.allclose(log_prob, l)
assert np.allclose(blobs, b)
assert all(np.allclose(a, b) for a, b in zip(rstate[1:], r[1:]))
check_rstate(rstate, r)

state = State(coords, log_prob, None, rstate)
c, l, r = state
assert np.allclose(coords, c)
assert np.allclose(log_prob, l)
assert all(np.allclose(a, b) for a, b in zip(rstate[1:], r[1:]))
check_rstate(rstate, r)


def test_overwrite(seed=1234):
Expand All @@ -40,3 +45,28 @@ def ll(x):
sampler = EnsembleSampler(nwalkers, 1, ll)
sampler.run_mcmc(p0, 10)
assert np.allclose(init, p0)


def test_indexing(seed=1234):
np.random.seed(seed)
coords = np.random.randn(16, 3)
log_prob = np.random.randn(len(coords))
blobs = np.random.randn(len(coords))
rstate = np.random.get_state()

state = State(coords, log_prob, blobs, rstate)
np.testing.assert_allclose(state[0], state.coords)
np.testing.assert_allclose(state[1], state.log_prob)
check_rstate(state[2], state.random_state)
np.testing.assert_allclose(state[3], state.blobs)
np.testing.assert_allclose(state[-1], state.blobs)
with pytest.raises(IndexError):
state[4]

state = State(coords, log_prob, random_state=rstate)
np.testing.assert_allclose(state[0], state.coords)
np.testing.assert_allclose(state[1], state.log_prob)
check_rstate(state[2], state.random_state)
check_rstate(state[-1], state.random_state)
with pytest.raises(IndexError):
state[3]
3 changes: 2 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
[tox]
envlist = py{37,38,39}{,-extras},lint
envlist = py{37,38,39,310}{,-extras},lint

[gh-actions]
python =
3.7: py37
3.8: py38
3.9: py39-extras
3.10: py310

[testenv]
deps = coverage[toml]
Expand Down

0 comments on commit bee076f

Please sign in to comment.