How to define Pallas kernels usable in both fwd and bwd mode?

Open GallagherCommaJack opened this issue 1 year ago • 3 comments

Right now, there's clear documentation on how to use custom_jvp and custom_vjp, but the automatic transposition of a custom_jvp rule isn't necessarily going to be efficient for Pallas kernels (is it even defined?), and using custom_vjp isn't compatible with jvp.
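To make the trade-off concrete, here is a minimal sketch of both options. It assumes a plain `jnp.sin` as a stand-in for the kernel body (in practice this might be a `pl.pallas_call`); the names `_sin_impl`, `sin_jvp`, and `sin_vjp` are hypothetical and used only for illustration. The error behaviour shown for `jax.jvp` on a `custom_vjp` function reflects JAX as of this thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a kernel; a real version might wrap pl.pallas_call.
def _sin_impl(x):
    return jnp.sin(x)

# Option 1: custom_jvp. Only the forward-mode rule is written by hand.
@jax.custom_jvp
def sin_jvp(x):
    return _sin_impl(x)

@sin_jvp.defjvp
def _sin_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    # The tangent output is linear in t, so JAX can transpose it
    # automatically to obtain a reverse-mode rule.
    return _sin_impl(x), jnp.cos(x) * t

# Option 2: custom_vjp. Only the reverse-mode rule is written by hand.
@jax.custom_vjp
def sin_vjp(x):
    return _sin_impl(x)

def _sin_fwd(x):
    return _sin_impl(x), x  # keep x as the residual for the backward pass

def _sin_bwd(x, g):
    return (jnp.cos(x) * g,)

sin_vjp.defvjp(_sin_fwd, _sin_bwd)

x = jnp.float32(0.5)
one = jnp.float32(1.0)

# custom_jvp supports forward mode directly, and reverse mode via transposition.
_, fwd_tangent = jax.jvp(sin_jvp, (x,), (one,))
rev_grad = jax.grad(sin_jvp)(x)

# custom_vjp supports reverse mode...
vjp_grad = jax.grad(sin_vjp)(x)

# ...but forward mode over a custom_vjp function is rejected.
try:
    jax.jvp(sin_vjp, (x,), (one,))
    jvp_of_custom_vjp_ok = True
except TypeError:
    jvp_of_custom_vjp_ok = False
```

With `custom_jvp`, the backward pass is whatever JAX derives by transposing the tangent computation, so it only works when every operation in the tangent rule has a transpose rule, and it may not match a hand-tuned backward kernel.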

GallagherCommaJack · Dec 30 '23 23:12

#17840 might be of relevance here, as this PR makes it possible to do jvp-of-custom_vjp.

In practice it's still a little buggy and I haven't had the chance to fix it up yet, but it might be the essence of the solution needed.

patrick-kidger · Jan 02 '24 18:01

cc @sharadmv @apaszke @mattjj (I suspect one of you would be able to comment on this...)

skye · Jan 02 '24 22:01

> but the automatic transposition of custom_jvp isn't necessarily going to be very good for pallas kernels (is it even defined?) and using custom_vjp isn't compatible with jvp.

As Patrick implied, this seems to be more of a JAX issue than a Pallas-specific one. If JAX supported defining both a custom_vjp and a custom_jvp for the same function, this Pallas use case would work. @froystig

sharadmv · Jan 02 '24 22:01