
Commit 6f26fc3

Merge pull request #175 from apax-hub/bal
batch active learning
2 parents 0ff6b56 + 49d2ee0 commit 6f26fc3

15 files changed: +489 -194 lines

apax/bal/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from apax.bal.api import kernel_selection

__all__ = ["kernel_selection"]

apax/bal/api.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
from functools import partial
from typing import List, Union

import jax
import numpy as np
from ase import Atoms
from click import Path
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import TFPipeline
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import restore_parameters
from apax.train.run import RawDataset, initialize_dataset


def create_feature_fn(
    model: EnergyModel,
    params,
    base_feature_map,
    feature_transforms=[],
    is_ensemble: bool = False,
):
    """
    Converts a model into a feature map, applies the requested
    transformations, and sets it up for computing the features of a dataset.

    All transformations are applied to the feature function, not to computed
    features. Only the final function is jit compiled.
    """
    feature_fn = base_feature_map.apply(model)

    if is_ensemble:
        feature_fn = transforms.ensemble_features(feature_fn)

    for transform in feature_transforms:
        feature_fn = transform.apply(feature_fn)

    feature_fn = transforms.batch_features(feature_fn)
    feature_fn = partial(feature_fn, params)
    feature_fn = jax.jit(feature_fn)
    return feature_fn


def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int):
    """Compute the features of a dataset."""
    features = []
    n_data = dataset.n_data
    ds = dataset.batch(processing_batch_size)

    pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
    for i, (inputs, _) in enumerate(ds):
        g = feature_fn(inputs)
        features.append(np.asarray(g))
        pbar.update(g.shape[0])
    pbar.close()

    features = np.concatenate(features, axis=0)
    return features


def kernel_selection(
    model_dir: Union[Path, List[Path]],
    train_atoms: List[Atoms],
    pool_atoms: List[Atoms],
    base_fm_options: dict,
    selection_method: str,
    feature_transforms: list = [],
    selection_batch_size: int = 10,
    processing_batch_size: int = 64,
):
    n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
    is_ensemble = n_models > 1

    selection_fn = {
        "max_dist": selection.max_dist_selection,
    }[selection_method]

    base_feature_map = feature_maps.FeatureMapOptions(base_fm_options)

    config, params = restore_parameters(model_dir)

    n_train = len(train_atoms)
    dataset = initialize_dataset(config, RawDataset(atoms_list=train_atoms + pool_atoms))

    init_box = dataset.init_input()["box"][0]

    builder = ModelBuilder(config.model.get_dict(), n_species=119)
    model = builder.build_energy_model(apply_mask=True, init_box=init_box)

    feature_fn = create_feature_fn(
        model, params, base_feature_map, feature_transforms, is_ensemble
    )
    g = compute_features(feature_fn, dataset, processing_batch_size)
    km = kernel.KernelMatrix(g, n_train)
    new_indices = selection_fn(km, selection_batch_size)

    return new_indices
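
kernel_selection ties the module together: it restores a (possibly ensembled) model, featurizes train + pool data, and greedily picks the most distant pool points. A minimal usage sketch; the file names and model directory below are illustrative placeholders, not part of this commit:

from ase.io import read

from apax.bal import kernel_selection

train_atoms = read("train.extxyz", index=":")  # hypothetical labelled data
pool_atoms = read("pool.extxyz", index=":")  # hypothetical candidate pool

new_idxs = kernel_selection(
    "models/my_apax_model",  # hypothetical checkpoint directory
    train_atoms,
    pool_atoms,
    base_fm_options={"name": "ll_grad", "layer_name": "dense_2"},
    selection_method="max_dist",
    selection_batch_size=10,
)
next_batch = [pool_atoms[i] for i in new_idxs]  # indices are pool-relative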

apax/bal/feature_maps.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
from typing import Literal, Tuple, Union

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, TypeAdapter


def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
    """Separate params into those belonging to a selected layer
    and the remaining ones.
    """
    p_flat = flatten_dict(params)

    feature_layer_params = {k: v for k, v in p_flat.items() if layer_name in k}
    remaining_params = {k: v for k, v in p_flat.items() if layer_name not in k}

    if len(feature_layer_params.keys()) > 2:  # w and b
        print(feature_layer_params.keys())
        raise ValueError("Found more than one layer of the specified name")

    return feature_layer_params, remaining_params


class LastLayerGradientFeatures(BaseModel, extra="forbid"):
    """
    Model transformation which computes the gradient of the output
    w.r.t. the specified layer.

    https://arxiv.org/pdf/2203.09410
    """

    name: Literal["ll_grad"]
    layer_name: str = "dense_2"

    def apply(self, model):
        def ll_grad(params, inputs):
            ll_params, remaining_params = extract_feature_params(params, self.layer_name)

            def inner(ll_params):
                ll_params.update(remaining_params)
                full_params = unflatten_dict(ll_params)

                # TODO find better abstraction for inputs
                R, Z, idx, box, offsets = (
                    inputs["positions"],
                    inputs["numbers"],
                    inputs["idx"],
                    inputs["box"],
                    inputs["offsets"],
                )
                return model.apply(full_params, R, Z, idx, box, offsets)

            g_ll = jax.grad(inner)(ll_params)
            g_ll = unflatten_dict(g_ll)

            g_flat = jax.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)
            (gw, gb), _ = jax.tree_util.tree_flatten(g_flat)

            bias_factor = 0.1
            weight_factor = jnp.sqrt(1 / gw.shape[-1])
            g_scaled = [weight_factor * gw, bias_factor * gb]

            g = jnp.concatenate(g_scaled)

            return g

        return ll_grad


class IdentityFeatures(BaseModel, extra="forbid"):
    """Identity feature map. For debugging purposes."""

    name: Literal["identity"]

    def apply(self, model):
        return model.apply


FeatureMapOptions = TypeAdapter(
    Union[LastLayerGradientFeatures, IdentityFeatures]
).validate_python
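
FeatureMapOptions is a pydantic TypeAdapter over the union of feature-map configs, dispatching on the name literal. A small sketch of how the base_fm_options dicts from api.py are validated, assuming apax is installed:

from apax.bal.feature_maps import (
    FeatureMapOptions,
    IdentityFeatures,
    LastLayerGradientFeatures,
)

fm = FeatureMapOptions({"name": "ll_grad", "layer_name": "dense_2"})
assert isinstance(fm, LastLayerGradientFeatures)

fm_debug = FeatureMapOptions({"name": "identity"})
assert isinstance(fm_debug, IdentityFeatures)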

apax/bal/kernel.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
import einops
import numpy as np


class KernelMatrix:
    """
    Matrix representation of a kernel defined by a feature map g:
    K_{ij} = \\sum_{k} g_{ik} g_{jk}
    """

    def __init__(self, g: np.ndarray, n_train: int):
        self.num_columns = g.shape[0]
        self.g = g
        self.diagonal = einops.einsum(g, g, "s feature, s feature -> s")
        self.n_train = n_train

    def compute_column(self, idx: int) -> np.ndarray:
        return einops.einsum(self.g, self.g[idx, :], "s feature, feature -> s")

    def score(self, idx: int) -> np.ndarray:
        """Computes the squared distance of sample i to all other samples j as
        K_{ii} + K_{jj} - 2 K_{ij}
        """
        return self.diagonal[idx] + self.diagonal - 2 * self.compute_column(idx)
apax/bal/selection.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import numpy as np

from apax.bal.kernel import KernelMatrix


def max_dist_selection(matrix: KernelMatrix, batch_size: int):
    """
    Iteratively selects samples from the pool which are
    most distant from all previously selected samples:

    \\argmax_{S \\in \\mathbb{X}_{rem}} \\min_{S' \\in \\mathbb{X}_{sel}} d(S, S')

    https://arxiv.org/pdf/2203.09410.pdf
    https://doi.org/10.1039/D2DD00034B
    """
    n_train = matrix.n_train

    min_squared_distances = matrix.diagonal
    min_squared_distances[:n_train] = -np.inf

    # Use max norm for first point
    new_idx = np.argmax(min_squared_distances)
    selected_idxs = list(range(n_train)) + [new_idx]

    for _ in range(1, batch_size):
        squared_distances = matrix.score(new_idx)

        squared_distances[selected_idxs] = -np.inf
        min_squared_distances = np.minimum(min_squared_distances, squared_distances)

        new_idx = np.argmax(min_squared_distances)
        selected_idxs.append(new_idx)

    return (
        np.array(selected_idxs[n_train:]) - n_train
    )  # shift by number of train datapoints
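
A self-contained run on random features (the shapes and seed are illustrative): the first n_train rows play the role of the already-labelled training set, and the returned indices are relative to the pool.

import numpy as np

from apax.bal.kernel import KernelMatrix
from apax.bal.selection import max_dist_selection

rng = np.random.default_rng(seed=0)
g = rng.normal(size=(20, 4))  # 5 train + 15 pool feature vectors
km = KernelMatrix(g, n_train=5)

pool_idxs = max_dist_selection(km, batch_size=3)
print(pool_idxs)  # three distinct indices in [0, 15)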

apax/bal/transforms.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import jax
import jax.numpy as jnp


def ensemble_features(feature_fn):
    ensemble_feature_fn = jax.vmap(feature_fn, (0, None), 0)

    def averaged_feature_fn(params, x):
        g = ensemble_feature_fn(params, x)

        if len(g.shape) != 2:
            # models, features
            raise ValueError(
                "Dimension mismatch for input features. Expected shape (models,"
                f" features), got {g.shape}"
            )

        n_models = g.shape[0]
        # sqrt since the kernel is K = g^T g
        feature_scale_factor = jnp.sqrt(1 / n_models)
        g_ens = feature_scale_factor * jnp.sum(g, axis=0)  # shape: n_features
        return g_ens

    return averaged_feature_fn


def batch_features(feature_fn):
    batched_feature_fn = jax.vmap(feature_fn, (None, 0), 0)
    return batched_feature_fn
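
Both helpers are thin jax.vmap wrappers: batch_features maps over the data axis while sharing the params, and ensemble_features maps over a leading model axis of the params before averaging the member features. A toy sketch (toy_feature_fn is illustrative, not part of the commit):

import jax.numpy as jnp

from apax.bal.transforms import batch_features, ensemble_features

def toy_feature_fn(params, x):
    return params["w"] * x  # one feature vector per sample

# batch_features: vmap over the data axis, shared params
batched = batch_features(toy_feature_fn)
params = {"w": jnp.array([1.0, 2.0])}
print(batched(params, jnp.ones((8, 2))).shape)  # (8, 2)

# ensemble_features: vmap over a leading model axis of the params
ens = ensemble_features(toy_feature_fn)
ens_params = {"w": jnp.ones((4, 2))}  # 4 ensemble members
print(ens(ens_params, jnp.ones(2)).shape)  # (2,)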

apax/data/input_pipeline.py

Lines changed: 10 additions & 0 deletions
@@ -222,3 +222,13 @@ def shuffle_and_batch(self):
         shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
         return shuffled_ds
+
+    def batch(self, batch_size):
+        # TODO: the batch size here overrides self.batch_size
+        # we should find a better abstraction
+        ds = self.ds.batch(batch_size=batch_size).map(
+            PadToSpecificSize(self.max_atoms, self.max_nbrs)
+        )
+
+        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        return ds
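
The new method follows the standard tf.data batch-then-map pattern; a toy stand-in with an identity map in place of the apax-internal PadToSpecificSize and prefetch_to_single_device:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
batched = ds.batch(batch_size=4).map(lambda x: x)  # padding would happen in map()
for b in batched.as_numpy_iterator():
    print(b)  # [0 1 2 3], [4 5 6 7], [8 9]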

apax/data/preprocessing.py

Lines changed: 29 additions & 28 deletions
@@ -78,39 +78,40 @@ def dataset_neighborlist(
         r_max,
     )
 
-    pbar_update_freq = max(int(len(atoms_list) / 100), 50)
-    with trange(
+    nl_pbar = trange(
         len(positions),
         desc="Precomputing NL",
         ncols=100,
+        mininterval=0.25,
         disable=disable_pbar,
         leave=True,
-    ) as nl_pbar:
-        for i, position in enumerate(positions):
-            if np.all(box[i] < 1e-6):
-                position = jnp.asarray(position)
-                if n_atoms[i] != last_n_atoms:
-                    neighbors = neighbor_fn.allocate(position)
-                    last_n_atoms = n_atoms[i]
-
-                neighbors = extract_nl(neighbors, position)
-
-                if neighbors.did_buffer_overflow:
-                    log.info("Neighbor list overflowed, reallocating.")
-                    neighbors = neighbor_fn.allocate(position)
-
-                neighbor_idxs = np.asarray(neighbors.idx)
-                n_neighbors = neighbor_idxs.shape[1]
-                offsets = np.full([n_neighbors, 3], 0)
-            else:
-                idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms_list[i], r_max)
-                offsets = np.matmul(offsets, box[i])
-                neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
-
-            offset_list.append(offsets)
-            idx_list.append(neighbor_idxs)
-            if i % pbar_update_freq == 0:
-                nl_pbar.update(pbar_update_freq)
+    )
+    for i, position in enumerate(positions):
+        if np.all(box[i] < 1e-6):
+            position = jnp.asarray(position)
+            if n_atoms[i] != last_n_atoms:
+                neighbors = neighbor_fn.allocate(position)
+                last_n_atoms = n_atoms[i]
+
+            neighbors = extract_nl(neighbors, position)
+
+            if neighbors.did_buffer_overflow:
+                log.info("Neighbor list overflowed, reallocating.")
+                neighbors = neighbor_fn.allocate(position)
+
+            neighbor_idxs = np.asarray(neighbors.idx)
+            n_neighbors = neighbor_idxs.shape[1]
+            offsets = np.full([n_neighbors, 3], 0)
+        else:
+            idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms_list[i], r_max)
+            offsets = np.matmul(offsets, box[i])
+            neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
+
+        offset_list.append(offsets)
+        idx_list.append(neighbor_idxs)
+        nl_pbar.update()
+    nl_pbar.close()
 
     return idx_list, offset_list
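
The refactor drops the modulo-based update frequency and instead updates the bar once per structure, leaving redraw throttling to tqdm's mininterval. A minimal sketch of the pattern, with a sleep standing in for one neighbor-list build:

import time

from tqdm import trange

pbar = trange(100, desc="Precomputing NL", ncols=100, mininterval=0.25, leave=True)
for _ in range(100):
    time.sleep(0.01)  # stand-in for computing one neighbor list
    pbar.update()  # counts every item; tqdm redraws at most every 0.25 s
pbar.close()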
