Skip to content

Commit 2c991ee

Browse files
committed
Including minibatch MCMC for paper 2
1 parent 4cd9b5a commit 2c991ee

20 files changed

+1856
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ eeyore/data/__pycache__/
99
eeyore/datasets/__pycache__/
1010
eeyore/distributions/__pycache__/
1111
eeyore/integrators/__pycache__/
12+
eeyore/itertools/__pycache__/
1213
eeyore/kernels/__pycache__/
1314
eeyore/linalg/__pycache__/
1415
eeyore/models/__pycache__/

eeyore/chains/chain_file.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,19 @@ def close(self):
2525
for key in self.vals.keys():
2626
self.vals[key].close()
2727

28-
def update(self, state, reset=True, close=True):
28+
def update(self, state,
29+
reset=True, close=True, fmt={'sample': '%.18e', 'target_val': '%.18e', 'grad_val': '%.18e', 'accepted': '%d'}):
2930
""" Update the chain """
3031
if reset:
3132
self.reset(keys=self.vals.keys())
3233

3334
for key in self.vals.keys():
3435
if isinstance(state[key], torch.Tensor):
35-
np.savetxt(self.vals[key], state[key].detach().cpu().numpy().ravel()[np.newaxis], delimiter=',')
36+
np.savetxt(
37+
self.vals[key], state[key].detach().cpu().numpy().ravel()[np.newaxis], fmt=fmt[key], delimiter=','
38+
)
3639
elif isinstance(state[key], np.ndarray):
37-
np.savetxt(self.vals[key], state[key].ravel()[np.newaxis], delimiter=',')
40+
np.savetxt(self.vals[key], state[key].ravel()[np.newaxis], fmt=fmt[key], delimiter=',')
3841
else:
3942
self.vals[key].write(str(state[key])+'\n')
4043

eeyore/chains/chain_list.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ def mc_cor(self, mc_cov_mat=None, method='inse', adjust=False):
9292
return st.cor_from_cov(mc_cov_mat)
9393

9494
def acceptance_rate(self):
    """Fraction of stored proposals that were accepted."""
    num_accepted = sum(self.vals['accepted'])
    return num_accepted / self.num_samples()
9797

98+
def block_acceptance_rate(self):
    """Element-wise (per parameter block) fraction of accepted proposals."""
    accepted = torch.stack(self.vals['accepted'])
    return accepted.sum(axis=0) / self.num_samples()
100+
98101
def multi_ess(self, mc_cov_mat=None, method='inse', adjust=False):
    """Multivariate effective sample size of the stored samples.

    Delegates to st.multi_ess; `mc_cov_mat`, `method` and `adjust` are
    passed through to control the Monte Carlo covariance estimation.
    """
    return st.multi_ess(self.get_samples(), mc_cov_mat=mc_cov_mat, method=method, adjust=adjust)
100103

@@ -106,13 +109,17 @@ def load(self, path):
106109
""" Load a previously saved chain """
107110
self.vals = torch.load(path)
108111

109-
def to_chainfile(self, keys=None, path=None, mode='a', fmt=None):
    """Write the in-memory chain out to per-key text files via ChainFile.

    Parameters:
        keys: iterable of state keys to write; defaults to all stored keys.
        path: output directory; defaults to the current working directory.
        mode: file open mode passed to ChainFile (default append).
        fmt: per-key numpy savetxt format strings; defaults to '%.18e' for
            float-valued keys and '%d' for 'accepted'.
    """
    from .chain_file import ChainFile

    # Resolve defaults at call time: a `path=Path.cwd()` default would be
    # frozen at import time, and a dict default would be one shared mutable
    # object across all calls.
    if path is None:
        path = Path.cwd()
    if fmt is None:
        fmt = {'sample': '%.18e', 'target_val': '%.18e', 'grad_val': '%.18e', 'accepted': '%d'}

    chainfile = ChainFile(keys=keys or self.vals.keys(), path=path, mode=mode)

    for i in range(len(self)):
        chainfile.update(self.state(i), reset=False, close=False, fmt=fmt)

    chainfile.close()
118125

eeyore/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@
22
from .data_info import data_paths
33
from .empty_dataset import EmptyXYDataset
44
from .idataset import IDataset
5+
from .mld_batcher import MLDBatcher
6+
from .mld_classification_batcher import MLDClassificationBatcher
57
from .xydataset import XYDataset
68
from .xyidataset import XYIDataset

