
Flash attention - head_dim 64

Open · peregilk opened this issue 1 year ago · 4 comments

I have tried using MaxText to train Llama 3.2 3B. This seems to work fine with just minor modifications to the configs.

However, I am unable to train Llama 3.2 1B. Flash/Splash attention seems to require head_dim to be divisible by 128, but the 1B model's head_dim is only 64, so I get a "not implemented" error. Falling back to dot_product attention is impractical for long context lengths, because it materializes the full attention score matrix, whose memory grows quadratically with sequence length.
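
The snippet below is a rough sketch of both points. It checks the divisibility constraint reported above. It also estimates how much memory dot-product attention needs just to hold the attention scores. The head_dim values (128 for 3B, 64 for 1B), the 32 query heads, batch size 1 and float32 scores are illustrative assumptions, not values read from MaxText's configs. Both helper names are hypothetical.

```python
def meets_head_dim_requirement(head_dim: int, multiple: int = 128) -> bool:
    """Mirror of the constraint reported above: head_dim must be divisible by `multiple`."""
    return head_dim % multiple == 0


def attention_scores_bytes(batch: int, num_heads: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes to materialize one layer's [batch, num_heads, seq_len, seq_len] score matrix."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem


# Assumed Llama 3.2 head dims: 128 for the 3B model, 64 for the 1B model.
for name, head_dim in (("3B", 128), ("1B", 64)):
    print(name, head_dim, meets_head_dim_requirement(head_dim))

# Assumed 32 query heads, batch 1, float32 scores: memory grows with seq_len**2.
for seq_len in (2048, 8192, 32768):
    gib = attention_scores_bytes(batch=1, num_heads=32, seq_len=seq_len) / 2**30
    print(f"seq_len={seq_len}: {gib:.1f} GiB per layer")
```

Under these assumptions the score matrix takes 0.5, 8.0 and 128.0 GiB per layer, so every 4x increase in context length costs 16x the memory.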

Any ideas?

peregilk, Nov 18 '24

It would be great if someone could solve this. DeepSeek suffers from the same problem, because its queries and keys have 192 dims. A workaround is to pad them to 256 dims, but this wastes computation on the padding.
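
Zero-padding queries and keys along head_dim does not change the attention scores. The padded columns contribute 0 to every dot product, so the only cost is the wasted FLOPs. Below is a minimal NumPy sketch with toy shapes and a hypothetical `pad_last_dim` helper. The same idea carries over to `jax.numpy` arrays.

```python
import numpy as np


def pad_last_dim(x: np.ndarray, target: int) -> np.ndarray:
    """Zero-pad the last axis of `x` up to size `target` (hypothetical helper)."""
    pad = target - x.shape[-1]
    if pad < 0:
        raise ValueError("target is smaller than the current last-axis size")
    widths = [(0, 0)] * (x.ndim - 1) + [(0, pad)]
    return np.pad(x, widths)  # default mode="constant" pads with zeros


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 192))  # [q_len, head_dim], DeepSeek-style 192-dim queries
k = rng.standard_normal((6, 192))  # [kv_len, head_dim]

scores = q @ k.T
scores_padded = pad_last_dim(q, 256) @ pad_last_dim(k, 256).T
print(np.allclose(scores, scores_padded))

# Share of the QK^T multiply-adds spent on the zero padding.
wasted = (256 - 192) / 256
print(f"wasted fraction: {wasted:.0%}")
```

With 192 real dims padded to 256, a quarter of the QKᵀ work operates on zeros.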

Also, flash/splash attention requires keys and values to have the same head dim. That is not the case in DeepSeek, whose keys have 192 dims but whose values have 128. Padding the values to 256 dims solves the problem, again at the cost of unnecessary computation.
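
Padding the values is also safe, provided the attention output is sliced back to the original value dim afterwards. The padded value columns are zero, so the corresponding output columns are exactly zero. Below is a minimal NumPy sketch with toy shapes. The `attention` and `pad_last_dim` helpers are hypothetical reference implementations, not a kernel API.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v, scale):
    """Reference dot-product attention: softmax(q k^T * scale) v."""
    return softmax((q @ k.T) * scale) @ v


def pad_last_dim(x: np.ndarray, target: int) -> np.ndarray:
    widths = [(0, 0)] * (x.ndim - 1) + [(0, target - x.shape[-1])]
    return np.pad(x, widths)


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 192))
k = rng.standard_normal((6, 192))
v = rng.standard_normal((6, 128))  # values are narrower than queries/keys

# Take the scale from the ORIGINAL q/k head dim (192), not the padded 256.
scale = 1.0 / np.sqrt(q.shape[-1])
reference = attention(q, k, v, scale)

padded_dim = 256
out_padded = attention(
    pad_last_dim(q, padded_dim), pad_last_dim(k, padded_dim), pad_last_dim(v, padded_dim), scale
)
out = out_padded[..., : v.shape[-1]]  # drop the zero columns produced by the padded values
print(np.allclose(out, reference))
```

One caveat, stated as an assumption about how a given kernel is wired: the scale may be derived inside the kernel from the padded head_dim. In that case it would become 1/sqrt(256) instead of 1/sqrt(192), so pass or pre-apply the original scale.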

rodrigo-f-nogueira, Feb 09 '25

@gobbleturk @RissyRan, can you please take a look?

shralex, Feb 19 '25

Hi, thanks for reaching out! Could you provide more detailed logs for the "not implemented" error with head_dim 64? For 192 dims, yes, padding may be needed because of the hardware design.

For the assertion that keys and values must have the same dim, I would suggest submitting a feature request in the JAX repo.

RissyRan, Feb 19 '25

Hi @RissyRan, thanks. I've already opened an issue in the JAX repo: https://github.com/jax-ml/jax/issues/26433. For now, a quick workaround is to pad the dimensions as follows.

The code below should be inserted right before this line:

```python
# Module-level imports this snippet relies on (MaxText's attention module already has them).
import jax.numpy as jnp
from jax import lax
import flax.linen as nn


def maybe_pad_for_flash_attn(arr, multiple_of=128):
  """Zero-pad the last (head_dim) axis of `arr` up to a multiple of `multiple_of`."""
  head_dim = arr.shape[-1]
  # Padding needed to make the head dimension a multiple of `multiple_of`.
  pad_length = (multiple_of - (head_dim % multiple_of)) % multiple_of
  if pad_length > 0:
    # lax.pad takes one (low, high, interior) triple per axis; pad only the last axis.
    padding_config = [(0, 0, 0)] * (arr.ndim - 1) + [(0, pad_length, 0)]
    arr = lax.pad(arr, jnp.array(0.0, dtype=arr.dtype), padding_config)
  return arr

# Flash attention needs head_dim to be a multiple of 128, so pad if necessary.
# Find the largest of the three head dims. For DeepSeek V2/V3 this is 192.
max_head_dim = max(q.shape[-1], k.shape[-1], v.shape[-1])

# Round up to the nearest multiple of 128 (ceil division), so an already-aligned
# head_dim such as 128 is not padded further. For DeepSeek V2/V3 this gives 256.
multiple_of = -(-max_head_dim // 128) * 128

# Remember v's original head dim: after the flash kernel returns, its output must be
# sliced back to it, e.g. out[..., :v_head_dim], to drop the zero-padded columns.
v_head_dim = v.shape[-1]

if self.config.attention == "flash":
  q = maybe_pad_for_flash_attn(q, multiple_of=multiple_of)
  q = nn.with_logical_constraint(q, self.query_axis_names)

  k = maybe_pad_for_flash_attn(k, multiple_of=multiple_of)
  k = nn.with_logical_constraint(k, self.key_axis_names)

  v = maybe_pad_for_flash_attn(v, multiple_of=multiple_of)
  v = nn.with_logical_constraint(v, self.value_axis_names)
```
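
The choice of padded size matters for the extra compute discussed earlier in the thread. The plain-Python sketch below compares two ways of picking it. One rounds up to the nearest multiple of 128. The other always jumps to the next multiple, which pads an already-aligned 128-dim head to 256 and a 256-dim head to 384. Both function names are hypothetical.

```python
def round_up_to_multiple(n: int, multiple: int = 128) -> int:
    """Smallest multiple of `multiple` that is >= n (ceil division)."""
    return -(-n // multiple) * multiple


def next_multiple_strictly_greater(n: int, multiple: int = 128) -> int:
    """Always moves to the next multiple, even when n is already aligned."""
    return ((n // multiple) + 1) * multiple


for d in (64, 128, 192, 256):
    print(d, round_up_to_multiple(d), next_multiple_strictly_greater(d))
```

The two agree for unaligned sizes such as 64 (Llama 3.2 1B) and 192 (DeepSeek). Only the round-up version leaves aligned sizes untouched.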

rodrigo-f-nogueira, Feb 19 '25

The recent change has been merged; please give it a try: https://github.com/jax-ml/jax/pull/30862

RissyRan, Aug 20 '25