Pallas NotImplementedError: Unimplemented primitive in Pallas TPU lowering: dynamic_slice.
Description
I'm trying to implement RWKV in JAX using Pallas, but I run into problems with the for loop. In my code, I use `jax.lax.fori_loop` to iterate over the sequence dimension. When indexing into my arrays I get the error: JaxStackTraceBeforeTransformation: NotImplementedError: Unimplemented primitive in Pallas TPU lowering: dynamic_slice. Please file an issue on https://github.com/google/jax/issues.
Code:
```python
# r, k, v, w: T x D
# u: D
# out: T x D
def loop(t, W_t):
    out, s = W_t
    kv_t = k[t] * v[t]  # D x D
    out = out.at[t].set((r[:, t] @ (u * kv_t + s)).squeeze(0))  # D @ (D x D) -> 1 x D
    s = kv_t + w[t, None] * s  # D x D
    return out, s

out_ref[...] = jax.lax.fori_loop(0, T, loop, (out, s))[0]
```
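For checking a Pallas kernel's output, the same recurrence can be written in plain JAX with `jax.lax.scan`, which carries the D x D state and stacks the per-step outputs. This is a sketch under stated assumptions, not the author's code: `wkv_reference` is a hypothetical name, `kv_t` is taken to be the outer product implied by the `# D x D` comment, `r[t]` is assumed to be the row that `r[:, t]` was meant to select, and the initial state is assumed to be zeros.

```python
import jax
import jax.numpy as jnp


def wkv_reference(r, k, v, w, u):
    """Hypothetical plain-JAX reference for the loop body above (no Pallas).

    Assumed shapes: r, k, v, w: (T, D); u: (D,).
    Assumptions: kv_t is the outer product of k[t] and v[t], r[t] is the
    intended row of r, and the state starts at zero.
    """
    D = k.shape[1]

    def step(s, inputs):
        r_t, k_t, v_t, w_t = inputs              # each (D,)
        kv_t = k_t[:, None] * v_t[None, :]       # (D, D) outer product
        out_t = r_t @ (u * kv_t + s)             # (D,) @ (D, D) -> (D,)
        s = kv_t + w_t[None, :] * s              # same broadcast as w[t, None] * s
        return s, out_t

    s0 = jnp.zeros((D, D), dtype=k.dtype)        # assumed initial state
    _, out = jax.lax.scan(step, s0, (r, k, v, w))  # iterate over the T axis
    return out                                   # (T, D)
```

Because `scan` slices its inputs along the leading axis itself, the loop body never indexes with a traced `t`; this makes it a convenient baseline to compare a kernel against with `jnp.allclose`.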
The problem arises at `k[t] * v[t]`, where the loop index `t` is used to index into the arrays.
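The traceback suggests that `k` and `v` are already-loaded arrays, so `k[t]` with a traced `t` presumably becomes a `dynamic_slice`. One alternative is to read each timestep straight from the input Refs with `pl.ds(t, 1)` and write each output row back to `out_ref` the same way. The sketch below is hypothetical (`wkv_kernel` and `wkv` are illustrative names), keeps everything 2-D (`u` is passed as `(1, D)`), and mirrors the broadcasting of the original loop. Whether the TPU lowering in jax 0.4.21 accepts this indexing for a given `T` and `D` is not something it establishes; it is only exercised with `interpret=True`, which runs the kernel through ordinary JAX ops on any backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def wkv_kernel(r_ref, k_ref, v_ref, w_ref, u_ref, out_ref):
    # Hypothetical kernel. r, k, v, w, out refs are (T, D); u_ref is (1, D).
    T, D = k_ref.shape
    u = u_ref[...]                               # (1, D), loaded once

    def body(t, s):
        rows = pl.ds(t, 1)                       # dynamic row index on the Refs
        r_t = r_ref[rows, :]                     # (1, D)
        k_t = k_ref[rows, :]
        v_t = v_ref[rows, :]
        w_t = w_ref[rows, :]
        kv_t = k_t.T * v_t                       # (D, 1) * (1, D) -> (D, D)
        out_ref[rows, :] = r_t @ (u * kv_t + s)  # (1, D) @ (D, D) -> (1, D)
        return kv_t + w_t * s                    # updated (D, D) state

    s0 = jnp.zeros((D, D), dtype=out_ref.dtype)  # assumed zero initial state
    jax.lax.fori_loop(0, T, body, s0)


def wkv(r, k, v, w, u, interpret=True):
    # No grid: the kernel runs once and sees the whole arrays as Refs.
    return pl.pallas_call(
        wkv_kernel,
        out_shape=jax.ShapeDtypeStruct(k.shape, k.dtype),
        interpret=interpret,
    )(r, k, v, w, u.reshape(1, -1))
```

With `interpret=True` the result can be compared against a plain-JAX reference of the same recurrence; switching to `interpret=False` on a TPU is where the lowering question above actually gets tested.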
What jax/jaxlib version are you using?
jax v0.4.21 jaxlib v0.4.21
Which accelerator(s) are you using?
TPU
Additional system info?
Python 3.10, Kaggle Notebook
NVIDIA GPU info
No response