Automatic differentiation of a Pallas kernel fails
Description
The Pallas docs claim that automatic differentiation of Pallas kernels works, albeit more slowly than a hand-written backward pass. However, when tested on TPU with a simple add kernel, an AssertionError is raised. (Reference: https://jax.readthedocs.io/en/latest/pallas/design.html#grad-of-pallas-call)
Tested on both a Kaggle TPU and a TRC TPU (both v3-8); both showed the same result.
Traceback (most recent call last):
  File "/kaggle/working/pallas_bug_test.py", line 26, in <module>
    pallas_grad_output=grad_pallas_call(x,y)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 691, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 2176, in _vjp
    out_primals, vjp = ad.vjp(flat_fun, primals_flat)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 143, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 134, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError
Code to reproduce the error:
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

def pallas_call(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

grad_pallas_call = jax.grad(pallas_call)

x = jax.random.normal(jax.random.key(0), (1024, 1024))
y = jax.random.normal(jax.random.key(0), (1024, 1024))

pallas_output = pallas_call(x, y)
print(pallas_output.sum())
pallas_grad_output = grad_pallas_call(x, y)
System info (python version, jaxlib version, accelerator, etc.)
TPU v3, libtpu-nightly-0.1.dev20240617
jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.26.4
python: 3.10.13 (main, Mar 12 2024, 12:16:25) [GCC 12.2.0]
Thanks @Lime-Cakes, this looks like a bug indeed.
Is there any possible workaround at the moment besides also writing a backward kernel, or is it best to just wait for a bugfix?
I think you probably want to write a backwards kernel anyway, because the automatically derived one (even if we fixed the assertion) is guaranteed to be inefficient.
Still, having a slow but correct one would make it easier to write the backward pass, since I'd have a correct kernel to test my backward-pass implementation against. It'd be quite useful to have.
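For anyone looking for an interim workaround: one option (a sketch, not something from this thread) is to wrap the forward `pl.pallas_call` in `jax.custom_vjp` and supply the backward pass as its own Pallas kernel, bypassing the broken automatic derivation entirely. For addition the VJP is trivial: the incoming cotangent passes through unchanged to both inputs. The kernel names below are illustrative, and `interpret=True` is only there so the sketch runs off-TPU; drop it on real hardware.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def identity_kernel(g_ref, o_ref):
    # d(x + y)/dx = d(x + y)/dy = 1, so the cotangent passes through as-is.
    o_ref[...] = g_ref[...]

@jax.custom_vjp
def add_vectors(x, y):
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # for illustration only; remove on TPU
    )(x, y)

def add_vectors_fwd(x, y):
    return add_vectors(x, y), None  # addition needs no residuals

def add_vectors_bwd(residuals, g):
    dg = pl.pallas_call(
        identity_kernel,
        out_shape=jax.ShapeDtypeStruct(g.shape, g.dtype),
        interpret=True,  # for illustration only; remove on TPU
    )(g)
    return dg, dg  # one cotangent per primal input

add_vectors.defvjp(add_vectors_fwd, add_vectors_bwd)

x = jnp.ones((8, 128), jnp.float32)
y = jnp.ones((8, 128), jnp.float32)
gx, gy = jax.grad(lambda a, b: add_vectors(a, b).sum(), argnums=(0, 1))(x, y)
```

This is the same pattern you'd use once you write a properly optimized backward kernel; only the body of `identity_kernel` would change.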