
nnx.make_causal_mask() usage

Open windmaple opened this issue 10 months ago • 3 comments

So this is a follow-up to #4290 (@cgarciae). For building a causal LM, I need to use causal masking. Here is my attempt (adding a single line, marked below, to the code from #4290):

import jax.numpy as jnp
from flax import nnx

batch_size = 2
seqlen = 40
emb_size = 256

x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
shape = x.shape

for i in range(seqlen):  # iterate over all tokens
  y = mha(inputs_q=x[:, i : i + 1],
          mask=nnx.make_causal_mask(x[:, i : i + 1]))  # newly added

The error I got is:

AssertionError: masks must have same rank: (5, 4)

I cannot make sense of this error :(

windmaple avatar Jan 25 '25 11:01 windmaple
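
For context, a likely reason for the rank mismatch (my reading; not stated in the thread): nnx.make_causal_mask expects an array of shape [batch..., seq_len], such as token ids, and returns a [batch..., 1, seq_len, seq_len] mask. Passing the [batch, 1, emb_size] embedding slice adds an extra batch dimension, so the resulting mask is rank 5 while the mask built internally in decode mode is rank 4:

import jax.numpy as jnp
from flax import nnx

ids = jnp.ones((2, 40), dtype=jnp.int32)  # [batch, seq_len] token ids
emb = jnp.ones((2, 1, 256))               # [batch, 1, emb_size] embedding slice

print(nnx.make_causal_mask(ids).shape)  # (2, 1, 40, 40)      -- rank 4
print(nnx.make_causal_mask(emb).shape)  # (2, 1, 1, 256, 256)  -- rank 5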

Hi @windmaple, in decode mode MultiHeadAttention is always causal, meaning you don't have to provide a mask in this case. See:

https://github.com/google/flax/blob/a8a192ff167f8b25b9b568cfece44ef043a82dad/flax/nnx/nn/attention.py#L518-L528

cgarciae avatar Jan 28 '25 03:01 cgarciae
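
For reference, a minimal sketch of the decode-mode loop without an explicit mask, based on the snippet above; the init_cache call is taken from #4290 and its exact signature is an assumption, not confirmed in this thread:

import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size = 2, 40, 256
x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache(x.shape)  # assumed cache-initialization call, as in #4290

for i in range(seqlen):
  # no mask argument: decode mode applies causal masking internally
  y = mha(inputs_q=x[:, i : i + 1])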

Yeah, I realized that, since we are feeding tokens in one by one.

However, for some reason it's not working as expected. I'll try to provide a repro.

windmaple avatar Jan 28 '25 12:01 windmaple

Here is the notebook: https://colab.research.google.com/drive/1kk7xcFSA7KzVQnekfqmdd1Gq_Z4qsLvU#scrollTo=NIOXoY1xgiww

Turning on the KV cache makes it so much slower, which doesn't make any sense to me :(

windmaple avatar Jan 30 '25 10:01 windmaple
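
One possible contributor to the slowdown (speculation on my part; the thread does not establish the cause) is that the per-token loop runs uncompiled, so each call pays dispatch and tracing overhead. A sketch of compiling the single-token step with nnx.jit, reusing mha, x, and seqlen from the snippet above:

from flax import nnx

@nnx.jit  # compile the per-token decode step; nnx.jit propagates the in-place cache update
def decode_step(mha, x_t):
  return mha(inputs_q=x_t)

for i in range(seqlen):
  y = decode_step(mha, x[:, i : i + 1])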