nnx.make_causal_mask() usage
So this is a follow-up on #4290 (@cgarciae). For building a causal LM, I need to use causal masking. Here is my attempt (adding a single line to the code from #4290):
```python
import jax.numpy as jnp
from flax import nnx

batch_size = 2
seqlen = 40
emb_size = 256
x = jnp.ones((batch_size, seqlen, emb_size))
mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
shape = x.shape
mha.init_cache(shape)  # decode=True requires an initialized KV cache
for i in range(seqlen):  # iterate over all tokens, one at a time
    y = mha(
        inputs_q=x[:, i : i + 1],
        mask=nnx.make_causal_mask(x[:, i : i + 1]),  # newly added
    )
```
The error I got is:
```
AssertionError: masks must have same rank: (5, 4)
```
I cannot make sense of this error :(
Hi @windmaple, in decode mode MultiHeadAttention is always causal, meaning you don't have to provide a mask in this case. See:
https://github.com/google/flax/blob/a8a192ff167f8b25b9b568cfece44ef043a82dad/flax/nnx/nn/attention.py#L518-L528
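For reference, the `(5, 4)` rank mismatch appears to come from feeding embeddings to `nnx.make_causal_mask`, which expects token ids of shape `(batch, seq_len)`: called on `x[:, i : i + 1]` of shape `(2, 1, 256)` it returns a mask of shape `(2, 1, 1, 256, 256)` (rank 5), while the mask that decode mode builds internally has shape `(2, 1, 1, 40)` (rank 4). Here is a minimal sketch of decode-mode usage without an explicit mask, reusing the toy setup from the repro above (shapes are illustrative):

```python
import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size = 2, 40, 256
x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache(x.shape)  # set up the autoregressive KV cache

for i in range(seqlen):
    # no mask argument: in decode mode the layer only attends
    # to cached positions <= i, so causality is automatic
    y = mha(inputs_q=x[:, i : i + 1])
```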
Yeah, I realized that, since we are feeding tokens in one by one.
However, for some reason it's not working as expected. I'll try to provide a repro.
Here is the notebook: https://colab.research.google.com/drive/1kk7xcFSA7KzVQnekfqmdd1Gq_Z4qsLvU#scrollTo=NIOXoY1xgiww
Turning on the KV cache makes it so much slower, which doesn't make any sense to me :(
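One possible cause, not confirmed in this thread: if each per-token call runs eagerly, per-step Python dispatch overhead dominates the tiny attention computation, so the cached loop can end up slower than a single full-sequence call. A sketch, under the same toy setup, of wrapping the decode step in `nnx.jit` so the cache update runs as one compiled step (the `decode_step` helper is hypothetical, introduced here for illustration):

```python
import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size = 2, 40, 256
x = jnp.ones((batch_size, seqlen, emb_size))
mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache(x.shape)

@nnx.jit  # traced and compiled once; cache mutations are propagated by nnx.jit
def decode_step(mha, token):
    return mha(inputs_q=token)

for i in range(seqlen):
    y = decode_step(mha, x[:, i : i + 1])
```

With the step compiled, each iteration is a single device call that reuses the cached keys and values, rather than dispatching every op from Python.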