Simplify Zernike Radial with vmap() and other performance suggestions
First Suggestion: #758 increased the performance of the Zernike radial function, but to be able to use reverse-mode autodiff, we had to make some compromises, either in performance or in duplicated code for the different derivative orders.
The main problem is that inside the jnp.vectorize decorator all of the inputs are transformed into tracer values, so the vector dxs = jnp.arange(0, dr + 1) cannot be generated. We tried to declare dr as a static argument of the jit-compiled function that calls the vectorized function, but jnp.vectorize causes dr to become a tracer even outside of the vectorized function (probably because of the jit compilation). Our solution was to create dxs with a manually declared static value MAXDR, the maximum required derivative order. But that makes the function evaluation 2-4 times slower than what is possible.
A similar issue for vmap() is discussed in the google/jax issue tracker here. Maybe we can implement this to make dr static inside vmap() and do the vectorization with vmap() instead of jnp.vectorize.
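A minimal sketch of the idea, assuming a hypothetical function name and a placeholder body in place of the actual Zernike radial recursion: because vmap only maps over the array argument, dr can stay a static Python int that is closed over, so jnp.arange(0, dr + 1) sees a concrete value.

import jax
import jax.numpy as jnp
from functools import partial

# Sketch only (hypothetical name, placeholder body): dr stays a static
# Python int because vmap maps only over r; dr never becomes a tracer.
@partial(jax.jit, static_argnums=(1,))
def radial_all_derivatives(r, dr):
    dxs = jnp.arange(0, dr + 1)  # concrete: dr is static at trace time

    def single_point(r_i):
        # placeholder for the actual Zernike radial recursion
        return r_i ** dxs  # one value per derivative order, shape (dr + 1,)

    return jax.vmap(single_point)(r)  # vectorize over r with vmap, not jnp.vectorize

radial_all_derivatives(jnp.linspace(0, 1, 5), 2)  # shape (5, 3)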
Second Suggestion: We do not always have a complete set of modes for ZernikePolynomials. This requires a check in the recursion relation function to decide whether to update the output with the calculated value inside the loop or to pass. Among the many methods tested, a fori_loop over every index of the m values was the best that provides all the required checks (for duplicate modes as well as non-existent modes). However, this still has a major effect on performance: without any checks, if the mode set is complete, the function is 3-4 times faster than the one with checks (a sketch of this check-free path is given after the snippets below). Some other methods tested (for reference) are:
# Find the index corresponding to the original array
# I changed the arange to start at 1 to avoid confusing 0 with a valid index,
# so if index is all 0s, there is no such mode
# (FASTER THAN CURRENT WAY BUT LACKS A CHECK FOR DUPLICATE MODES)
index = jnp.where(
    jnp.logical_and(m == alpha, n == N),
    jnp.arange(1, m.size + 1),
    0,
)
idx = jnp.sum(index)
# shift back by 1 to recover the proper 0-based index (-1 means no such mode)
idx -= 1
result = (-1) ** N * r**alpha * P_n
out = out.at[:, idx].set(jnp.where(idx >= 0, result, out.at[:, idx].get()))
# Replace only if that mode exists (SLOW)
mask = jnp.logical_and(m == alpha, n == N)
result = (-1) ** N * r**alpha * P_n
out = jnp.where(mask[None, :], result[:, None], out)
# (SLOW)
result = (-1) ** N * r**alpha * P_n
mask = jnp.logical_and(m == alpha, n == N)
# mask.size + 1 is out of bounds, so those scatter updates are dropped by JAX
idx = jnp.where(mask, jnp.arange(mask.size), mask.size + 1)
out = out.at[:, idx].set(result[:, None])
# (SLOW)
def update(x, args):
    index, result, out = args
    idx = index.at[x].get()
    idx -= 1
    out = out.at[:, idx].set(jnp.where(idx >= 0, result, out.at[:, idx].get()))
    return (index, result, out)

index = jnp.where(
    jnp.logical_and(m == alpha, n == N),
    jnp.arange(1, m.size + 1),
    0,
)
result = (-1) ** N * r**alpha * P_n
_, _, out = fori_loop(0, index.size, update, (index, result, out))
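For comparison, a minimal sketch of the check-free fast path mentioned above, under two assumptions not stated in the snippets: the mode set is complete and duplicate-free, and the recursion over (alpha, N) is a Python loop unrolled at trace time, so both are concrete ints. The (alpha, N) -> column lookup can then be precomputed with plain NumPy outside the traced code, and each update becomes an unchecked out.at[:, idx].set(result) with a static index (static_mode_index is a hypothetical helper, not actual zernike_radial code).

import numpy as np

# Hypothetical sketch: with a complete, duplicate-free mode set the
# (alpha, N) -> column lookup happens in plain NumPy at trace time,
# outside jit/vmap, so no runtime checks are traced at all.
def static_mode_index(modes_m, modes_n, alpha, N):
    idx = np.flatnonzero((modes_m == alpha) & (modes_n == N))
    return int(idx[0])  # static Python int, never a tracer

# inside the unrolled recursion, the update is then simply:
# result = (-1) ** N * r**alpha * P_n
# out = out.at[:, static_mode_index(modes_m, modes_n, alpha, N)].set(result)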
In my opinion, if we can implement these in the zernike_radial function, there is room for a 3-5x speed improvement, as well as a reduction of duplicated code (if #758 ends up with new functions for the derivatives).
@YigitElma can we close this?