
Automatically differentiate Pallas kernel failed

Open Lime-Cakes opened this issue 1 year ago • 3 comments

Description

The Pallas docs claim that automatic differentiation of a Pallas kernel works, just more slowly than a hand-written backward pass. However, when tested on TPU, an AssertionError is raised, even for a simple add kernel. (Reference: https://jax.readthedocs.io/en/latest/pallas/design.html#grad-of-pallas-call)

Tested on a Kaggle TPU and a TRC TPU (both v3-8); both showed the same result.

Traceback (most recent call last):
  File "/kaggle/working/pallas_bug_test.py", line 26, in <module>
    pallas_grad_output=grad_pallas_call(x,y)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 691, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 2176, in _vjp
    out_primals, vjp = ad.vjp(flat_fun, primals_flat)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 143, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 134, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError

Code to reproduce the error:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

def pallas_call(x, y):
    return pl.pallas_call(add_vectors_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                          )(x, y)

grad_pallas_call = jax.grad(pallas_call)

x = jax.random.normal(jax.random.key(0), (1024, 1024))
y = jax.random.normal(jax.random.key(0), (1024, 1024))
pallas_output = pallas_call(x, y)
print(pallas_output.sum())
pallas_grad_output = grad_pallas_call(x, y)
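For comparison, the same computation written in plain jnp differentiates without issue, which also gives the expected gradient values. A minimal sketch (note that jax.grad requires a scalar output, so the sum is taken inside the differentiated function):

```python
import jax
import jax.numpy as jnp

def add_then_sum(x, y):
    # plain-jnp version of the add kernel; summed so jax.grad applies
    return jnp.sum(x + y)

x = jax.random.normal(jax.random.key(0), (1024, 1024))
y = jax.random.normal(jax.random.key(0), (1024, 1024))
gx = jax.grad(add_then_sum)(x, y)  # gradient of sum(x + y) w.r.t. x: all ones
```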

System info (python version, jaxlib version, accelerator, etc.)

TPU V3 jax-0.4.30 jaxlib-0.4.30 libtpu-nightly-0.1.dev20240617

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.13 (main, Mar 12 2024, 12:16:25) [GCC 12.2.0]

Lime-Cakes avatar Jun 29 '24 09:06 Lime-Cakes

Thanks @Lime-Cakes, this looks like a bug indeed.

superbobry avatar Jul 01 '24 09:07 superbobry

> Thanks @Lime-Cakes, this looks like a bug indeed.

Is there any possible workaround at the moment besides writing a backward kernel as well, or is it best to just wait for a bugfix?

Lime-Cakes avatar Jul 01 '24 09:07 Lime-Cakes

I think you probably want to write a backwards kernel anyway, because the automatically derived one (even if we fixed the assertion) is guaranteed to be inefficient.

superbobry avatar Jul 01 '24 09:07 superbobry
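The backward kernel suggested above can be attached to the Pallas forward pass with jax.custom_vjp. The sketch below is illustrative, not code from this thread: for elementwise addition the backward pass is just the identity on the cotangent, and interpret=True is used so the sketch runs without a TPU.

```python
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.custom_vjp
def pallas_add(x, y):
    # interpret=True runs the kernel in interpreter mode (no TPU needed)
    return pl.pallas_call(add_vectors_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                          interpret=True)(x, y)

def pallas_add_fwd(x, y):
    return pallas_add(x, y), None  # addition needs no residuals

def pallas_add_bwd(_, g):
    return g, g  # d(x + y)/dx = d(x + y)/dy = identity on the cotangent

pallas_add.defvjp(pallas_add_fwd, pallas_add_bwd)
```

With the custom VJP in place, jax.grad of a scalar-valued function built on pallas_add works without ever differentiating through the kernel body.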

Having a slow but correct one would make it easier to write the backward pass, since I'd have a correct kernel to test my backward-pass implementation against. It'd be quite useful to have.

Lime-Cakes avatar Jul 03 '24 15:07 Lime-Cakes
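Even without a working autodiff path to compare against, a hand-written backward pass can be validated by numerical gradient checking with jax.test_util.check_grads. A minimal sketch, where my_add is a plain-jnp stand-in for a Pallas forward kernel (not code from this issue):

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

@jax.custom_vjp
def my_add(x, y):
    return x + y  # stand-in for the Pallas forward kernel

def my_add_fwd(x, y):
    return my_add(x, y), None  # no residuals needed

def my_add_bwd(_, g):
    return g, g  # hand-written backward pass under test

my_add.defvjp(my_add_fwd, my_add_bwd)

x = jax.random.normal(jax.random.key(0), (16,))
y = jax.random.normal(jax.random.key(1), (16,))
# compares the custom VJP against finite-difference gradients;
# raises if the hand-written backward pass is wrong
check_grads(my_add, (x, y), order=1, modes=["rev"])
```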