How to define Pallas kernels usable in both forward (fwd) and backward (bwd) mode?
Right now, there's clear documentation on how to use `custom_jvp` and `custom_vjp`, but the automatic transposition of `custom_jvp` isn't necessarily going to be very good for Pallas kernels (is it even defined?), and using `custom_vjp` isn't compatible with `jvp`.
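A minimal sketch of the `custom_jvp` route, assuming a recent JAX that ships `jax.experimental.pallas` and running the kernel with `interpret=True` so it also works on CPU. The names `_sin_kernel`, `sin_pallas` and `sin_op` are hypothetical. The primal output comes from a Pallas kernel, while the tangent is written as a plain `jnp` expression that is linear in `t`. JAX can transpose that linear expression on its own, so one definition serves both `jax.jvp` and `jax.grad`. The trade-off is that you don't get to choose a hand-written Pallas kernel for the backward pass, and whether a Pallas kernel used *inside* the tangent computation could be transposed is a separate question this sketch sidesteps.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sin_kernel(x_ref, o_ref):
    # No grid: each ref holds the full input/output array.
    o_ref[...] = jnp.sin(x_ref[...])


def sin_pallas(x):
    # interpret=True runs the kernel through the Pallas interpreter (CPU-friendly).
    return pl.pallas_call(
        _sin_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


@jax.custom_jvp
def sin_op(x):
    return sin_pallas(x)


@sin_op.defjvp
def sin_op_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = sin_pallas(x)  # primal output still comes from the Pallas kernel
    # The tangent is linear in t and uses only jnp ops, so JAX can transpose
    # it automatically when reverse mode (jax.grad / jax.vjp) is requested.
    return y, jnp.cos(x) * t


x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)

# Forward mode: uses the custom JVP rule directly.
y, y_dot = jax.jvp(sin_op, (x,), (jnp.ones_like(x),))

# Reverse mode: JAX linearizes via the JVP rule, then transposes the tangent part.
g = jax.grad(lambda v: sin_op(v).sum())(x)
```

Both `y_dot` and `g` should equal `cos(x)` here, since the input tangent and the output cotangent are all ones.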
#17840 might be relevant here, as this PR makes it possible to do jvp-of-`custom_vjp`.
In practice it's still a little buggy and I haven't had the chance to fix it up yet, but it might be the essence of the solution needed.
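For contrast, a sketch of the `custom_vjp` route under the same assumptions (recent JAX with Pallas, interpret mode, hypothetical names `_sin_bwd_kernel`, `_call`, `sin_vjp`). Here the backward pass is a dedicated Pallas kernel, which is usually the point of going this way, but forward mode is the missing piece: in JAX releases without a change like the one in #17840, `jax.jvp` on a `custom_vjp` function raises a `TypeError`. The code catches that instead of assuming either outcome.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sin_kernel(x_ref, o_ref):
    o_ref[...] = jnp.sin(x_ref[...])


def _sin_bwd_kernel(x_ref, ct_ref, o_ref):
    # Hand-written backward kernel: cotangent times d/dx sin(x) = cos(x).
    o_ref[...] = jnp.cos(x_ref[...]) * ct_ref[...]


def _call(kernel, out_like, *args):
    # Hypothetical helper: whole-array pallas_call whose output matches out_like.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(out_like.shape, out_like.dtype),
        interpret=True,  # run via the Pallas interpreter so it works on CPU
    )(*args)


@jax.custom_vjp
def sin_vjp(x):
    return _call(_sin_kernel, x, x)


def sin_vjp_fwd(x):
    return _call(_sin_kernel, x, x), x  # keep x as the residual


def sin_vjp_bwd(x, ct):
    return (_call(_sin_bwd_kernel, x, x, ct),)


sin_vjp.defvjp(sin_vjp_fwd, sin_vjp_bwd)

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)

# Reverse mode runs the Pallas backward kernel.
g = jax.grad(lambda v: sin_vjp(v).sum())(x)

# Forward mode is not defined for custom_vjp functions in releases lacking jvp-of-custom_vjp.
try:
    jax.jvp(sin_vjp, (x,), (jnp.ones_like(x),))
    jvp_ok = True
except TypeError:
    jvp_ok = False
```

Attaching both rules to one function is not something `jax.custom_jvp` / `jax.custom_vjp` offer, which is why the two routes above end up as an either/or.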
cc @sharadmv @apaszke @mattjj (I suspect one of you would be able to comment on this...)
> but the automatic transposition of `custom_jvp` isn't necessarily going to be very good for pallas kernels (is it even defined?) and using `custom_vjp` isn't compatible with `jvp`.
As Patrick implied, this seems to be more of a JAX issue than a Pallas-specific one. If JAX supported defining both a custom VJP and a custom JVP on the same function, this Pallas use case would work. @froystig