eeyore/datasets/mld_batcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Minimum likelihood distance batcher (abstract base class)
2+
3+
class MLDBatcher:
    """Abstract base class for minimum likelihood distance (MLD) batchers."""
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Minimum likelihood distance batcher for classification
2+
3+
import copy
4+
import functools
5+
import operator
6+
import random
7+
import torch
8+
9+
from .mld_batcher import MLDBatcher
10+
11+
class MLDClassificationBatcher(MLDBatcher):
    """Minimum likelihood distance (MLD) batcher for classification data.

    Draws `num_batches` candidate minibatches, stratified by class label in
    proportion to the class frequencies, and returns the candidate whose mean
    log-likelihood is closest (summed over two parameter states) to the mean
    log-likelihood of the full dataset.
    """

    def __init__(self, num_batches, chunk_sizes, dataset=None):
        # Number of candidate minibatches scored per get_batch() call.
        self.num_batches = num_batches

        # Exactly two chunks per batch (one per parameter state passed to
        # get_batch: current and proposed).
        self.chunk_sizes = chunk_sizes
        assert len(self.chunk_sizes) == 2

        self.set_dataset(dataset)

    def set_dataset(self, dataset):
        """Bind a dataset and precompute per-class index lists, class
        proportions and the floored per-class points per chunk."""
        self.dataset = dataset

        if self.dataset is not None:
            self.num_points = len(dataset)
            # Assumes one-hot encoded labels in dataset.y — TODO confirm.
            self.num_classes = len(dataset.y[0])

            label_argmax = torch.argmax(self.dataset.y, axis=1)

            # Indices of the data points belonging to each class.
            self.class_indices = [[] for _ in range(self.num_classes)]
            for i in range(self.num_points):
                self.class_indices[label_argmax[i].item()].append(i)

            self.class_props = [len(self.class_indices[i]) / self.num_points for i in range(self.num_classes)]

            # Floored per-class counts for each of the two chunks; the
            # rounding shortfall is topped up by fill_class_sizes().
            self.class_num_batch_points = [
                [int(self.class_props[j]*self.chunk_sizes[i]) for j in range(self.num_classes)] for i in range(2)
            ]

    def batch_size(self):
        """Total minibatch size (sum of both chunk sizes)."""
        return sum(self.chunk_sizes)

    def fill_class_sizes(self):
        """Return per-chunk, per-class counts whose totals equal the chunk
        sizes, distributing the flooring shortfall over randomly chosen
        classes."""
        class_num_batch_points = copy.deepcopy(self.class_num_batch_points)

        sampled_classes = [
            random.choices(range(self.num_classes), k=self.chunk_sizes[i]-sum(class_num_batch_points[i])) for i in range(2)
        ]

        for i in range(2):
            for j in sampled_classes[i]:
                class_num_batch_points[i][j] = class_num_batch_points[i][j] + 1

        return class_num_batch_points

    def get_batch(self, model, params, fill=True):
        """Return the (x, y) minibatch minimizing the likelihood distance.

        Parameters:
            model: model exposing set_params_and_log_lik(params, x, y).
            params: sequence of two parameter tensors (current, proposed).
            fill: accepted for interface compatibility; currently unused.
        """
        class_num_batch_points = [self.fill_class_sizes() for _ in range(self.num_batches)]

        # The full-dataset log-likelihoods depend only on `params`, not on
        # the candidate batch, so evaluate them once instead of once per
        # candidate (the original recomputed them self.num_batches times).
        # The model ends up holding params[1], exactly as before.
        log_lik_vals = [
            model.set_params_and_log_lik(params[j].clone().detach(), self.dataset.x, self.dataset.y) for j in range(2)
        ]

        mld_distance = float('inf')

        for i in range(self.num_batches):
            indices = []

            # First chunk: stratified sample per class.
            indices.extend([
                random.sample(self.class_indices[j], class_num_batch_points[i][0][j]) for j in range(self.num_classes)
            ])

            # Second chunk: stratified sample from the remaining points.
            rest_indices = [list(set(self.class_indices[j]) - set(indices[j])) for j in range(self.num_classes)]

            indices.extend(
                [random.sample(rest_indices[j], class_num_batch_points[i][1][j]) for j in range(self.num_classes)]
            )

            indices = functools.reduce(operator.iconcat, indices, [])

            indices.sort()

            distance = 0.

            for j in range(2):
                distance = distance + (log_lik_vals[j].mean() - log_lik_vals[j][indices].mean()).abs()

            # NOTE(review): sqrt of a sum of absolute differences — confirm
            # this is the intended distance metric.
            distance = distance.sqrt().item()

            if distance < mld_distance:
                mld_indices = indices
                mld_distance = distance

        # NOTE(review): requires num_batches >= 1, otherwise mld_indices is
        # unbound here.
        return self.dataset.x[mld_indices, :], self.dataset.y[mld_indices, :]

