celerite2

Work out traceable JVP and transpose rules for JAX

Open · dfm opened this issue 5 years ago · 1 comment

It should be possible to write the JVP ops in terms of the existing celerite primitives. This would allow support for higher-order differentiation, hopefully without significant computational overhead.
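The reason this should work is the product rule. Assuming `matmul_lower` is equivalent to the dense product $Z = \operatorname{tril}(\tilde U \tilde V^\top, -1)\,Y$ with rescaled factors $\tilde U_{nj} = e^{-c_j t_n} U_{nj}$ and $\tilde V_{mj} = e^{c_j t_m} V_{mj}$ (the same form checked numerically later in this thread), the tangent splits into three products of the same kind:

```latex
\dot Z = \operatorname{tril}(\tilde U \tilde V^\top, -1)\,\dot Y
       + \operatorname{tril}(\dot{\tilde U} \tilde V^\top, -1)\,Y
       + \operatorname{tril}(\tilde U \dot{\tilde V}^\top, -1)\,Y,

\dot{\tilde U}_{nj} = e^{-c_j t_n}\left[\dot U_{nj} - (c_j \dot t_n + \dot c_j t_n)\,U_{nj}\right],
\qquad
\dot{\tilde V}_{mj} = e^{c_j t_m}\left[\dot V_{mj} + (c_j \dot t_m + \dot c_j t_m)\,V_{mj}\right].
```

The bracketed terms are what the example that follows calls `Ut` and `Vt`; passing them to `matmul_lower` in place of `U` or `V` reapplies the exponential factors.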

For example, the matmul_lower JVP can be implemented as follows:

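```python
from jax import lax
from jax.interpreters import ad

# matmul_lower (wrapper) and matmul_lower_p (primitive) are assumed to come
# from celerite2's JAX ops module.


def matmul_lower_jvp(arg_values, arg_tangents):
    def make_zero(x, dx):
        # Instantiate symbolic zero tangents as concrete zero arrays.
        return lax.zeros_like_array(x) if isinstance(dx, ad.Zero) else dx

    t, c, U, V, Y = arg_values
    tp, cp, Up, Vp, Yp = (
        make_zero(x, dx) for x, dx in zip(arg_values, arg_tangents)
    )

    # Tangent of the exponent c_j * t_n (product rule), shared by Ut and Vt.
    dlog = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -dlog * U + Up
    Vt = dlog * V + Vp

    # The tangent is a sum of three matmul_lower calls, one per factor.
    Zp = matmul_lower(t, c, U, V, Yp)
    Zp += matmul_lower(t, c, Ut, V, Y)
    Zp += matmul_lower(t, c, U, Vt, Y)

    return matmul_lower_p.bind(t, c, U, V, Y), (Zp, None)
```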
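The tangent algebra can be checked without JAX using a finite-difference test against a dense NumPy version. This is a sketch: `matmul_lower_dense` and `matmul_lower_jvp_dense` are hypothetical helpers that assume `matmul_lower` computes tril(Ũ Ṽᵀ, -1) Y with Ũ = exp(-c t) U and Ṽ = exp(c t) V, and they reuse the `Ut`/`Vt` expressions from `matmul_lower_jvp` above.

```python
import numpy as np


def matmul_lower_dense(t, c, U, V, Y):
    # Hypothetical dense stand-in for matmul_lower (assumed convention).
    Ut = np.exp(-c[None, :] * t[:, None]) * U
    Vt = np.exp(c[None, :] * t[:, None]) * V
    return np.tril(Ut @ Vt.T, -1) @ Y


def matmul_lower_jvp_dense(primals, tangents):
    # Same algebra as matmul_lower_jvp, with the dense stand-in as the primitive.
    t, c, U, V, Y = primals
    tp, cp, Up, Vp, Yp = tangents
    dlog = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -dlog * U + Up
    Vt = dlog * V + Vp
    Zp = matmul_lower_dense(t, c, U, V, Yp)
    Zp += matmul_lower_dense(t, c, Ut, V, Y)
    Zp += matmul_lower_dense(t, c, U, Vt, Y)
    return matmul_lower_dense(t, c, U, V, Y), Zp


rng = np.random.default_rng(42)
N, J, K = 20, 3, 2
primals = (
    np.sort(rng.uniform(0, 5, N)),  # t
    rng.uniform(0.1, 1.0, J),       # c
    rng.standard_normal((N, J)),    # U
    rng.standard_normal((N, J)),    # V
    rng.standard_normal((N, K)),    # Y
)
tangents = tuple(rng.standard_normal(x.shape) for x in primals)

Z, Zp = matmul_lower_jvp_dense(primals, tangents)

# Central finite difference along the same tangent direction.
eps = 1e-6
plus = matmul_lower_dense(*(x + eps * dx for x, dx in zip(primals, tangents)))
minus = matmul_lower_dense(*(x - eps * dx for x, dx in zip(primals, tangents)))
Zp_fd = (plus - minus) / (2 * eps)
print(np.allclose(Zp, Zp_fd, rtol=1e-5, atol=1e-5))
```

Perturbing all five inputs at once exercises every term of the product rule in a single comparison.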

But I haven't figured out the correct transpose yet.
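For reference, the transposes are easy to write down in dense form. Assuming `matmul_lower` is equivalent to $Z = \operatorname{tril}(\tilde U \tilde V^\top, -1)\,Y$ with $\tilde U_{nj} = e^{-c_j t_n} U_{nj}$ and $\tilde V_{mj} = e^{c_j t_m} V_{mj}$, pairing $Z$ with a cotangent $\bar Z$ gives:

```latex
\bar Y = \operatorname{triu}(\tilde V \tilde U^\top, 1)\,\bar Z,
\qquad
\bar{\tilde U} = \operatorname{tril}(\bar Z Y^\top, -1)\,\tilde V,
\qquad
\bar{\tilde V} = \operatorname{triu}(Y \bar Z^\top, 1)\,\tilde U.
```

Only $\bar Y$ is again an upper-triangular product with the same factors. Mapping $\bar{\tilde U}$ back to $U$ gives $\bar U_{nj} = \sum_{m<n} (\bar Z_n \cdot Y_m)\, e^{c_j (t_m - t_n)} V_{mj}$: the decay now multiplies the columns of the right-hand factor $V$ rather than the inner rank index, so it cannot be expressed with a single-propagator product.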

dfm, Nov 09 '20

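I figured out the transpose rules for the multiplication. The matmul needs to be generalized to have "propagators" on both the left and the right: `cl` acts on the columns of `U` and `V` (length `J`) and `cr` acts on the columns of `Y` (length `K`). With that generalization, if

```python
Z = mml(t, cl, cr, U, V, Y)
```

then, for a cotangent `bZ` of `Z`,

```python
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
bY = mmu(t, cl, cr, U, V, bZ)
```

where `mml` and `mmu` are the strictly lower- and upper-triangular products implemented below.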
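Written entrywise (read off from the loops in the code below), the two generalized products are as follows, with $c^l$ of length $J$ and $c^r$ of length $K$:

```latex
\mathrm{mml}(t, c^l, c^r, U, V, Y)_{nk}
  = \sum_{m<n} \sum_j U_{nj}\, e^{c^l_j (t_m - t_n)}\, V_{mj}\, Y_{mk}\, e^{c^r_k (t_m - t_n)},

\mathrm{mmu}(t, c^l, c^r, U, V, Y)_{nk}
  = \sum_{m>n} \sum_j V_{nj}\, e^{c^l_j (t_n - t_m)}\, U_{mj}\, Y_{mk}\, e^{c^r_k (t_n - t_m)}.

\bar U_{nj} = \frac{\partial}{\partial U_{nj}} \sum_{n',k} \bar Z_{n'k} Z_{n'k}
  = \sum_{m<n} \sum_k \bar Z_{nk}\, e^{c^r_k (t_m - t_n)}\, Y_{mk}\, V_{mj}\, e^{c^l_j (t_m - t_n)}
  = \mathrm{mml}(t, c^r, c^l, \bar Z, Y, V)_{nj}.
```

The rules for $\bar V$ and $\bar Y$ follow the same way by differentiating with respect to $V_{mj}$ and $Y_{mk}$, which turns the sum over $m<n$ into a sum over $n>m$, i.e. an `mmu`.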
```python
import numpy as np


def mml(t, cl, cr, U, V, Y):
    """Strictly lower-triangular product with left (cl) and right (cr) propagators."""
    Z = np.empty_like(Y)
    Z[0] = 0.0
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(1, U.shape[0]):
        F += np.outer(V[n - 1], Y[n - 1])

        # Same as np.diag(pl) @ F @ np.diag(pr), without forming the diagonals.
        pl = np.exp(cl * (t[n - 1] - t[n]))
        pr = np.exp(cr * (t[n - 1] - t[n]))
        F = pl[:, None] * F * pr[None, :]

        Z[n] = U[n] @ F
    return Z


def mmu(t, cl, cr, U, V, Y):
    """Strictly upper-triangular counterpart of mml, swept backwards in time."""
    Z = np.empty_like(Y)
    Z[-1] = 0.0
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(U.shape[0] - 2, -1, -1):
        F += np.outer(U[n + 1], Y[n + 1])

        pl = np.exp(cl * (t[n] - t[n + 1]))
        pr = np.exp(cr * (t[n] - t[n + 1]))
        F = pl[:, None] * F * pr[None, :]

        Z[n] = V[n] @ F
    return Z


N = 100
J = 4
K = 3

t = np.sort(np.random.uniform(0, 10, N))
cl = np.random.rand(J)
cr = np.zeros(K)  # cr = 0 reduces to the single-propagator case checked below

U = np.random.randn(N, J)
V = np.random.randn(N, J)
Y = np.random.randn(N, K)

# Dense reference for cr = 0: absorb the left propagator into U and V.
Up = np.exp(-cl[None, :] * t[:, None]) * U
Vp = np.exp(cl[None, :] * t[:, None]) * V

assert np.allclose(mml(t, cl, cr, U, V, Y), np.tril(Up @ Vp.T, -1) @ Y)
assert np.allclose(mmu(t, cl, cr, U, V, Y), np.triu(Vp @ Up.T, 1) @ Y)

# Cotangent of Z = mml(t, cl, cr, U, V, Y); any array with Z's shape works.
bZ = np.random.randn(N, K)

bY = mmu(t, cl, cr, U, V, bZ)
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
```
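The three rules can also be checked numerically with a dot-product (adjoint) test: because `Z` is linear in each of `U`, `V`, and `Y` separately, ⟨bZ, Z⟩ must equal ⟨bU, U⟩, ⟨bV, V⟩, and ⟨bY, Y⟩. This sketch repeats `mml` and `mmu` (with the diagonal scaling written as broadcasting and `Z` zero-initialized) so that it runs on its own, and it uses a nonzero `cr`, which the dense check above does not exercise.

```python
import numpy as np


def mml(t, cl, cr, U, V, Y):
    # Strictly lower product with left (cl) and right (cr) propagators.
    Z = np.zeros((U.shape[0], Y.shape[1]))
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(1, U.shape[0]):
        F += np.outer(V[n - 1], Y[n - 1])
        dt = t[n - 1] - t[n]
        # Equivalent to np.diag(exp(cl*dt)) @ F @ np.diag(exp(cr*dt)).
        F = np.exp(cl * dt)[:, None] * F * np.exp(cr * dt)[None, :]
        Z[n] = U[n] @ F
    return Z


def mmu(t, cl, cr, U, V, Y):
    # Strictly upper counterpart, accumulated from the last row backwards.
    Z = np.zeros((V.shape[0], Y.shape[1]))
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(U.shape[0] - 2, -1, -1):
        F += np.outer(U[n + 1], Y[n + 1])
        dt = t[n] - t[n + 1]
        F = np.exp(cl * dt)[:, None] * F * np.exp(cr * dt)[None, :]
        Z[n] = V[n] @ F
    return Z


rng = np.random.default_rng(0)
N, J, K = 50, 4, 3
t = np.sort(rng.uniform(0, 10, N))
cl = rng.uniform(0.1, 1.0, J)
cr = rng.uniform(0.1, 1.0, K)  # nonzero right propagator
U = rng.standard_normal((N, J))
V = rng.standard_normal((N, J))
Y = rng.standard_normal((N, K))
bZ = rng.standard_normal((N, K))  # arbitrary cotangent for Z

Z = mml(t, cl, cr, U, V, Y)
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
bY = mmu(t, cl, cr, U, V, bZ)

# Z is linear in each of U, V, Y, so <bZ, Z> = <bX, X> for X in {U, V, Y}.
lhs = np.sum(bZ * Z)
checks = {
    name: bool(np.isclose(lhs, np.sum(g * x)))
    for name, g, x in [("U", bU, U), ("V", bV, V), ("Y", bY, Y)]
}
print(checks)
```

The same test is a cheap way to validate a JAX transpose rule later: evaluate the primitive and its transposes on random inputs and compare the inner products.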

dfm, Dec 09 '20