Skip to content

Commit

Permalink
ENH: Add initial/mean point to chains to fix (#249)
Browse files Browse the repository at this point in the history
* Add initial/mean point to chains (closes #243)
* A update/fix for the test suite
  • Loading branch information
toastedcrumpets authored Nov 28, 2023
1 parent 95b0973 commit d435b13
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions espei/optimizers/opt_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def initialize_new_chains(params, chains_per_parameter, std_deviation, determini
# apply a Gaussian random to each parameter with std dev of std_deviation*parameter
tiled_parameters = np.tile(params, (nchains, 1))
chains = rng.normal(tiled_parameters, np.abs(tiled_parameters * std_deviation))
chains[0] = params #Ensure the initial guess is always included in the set, as the std_deviation may be too large to generate feasible points around a "good" initial point.
return chains

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,14 @@ def test_parameter_initialization():
initial_parameters = np.array([1, 10, 100, 1000])
opt = EmceeOptimizer(Database())
deterministic_params = opt.initialize_new_chains(initial_parameters, 1, 0.10, deterministic=True)
print(repr(deterministic_params))
expected_parameters = np.array([
[9.81708401e-01, 9.39027722e+00, 1.08016748e+02, 9.13512881e+02],
[1.03116874, 9.01412995, 112.79594345, 916.44725799],
list(initial_parameters), # The first element is always a start at the initial parameters
#These values are known due to deterministic=True above
[1.03116874e+00, 9.01412995e+00, 1.12795943e+02, 9.16447258e+02],
[1.00664662e+00, 1.07178898e+01, 9.63696718e+01, 1.36872292e+03],
[1.07642366e+00, 1.16413520e+01, 8.71742457e+01, 9.61836382e+02]])
[1.07642366e+00, 1.16413520e+01, 8.71742457e+01, 9.61836382e+02],
])
assert np.all(np.isclose(deterministic_params, expected_parameters))


Expand Down

0 comments on commit d435b13

Please sign in to comment.