celerite2
Work out traceable JVP and transpose rules for JAX
It should be possible to write the JVP ops in terms of the existing celerite primitives. This would support higher-order differentiation, hopefully without a significant computational overhead.
For example, the `matmul_lower` JVP can be implemented as follows:
def matmul_lower_jvp(arg_values, arg_tangents):
    """JVP rule for ``matmul_lower_p`` written with celerite primitives only.

    The tangent is assembled from calls to ``matmul_lower`` rather than from a
    dedicated kernel, so the rule itself stays traceable. That is what would
    make higher-order differentiation possible.

    Parameters
    ----------
    arg_values : the primal inputs ``(t, c, U, V, Y)``.
    arg_tangents : the matching tangents. Entries whose type is exactly
        ``ad.Zero`` are replaced by dense zeros before use.

    Returns
    -------
    The result of ``matmul_lower_p.bind(t, c, U, V, Y)`` together with the
    tangent pair ``(Zp, None)``.
    """

    def make_zero(primal, tangent):
        # Materialize a symbolic zero so that the arithmetic below is uniform.
        if type(tangent) is ad.Zero:
            return lax.zeros_like_array(primal)
        return tangent

    t, c, U, V, Y = arg_values
    tp, cp, Up, Vp, Yp = (
        make_zero(primal, tangent)
        for primal, tangent in zip(arg_values, arg_tangents)
    )

    # NOTE(review): assumes matmul_lower(t, c, U, V, Y) equals
    # tril((exp(-c t) U) (exp(c t) V)^T, -1) @ Y, i.e. the cr = 0 case of the
    # NumPy reference `mml` in this file -- confirm. Under that assumption,
    # perturbing t and c shifts the exponent c*t by (c * dt + dc * t). That
    # shift enters U with a minus sign and V with a plus sign.
    rate_tangent = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -rate_tangent * U + Up
    Vt = rate_tangent * V + Vp

    # Product rule: Z is linear in U, V and Y separately, so add one
    # contribution per perturbed argument.
    Zp = matmul_lower(t, c, U, V, Yp)
    Zp += matmul_lower(t, c, Ut, V, Y)
    Zp += matmul_lower(t, c, U, Vt, Y)

    # NOTE(review): the second output tangent is None. If matmul_lower_p has
    # multiple results, JAX may expect an ad.Zero instance here -- verify.
    primal_out = matmul_lower_p.bind(t, c, U, V, Y)
    return primal_out, (Zp, None)
But I haven't figured out the correct transpose yet.
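I figured out the transpose rules for the multiplication. We need to generalize the matmul to have "propagators" on both the left and the right: `cl` acts on the columns of U and V, and `cr` acts on the columns of Y. With that generalization, if
Z = mml(t, cl, cr, U, V, Y)
then, for an output cotangent bZ, the cotangents of the inputs are
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
bY = mmu(t, cl, cr, U, V, bZ)
(note that cl and cr swap places in the rules for bU and bV).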
import numpy as np
def mml(t, cl, cr, U, V, Y):
    """Strictly lower-triangular celerite product with left and right propagators.

    For every row ``n`` this computes::

        Z[n, k] = sum_{m < n} sum_j U[n, j] * V[m, j] * Y[m, k]
                  * exp(cl[j] * (t[m] - t[n])) * exp(cr[k] * (t[m] - t[n]))

    It uses a forward recursion over the rows in O(N * J * K) time, without
    forming the dense N x N matrix. With ``cr = 0`` this equals
    ``np.tril(Up @ Vp.T, -1) @ Y`` where ``Up = exp(-cl * t) * U`` and
    ``Vp = exp(cl * t) * V``.

    Parameters
    ----------
    t : (N,) array of coordinates.
    cl : (J,) array of rates acting along the columns of ``U`` and ``V``.
    cr : (K,) array of rates acting along the columns of ``Y``.
    U, V : (N, J) arrays.
    Y : (N, K) array.

    Returns
    -------
    Z : (N, K) array. ``Z[0]`` is always zero, and an empty input (N = 0)
        gives an empty result. The dtype is ``Y.dtype`` for floating/complex
        ``Y`` and float64 otherwise.
    """
    t, cl, cr, U, V, Y = map(np.asarray, (t, cl, cr, U, V, Y))
    # The weights are exponentials, so an integer Y must not force the output
    # to be truncated to integers.
    out_dtype = Y.dtype if np.issubdtype(Y.dtype, np.inexact) else np.float64
    Z = np.zeros(Y.shape, dtype=out_dtype)
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(1, U.shape[0]):
        F += np.outer(V[n - 1], Y[n - 1])
        dt = t[n - 1] - t[n]
        # Same as diag(exp(cl*dt)) @ F @ diag(exp(cr*dt)), but it costs O(J*K)
        # and avoids the 0 * inf = nan contamination of dense diagonal matmuls.
        F = np.exp(cl * dt)[:, None] * F * np.exp(cr * dt)[None, :]
        Z[n] = U[n] @ F
    return Z
def mmu(t, cl, cr, U, V, Y):
    """Strictly upper-triangular counterpart of ``mml``.

    For every row ``m`` this computes::

        Z[m, k] = sum_{n > m} sum_j V[m, j] * U[n, j] * Y[n, k]
                  * exp(cl[j] * (t[m] - t[n])) * exp(cr[k] * (t[m] - t[n]))

    It uses a backward recursion over the rows in O(N * J * K) time. With
    ``cr = 0`` this equals ``np.triu(Vp @ Up.T, 1) @ Y`` where
    ``Up = exp(-cl * t) * U`` and ``Vp = exp(cl * t) * V``.

    Parameters
    ----------
    t : (N,) array of coordinates.
    cl : (J,) array of rates acting along the columns of ``U`` and ``V``.
    cr : (K,) array of rates acting along the columns of ``Y``.
    U, V : (N, J) arrays.
    Y : (N, K) array.

    Returns
    -------
    Z : (N, K) array. ``Z[-1]`` is always zero, and an empty input (N = 0)
        gives an empty result. The dtype is ``Y.dtype`` for floating/complex
        ``Y`` and float64 otherwise.
    """
    t, cl, cr, U, V, Y = map(np.asarray, (t, cl, cr, U, V, Y))
    # The weights are exponentials, so an integer Y must not force the output
    # to be truncated to integers.
    out_dtype = Y.dtype if np.issubdtype(Y.dtype, np.inexact) else np.float64
    Z = np.zeros(Y.shape, dtype=out_dtype)
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(U.shape[0] - 2, -1, -1):
        F += np.outer(U[n + 1], Y[n + 1])
        dt = t[n] - t[n + 1]
        # Same as diag(exp(cl*dt)) @ F @ diag(exp(cr*dt)), but it costs O(J*K)
        # and avoids the 0 * inf = nan contamination of dense diagonal matmuls.
        F = np.exp(cl * dt)[:, None] * F * np.exp(cr * dt)[None, :]
        Z[n] = V[n] @ F
    return Z
# Numerical checks: first compare mml/mmu against dense reference products,
# then verify the transpose rules with adjoint ("dot-product") tests.
rng = np.random.default_rng(42)
N = 100
J = 4
K = 3
t = np.sort(rng.uniform(0, 10, N))
cl = rng.random(J)
cr = np.zeros(K)
U = rng.standard_normal((N, J))
V = rng.standard_normal((N, J))
Y = rng.standard_normal((N, K))

# With cr = 0 both products reduce to dense low-rank triangular matrices.
Up = np.exp(-cl[None, :] * t[:, None]) * U
Vp = np.exp(cl[None, :] * t[:, None]) * V
assert np.allclose(mml(t, cl, cr, U, V, Y), np.tril(Up @ Vp.T, -1) @ Y)
assert np.allclose(mmu(t, cl, cr, U, V, Y), np.triu(Vp @ Up.T, 1) @ Y)

# Use a nonzero right propagator so that the cl/cr swap in the rules for bU
# and bV is actually exercised.
cr_nz = rng.random(K)
bZ = rng.standard_normal(mml(t, cl, cr_nz, U, V, Y).shape)
bY = mmu(t, cl, cr_nz, U, V, bZ)
bU = mml(t, cr_nz, cl, bZ, Y, V)
bV = mmu(t, cr_nz, cl, bZ, Y, U)

# Z is linear in each of U, V and Y separately, so the transpose of each
# linear map X -> Z must satisfy <bZ, Z(dX)> == <bX, dX> for any dX.
dU = rng.standard_normal(U.shape)
dV = rng.standard_normal(V.shape)
dY = rng.standard_normal(Y.shape)
assert np.isclose(np.sum(bZ * mml(t, cl, cr_nz, dU, V, Y)), np.sum(bU * dU))
assert np.isclose(np.sum(bZ * mml(t, cl, cr_nz, U, dV, Y)), np.sum(bV * dV))
assert np.isclose(np.sum(bZ * mml(t, cl, cr_nz, U, V, dY)), np.sum(bY * dY))