Skip to content

Commit f2a4c90

Browse files
authored
Merge pull request #192 from yonatank93/update_uq_test
Refactor UQ tests
2 parents 8fcf4f2 + 21f632c commit f2a4c90

File tree

7 files changed

+549
-198
lines changed

7 files changed

+549
-198
lines changed

docs/source/tutorials.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ Tutorials
1515
tutorials/nn_SiC
1616
tutorials/parameter_transform
1717
tutorials/uq_mcmc
18+
tutorials/uq_bootstrap
1819
tutorials/lennard_jones
1920
tutorials/linear_regression
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "656f7afc",
6+
"metadata": {},
7+
"source": [
8+
"# Bootstrapping\n",
9+
"\n",
10+
"In this example, we demonstrate how to perform uncertainty quantification (UQ) using\n",
11+
"bootstrap method. We use a Stillinger-Weber (SW) potential for silicon that is archived\n",
12+
"in OpenKIM_.\n",
13+
"\n",
14+
"For simplicity, we only set the energy-scaling parameters, i.e., ``A`` and ``lambda`` as\n",
15+
"the tunable parameters. These parameters will be calibrated to energies and forces of a\n",
16+
"small dataset, consisting of 4 compressed and stretched configurations of diamond silicon\n",
17+
"structure."
18+
]
19+
},
20+
{
21+
"cell_type": "markdown",
22+
"id": "98b590d7",
23+
"metadata": {},
24+
"source": [
25+
"To start, let's first install the SW model::\n",
26+
"\n",
27+
"$ kim-api-collections-management install user SW_StillingerWeber_1985_Si__MO_405512056662_006\n",
28+
"\n",
29+
".. seealso::\n",
30+
" This installs the model and its driver into the ``User Collection``. See\n",
31+
" :ref:`install_model` for more information about installing KIM models."
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": null,
37+
"id": "e617de91",
38+
"metadata": {
39+
"ExecuteTime": {
40+
"end_time": "2024-10-07T12:37:31.393276Z",
41+
"start_time": "2024-10-07T12:37:29.472146Z"
42+
}
43+
},
44+
"outputs": [],
45+
"source": [
46+
"import matplotlib.pyplot as plt\n",
47+
"import numpy as np\n",
48+
"\n",
49+
"from kliff.calculators import Calculator\n",
50+
"from kliff.dataset import Dataset\n",
51+
"from kliff.loss import Loss\n",
52+
"from kliff.models import KIMModel\n",
53+
"from kliff.uq.bootstrap import BootstrapEmpiricalModel\n",
54+
"from kliff.utils import download_dataset\n",
55+
"\n",
56+
"%matplotlib inline"
57+
]
58+
},
59+
{
60+
"cell_type": "markdown",
61+
"id": "57f71678",
62+
"metadata": {},
63+
"source": [
64+
"Before running bootstrap, we need to define a loss function and train the model. More\n",
65+
"detail information about this step can be found in :ref:`tut_kim_sw` and :ref:`tut_params_transform`."
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"id": "a3aa1d13",
72+
"metadata": {
73+
"ExecuteTime": {
74+
"end_time": "2024-10-07T12:37:59.004347Z",
75+
"start_time": "2024-10-07T12:37:58.979490Z"
76+
}
77+
},
78+
"outputs": [],
79+
"source": [
80+
"# Create the model\n",
81+
"model = KIMModel(model_name=\"SW_StillingerWeber_1985_Si__MO_405512056662_006\")\n",
82+
"\n",
83+
"# Set the tunable parameters and the initial guess\n",
84+
"opt_params = {\"A\": [[\"default\"]], \"lambda\": [[\"default\"]]}\n",
85+
"\n",
86+
"model.set_opt_params(**opt_params)\n",
87+
"model.echo_opt_params()\n",
88+
"\n",
89+
"# Get the dataset\n",
90+
"dataset_path = download_dataset(dataset_name=\"Si_training_set_4_configs\")\n",
91+
"# Read the dataset\n",
92+
"tset = Dataset(dataset_path)\n",
93+
"configs = tset.get_configs()\n",
94+
"\n",
95+
"# Create calculator\n",
96+
"calc = Calculator(model)\n",
97+
"# Only use the forces data\n",
98+
"ca = calc.create(configs, use_energy=False, use_forces=True)\n",
99+
"\n",
100+
"# Instantiate the loss function\n",
101+
"residual_data = {\"normalize_by_natoms\": False}\n",
102+
"loss = Loss(calc, residual_data=residual_data)"
103+
]
104+
},
105+
{
106+
"cell_type": "markdown",
107+
"id": "39a95904",
108+
"metadata": {},
109+
"source": [
110+
"To perform UQ by bootstrapping, the general workflow starts by instantiating :class:`~kliff.uq.bootstrap.BootstrapEmpiricalModel`, or :class:`~kliff.uq.bootstrap.BootstrapNeuralNetworkModel` if using a neural network\n",
111+
"potential."
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"id": "966614ab",
118+
"metadata": {
119+
"ExecuteTime": {
120+
"end_time": "2024-10-07T12:38:38.479190Z",
121+
"start_time": "2024-10-07T12:38:38.475357Z"
122+
}
123+
},
124+
"outputs": [],
125+
"source": [
126+
"# Instantiate bootstrap class object\n",
127+
"BS = BootstrapEmpiricalModel(loss, seed=1717)"
128+
]
129+
},
130+
{
131+
"cell_type": "markdown",
132+
"id": "b8be6029",
133+
"metadata": {},
134+
"source": [
135+
"Then, we generate some bootstrap compute arguments. This is equivalent to generating\n",
136+
"bootstrap data. Typically, we just need to specify how many bootstrap data samples to\n",
137+
"generate. Additionally, if we call ``generate_bootstrap_compute_arguments`` multiple\n",
138+
"times, the new generated data samples will be appended to the previously generated data\n",
139+
"samples. This is also the behavior if we read the data samples from the previously\n",
140+
"exported file."
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": null,
146+
"id": "e660eb87",
147+
"metadata": {
148+
"ExecuteTime": {
149+
"end_time": "2024-10-07T12:39:14.455217Z",
150+
"start_time": "2024-10-07T12:39:14.442511Z"
151+
}
152+
},
153+
"outputs": [],
154+
"source": [
155+
"# Generate bootstrap compute arguments\n",
156+
"BS.generate_bootstrap_compute_arguments(100)"
157+
]
158+
},
159+
{
160+
"cell_type": "markdown",
161+
"id": "898350eb",
162+
"metadata": {},
163+
"source": [
164+
"Finally, we will iterate over these bootstrap data samples and train the potential\n",
165+
"using each data sample. The resulting optimal parameters from each data sample give a\n",
166+
"single sample of parameters. By iterating over all data samples, then we will get an\n",
167+
"ensemble of parameters.\n",
168+
"\n",
169+
"Note that the mapping from the bootstrap dataset to the parameters involve optimization.\n",
170+
"We suggest to use the same mapping, i.e., the same optimizer setting, in each iteration.\n",
171+
"This includes using the same set of initial parameter guess. In the case when the loss\n",
172+
"function has multiple local minima, we don't want the parameter ensemble to be biased\n",
173+
"on the results of the other optimizations. For neural network model, we need to reset\n",
174+
"the initial parameter value, which is done internally.\n"
175+
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"id": "d347a576",
181+
"metadata": {
182+
"ExecuteTime": {
183+
"end_time": "2024-10-07T12:39:53.510993Z",
184+
"start_time": "2024-10-07T12:39:48.359289Z"
185+
}
186+
},
187+
"outputs": [],
188+
"source": [
189+
"# Run bootstrap\n",
190+
"min_kwargs = dict(method=\"lm\") # Optimizer setting\n",
191+
"initial_guess = calc.get_opt_params() # Initial guess in the optimization\n",
192+
"BS.run(min_kwargs=min_kwargs, initial_guess=initial_guess)"
193+
]
194+
},
195+
{
196+
"cell_type": "markdown",
197+
"id": "e2526a32",
198+
"metadata": {},
199+
"source": [
200+
"The resulting parameter ensemble can be accessed in `BS.samples` as a `np.ndarray`.\n",
201+
"Then, we can plot the distribution of the parameters, as an example, or propagate the\n",
202+
"error to the target quantities we want to study."
203+
]
204+
},
205+
{
206+
"cell_type": "code",
207+
"execution_count": null,
208+
"id": "e33a7732",
209+
"metadata": {
210+
"ExecuteTime": {
211+
"end_time": "2024-10-07T12:40:23.927758Z",
212+
"start_time": "2024-10-07T12:40:23.759710Z"
213+
}
214+
},
215+
"outputs": [],
216+
"source": [
217+
"# Plot the distribution of the parameters\n",
218+
"plt.figure()\n",
219+
"plt.plot(*(BS.samples.T), \".\", alpha=0.5)\n",
220+
"param_names = list(opt_params.keys())\n",
221+
"plt.xlabel(param_names[0])\n",
222+
"plt.ylabel(param_names[1])\n",
223+
"plt.show()"
224+
]
225+
},
226+
{
227+
"cell_type": "markdown",
228+
"id": "fe68cf9b",
229+
"metadata": {},
230+
"source": [
231+
".. _OpenKIM: https://openkim.org"
232+
]
233+
}
234+
],
235+
"metadata": {
236+
"kernelspec": {
237+
"display_name": "Python 3 (ipykernel)",
238+
"language": "python",
239+
"name": "python3"
240+
},
241+
"language_info": {
242+
"codemirror_mode": {
243+
"name": "ipython",
244+
"version": 3
245+
},
246+
"file_extension": ".py",
247+
"mimetype": "text/x-python",
248+
"name": "python",
249+
"nbconvert_exporter": "python",
250+
"pygments_lexer": "ipython3",
251+
"version": "3.10.12"
252+
},
253+
"toc": {
254+
"base_numbering": 1,
255+
"nav_menu": {},
256+
"number_sections": true,
257+
"sideBar": false,
258+
"skip_h1_title": false,
259+
"title_cell": "Table of Contents",
260+
"title_sidebar": "Contents",
261+
"toc_cell": false,
262+
"toc_position": {},
263+
"toc_section_display": true,
264+
"toc_window_display": false
265+
}
266+
},
267+
"nbformat": 4,
268+
"nbformat_minor": 5
269+
}

tests/uq/conftest.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from kliff.dataset import Dataset
6+
from kliff.models import KIMModel
7+
8+
9+
@pytest.fixture(scope="session")
10+
def uq_test_dir():
11+
# Directory of uq test files
12+
return Path(__file__).resolve().parent
13+
14+
15+
@pytest.fixture(scope="session")
16+
def uq_test_data_dir():
17+
# Directory of uq test data
18+
return Path(__file__).resolve().parents[1] / "test_data/configs/Si_4"
19+
20+
21+
@pytest.fixture(scope="session")
22+
def uq_test_configs(uq_test_data_dir):
23+
# Load test configs
24+
data = Dataset(uq_test_data_dir)
25+
return data.get_configs()
26+
27+
28+
@pytest.fixture(scope="session")
29+
def uq_kim_model():
30+
# Load a KIM model
31+
modelname = "SW_StillingerWeber_1985_Si__MO_405512056662_006"
32+
model = KIMModel(modelname)
33+
model.set_opt_params(A=[["default"]])
34+
return model
35+
36+
37+
@pytest.fixture(scope="session")
38+
def uq_nn_orig_state_filename(uq_test_dir):
39+
"""Return the original state filename for the NN model."""
40+
return uq_test_dir / "orig_model.pkl"

0 commit comments

Comments
 (0)