Initial implementation for #3 and results on current datasets
adamesalles committed Apr 24, 2024
1 parent 3412b6a commit b43f7d0
Showing 11 changed files with 319 additions and 0 deletions.
(8 changed files could not be displayed.)
175 changes: 175 additions & 0 deletions src/prototypes/gpytorch-scgp-second-order-only.py
@@ -0,0 +1,175 @@
import torch
import gpytorch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import time
from tqdm import tqdm, trange
from scgp.kernels import RBFKernelSecondGrad

PATH = pathlib.Path(__file__).parent.parent.parent.absolute()
DATA_PATH = PATH / "data"
LOGGING = False


matplotlib.rc('text', usetex=True)
matplotlib.rc('font', family='serif')
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath}')


# Loading data
def load_data(path: str) -> tuple:
data = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
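    # Columns are assumed to be alpha, f, f', f''; column 2 (the first
    # derivative) is skipped on purpose in this second-order-only prototype.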
x_train = torch.from_numpy(data[:, 0])
y_train = torch.stack(
[torch.from_numpy(data[:, 1]), torch.from_numpy(data[:, 3])], -1
).squeeze(1)
return x_train, y_train


# Read http://www.gaussianprocess.org/gpml/chapters/RW9.pdf
class SCGP(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood, kernel):
super(SCGP, self).__init__(train_x, train_y, likelihood)
self.mean_module = gpytorch.means.ConstantMeanGrad()
self.base_kernel = kernel
self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultitaskMultivariateNormal(mean_x,
covar_x)


def scgp_fit(path: str, iters: int, kernel: gpytorch.kernels.Kernel) -> tuple:
train_x, train_y = load_data(path)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
num_tasks=2
) # y and y_second_prime
model = SCGP(train_x, train_y, likelihood, kernel)

# Optimizing hyperparameters via marginal log likelihood
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in trange(iters):
optimizer.zero_grad()
output = model(train_x)
loss = -mll(output, train_y)
loss.backward()
if LOGGING:
            tqdm.write(
                f"Iter {i + 1}/{iters} - Loss: {loss.item():.3f} - "
                f"noise: {model.likelihood.noise.item():.3f}"
            )
optimizer.step()

return train_x, train_y, model, likelihood


def save_plot_scgp(
train_x: torch.Tensor,
train_y: torch.Tensor,
model: SCGP,
likelihood: gpytorch.likelihoods.MultitaskGaussianLikelihood,
ker_name: str,
dataset_name: str
) -> None:

# Name of the plot
name = f"S2OCGP_Scale{ker_name}_{dataset_name}_test"

# Evaluation mode
    model.eval()
likelihood.eval()

# Initialize plots
f, (y_ax, y_prime_ax) = plt.subplots(1, 2,
figsize=(10, 4), tight_layout=True)

# Make predictions
with torch.no_grad(), gpytorch.settings.max_cg_iterations(100):
test_x = torch.linspace(0, 1, 100)
predictions = likelihood(model(test_x))
mean = predictions.mean
lower, upper = predictions.confidence_region()

# Plotting predictions for f
y_ax.plot(train_x.detach().numpy(), train_y[:, 0].detach().numpy(), "k*")
y_ax.plot(test_x.numpy(), mean[:, 0].numpy(), "b")
y_ax.fill_between(
test_x.numpy(), lower[:, 0].numpy(), upper[:, 0].numpy(), alpha=0.5
)
y_ax.legend(["Observed Values", "Mean", "Confidence"])
y_ax.set_title("Function values")
y_ax.set_xlim([0, 1])
# y_ax.set_ylim([-7.5, 12.5])
y_ax.set_xlabel(r"$\alpha$")
y_ax.set_ylabel(r"$f_{\boldsymbol{z}}(\alpha)$")

# Plotting predictions for f'
y_prime_ax.plot(train_x.detach().numpy(), train_y[:, 1].detach().numpy(),
"k*")
y_prime_ax.plot(test_x.numpy(), mean[:, 1].numpy(), "b")
y_prime_ax.fill_between(
test_x.numpy(), lower[:, 1].numpy(), upper[:, 1].numpy(), alpha=0.5
)
y_prime_ax.legend(["Observed Derivatives", "Mean", "Confidence"])
y_prime_ax.set_title(r"Second Order Derivatives with respect to $\alpha$")
y_prime_ax.set_xlim([0, 1])

y_prime_ax.set_xlabel(r"$\alpha$")
y_prime_ax.set_ylabel(r"$\frac{\mathrm{d}^2}{\mathrm{d}\alpha^2}f_{\boldsymbol{z}}(\alpha)$")

save_path = PATH / "results" / str(dataset_name) / 's2ocgp'
save_path.mkdir(parents=True, exist_ok=True)

f.savefig(save_path / (name + ".png"), dpi=600, bbox_inches="tight")
plt.close(f)


if __name__ == "__main__":
datasets = {
"uniform": DATA_PATH / "Gaussian_logCA0_uniform_J=20.csv",
"adaptive": DATA_PATH / "Gaussian_logCA0_adaptive_J=20.csv",
"uniform_HC": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20.csv",
"adaptive_HC": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20.csv",
"uniform_HC2": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20_HC2.csv",
"adaptive_HC2": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20_HC2.csv",
"uniform_HC3": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20_HC3.csv",
"adaptive_HC3": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20_HC3.csv",
}

kernels = {"RBFKernel": RBFKernelSecondGrad(),}

for ker_name, kernel in kernels.items():
print(f"Running {ker_name}")

for dataset_name, dataset_path in datasets.items():
start_time = time.time()
try:
train_x, train_y, data_scgp, data_likelihood = scgp_fit(
dataset_path, 100, kernel=kernel
)
                save_plot_scgp(
train_x,
train_y,
data_scgp,
data_likelihood,
ker_name,
dataset_name
)
final_time = time.time() - start_time
print(
f"Finished {ker_name} on {dataset_name} took {final_time:.2f}s"
)
except Exception as e:
print(f"Error on {ker_name} on {dataset_name} took {e}")
continue
Empty file added src/scgp/__init__.py
Empty file.
144 changes: 144 additions & 0 deletions src/scgp/kernels.py
@@ -0,0 +1,144 @@
import torch
from linear_operator.operators import KroneckerProductLinearOperator
from gpytorch.kernels.rbf_kernel import postprocess_rbf, RBFKernel


class RBFKernelSecondGrad(RBFKernel):
def forward(self, x1, x2, diag=False, **params):
batch_shape = x1.shape[:-2]
n_batch_dims = len(batch_shape)
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

K = torch.zeros(
*batch_shape,
n1 * (2 * d + 1),
n2 * (2 * d + 1),
device=x1.device,
dtype=x1.dtype
)
final_K = torch.zeros(
*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype
)
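        # K holds the full value / first-derivative / second-derivative block
        # structure (2*d + 1 blocks per axis); final_K keeps only the value
        # and second-derivative blocks, since first-derivative observations
        # are not used in this second-order-only kernel.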

# Scale the inputs by the lengthscale (for stability)
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)

# Form all possible rank-1 products for the gradient and Hessian blocks
outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
outer = outer / self.lengthscale.unsqueeze(-2)
outer = torch.transpose(outer, -1, -2).contiguous()
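        # Pairwise differences (x1 - x2) / lengthscale**2; all first- and
        # second-derivative blocks below are built from these.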

# 1) Kernel block
diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
K_11 = postprocess_rbf(diff)
K[..., :n1, :n2] = K_11
final_K[..., :n1, :n2] = K_11

# 2) First gradient block
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat(
[*([1] * (n_batch_dims + 1)), d]
)

# 3) Second gradient block
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat(
[*([1] * n_batch_dims), d, 1]
)

# 4) Hessian block
outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat(
[*([1] * (n_batch_dims + 1)), d]
)
kp = KroneckerProductLinearOperator(
torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1)
/ self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
*batch_shape, 1, 1
),
)
chain_rule = kp.to_dense() - outer3
K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat(
[*([1] * n_batch_dims), d, d]
)
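        # Sanity check for d = 1: this block is
        # d^2 k / (dx1 dx2) = (1/l^2 - r^2/l^4) * k, with r = x1 - x2.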

# 5) 1-3 block
douter1dx2 = KroneckerProductLinearOperator(
torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(
*batch_shape, 1, 1
)
/ self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
*batch_shape, 1, 1
),
).to_dense()

K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
[*([1] * (n_batch_dims + 1)), d]
) # verified for n1=n2=1 case
K[..., :n1, (n2 * (d + 1)) :] = K_13
final_K[..., :n1, n2:] = K_13
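        # For d = 1, K_13 reduces to d^2 k / dx2^2 = (r^2/l^4 - 1/l^2) * k,
        # with r = x1 - x2: the covariance between a function value at x1
        # and a second derivative at x2.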

K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat(
[*([1] * n_batch_dims), d, 1]
) # verified for n1=n2=1 case
K[..., (n1 * (d + 1)) :, :n2] = K_31
final_K[..., n1:, :n2] = K_31

        # The remaining blocks are all of size (n1*d, n2*d)
outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
# II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
kp2 = KroneckerProductLinearOperator(
torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(
*batch_shape, 1, 1
)
/ self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
*batch_shape, 1, 1
),
).to_dense()

        # NOTE: II may not be the correct term to use here; kp might be more
        # appropriate.
II = kp.to_dense()
K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

K_23 = (
(-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1
) * K_11dd # verified for n1=n2=1 case

K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

K_32 = (
(-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
) * K_11dd # verified for n1=n2=1 case

K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

K_33 = (
(-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2)
- 2.0 * II * outer2 * outer1
+ 2.0 * (II) ** 2
) * K_11dd + (
(-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
) * outer1 * K_11dd # verified for n1=n2=1 case

K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33
final_K[..., n1:, n2:] = K_33
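        # For d = 1, K_33 reduces to
        # d^4 k / (dx1^2 dx2^2) = (3/l^4 - 6*r^2/l^6 + r^4/l^8) * k,
        # the covariance between two second-derivative observations.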

# Symmetrize for stability
if n1 == n2 and torch.eq(x1, x2).all():
final_K = 0.5 * (final_K.transpose(-1, -2) + final_K)

        # Apply a perfect shuffle permutation to match the MultiTask ordering
pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
final_K = final_K[..., pi1, :][..., :, pi2]
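        # For example, with n = 3 points and d = 1, [0, 1, 2, 3, 4, 5] maps to
        # [0, 3, 1, 4, 2, 5]: each point's value row is followed by its
        # second-derivative row.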

return final_K

def num_outputs_per_input(self, x1, x2):
return x1.size(-1) + 1
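
For reference, a minimal usage sketch of the kernel on its own (not part of this commit; it assumes src/ is on the Python path and a recent GPyTorch/linear_operator):

import torch
from scgp.kernels import RBFKernelSecondGrad

# Five one-dimensional inputs (n = 5, d = 1).
x = torch.linspace(0, 1, 5).unsqueeze(-1)

kernel = RBFKernelSecondGrad()
K = kernel(x, x).to_dense()

# One function value plus one second derivative per input, interleaved
# per point, so the covariance matrix is (n * (d + 1)) x (n * (d + 1)).
print(K.shape)  # torch.Size([10, 10])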
