
Pallas NotImplementedError: Unimplemented primitive in Pallas TPU lowering: dynamic_slice.

Open · faresobeid opened this issue · 0 comments

Description

I'm trying to implement RWKV in JAX using Pallas, but I run into problems with the for loop. In my code, I use jax.lax.fori_loop to iterate over the sequence dimension. When I index into my arrays with the loop index, I get this error:

JaxStackTraceBeforeTransformation: NotImplementedError: Unimplemented primitive in Pallas TPU lowering: dynamic_slice. Please file an issue on https://github.com/google/jax/issues.
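For context, the dynamic_slice in the error comes from ordinary NumPy-style indexing: when the index is a traced integer, as the fori_loop counter is, jax.numpy turns x[t] into a dynamic_slice instead of a static slice. A minimal check with jax.make_jaxpr, run outside Pallas (take_row is just an illustrative name):

```python
import jax
import jax.numpy as jnp


def take_row(k, t):
    # make_jaxpr traces `t` as an int32 scalar, just like the fori_loop index,
    # so k[t] cannot be resolved to a static slice
    return k[t]


jaxpr = jax.make_jaxpr(take_row)(jnp.zeros((8, 128), jnp.float32), 3)
print(jaxpr)  # the printed jaxpr includes a dynamic_slice equation
```

Inside a Pallas TPU kernel, that dynamic_slice on an already-loaded array value is the primitive the lowering rejects.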

Code:

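```python
import jax

# Body of the Pallas kernel. r, k, v, w: T x D; u: D; out: T x D.
# r, k, v, w, u are assumed to be arrays loaded from the input refs
# (e.g. k = k_ref[...]); out and s are the initial output and D x D state.
def loop(t, carry):
    out, s = carry
    kv_t = k[t][:, None] * v[t][None, :]      # outer product: D x D
    y_t = r[t][None, :] @ (u * kv_t + s)      # (1 x D) @ (D x D) -> 1 x D
    out = out.at[t].set(y_t.squeeze(0))       # write row t of the output
    s = kv_t + w[t, None] * s                 # update the D x D state
    return out, s

out_ref[...] = jax.lax.fori_loop(0, T, loop, (out, s))[0]
```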
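As a reference to check a kernel against, here is a sketch of the same recurrence in plain JAX with jax.lax.scan, outside Pallas. It assumes kv_t is meant to be the outer product of k[t] and v[t] (as the D x D comment implies), keeps the broadcasting of u * kv_t and w[t, None] * s from the snippet, and assumes the initial state s is all zeros; rwkv_reference is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rwkv_reference(r, k, v, w, u):
    """r, k, v, w: (T, D); u: (D,). Returns out: (T, D)."""
    D = r.shape[1]

    def step(s, xs):
        r_t, k_t, v_t, w_t = xs                # one row of each input, shape (D,)
        kv_t = jnp.outer(k_t, v_t)             # assumed outer product, (D, D)
        y_t = r_t @ (u * kv_t + s)             # (D,) @ (D, D) -> (D,)
        s = kv_t + w_t[None, :] * s            # same broadcasting as w[t, None] * s
        return s, y_t

    s0 = jnp.zeros((D, D), r.dtype)            # assumed zero initial state
    _, out = jax.lax.scan(step, s0, (r, k, v, w))
    return out
```

scan stacks the per-step outputs into a (T, D) array, so there is no out.at[t].set(...) and no explicit indexing with a traced t.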

The problem arises at the first place the loop index is used for indexing: k[t], in the line that computes kv_t.
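A pattern that tends to fit Pallas better is to avoid indexing loaded arrays with the loop counter altogether: read and write rows directly on the refs with pl.ds, and carry only the D x D state through fori_loop. This also replaces out.at[t].set(...), which would otherwise need a dynamic update. The sketch below assumes kv_t is the outer product of the k and v rows, a zero initial state, and u passed as a (1, D) array; it runs with interpret=True so it executes on CPU. Whether each op here (the size-1 dynamic row window, the transpose) lowers on a real TPU depends on the JAX version and on the shapes, so treat it as an untested sketch for TPU. The names rwkv_kernel and rwkv_pallas are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def rwkv_kernel(r_ref, k_ref, v_ref, w_ref, u_ref, out_ref):
    T, D = r_ref.shape
    u = u_ref[...]                                # (1, D), loaded once

    def body(t, s):
        rows = pl.ds(t, 1)                        # dynamic window: row t only
        r_t = r_ref[rows, :]                      # (1, D), read from the ref
        k_t = k_ref[rows, :]
        v_t = v_ref[rows, :]
        w_t = w_ref[rows, :]
        kv_t = k_t.T * v_t                        # (D, 1) * (1, D) -> (D, D)
        y_t = r_t @ (u * kv_t + s)                # (1, D) @ (D, D) -> (1, D)
        out_ref[rows, :] = y_t.astype(out_ref.dtype)  # store row t directly
        return kv_t + w_t * s                     # next D x D state

    s0 = jnp.zeros((D, D), jnp.float32)           # assumed zero initial state
    jax.lax.fori_loop(0, T, body, s0)


@jax.jit
def rwkv_pallas(r, k, v, w, u):
    return pl.pallas_call(
        rwkv_kernel,
        out_shape=jax.ShapeDtypeStruct(r.shape, r.dtype),
        interpret=True,  # CPU interpreter for checking; drop on a real TPU
    )(r, k, v, w, u.reshape(1, -1))
```

Comparing rwkv_pallas against a plain-JAX or NumPy version of the loop on small random inputs is a quick way to confirm the indexing before trying the kernel on TPU.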

What jax/jaxlib version are you using?

jax v0.4.21 jaxlib v0.4.21

Which accelerator(s) are you using?

TPU

Additional system info?

Python 3.10, Kaggle Notebook

NVIDIA GPU info

No response

faresobeid · Dec 09 '23 14:12