eeyore/itertools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .chunk_evenly import chunk_evenly

eeyore/itertools/chunk_evenly.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def chunk_evenly(iterable, n):
    """Yield consecutive slices of `iterable` with sizes as even as possible.

    Produces len(iterable) // n chunks (a single chunk when the iterable is
    shorter than n), each of base size n with the remainder spread over the
    leading chunks. Yields nothing for an empty iterable.

    Bug fix: the original dropped elements whenever
    len(iterable) % n > len(iterable) // n (e.g. len 5 with n=3 lost the
    last element; len 2 with n=3 yielded nothing). All previously-correct
    inputs produce identical output.
    """
    total = len(iterable)

    if total == 0:
        return

    # Same chunk count as before: floor(total / n), but never zero.
    num_chunks = max(total // n, 1)
    base, extra = divmod(total, num_chunks)

    start = 0
    for i in range(num_chunks):
        # The first `extra` chunks absorb one extra element each.
        size = base + (1 if i < extra else 0)
        yield iterable[start:start + size]
        start += size

eeyore/models/mlp.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import torch
23
import torch.nn as nn
34

@@ -51,3 +52,47 @@ def forward(self, x):
5152
def num_hidden_layers(self):
5253
""" Get the number of hidden layers. """
5354
return len(self.hp.dims)-2
55+
56+
def num_par_blocks(self):
    """Total number of parameter blocks: one per non-input node."""
    count = 0
    for width in self.hp.dims[1:]:
        count += width
    return count
58+
59+
def layer_and_node_from_par_block(self, b):
    """Map a flat parameter-block index `b` to a (layer, node) pair.

    Blocks are numbered consecutively over the non-input nodes, so the layer
    widths self.hp.dims[1:] are accumulated and `b` is located within the
    resulting half-open ranges.

    Raises:
        ValueError: if `b` is outside [0, num_par_blocks()).

    Bug fix: the node offset was computed as `b % lower_bound`, which equals
    the intended `b - lower_bound` only while b < 2 * lower_bound; for a
    layer wider than all preceding layers combined it returned the wrong
    node. Subtraction is always correct (and handles lower_bound == 0).
    """
    num_nodes_per_layer = [0] + list(itertools.accumulate(self.hp.dims[1:]))
    l = self.num_hidden_layers()

    for i in range(1, len(num_nodes_per_layer)):
        if num_nodes_per_layer[-i-1] <= b < num_nodes_per_layer[-i]:
            n = b - num_nodes_per_layer[-i-1]
            return l, n
        l = l - 1

    raise ValueError('parameter block index {} out of range'.format(b))
71+
72+
def starting_par_block_idx(self, l):
    """Index of layer `l`'s first parameter in the flat parameter vector."""
    offset = 0
    # Each preceding layer contributes (fan-in [+ 1 for bias]) * fan-out
    # parameters; range(0) is empty, so layer 0 yields offset 0.
    for i in range(l):
        fan_in = self.hp.dims[i] + 1 if self.hp.bias[i] else self.hp.dims[i]
        offset += fan_in * self.hp.dims[i+1]
    return offset
80+
81+
def starting_par_block_indices(self):
    """Starting flat-parameter index of every layer, as a cumulative list."""
    starts = [0]
    for l in range(self.num_hidden_layers()):
        fan_in = self.hp.dims[l] + 1 if self.hp.bias[l] else self.hp.dims[l]
        starts.append(starts[-1] + fan_in * self.hp.dims[l+1])
    return starts
88+
89+
def par_block_indices(self, b):
    """Flat-parameter indices making up block `b`, plus its (layer, node).

    Returns (indices, l, n): `indices` covers node n's incoming weights and,
    when layer l has a bias, its bias term.
    """
    l, n = self.layer_and_node_from_par_block(b)
    s = self.starting_par_block_idx(l)

    # Incoming weights of node n. Assumes layer l's weights are stored
    # row-major as a (dims[l+1], dims[l]) matrix flattened after offset s —
    # TODO confirm against the model's parameter packing.
    # NOTE(review): the dims[l] == 1 branch yields the same value as the
    # range expression would ([s+n]); the special case looks redundant.
    indices = list(range(s+n*self.hp.dims[l], s+(n+1)*self.hp.dims[l])) if (self.hp.dims[l] > 1) else [s+n]

    if self.hp.bias[l]:
        # Bias of node n, stored after all of layer l's weight entries.
        indices.append(s+self.hp.dims[l]*self.hp.dims[l+1]+n)

    return indices, l, n

eeyore/models/tmp.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# %%
# Scratch/demo script: exercises the MLP parameter-block helpers
# (num_par_blocks, starting_par_block_idx, starting_par_block_indices,
# layer_and_node_from_par_block, par_block_indices) on a small 2-3-3-2
# network with biases. Output is printed for manual inspection only.

from eeyore.constants import loss_functions
from eeyore.models import mlp

# %%

# hparams = mlp.Hyperparameters(dims=[2, 3, 3, 2], bias=3*[True], activations=3*[None])
hparams = mlp.Hyperparameters(dims=[2, 3, 3, 2], bias=[True, True, True], activations=3*[None])

model = mlp.MLP(loss=loss_functions['multiclass_classification'], hparams=hparams)

# %%

# Expect 8 blocks: one per non-input node (3 + 3 + 2).
print(model.num_par_blocks())

print([model.starting_par_block_idx(i) for i in [0, 1, 2]])

print(model.starting_par_block_indices())

for b in range(8):
    l, n = model.layer_and_node_from_par_block(b)
    print("Block {} is in layer {} and node {} of that layer".format(b, l, n))

for b in range(8):
    indices, l, n = model.par_block_indices(b)
    print("Block {} is in layer {} and node {} of that layer and has indices {}".format(b, l, n, indices))

eeyore/samplers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .am import AM
2+
from .dmcl import DMCL
3+
from .gibbs import Gibbs
24
from .hmc import HMC
35
from .mala import MALA
46
from .metropolis_hastings import MetropolisHastings

eeyore/samplers/dmcl.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import json
2+
import torch
3+
4+
from .single_chain_serial_sampler import SingleChainSerialSampler
5+
from eeyore.chains import ChainList
6+
from eeyore.datasets import DataCounter
7+
from eeyore.itertools import chunk_evenly
8+
from eeyore.kernels import NormalKernel
9+
10+
class DMCL(SingleChainSerialSampler):
    """Blockwise Metropolis-within-Gibbs sampler on selected minibatches.

    One `draw` sweeps all parameter blocks of `model` (one block per
    non-input node, optionally split into sub-blocks via
    `node_subblock_size`); each (sub-)block receives a normal random-walk
    proposal accepted or rejected by a Metropolis test whose log-target is
    evaluated on a minibatch chosen by `batcher`.

    NOTE(review): what the acronym DMCL stands for is not documented in
    this file — confirm with the authors.
    """

    # NOTE(review): `chain=ChainList()` is a mutable default argument — the
    # same ChainList instance is shared by every DMCL constructed without an
    # explicit `chain`. Confirm this sharing is intended.
    def __init__(self, model, batch_model, batcher,
        theta0=None, dataloader=None, data0=None, counter=None,
        scales=1., node_subblock_size=None, chain=ChainList()):
        super(DMCL, self).__init__(counter or DataCounter.from_dataloader(dataloader))
        self.model = model  # model whose log-target is sampled
        self.batch_model = batch_model  # model instance handed to the batcher for scoring candidate batches
        self.dataloader = dataloader

        self.batcher = batcher
        self.batcher.set_dataset(dataloader.dataset)

        # Keys of the per-state quantities stored in the chain.
        self.keys = ['sample', 'target_val', 'accepted']
        self.chain = chain

        if theta0 is not None:
            self.set_current(theta0.clone().detach(), data=data0)

        # Per-block proposal scales: broadcast a scalar, convert a list, or
        # move a tensor to the model's dtype/device; anything else is stored
        # as-is.
        if isinstance(scales, float):
            self.scales = torch.full([self.model.num_par_blocks()], scales, dtype=self.model.dtype, device=self.model.device)
        elif isinstance(scales, torch.Tensor):
            self.scales = scales.to(dtype=self.model.dtype, device=self.model.device)
        elif isinstance(scales, list):
            self.scales = torch.tensor(scales, dtype=self.model.dtype, device=self.model.device)
        else:
            self.scales = scales

        # A None entry means "do not split that node's block into sub-blocks".
        if node_subblock_size is None:
            self.node_subblock_size = [None for _ in range(self.model.num_par_blocks())]
        else:
            self.node_subblock_size = node_subblock_size

    # NOTE(review): pass-through override — removable unless kept as an
    # explicit extension point.
    def set_current(self, theta, data=None):
        super().set_current(theta, data=data)

    # NOTE(review): pass-through override, same remark as set_current.
    def reset(self, theta, data=None, reset_counter=True, reset_chain=True):
        super().reset(theta, data=data, reset_counter=reset_counter, reset_chain=reset_chain)

    def get_blocks(self):
        """Return [layer, node, sub-block index lists] for every parameter block."""
        blocks = []

        for b in range(self.model.num_par_blocks()):
            indices, l, n = self.model.par_block_indices(b)

            if self.node_subblock_size[b] is None:
                indices = [indices]
            else:
                indices = list(chunk_evenly(indices, self.node_subblock_size[b]))

            blocks.append([l, n, indices])

        return blocks

    def save_blocks(self, path='gibbs_lbocks.txt', mode='w'):
        """Dump the block structure returned by get_blocks() to `path` as JSON."""
        # NOTE(review): the default filename 'gibbs_lbocks.txt' looks like a
        # typo for 'gibbs_blocks.txt' — confirm before changing, callers may
        # already rely on the current name.
        with open(path, mode) as file:
            json.dump(self.get_blocks(), file)

    def draw(self, x, y, savestate=False):
        """Advance the chain by one full sweep over all parameter (sub-)blocks.

        `x` and `y` are not used directly: the Metropolis targets are
        evaluated on minibatches produced by `self.batcher` instead.
        """
        proposed = {key : None for key in self.keys}
        # Per-(sub-)block accept flags collected during this sweep.
        self.current['accepted'] = []

        proposed['sample'] = self.current['sample'].clone().detach()

        for b in range(self.model.num_par_blocks()):
            indices, _, _ = self.model.par_block_indices(b)

            if self.node_subblock_size[b] is None:
                indices = [indices]
            else:
                indices = list(chunk_evenly(indices, self.node_subblock_size[b]))

            for i in range(len(indices)):
                # Normal random-walk proposal on the current sub-block only.
                kernel = NormalKernel(proposed['sample'][indices[i]], self.scales[b])

                proposed['sample'][indices[i]] = kernel.sample()

                # Minibatch chosen given both the current and proposed states.
                x_batch, y_batch = self.batcher.get_batch(
                    self.batch_model,
                    [self.current['sample'].clone().detach(), proposed['sample'].clone().detach()],
                    fill=True
                )

                # Both log-targets are evaluated on the same minibatch so the
                # Metropolis ratio is internally consistent.
                self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x_batch, y_batch)
                proposed['target_val'] = self.model.log_target(proposed['sample'].clone().detach(), x_batch, y_batch)

                log_rate = proposed['target_val'] - self.current['target_val']
                if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
                    self.current['sample'][indices[i]] = proposed['sample'][indices[i]]
                    self.current['accepted'].append(1)
                else:
                    # Rejected: restore the model's parameters to the current
                    # state (log_target presumably set them — TODO confirm).
                    # NOTE(review): proposed['sample'][indices[i]] is NOT
                    # reset here, so later sub-block proposals start from the
                    # rejected values of this sub-block — confirm intended.
                    self.model.set_params(self.current['sample'].clone().detach())
                    self.current['accepted'].append(0)

        self.current['accepted'] = torch.tensor(self.current['accepted'], device=self.model.device)

        if savestate:
            self.chain.detach_and_update(self.current)

        self.current['sample'].detach_()
        self.current['target_val'].detach_()

0 commit comments

Comments
 (0)