Bugfixes and a 10% speed improvement due to caching parameter transformation
tdewolff committed Dec 6, 2023
1 parent 336bb57 commit ffda267
Showing 6 changed files with 46 additions and 32 deletions.
3 changes: 2 additions & 1 deletion mogptk/dataset.py
@@ -1,5 +1,6 @@
 import copy
 
+import torch
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
@@ -166,7 +167,7 @@ def __init__(self, *args, names=None):
         self.channels = []
         if len(args) == 2 and (isinstance(args[1], np.ndarray) or isinstance(args[1], list) and all(isinstance(item, np.ndarray) for item in args[1])):
             if names is None or isinstance(names, str):
-                names = [names]
+                names = [names]*len(args[0])
 
             if isinstance(args[0], np.ndarray):
                 for name, y in zip(names, args[1]):
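
Note: the dataset.py change fixes a channel-naming bug. When a single name (or None) was given, `names = [names]` produced a one-element list, so the `zip(names, args[1])` loop below it paired only the first channel and silently dropped the rest. A minimal standalone sketch of that truncation (the commit repeats the name with `len(args[0])`; `len(ys)` is used here for brevity):

    import numpy as np

    ys = [np.zeros(5), np.ones(5), np.full(5, 2.0)]  # three output channels
    names = ['sensor']                               # before the fix: one-element list
    print(list(zip(names, ys)))                      # -> one pair; two channels silently dropped
    names = ['sensor'] * len(ys)                     # after the fix: one name per channel
    print(len(list(zip(names, ys))))                 # -> 3
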
4 changes: 2 additions & 2 deletions mogptk/gpr/config.py
@@ -1,7 +1,7 @@
 import torch
 
 class Config:
-    dtype = torch.float64
+    dtype = torch.float32
     if torch.cuda.is_available():
         device = torch.device('cuda', torch.cuda.current_device())
     else:
@@ -14,7 +14,7 @@ def use_half_precision():
     Use half precision (float16) for all tensors. This may be much faster on GPUs, but has reduced precision and may more often cause numerical instability. Only recommended on GPUs.
     """
     if config.device.type == 'cpu':
-        print('WARNING: half precision not recommend on CPU')
+        print('WARNING: half precision not recommended on CPU')
     config.dtype = torch.float16
 
 def use_single_precision():
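
Note: this changes the library's default dtype from float64 to float32 and fixes the warning's typo. A hedged sketch of restoring double precision after this commit; it assumes the module-level `config` instance referenced by `use_half_precision()` above is importable as shown:

    import torch
    from mogptk.gpr.config import config  # assumption: the module exposes its Config instance as `config`

    config.dtype = torch.float64          # restore the pre-commit default explicitly
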
20 changes: 10 additions & 10 deletions mogptk/gpr/kernel.py
@@ -329,17 +329,18 @@ def K(self, X1, X2=None):
         c1 = X1[:,0].long()
         m1 = [c1==i for i in range(self.output_dims)]
         x1 = [X1[m1[i],1:] for i in range(self.output_dims)]
-        r1 = [torch.nonzero(m1[i], as_tuple=False) for i in range(self.output_dims)]  # as_tuple avoids warning
+        r1 = [torch.nonzero(m1[i], as_tuple=False) for i in range(self.output_dims)]
 
         if X2 is None:
             r2 = [r1[i].reshape(1,-1) for i in range(self.output_dims)]
-            res = torch.empty(X1.shape[0], X1.shape[0], device=config.device, dtype=config.dtype)  # N1 x N1
-            # calculate lower triangle of main kernel matrix, the upper triangle is a transpose
+
+            res = torch.empty(X1.shape[0], X1.shape[0], device=config.device, dtype=config.dtype)  # NxM
             for i in range(self.output_dims):
                 for j in range(i+1):
                     # calculate sub kernel matrix and add to main kernel matrix
                     if i == j:
-                        res[r1[i],r2[i]] = self.Ksub(i, i, x1[i])
+                        k = self.Ksub(i, i, x1[i])
+                        res[r1[i],r2[i]] = k
                     else:
                         k = self.Ksub(i, j, x1[i], x1[j])
                         res[r1[i],r2[j]] = k
@@ -349,9 +350,9 @@
             c2 = X2[:,0].long()
             m2 = [c2==j for j in range(self.output_dims)]
             x2 = [X2[m2[j],1:] for j in range(self.output_dims)]
-            r2 = [torch.nonzero(m2[j], as_tuple=False).reshape(1,-1) for j in range(self.output_dims)]  # as_tuple avoids warning
+            r2 = [torch.nonzero(m2[j], as_tuple=False).reshape(1,-1) for j in range(self.output_dims)]
 
-            res = torch.empty(X1.shape[0], X2.shape[0], device=config.device, dtype=config.dtype)  # N1 x N2
+            res = torch.empty(X1.shape[0], X2.shape[0], device=config.device, dtype=config.dtype)  # NxM
             for i in range(self.output_dims):
                 for j in range(self.output_dims):
                     # calculate sub kernel matrix and add to main kernel matrix
@@ -363,12 +364,11 @@ def K_diag(self, X1):
         # extract channel mask, get data, and find indices that belong to the channels
         c1 = X1[:,0].long()
         m1 = [c1==i for i in range(self.output_dims)]
-        x1 = [X1[m1[i],1:] for i in range(self.output_dims)]  # I is broadcastable with last dimension in X
-        r1 = [torch.nonzero(m1[i], as_tuple=False)[:,0] for i in range(self.output_dims)]  # as_tuple avoids warning
+        x1 = [X1[m1[i],1:] for i in range(self.output_dims)]
+        r1 = [torch.nonzero(m1[i], as_tuple=False)[:,0] for i in range(self.output_dims)]
 
-        res = torch.empty(X1.shape[0], device=config.device, dtype=config.dtype)  # N1 x N1
+        res = torch.empty(X1.shape[0], device=config.device, dtype=config.dtype)  # NxM
 
-        # calculate lower triangle of main kernel matrix, the upper triangle is a transpose
         for i in range(self.output_dims):
             # calculate sub kernel matrix and add to main kernel matrix
             res[r1[i]] = self.Ksub_diag(i, x1[i])
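
Note: the kernel.py changes are comment cleanup plus holding each `Ksub` result in a local `k` before scattering it into the full matrix. A standalone sketch of the indexing pattern `res[r1[i],r2[j]] = k` used above, with a dummy block standing in for `Ksub`:

    import torch

    X = torch.tensor([[0., 1.], [1., 2.], [0., 3.]])             # first column is the channel ID
    c = X[:, 0].long()
    m = [c == i for i in range(2)]                               # boolean mask per channel
    r = [torch.nonzero(m[i], as_tuple=False) for i in range(2)]  # (n_i, 1) row indices per channel

    res = torch.zeros(3, 3)
    k = torch.ones(r[0].shape[0], r[1].shape[0])                 # dummy stand-in for Ksub(0, 1, ...)
    res[r[0], r[1].reshape(1, -1)] = k                           # broadcast-scatter into the (ch 0, ch 1) block
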
23 changes: 15 additions & 8 deletions mogptk/gpr/model.py
@@ -180,13 +180,9 @@ def _register_parameters(self, obj, name=None):
 
     def zero_grad(self):
         for p in self._params:
-            p = p.unconstrained
-            if p.grad is not None:
-                if p.grad.grad_fn is not None:
-                    p.grad.detach_()
-                else:
-                    p.grad.requires_grad_(False)
-                p.grad.zero_()
+            p._constrained = None
+            if p.unconstrained.grad is not None:
+                p.unconstrained.grad = None
 
     def parameters(self):
         """
@@ -199,6 +195,17 @@ def parameters(self):
             if p.train:
                 yield p.unconstrained
 
+    def named_parameters(self):
+        """
+        Yield trainable parameters of model.
+
+        Returns:
+            Parameter generator
+        """
+        for p in self._params:
+            if p.train:
+                yield p.name, p.unconstrained
+
     def get_parameters(self):
         """
         Return all parameters of model.
@@ -309,7 +316,7 @@ def loss(self):
         self.zero_grad()
         loss = -self.log_marginal_likelihood() - self.log_prior()
         loss.backward()
-        return float(loss)
+        return loss
 
     def K(self, X1, X2=None):
         """
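
Note: zero_grad() now invalidates each parameter's cached constrained value (see parameter.py below) and releases gradients with `grad = None` instead of zeroing in place, loss() returns the loss tensor rather than a float (callers convert; see mogptk/model.py below), and a named_parameters() generator is added. A hedged usage sketch of the new generator, assuming `model` is a gpr Model instance:

    for name, tensor in model.named_parameters():
        print(name, tuple(tensor.shape))  # yields only trainable, unconstrained tensors
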
22 changes: 14 additions & 8 deletions mogptk/gpr/parameter.py
@@ -95,6 +95,7 @@ def __init__(self, value, name=None, lower=None, upper=None, prior=None, train=T
         self.unconstrained = None
         self.pegged_parameter = None
         self.pegged_transform = None
+        self._constrained = None
 
         self.assign(value, lower=lower, upper=upper)
 
@@ -129,14 +130,17 @@ def constrained(self):
         Returns:
             torch.tensor
         """
-        if self.pegged:
-            other = self.pegged_parameter.constrained
-            if self.pegged_transform is not None:
-                other = self.pegged_transform(other)
-            return other
-        if self.transform is not None:
-            return self.transform.forward(self.unconstrained)
-        return self.unconstrained
+        if self._constrained is None:
+            if self.pegged:
+                other = self.pegged_parameter.constrained
+                if self.pegged_transform is not None:
+                    other = self.pegged_transform(other)
+                self._constrained = other
+            elif self.transform is not None:
+                self._constrained = self.transform.forward(self.unconstrained)
+            else:
+                self._constrained = self.unconstrained
+        return self._constrained
 
     def numpy(self):
         """
@@ -255,6 +259,7 @@ def assign(self, value=None, name=None, lower=None, upper=None, prior=None, trai
         self.unconstrained = value
         self.pegged_parameter = None
         self.pegged_transform = None
+        self._constrained = None
 
     def peg(self, other, transform=None):
         """
@@ -271,6 +276,7 @@ def peg(self, other, transform=None):
         self.pegged_parameter = other
         self.pegged_transform = transform
         self.train = False
+        self._constrained = None
 
     def log_prior(self):
         """
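
Note: this is the caching the commit message credits for the ~10% speedup. The constrained value is memoized in `_constrained`, and every site that can change it (`__init__`, `assign`, `peg` here, and `zero_grad` in gpr/model.py) resets the cache to None. A minimal standalone sketch of the same pattern, not mogptk's actual class:

    class CachedParam:
        def __init__(self, raw, transform=None):
            self.unconstrained = raw
            self.transform = transform
            self._constrained = None                  # cache; None means "recompute"

        @property
        def constrained(self):
            if self._constrained is None:             # pay the transform cost once per update
                if self.transform is not None:
                    self._constrained = self.transform.forward(self.unconstrained)
                else:
                    self._constrained = self.unconstrained
            return self._constrained

        def assign(self, raw):
            self.unconstrained = raw
            self._constrained = None                  # invalidate on any change
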
6 changes: 3 additions & 3 deletions mogptk/model.py
@@ -317,7 +317,7 @@ def log_marginal_likelihood(self):
         Examples:
             >>> model.log_marginal_likelihood()
         """
-        return self.gpr.log_marginal_likelihood().detach().cpu().item()
+        return float(self.gpr.log_marginal_likelihood())
 
     def BIC(self):
         """
@@ -353,7 +353,7 @@ def loss(self):
         Examples:
             >>> model.loss()
         """
-        return self.gpr.loss()
+        return float(self.gpr.loss())
 
     def error(self, method='MAE', use_all_data=False):
         """
@@ -483,7 +483,7 @@ def train(
         initial_time = time.time()
         progress_time = 0.0
 
-        iters_len = int(math.log10(iter_offset+iters)) + 1
+        iters_len = 1 if iters == 0 else int(math.log10(iter_offset+iters)) + 1
         def progress(i, loss, last=False):
             nonlocal progress_time
 
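
Note: log_marginal_likelihood() and loss() now wrap the tensors returned by the gpr layer in float(), matching gpr.Model.loss() returning a tensor above. The train() change guards the progress-output width: with iters == 0 and no offset, the old expression evaluated math.log10(0), which raises ValueError: math domain error. A quick check of the guarded expression:

    import math

    iters, iter_offset = 0, 0
    iters_len = 1 if iters == 0 else int(math.log10(iter_offset + iters)) + 1
    print(iters_len)  # 1; the old expression raised on log10(0)
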
