Documentation about Multi-Output Regression
Hi @wesselb,
I am trying to use your multi-output regression example with some data I have, but I don't understand how to correctly pass the data to the VGP and then make a prediction. My input locations x_obs are not the same for every output, so my setup is not exactly like the example: I have nine sets of x observations, [x1, x2, x3, x4, x5, x6, x7, x8, x9], with their corresponding y observations, [y1, y2, y3, y4, y5, y6, y7, y8, y9]. Also, with the example provided, is it possible to optimize some hyperparameters if the VGP had any?
Here is the code I was trying to use, with 3 different outputs to simulate data. Thank you in advance for your help.
import matplotlib.pyplot as plt
import numpy as np

from wbml.plot import tweak
from stheno import B, Measure, GP, EQ, Delta, Matern52
class VGP:
    """A vector-valued GP."""

    def __init__(self, ps):
        self.ps = ps

    def __add__(self, other):
        return VGP([f + g for f, g in zip(self.ps, other.ps)])

    def lmatmul(self, A):
        m, n = A.shape
        ps = [0 for _ in range(m)]
        for i in range(m):
            for j in range(n):
                ps[i] += A[i, j] * self.ps[j]
        return VGP(ps)
# Define points to predict at.
x = B.linspace(0, 10, 5)
# Create some sample data.
x1 = np.atleast_2d(np.linspace(0, 10, 5)).T
x2 = np.atleast_2d(np.linspace(0, 9, 5)).T
x3 = np.atleast_2d(np.linspace(0, 7, 5)).T
y1 = np.atleast_2d(np.linspace(0, 10, 5)).T
y2 = np.atleast_2d(np.linspace(0, 10, 5)).T
y3 = np.atleast_2d(np.linspace(0, 10, 5)).T
x_obs = [x1, x2, x3]
y_obs = [y1, y2, y3]
# Model parameters:
m = 3
p = 3
H = B.randn(p, m)
with Measure() as prior:
    # Construct latent functions.
    us = VGP([GP(Matern52()) for _ in range(m)])
    # Construct multi-output prior.
    fs = us.lmatmul(H)
    # Construct noise.
    e = VGP([GP(0 * Delta()) for _ in range(p)])
    # Construct observation model.
    ys = e + fs
# Sample a true, underlying function and observations.
samples = prior.sample(*(f(x) for f in fs.ps), *(y(xi) for y, xi in zip(ys.ps, x_obs)))
fs_true, ys_obs = samples[:p], samples[p:]
# Compute the posterior and make predictions.
post = prior.condition(*((y(xi), yi) for y, yi, xi in zip(ys.ps, ys_obs, x_obs)))
preds = [post(p(x)) for p in fs.ps]
# Plot results.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = pred.marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()
plt.figure(figsize=(10, 6))
for i in range(3):
    plt.subplot(3, 1, i + 1)
    plt.title(f"Output {i + 1}")
    plot_prediction(x, fs_true[i], preds[i], x_obs[i], ys_obs[i])
plt.show()
Hi @vdsmax!
I've put together a simple MOGP model (not using the example) which might better suit your use case. The script uses JAX to learn hyperparameters. (You can also use another AD framework if you like.)
from stheno.jax import GP, Matern52, Measure
from varz.jax import Vars, minimise_l_bfgs_b
from wbml.plot import tweak
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np
x1 = np.linspace(0, 10, 30)
x2 = np.linspace(0, 9, 40)
x3 = np.linspace(0, 7, 50)
# Generate some test data.
f = GP(Matern52())
y1 = f(x1, 0.2).sample().flatten()
y2 = f(x2, 0.2).sample().flatten()
y3 = f(x3, 0.2).sample().flatten()
p = 3 # Number of outputs
m = 3 # Number of latent processes
def model(vs):
    ps = vs.struct
    with Measure() as prior:
        # Create independent processes with learnable length scales initialised to `1`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(m))
        ]
        # Mix processes together to induce correlations between the outputs.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]
        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))
    return prior, fs, noises
def objective(vs):
    prior, fs, noises = model(vs)
    return -prior.logpdf(
        (fs[0](x1, noises[0]), y1),
        (fs[1](x2, noises[1]), y2),
        (fs[2](x3, noises[2]), y3),
    )
# Perform learning.
vs = Vars(jnp.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print() # Display learned parameters.
# Compute posterior and predictions.
prior, fs, noises = model(vs)
posterior = prior | (
(fs[0](x1, noises[0]), y1),
(fs[1](x2, noises[1]), y2),
(fs[2](x3, noises[2]), y3),
)
f1_post = posterior(fs[0])
f2_post = posterior(fs[1])
f3_post = posterior(fs[2])
def plot_posterior(x, f, x_obs=None, y_obs=None):
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = f(x).marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()
# Plot results.
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
plt.subplot(3, 1, 1)
plt.title("Output 1")
plot_posterior(x_to_plot, f1_post, x1, y1)
plt.subplot(3, 1, 2)
plt.title("Output 2")
plot_posterior(x_to_plot, f2_post, x2, y2)
plt.subplot(3, 1, 3)
plt.title("Output 3")
plot_posterior(x_to_plot, f3_post, x3, y3)
plt.show()
The script produces the following plot:
[figure: posterior means and credible bounds for the three outputs, with the observations overlaid]
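Since you mention nine outputs, the same pattern generalises to any number of outputs by keeping the observations in lists. As a minimal sketch, reusing model from the script above and assuming x_obs and y_obs are Python lists with one array per output (hypothetical names), the objective could be written as:

def objective(vs):
    prior, fs, noises = model(vs)
    # Pair every output's finite-dimensional distribution with its observations.
    pairs = [(fs[i](x_obs[i], noises[i]), y_obs[i]) for i in range(p)]
    return -prior.logpdf(*pairs)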
Let me know if this suits your needs. :)
Thank you very much for your code example. It runs on my side too, and I get the same results using my CPU.
Because the computation time is high for nine inputs on a CPU, I wanted to use my GPU to see if it would be faster. I followed the steps to use CUDA with the JAX library and was able to link the two. However, using the same code you gave me, this time I obtained an error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-45889ae4f67a> in <module>
55 # Perform learning.
56 vs = Vars(jnp.float64)
---> 57 minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
58 vs.print() # Display learned parameters.
59
~/python-env/lib/python3.6/site-packages/varz/minimise.py in minimise_l_bfgs_b(f, vs, f_calls, iters, trace, names, jit)
77 trace=trace,
78 names=names,
---> 79 jit=jit,
80 )
81
~/python-env/lib/python3.6/site-packages/varz/minimise.py in _minimise_l_bfgs_b(f, vs, f_calls, iters, trace, names, jit)
154 # Run function once to ensure that all variables are initialised and
155 # available.
--> 156 res = convert(f(vs, *args), tuple)
157 val_init, args = res[0], res[1:]
158
<ipython-input-4-45889ae4f67a> in objective(vs)
49 (fs[0](x1, noises[0]), y1),
50 (fs[1](x2, noises[1]), y2),
---> 51 (fs[2](x3, noises[2]), y3),
52 )
53
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function._BoundFunction.__call__()
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/stheno/model/measure.py in logpdf(self, *pairs)
461 """
462 fdd, y = combine(*pairs)
--> 463 return self(fdd).logpdf(y)
464
465 @_dispatch
~/python-env/lib/python3.6/site-packages/stheno/random.py in logpdf(self, x)
210 B.logdet(self.var)[..., None] # Correctly line up with `iqf_diag`.
211 + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
--> 212 + B.iqf_diag(self.var, B.subtract(x, self.mean))
213 )
214 / 2
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/matrix/ops/iqf_diag.py in iqf_diag(a, b)
33 @B.dispatch
34 def iqf_diag(a, b):
---> 35 return iqf_diag(a, b, b)
36
37
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/matrix/ops/iqf_diag.py in iqf_diag(a, b, c)
20 """
21 chol = B.cholesky(a)
---> 22 chol_b = B.solve(chol, b)
23 if c is b:
24 chol_c = chol_b
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/lab/util.py in wrapper(*args, **kw_args)
212
213 # Retry call.
--> 214 return getattr(B, f.__name__)(*args, **kw_args)
215
216 return wrapper
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/matrix/ops/solve.py in solve(a, b)
41 )
42 a, b = align_batch(a.mat, b)
---> 43 return Dense(B.trisolve(B.dense(a), B.dense(b), lower_a=True))
44
45
~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()
~/python-env/lib/python3.6/site-packages/lab/shape.py in f_wrapped(*args, **kw_args)
183 @wraps(f)
184 def f_wrapped(*args, **kw_args):
--> 185 return f(*(unwrap_dimension(arg) for arg in args), **kw_args)
186
187 return dispatch(f_wrapped)
~/python-env/lib/python3.6/site-packages/lab/jax/linear_algebra.py in triangular_solve(a, b, lower_a)
125 )
126
--> 127 return batch_computation(_triangular_solve, (a, b), (2, 2))
128
129
~/python-env/lib/python3.6/site-packages/lab/util.py in batch_computation(f, xs, ranks)
149 for index in indices:
150 batches.append(
--> 151 f(*[x[_translate_index(index, s)] for x, s in zip(xs, batch_shapes)])
152 )
153
~/python-env/lib/python3.6/site-packages/lab/jax/linear_algebra.py in _triangular_solve(a_, b_)
122 def _triangular_solve(a_, b_):
123 return jsla.solve_triangular(
--> 124 a_, b_, trans="N", lower=lower_a, check_finite=False
125 )
126
~/python-env/lib/python3.6/site-packages/jax/_src/scipy/linalg.py in solve_triangular(***failed resolving arguments***)
223 overwrite_b=False, debug=None, check_finite=True):
224 del overwrite_b, debug, check_finite
--> 225 return _solve_triangular(a, b, trans, lower, unit_diagonal)
226
227
~/python-env/lib/python3.6/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
425 flat_fun, *args_flat,
426 device=device, backend=backend, name=flat_fun.__name__,
--> 427 donated_invars=donated_invars, inline=inline)
428 out_pytree_def = out_tree()
429 out = tree_unflatten(out_pytree_def, out_flat)
~/python-env/lib/python3.6/site-packages/jax/core.py in bind(self, fun, *args, **params)
1558
1559 def bind(self, fun, *args, **params):
-> 1560 return call_bind(self, fun, *args, **params)
1561
1562 def process(self, trace, fun, tracers, params):
~/python-env/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1549 params_tuple, out_axes_transforms)
1550 tracers = map(top_trace.full_raise, args)
-> 1551 outs = primitive.process(top_trace, fun, tracers, params)
1552 return map(full_lower, apply_todos(env_trace_todo(), outs))
1553
~/python-env/lib/python3.6/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1561
1562 def process(self, trace, fun, tracers, params):
-> 1563 return trace.process_call(self, fun, tracers, params)
1564
1565 def post_process(self, trace, out_tracers, params):
~/python-env/lib/python3.6/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
604
605 def process_call(self, primitive, f, tracers, params):
--> 606 return primitive.impl(f, *tracers, **params)
607 process_map = process_call
608
~/python-env/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
593 *unsafe_map(arg_spec, args))
594 try:
--> 595 return compiled_fun(*args)
596 except FloatingPointError:
597 assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
~/python-env/lib/python3.6/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, kept_var_idx, *args)
891 for i, x in enumerate(args)
892 if x is not token and i in kept_var_idx))
--> 893 out_bufs = compiled.execute(input_bufs)
894 check_special(xla_call_p.name, out_bufs)
895 return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
RuntimeError: Internal: Unable to launch triangular solve for thunk 0x2c46c570
Do I need to add something to the code to make it work with a GPU?
Ouch! That doesn't look good. Could you confirm that running other JAX code on the GPU works fine? If that's the case, I can look into this more closely to see what's going on.
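For example, a quick sanity check along the following lines (a sketch assuming a standard CUDA-enabled JAX install) should list a GPU device and run the operation that the traceback points at, a triangular solve, in isolation:

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

# A CUDA-enabled install should list at least one GPU device here.
print(jax.devices())

# The traceback points at a triangular solve, so try that op by itself.
a = jnp.tril(jnp.ones((4, 4))) + 3 * jnp.eye(4)
b = jnp.ones((4, 2))
print(jsla.solve_triangular(a, b, lower=True))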
I tried some examples of JAX code with my GPU (like the ones here: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) and they worked, so I think the issue comes from the library. I have jax-0.2.17 and jaxlib-0.1.65+cuda110 installed on my computer.
Hey @vdsmax,
That's very frustrating. I'm not sure what's going wrong. I am able to run the example on my end on a GPU, running jaxlib-0.1.73+cuda11.cudnn82 and jax-0.2.25.
I've created a version of the example using TensorFlow. Perhaps that works for you:
from stheno.tensorflow import GP, Matern52, Measure
from varz.tensorflow import Vars, minimise_l_bfgs_b
from wbml.plot import tweak
import lab.tensorflow as B
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
B.set_global_device("gpu")
x1 = np.linspace(0, 10, 200)
x2 = np.linspace(0, 9, 200)
x3 = np.linspace(0, 7, 200)
# Generate some test data.
f = GP(Matern52())
y1 = f(x1, 0.2).sample().flatten()
y2 = f(x2, 0.2).sample().flatten()
y3 = f(x3, 0.2).sample().flatten()
p = 3 # Number of outputs
m = 3 # Number of latent processes
def model(vs):
    ps = vs.struct
    with Measure() as prior:
        # Create independent processes with learnable length scales initialised to `1`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(m))
        ]
        # Mix processes together to induce correlations between the outputs.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]
        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))
    return prior, fs, noises
def objective(vs):
    prior, fs, noises = model(vs)
    return -prior.logpdf(
        (fs[0](x1, noises[0]), y1),
        (fs[1](x2, noises[1]), y2),
        (fs[2](x3, noises[2]), y3),
    )
# Perform learning.
vs = Vars(tf.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print() # Display learned parameters.
# Compute posterior and predictions.
prior, fs, noises = model(vs)
posterior = prior | (
(fs[0](x1, noises[0]), y1),
(fs[1](x2, noises[1]), y2),
(fs[2](x3, noises[2]), y3),
)
f1_post = posterior(fs[0])
f2_post = posterior(fs[1])
f3_post = posterior(fs[2])
def plot_posterior(x, f, x_obs=None, y_obs=None):
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = f(x).marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()
# Plot results.
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
plt.subplot(3, 1, 1)
plt.title("Output 1")
plot_posterior(x_to_plot, f1_post, x1, y1)
plt.subplot(3, 1, 2)
plt.title("Output 2")
plot_posterior(x_to_plot, f2_post, x2, y2)
plt.subplot(3, 1, 3)
plt.title("Output 3")
plot_posterior(x_to_plot, f3_post, x3, y3)
plt.show()
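If the TensorFlow version also gives trouble on the GPU, it may be worth first confirming that TensorFlow sees the device at all; a quick check using the standard TensorFlow API:

import tensorflow as tf

# Should list at least one physical GPU if TensorFlow's CUDA setup works.
print(tf.config.list_physical_devices("GPU"))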