Commit b43f7d0
Initial implementation for #3 and results on current datasets
1 parent: 3412b6a
Showing 11 changed files with 319 additions and 0 deletions.
Binary file added (+239 KB): results/adaptive_HC2/s2ocgp/S2OCGP_ScaleRBFKernel_adaptive_HC2_test.png
Binary file added (+212 KB): results/adaptive_HC3/s2ocgp/S2OCGP_ScaleRBFKernel_adaptive_HC3_test.png
@@ -0,0 +1,175 @@
import torch
import gpytorch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import time
from tqdm import tqdm, trange
from scgp.kernels import RBFKernelSecondGrad

PATH = pathlib.Path(__file__).parent.parent.parent.absolute()
DATA_PATH = PATH / "data"
LOGGING = False


matplotlib.rc('text', usetex=True)
matplotlib.rc('font', family='serif')
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath}')

# Loading data
def load_data(path: str) -> tuple:
    data = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
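    # Column layout assumed from the indexing below: column 0 is the input
    # alpha, column 1 the function value, and column 3 its second
    # derivative (column 2 appears to be unused here).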
    x_train = torch.from_numpy(data[:, 0])
    y_train = torch.stack(
        [torch.from_numpy(data[:, 1]), torch.from_numpy(data[:, 3])], -1
    ).squeeze(1)
    return x_train, y_train


# Read http://www.gaussianprocess.org/gpml/chapters/RW9.pdf
class SCGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, kernel):
        super(SCGP, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.base_kernel = kernel
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x,
                                                                  covar_x)
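
# Note: each input produces two correlated outputs (the function value and its
# second derivative), so forward returns a MultitaskMultivariateNormal whose
# cross-task covariance comes from the differentiated kernel.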


def scgp_fit(path: str, iters: int, kernel: gpytorch.kernels.Kernel) -> tuple:
    train_x, train_y = load_data(path)

    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
        num_tasks=2
    )  # y and y_second_prime
    model = SCGP(train_x, train_y, likelihood, kernel)

    # Optimizing hyperparameters via marginal log likelihood
    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in trange(iters):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        if LOGGING:
            tqdm.write(
                f"Iter {i + 1}/{iters} - Loss: {loss.item():.3f} "
                f"noise: {model.likelihood.noise.item():.3f}"
            )
        optimizer.step()

    return train_x, train_y, model, likelihood


def save_plot_scpg(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    model: SCGP,
    likelihood: gpytorch.likelihoods.MultitaskGaussianLikelihood,
    ker_name: str,
    dataset_name: str
) -> None:

    # Name of the plot
    name = f"S2OCGP_Scale{ker_name}_{dataset_name}_test"

    # Evaluation mode
    model.eval()
    likelihood.eval()

    # Initialize plots
    f, (y_ax, y_prime_ax) = plt.subplots(1, 2,
                                         figsize=(10, 4), tight_layout=True)

    # Make predictions
    with torch.no_grad(), gpytorch.settings.max_cg_iterations(100):
        test_x = torch.linspace(0, 1, 100)
        predictions = likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    # Plotting predictions for f
    y_ax.plot(train_x.detach().numpy(), train_y[:, 0].detach().numpy(), "k*")
    y_ax.plot(test_x.numpy(), mean[:, 0].numpy(), "b")
    y_ax.fill_between(
        test_x.numpy(), lower[:, 0].numpy(), upper[:, 0].numpy(), alpha=0.5
    )
    y_ax.legend(["Observed Values", "Mean", "Confidence"])
    y_ax.set_title("Function values")
    y_ax.set_xlim([0, 1])
    # y_ax.set_ylim([-7.5, 12.5])
    y_ax.set_xlabel(r"$\alpha$")
    y_ax.set_ylabel(r"$f_{\boldsymbol{z}}(\alpha)$")

    # Plotting predictions for f''
    y_prime_ax.plot(train_x.detach().numpy(), train_y[:, 1].detach().numpy(),
                    "k*")
    y_prime_ax.plot(test_x.numpy(), mean[:, 1].numpy(), "b")
    y_prime_ax.fill_between(
        test_x.numpy(), lower[:, 1].numpy(), upper[:, 1].numpy(), alpha=0.5
    )
    y_prime_ax.legend(["Observed Derivatives", "Mean", "Confidence"])
    y_prime_ax.set_title(r"Second Order Derivatives with respect to $\alpha$")
    y_prime_ax.set_xlim([0, 1])

    y_prime_ax.set_xlabel(r"$\alpha$")
    y_prime_ax.set_ylabel(r"$\frac{\mathrm{d}^2}{\mathrm{d}\alpha^2}f_{\boldsymbol{z}}(\alpha)$")

    save_path = PATH / "results" / str(dataset_name) / 's2ocgp'
    save_path.mkdir(parents=True, exist_ok=True)

    f.savefig(save_path / (name + ".png"), dpi=600, bbox_inches="tight")
    plt.close(f)


if __name__ == "__main__":
    datasets = {
        "uniform": DATA_PATH / "Gaussian_logCA0_uniform_J=20.csv",
        "adaptive": DATA_PATH / "Gaussian_logCA0_adaptive_J=20.csv",
        "uniform_HC": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20.csv",
        "adaptive_HC": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20.csv",
        "uniform_HC2": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20_HC2.csv",
        "adaptive_HC2": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20_HC2.csv",
        "uniform_HC3": DATA_PATH / "Gaussian_HC_logCA0_uniform_J=20_HC3.csv",
        "adaptive_HC3": DATA_PATH / "Gaussian_HC_logCA0_adaptive_J=20_HC3.csv",
    }

    kernels = {"RBFKernel": RBFKernelSecondGrad()}

    for ker_name, kernel in kernels.items():
        print(f"Running {ker_name}")

        for dataset_name, dataset_path in datasets.items():
            start_time = time.time()
            try:
                train_x, train_y, data_scgp, data_likelihood = scgp_fit(
                    dataset_path, 100, kernel=kernel
                )
                save_plot_scpg(
                    train_x,
                    train_y,
                    data_scgp,
                    data_likelihood,
                    ker_name,
                    dataset_name
                )
                final_time = time.time() - start_time
                print(
                    f"Finished {ker_name} on {dataset_name} in {final_time:.2f}s"
                )
            except Exception as e:
                print(f"Error on {ker_name} on {dataset_name}: {e}")
                continue
Empty file.
@@ -0,0 +1,144 @@
import torch
from linear_operator.operators import KroneckerProductLinearOperator
from gpytorch.kernels.rbf_kernel import postprocess_rbf, RBFKernel
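
# Note: in the gpytorch versions this code targets, postprocess_rbf(sq_dist)
# computes exp(-sq_dist / 2), i.e. the RBF kernel evaluated on inputs that
# have already been scaled by the lengthscale.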


class RBFKernelSecondGrad(RBFKernel):
    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        K = torch.zeros(
            *batch_shape,
            n1 * (2 * d + 1),
            n2 * (2 * d + 1),
            device=x1.device,
            dtype=x1.dtype
        )
        final_K = torch.zeros(
            *batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype
        )
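
        # K holds the full (value, first-derivative, second-derivative)
        # covariance; final_K keeps only the value and second-derivative
        # blocks, matching the two observed tasks (f and f'').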

        # Scale the inputs by the lengthscale (for stability)
        x1_ = x1.div(self.lengthscale)
        x2_ = x2.div(self.lengthscale)

        # Form all possible rank-1 products for the gradient and Hessian blocks
        outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
        outer = outer / self.lengthscale.unsqueeze(-2)
        outer = torch.transpose(outer, -1, -2).contiguous()

        # 1) Kernel block
        diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
        K_11 = postprocess_rbf(diff)
        K[..., :n1, :n2] = K_11
        final_K[..., :n1, :n2] = K_11
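
        # The blocks below use the standard RBF derivative identities with
        # u = (x1 - x2) / lengthscale**2 (the "outer" tensor above):
        # dk/dx2 = k * u and dk/dx1 = -k * u.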

        # 2) First gradient block
        outer1 = outer.view(*batch_shape, n1, n2 * d)
        K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat(
            [*([1] * (n_batch_dims + 1)), d]
        )

        # 3) Second gradient block
        outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
        outer2 = outer2.transpose(-1, -2)
        K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat(
            [*([1] * n_batch_dims), d, 1]
        )

        # 4) Hessian block
        outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat(
            [*([1] * (n_batch_dims + 1)), d]
        )
        kp = KroneckerProductLinearOperator(
            torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1)
            / self.lengthscale.pow(2),
            torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
                *batch_shape, 1, 1
            ),
        )
        chain_rule = kp.to_dense() - outer3
        K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat(
            [*([1] * n_batch_dims), d, d]
        )

        # 5) 1-3 block
        douter1dx2 = KroneckerProductLinearOperator(
            torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(
                *batch_shape, 1, 1
            )
            / self.lengthscale.pow(2),
            torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
                *batch_shape, 1, 1
            ),
        ).to_dense()

        K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
            [*([1] * (n_batch_dims + 1)), d]
        )  # verified for n1=n2=1 case
        K[..., :n1, (n2 * (d + 1)) :] = K_13
        final_K[..., :n1, n2:] = K_13

        K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat(
            [*([1] * n_batch_dims), d, 1]
        )  # verified for n1=n2=1 case
        K[..., (n1 * (d + 1)) :, :n2] = K_31
        final_K[..., n1:, :n2] = K_31
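
        # Sanity check for d = 1: with u = (x1 - x2) / l**2, these reduce to
        # d2k/dx2**2 = (u**2 - 1 / l**2) * k, and symmetrically for x1.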

        # rest of the blocks are all of size (n1*d,n2*d)
        outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
        outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
        # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
        kp2 = KroneckerProductLinearOperator(
            torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(
                *batch_shape, 1, 1
            )
            / self.lengthscale.pow(2),
            torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(
                *batch_shape, 1, 1
            ),
        ).to_dense()

        # II may not be the correct thing to use. It might be more appropriate to use kp instead??
        II = kp.to_dense()
        K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

        K_23 = (
            (-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1
        ) * K_11dd  # verified for n1=n2=1 case

        K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

        K_32 = (
            (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
        ) * K_11dd  # verified for n1=n2=1 case

        K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

        K_33 = (
            (-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2)
            - 2.0 * II * outer2 * outer1
            + 2.0 * (II) ** 2
        ) * K_11dd + (
            (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
        ) * outer1 * K_11dd  # verified for n1=n2=1 case

        K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33
        final_K[..., n1:, n2:] = K_33

        # Symmetrize for stability
        if n1 == n2 and torch.eq(x1, x2).all():
            final_K = 0.5 * (final_K.transpose(-1, -2) + final_K)

        # Apply a perfect shuffle permutation to match the MultiTask ordering
        pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
        pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
        final_K = final_K[..., pi1, :][..., :, pi2]
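        # e.g. for n1 = n2 = 2, d = 1 this maps the block ordering
        # [f(x1), f(x2), f''(x1), f''(x2)] to the interleaved
        # [f(x1), f''(x1), f(x2), f''(x2)] expected by the multitask model.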

        return final_K

    def num_outputs_per_input(self, x1, x2):
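        # Tell gpytorch that each input point contributes d + 1 outputs:
        # the function value plus its d second-derivative components.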
        return x1.size(-1) + 1