-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert analyzers/fisher.py to be compatible with latest calculator
- Loading branch information
Mingjian Wen
committed
Aug 19, 2019
1 parent
32f5d44
commit 70ddab9
Showing
11 changed files
with
206 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
Fisher information for the SW potential. | ||
""" | ||
|
||
|
||
from kliff.models import KIM | ||
from kliff.calculators import Calculator | ||
from kliff.dataset import Dataset | ||
from kliff.analyzers import Fisher | ||
|
||
|
||
########################################################################################## | ||
# Select the parameters that will be used to compute the Fisher information. Only | ||
# parameters specified below will be use, others will be kept fixed. The size of the | ||
# Fisher information matrix will be equal to the total size of the parameters specified | ||
# here. | ||
model = KIM(model_name='SW_StillingerWeber_1985_Si__MO_405512056662_005') | ||
model.set_fitting_params( | ||
A=[['default']], B=[['default']], sigma=[['default']], gamma=[['default']] | ||
) | ||
|
||
# dataset | ||
dataset_name = 'tmp_tset' | ||
tset = Dataset() | ||
tset.read(dataset_name) | ||
configs = tset.get_configs() | ||
|
||
# calculator | ||
calc = Calculator(model) | ||
calc.create(configs) | ||
|
||
########################################################################################## | ||
# Fisher information analyzer. | ||
analyzer = Fisher(calc) | ||
analyzer.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .rmse import energy_forces_RMSE | ||
from .rmse import EnergyForcesRMSE | ||
from .fisher import Fisher | ||
|
||
__all__ = ['energy_forces_RMSE'] | ||
__all__ = ['EnergyForcesRMSE', 'Fisher'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,106 +1,180 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
import numpy as np | ||
import copy | ||
import sys | ||
import numpy as np | ||
from ..log import log_entry | ||
from ..utils import split_string | ||
|
||
import kliff | ||
|
||
|
||
logger = kliff.logger.get_logger(__name__) | ||
|
||
|
||
class Fisher: | ||
"""Fisher information matrix. | ||
r"""Fisher information matrix. | ||
Compute the Fisher information according to $I_{ij} = \sum_m df_m/dp_i * df_m/dp_j$, | ||
where f_m are the forces on atoms in configuration m, p_i is the ith model parameter. | ||
Derivatives are computed numerically using Ridders' algorithm. | ||
Compute the Fisher information according to | ||
Parameters | ||
---------- | ||
..math:: | ||
I_{ij} = \sum_m \frac{\partial \bm f_m}{\partial \theta_i} | ||
\cdot \frac{\partial \bm f_m}{\partial \theta_j} | ||
KIMobjs: list of KIMcalculator objects | ||
where :math:`f_m` are the forces on atoms in configuration :math:`m`, :math:`\theta_i` | ||
is the ith model parameter. | ||
Derivatives are computed numerically using Ridders' algorithm: | ||
https://en.wikipedia.org/wiki/Ridders%27_method | ||
params: ModelParameters object | ||
Parameters | ||
---------- | ||
calculator: | ||
A calculator object. | ||
""" | ||
|
||
def __init__(self, params, calculator): | ||
self.params = params | ||
def __init__(self, calculator): | ||
self.calculator = calculator | ||
self.F = None | ||
self.F_std = None | ||
self.F_stdev = None | ||
self.delta_params = None | ||
|
||
def compute(self): | ||
"""Comptue the Fisher information matrix and the standard deviation. | ||
def run(self, verbose=1): | ||
"""Compute the Fisher information matrix and the standard deviation. | ||
Parameters | ||
---------- | ||
verbose: int | ||
If ``0``, do not write out to file; if ``1``, write to a file named | ||
``analysis_Fisher_info_matrix.txt``. | ||
Returns | ||
------- | ||
I: 2D array, shape(N, N) | ||
Fisher information matrix, where N is the number of parameters. | ||
F: 2D array, shape(N, N), where N is the number of parameters | ||
Fisher informaiton matrix (FIM) | ||
F_std: 2D array, shape(N, N), where N is the number of parameters | ||
standard deviation of FIM | ||
I_stdev: 2D array, shape(N, N) | ||
Standard deviation of Fisher information matrix, where N is the number of | ||
parameters. | ||
""" | ||
F_all = [] | ||
kim_in_out_data = self.calculator.get_kim_input_and_output() | ||
for in_out in kim_in_out_data: | ||
dfdp = self._get_derivative_one_conf(in_out) | ||
F_all.append(np.dot(dfdp, dfdp.T)) | ||
self.F = np.mean(F_all, axis=0) | ||
self.F_std = np.std(F_all, axis=0) | ||
return self.F, self.F_std | ||
|
||
def _get_derivative_one_conf(self, in_out): | ||
"""Compute the derivative dfm/dpi for one atom configuration. | ||
|
||
msg = 'Start computing Fisher information matrix.' | ||
log_entry(logger, msg, level='info') | ||
|
||
I_all = [] | ||
|
||
cas = self.calculator.get_compute_arguments() | ||
for i, ca in enumerate(cas): | ||
if i % 100 == 0: | ||
msg = 'Processing configuration {}.'.format(i) | ||
log_entry(logger, msg, level='info') | ||
dfdp = self._compute_jacobian_one_config(ca) | ||
I_all.append(np.dot(dfdp.T, dfdp)) | ||
I = np.mean(I_all, axis=0) | ||
I_stdev = np.std(I_all, axis=0) | ||
|
||
self._write_result(I, I_stdev, verbose) | ||
msg = 'Finish computing Fisher information matrix.' | ||
log_entry(logger, msg, level='info') | ||
|
||
return I, I_stdev | ||
|
||
def _write_result(self, I, stdev, verbose, path='analysis_Fisher_info_matrix.txt'): | ||
|
||
params = self.calculator.get_opt_params() | ||
nparams = len(params) | ||
names = [] | ||
values = [] | ||
component_idx = [] | ||
for i in range(len(params)): | ||
out = self.calculator.model.get_opt_param_name_value_and_indices(i) | ||
n, v, p_idx, c_idx = out | ||
names.append(n) | ||
values.append(v) | ||
component_idx.append(c_idx) | ||
|
||
# header | ||
header = '#' * 80 + '\n# Fisher information matrix.\n#\n' | ||
msg = ( | ||
'The size of the parameter list is {0}, and thus the Fisher information ' | ||
'matrix is a {0} by {0} matrix. The rows (columns) are associated with the ' | ||
'parameters in the following order:'.format(nparams) | ||
) | ||
header += split_string(msg, length=80, starter='#') | ||
header += '#\n' | ||
header += ( | ||
'row (column) index param name param value param component index\n' | ||
) | ||
for i, (n, v, c) in enumerate(zip(names, values, component_idx)): | ||
header += '{} {} {:23.15e} {}\n'.format(i, n, v, c) | ||
header += '#' * 80 + '\n' | ||
print(header) | ||
|
||
# write to file | ||
if verbose > 0: | ||
with open(path, 'w') as fout: | ||
|
||
fout.write(header) | ||
|
||
fout.write( | ||
'\n# Fisher information matrix, shape({0}, {0})\n'.format(nparams) | ||
) | ||
for line in I: | ||
for v in line: | ||
fout.write('{:23.15e} '.format(v)) | ||
fout.write('\n') | ||
|
||
fout.write( | ||
'\n# Standard deviation in Fisher information matrix, ' | ||
'shape({0}, {0})\n'.format(nparams) | ||
) | ||
for line in stdev: | ||
for v in line: | ||
fout.write('{:23.15e} '.format(v)) | ||
fout.write('\n') | ||
|
||
def _compute_jacobian_one_config(self, ca): | ||
"""Compute the Jacobian of forces w.r.t. parameters for one configuration. | ||
Parameters | ||
---------- | ||
in_out: Configuration object | ||
ca: object | ||
`compute argument` associated with one configuration. | ||
""" | ||
|
||
try: | ||
import numdifftools as nd | ||
except ImportError as e: | ||
raise ImportError( | ||
str(e) + '.\nFisher information computation needs ' | ||
'"numdifftools". Please install first.' | ||
+'{}\nFisher information analyzer needs "numdifftools". Please install ' | ||
'it first.'.format(str(e)) | ||
) | ||
|
||
derivs = [] | ||
ori_param_vals = self.params.get_x0() | ||
for i, p in enumerate(ori_param_vals): | ||
values = copy.deepcopy(ori_param_vals) | ||
Jfunc = nd.Jacobian(self._get_prediction) | ||
df = Jfunc(p, i, values, in_out) | ||
derivs.append(df.reshape((-1,))) | ||
# restore param values back | ||
self.params.update_params(ori_param_vals) | ||
# compute Jacobian of forces w.r.t. parameters | ||
original_params = self.calculator.get_opt_params() | ||
Jfunc = nd.Jacobian(self._compute_forces_one_config) | ||
j = Jfunc(copy.deepcopy(original_params), ca) | ||
|
||
# restore params back | ||
self.calculator.update_opt_params(original_params) | ||
|
||
return np.array(derivs) | ||
return j | ||
|
||
def _get_prediction(self, x, idx, values, in_out): | ||
""" Compute predictions using specific parameter. | ||
def _compute_forces_one_config(self, params, ca): | ||
""" Compute forces using a specific set of model parameters. | ||
Parameters | ||
---------- | ||
values: list of float | ||
params: list | ||
the parameter values | ||
idx: int | ||
the index of 'x' in the value list | ||
x: float | ||
the specific parameter value at slot 'idx' | ||
ca: object | ||
`compute argument` associated with one configuration | ||
Return | ||
------ | ||
forces: list of floats | ||
the forces on atoms in this configuration | ||
forces: 1D array | ||
the forces on atoms in this configuration | ||
""" | ||
values[idx] = x | ||
self.params.update_params(values) | ||
self.calculator.update_params(self.params) | ||
self.calculator.compute(in_out) | ||
forces = self.calculator.get_forces(in_out) | ||
self.calculator.update_opt_params(params) | ||
self.calculator.compute(ca) | ||
forces = self.calculator.get_forces(ca) | ||
forces = np.reshape(forces, (-1,)) | ||
|
||
return forces |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.