Skip to content

Commit 0956fdb

Browse files
committed
Made all samplers operational with mini-batching
1 parent 8e3910a commit 0956fdb

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

eeyore/samplers/am.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def set_recursive_cov(self, n, offset=0):
6161
def draw(self, x, y, savestate=False, offset=0):
6262
proposed = {key : None for key in self.keys}
6363

64+
if self.counter.num_batches != 1:
65+
self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y)
66+
6467
randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)
6568
if (self.counter.idx + 1 - offset > self.t0):
6669
if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < self.l:
@@ -76,7 +79,8 @@ def draw(self, x, y, savestate=False, offset=0):
7679

7780
if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
7881
self.current['sample'] = proposed['sample'].clone().detach()
79-
self.current['target_val'] = proposed['target_val'].clone().detach()
82+
if self.counter.num_batches == 1:
83+
self.current['target_val'] = proposed['target_val'].clone().detach()
8084
self.current['accepted'] = 1
8185
if (self.counter.idx > 0):
8286
self.num_accepted = self.num_accepted + 1

eeyore/samplers/hmc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def leapfrog(self, position0, momentum0, x, y):
126126
def draw(self, x, y, savestate=False):
127127
proposed = {key : None for key in self.keys}
128128

129+
if self.counter.num_batches != 1:
130+
self.current['target_val'], self.current['grad_val'] = \
131+
self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y)
132+
129133
proposed['sample'] = self.current['sample'].clone().detach()
130134
proposed['momentum'] = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)
131135

@@ -143,8 +147,9 @@ def draw(self, x, y, savestate=False):
143147

144148
if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < rate:
145149
self.current['sample'] = proposed['sample'].clone().detach()
146-
self.current['target_val'] = proposed['target_val'].clone().detach()
147-
self.current['grad_val'] = proposed['grad_val'].clone().detach()
150+
if self.counter.num_batches == 1:
151+
self.current['target_val'] = proposed['target_val'].clone().detach()
152+
self.current['grad_val'] = proposed['grad_val'].clone().detach()
148153
self.current['accepted'] = 1
149154
else:
150155
self.model.set_params(self.current['sample'].clone().detach())

eeyore/samplers/mala.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def set_kernel(self, state):
4646
def draw(self, x, y, savestate=False):
4747
proposed = {key : None for key in self.keys}
4848

49+
if self.counter.num_batches != 1:
50+
self.current['target_val'], self.current['grad_val'] = \
51+
self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y)
52+
4953
proposed['sample'] = self.kernel.sample()
5054

5155
proposed['target_val'], proposed['grad_val'] = \
@@ -61,8 +65,9 @@ def draw(self, x, y, savestate=False):
6165

6266
if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
6367
self.current['sample'] = proposed['sample'].clone().detach()
64-
self.current['target_val'] = proposed['target_val'].clone().detach()
65-
self.current['grad_val'] = proposed['grad_val'].clone().detach()
68+
if self.counter.num_batches == 1:
69+
self.current['target_val'] = proposed['target_val'].clone().detach()
70+
self.current['grad_val'] = proposed['grad_val'].clone().detach()
6671
self.current['accepted'] = 1
6772
else:
6873
self.model.set_params(self.current['sample'].clone().detach())

eeyore/samplers/ram.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def set_all(self, theta, data=None, cov=None):
3838
def draw(self, x, y, savestate=False, offset=0):
3939
proposed = {key : None for key in self.keys}
4040

41+
if self.counter.num_batches != 1:
42+
self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y)
43+
4144
randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)
4245
proposed['sample'] = self.current['sample'].clone().detach() + self.chol_cov @ randn_sample
4346
proposed['target_val'] = self.model.log_target(proposed['sample'].clone().detach(), x, y)
@@ -46,7 +49,8 @@ def draw(self, x, y, savestate=False, offset=0):
4649

4750
if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
4851
self.current['sample'] = proposed['sample'].clone().detach()
49-
self.current['target_val'] = proposed['target_val'].clone().detach()
52+
if self.counter.num_batches == 1:
53+
self.current['target_val'] = proposed['target_val'].clone().detach()
5054
self.current['accepted'] = 1
5155
else:
5256
self.model.set_params(self.current['sample'].clone().detach())

0 commit comments

Comments
 (0)