Documentation about Multi-Output Regression #20
Hi @vdsmax! I've put together a simple MOGP model (not using the example) which might better suit your use case. The script uses JAX to learn hyperparameters. (You can also use another AD framework if you like.)

```python
from stheno.jax import GP, Matern52, Measure
from varz.jax import Vars, minimise_l_bfgs_b
from wbml.plot import tweak
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

x1 = np.linspace(0, 10, 30)
x2 = np.linspace(0, 9, 40)
x3 = np.linspace(0, 7, 50)

# Generate some test data.
f = GP(Matern52())
y1 = f(x1, 0.2).sample().flatten()
y2 = f(x2, 0.2).sample().flatten()
y3 = f(x3, 0.2).sample().flatten()

p = 3  # Number of outputs
m = 3  # Number of latent processes


def model(vs):
    ps = vs.struct
    with Measure() as prior:
        # Create independent processes with learnable length scales initialised to `1`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(p))
        ]
        # Mix processes together to induce correlations between the outputs.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]
        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))
    return prior, fs, noises


def objective(vs):
    prior, fs, noises = model(vs)
    return -prior.logpdf(
        (fs[0](x1, noises[0]), y1),
        (fs[1](x2, noises[1]), y2),
        (fs[2](x3, noises[2]), y3),
    )


# Perform learning.
vs = Vars(jnp.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print()  # Display learned parameters.

# Compute posterior and predictions.
prior, fs, noises = model(vs)
posterior = prior | (
    (fs[0](x1, noises[0]), y1),
    (fs[1](x2, noises[1]), y2),
    (fs[2](x3, noises[2]), y3),
)
f1_post = posterior(fs[0])
f2_post = posterior(fs[1])
f3_post = posterior(fs[2])


def plot_posterior(x, f, x_obs=None, y_obs=None):
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = f(x).marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()


# Plot results.
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
plt.subplot(3, 1, 1)
plt.title("Output 1")
plot_posterior(x_to_plot, f1_post, x1, y1)
plt.subplot(3, 1, 2)
plt.title("Output 2")
plot_posterior(x_to_plot, f2_post, x2, y2)
plt.subplot(3, 1, 3)
plt.title("Output 3")
plot_posterior(x_to_plot, f3_post, x3, y3)
plt.show()
```

The script produces the following plot:

[Figure: posterior means and credible bounds, with observations, for Outputs 1, 2, and 3.]

Let me know if this suits your needs. :)
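For reference, nothing in this construction is tied to three outputs: mixing the m independent latent processes with H gives cross-covariances cov(f_i(x), f_j(x')) = Σ_k H[i, k] H[j, k] k_k(x, x'), where k_k is the kernel of the k-th latent process. Below is a minimal, untested sketch of the same model for nine outputs observed at nine different input sets; the data and the names `xs`/`ys` are purely illustrative.

```python
import numpy as np
from stheno.jax import GP, Matern52, Measure

# Illustrative synthetic data: nine outputs, each observed at its own inputs.
xs = [np.linspace(0, 10, 30 + 5 * i) for i in range(9)]
f_gen = GP(Matern52())
ys = [f_gen(x, 0.2).sample().flatten() for x in xs]

p = len(xs)  # Number of outputs
m = p        # Number of latent processes


def model(vs):
    ps = vs.struct
    with Measure() as prior:
        # Independent latent processes with learnable length scales.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(m))
        ]
        # Mixing matrix: each output is a linear combination of the latents.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [sum(H[i, j] * us[j] for j in range(m)) for i in range(p)]
        # Learnable observation noises, one per output.
        noises = ps.noises.positive(0.1, shape=(p,))
    return prior, fs, noises


def objective(vs):
    prior, fs, noises = model(vs)
    # Condition on all nine (input, output) pairs jointly.
    return -prior.logpdf(*[(fs[i](xs[i], noises[i]), ys[i]) for i in range(p)])
```

Learning and prediction then proceed exactly as in the three-output script: call `minimise_l_bfgs_b(objective, vs, trace=True, jit=True)` and condition with `posterior = prior | (...)` on all nine pairs.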
Thank you very much for your code example. It runs on my side too, and I get the same results using my CPU. Because the computational time is high for nine inputs on a CPU, I wanted to use my GPU to see if it would be faster. I followed the steps to use CUDA with the JAX library and was able to link the two. However, using the same code you gave me, this time I obtained an error.
Do I need to add something to the code to make it work with a GPU?
Ouch! That doesn't look good. Could you confirm that running other JAX code on the GPU works fine? If that's the case, I can look into this more closely to see what's going on.
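For example, a generic check like the following (plain JAX, nothing specific to Stheno; just a smoke test) should list a GPU device and run a small computation:

```python
import jax
import jax.numpy as jnp

# A working CUDA-enabled install should list at least one GPU device here.
print(jax.devices())

# Quick smoke test: run a matrix product on the default device.
x = jnp.ones((1000, 1000))
print(jnp.sum(x @ x))
```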
I tried some examples of JAX code with my GPU (like this one: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) and it was working. I think the issue comes from the library.
Hey @vdsmax,

That's very frustrating. I'm not sure what's going wrong. I am able to run the example on my end on a GPU. I've created a version of the example using TensorFlow. Perhaps that works for you:

```python
from stheno.tensorflow import GP, Matern52, Measure
from varz.tensorflow import Vars, minimise_l_bfgs_b
from wbml.plot import tweak
import lab.tensorflow as B
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

B.set_global_device("gpu")

x1 = np.linspace(0, 10, 200)
x2 = np.linspace(0, 9, 200)
x3 = np.linspace(0, 7, 200)

# Generate some test data.
f = GP(Matern52())
y1 = f(x1, 0.2).sample().flatten()
y2 = f(x2, 0.2).sample().flatten()
y3 = f(x3, 0.2).sample().flatten()

p = 3  # Number of outputs
m = 3  # Number of latent processes


def model(vs):
    ps = vs.struct
    with Measure() as prior:
        # Create independent processes with learnable length scales initialised to `1`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(p))
        ]
        # Mix processes together to induce correlations between the outputs.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]
        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))
    return prior, fs, noises


def objective(vs):
    prior, fs, noises = model(vs)
    return -prior.logpdf(
        (fs[0](x1, noises[0]), y1),
        (fs[1](x2, noises[1]), y2),
        (fs[2](x3, noises[2]), y3),
    )


# Perform learning.
vs = Vars(tf.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print()  # Display learned parameters.

# Compute posterior and predictions.
prior, fs, noises = model(vs)
posterior = prior | (
    (fs[0](x1, noises[0]), y1),
    (fs[1](x2, noises[1]), y2),
    (fs[2](x3, noises[2]), y3),
)
f1_post = posterior(fs[0])
f2_post = posterior(fs[1])
f3_post = posterior(fs[2])


def plot_posterior(x, f, x_obs=None, y_obs=None):
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = f(x).marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()


# Plot results.
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
plt.subplot(3, 1, 1)
plt.title("Output 1")
plot_posterior(x_to_plot, f1_post, x1, y1)
plt.subplot(3, 1, 2)
plt.title("Output 2")
plot_posterior(x_to_plot, f2_post, x2, y2)
plt.subplot(3, 1, 3)
plt.title("Output 3")
plot_posterior(x_to_plot, f3_post, x3, y3)
plt.show()
```
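If the JAX route keeps failing, it may also be worth first confirming that TensorFlow itself sees the GPU; a standard check (generic TensorFlow, not specific to this script) is:

```python
import tensorflow as tf

# Should list at least one physical GPU if the CUDA setup is working.
print(tf.config.list_physical_devices("GPU"))
```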
Hi @wesselb,
I am trying to use your example of multi-output regression with some data I have. I don't understand how to correctly pass the data to the VGP and then make a prediction.
My input data x_obs are not the same across outputs, so it's not exactly like the example. I have nine x observations [x1,x2,x3,x4,x5,x6,x7,x8,x9] with their y observations [y1,y2,y3,y4,y5,y6,y7,y8,y9].
Also, with the example you provided, is it possible to optimize some hyperparameters if we had some in the VGP?
Here is the code I was trying to use, with 3 different outputs to simulate data. Thank you in advance for your help.