
Expose attention weights' head dimension

Status: Open. neel04 opened this issue 1 year ago (8 comments).

Currently, Equinox handles attention heads opaquely: the _project method reshapes the projected Q, K and V to add the heads dimension.
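A minimal sketch of the project-then-reshape pattern described above, written in plain JAX rather than copied from Equinox's source. The weight is stored as (query_size, num_heads * qk_size) to match the shape discussed later in this thread (eqx.nn.Linear itself stores its weight transposed, as (out_features, in_features)). `project_then_split` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def project_then_split(x, w_flat, num_heads):
    """Project with a flat 2D weight, then split the last axis into heads.

    x:      (seq_len, query_size)
    w_flat: (query_size, num_heads * qk_size)
    returns (seq_len, num_heads, qk_size)
    """
    seq_len = x.shape[0]
    projected = x @ w_flat  # (seq_len, num_heads * qk_size): no heads axis yet
    # The heads axis only appears here, after the matmul.
    return projected.reshape(seq_len, num_heads, -1)

key = jax.random.PRNGKey(0)
x = jnp.ones((5, 16))
w_flat = jax.random.normal(key, (16, 4 * 8))
q = project_then_split(x, w_flat, num_heads=4)
print(q.shape)  # (5, 4, 8)
```

Because the heads axis exists only inside the computation, nothing in the parameter PyTree carries it, which is why it cannot be targeted directly for sharding.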

However, sharding along the heads dimension is a common way to parallelize a model.

I feel that the {query | key | value}_proj should be split to expose the head dimension.

WDYT?

neel04, Sep 05 '24 15:09

Sorry, it's not totally clear to me what change you're suggesting. Can you expand?

patrick-kidger, Sep 06 '24 17:09

The W_{q | k | v} projections are 2D, with the general shape (query_size, num_heads * qk_size).

I feel that the head dimension should be explicit here, so the shape would be 3D: (query_size, num_heads, qk_size).
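To illustrate the proposal, here is a sketch in which the projection weight carries the heads axis explicitly, with shape (query_size, num_heads, qk_size), and the projection is a single `jnp.einsum`. Reshaping the flat 2D weight splits its last axis exactly the way the per-token reshape does, so the two layouts hold the same numbers and are interchangeable through a plain reshape. `project_heads` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def project_heads(x, w_heads):
    """x: (seq_len, query_size); w_heads: (query_size, num_heads, qk_size)."""
    # Contract over query_size; the heads axis passes through untouched.
    return jnp.einsum("sq,qhd->shd", x, w_heads)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (5, 16))
w_flat = jax.random.normal(key_w, (16, 4 * 8))  # current 2D layout
w_heads = w_flat.reshape(16, 4, 8)              # proposed 3D layout, same values

q_flat = (x @ w_flat).reshape(5, 4, 8)          # project, then split heads
q_heads = project_heads(x, w_heads)             # heads explicit in the weight
print(bool(jnp.allclose(q_flat, q_heads, atol=1e-5)))  # True
```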

This might be a bit tricky to incorporate, I suppose, but it would definitely be helpful, for example for sharding along the head dimension or for weight sharing.
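A sketch of what sharding along an explicit heads axis could look like with `jax.sharding`, assuming the 3D (query_size, num_heads, qk_size) layout proposed above. The mesh axis name "heads" is an arbitrary choice. On a single CPU the mesh has one device, so the example runs anywhere but only shows a real split when several devices are available.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D device mesh whose single axis is named "heads" (an arbitrary label).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("heads",))

query_size, qk_size = 16, 8
num_heads = 2 * len(devices)  # heads must divide evenly across devices

# Head-explicit projection weight: (query_size, num_heads, qk_size).
w_heads = jnp.arange(query_size * num_heads * qk_size, dtype=jnp.float32).reshape(
    query_size, num_heads, qk_size
)

# Partition only axis 1 (heads); the other two axes are replicated.
sharding = NamedSharding(mesh, P(None, "heads", None))
w_sharded = jax.device_put(w_heads, sharding)

# Each device holds a (query_size, num_heads // n_devices, qk_size) slice.
for shard in w_sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

Sharding the last axis of the flat 2D layout can produce the same split when num_heads divides evenly across devices, but the head boundaries are then implied by the later reshape rather than visible in the PyTree.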

neel04, Sep 06 '24 23:09

Ah, I see what you're saying! So I think that, much like QKV fusion, this would unfortunately be a backward-incompatible change.

Specifically for the purposes of sharding, I think what we should do depends on whatever you and dlwh come up with in #825.

patrick-kidger, Sep 07 '24 16:09

🤷 Even with a dedicated sharding API, ideally one should only need to deal with the model PyTree. If someone wants to shard along the heads dimension, they would still have to insert explicit reshapes during the MHA computation to convert the 3D array created for sharding back to 2D for the _project method.

I suppose one could add a check in MHA that folds the heads dimension if the array is 3D... but that seems janky.
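The fold-on-3D check mentioned above might look roughly like this. It is a hypothetical helper, not Equinox code, and the branching on `ndim` is exactly the kind of special-casing being called janky: the function has to accept two weight layouts and normalize one into the other before doing the usual project-then-reshape.

```python
import jax
import jax.numpy as jnp

def project_any_layout(x, w, num_heads):
    """Hypothetical helper accepting either weight layout.

    w is either (query_size, num_heads * qk_size) or
    (query_size, num_heads, qk_size); x is (seq_len, query_size).
    """
    if w.ndim == 3:
        query_size, heads, qk_size = w.shape
        if heads != num_heads:
            raise ValueError("weight heads axis does not match num_heads")
        # Fold the heads axis back into the feature axis for the flat path.
        w = w.reshape(query_size, heads * qk_size)
    seq_len = x.shape[0]
    return (x @ w).reshape(seq_len, num_heads, -1)

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (5, 16))
w_flat = jax.random.normal(kw, (16, 4 * 8))
q_from_flat = project_any_layout(x, w_flat, num_heads=4)
q_from_heads = project_any_layout(x, w_flat.reshape(16, 4, 8), num_heads=4)
```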

neel04, Sep 07 '24 21:09

I wonder if, at this point, it might be better to add a separate, optimized version of eqx.nn.MultiheadAttention that would internally use JAX's SDPA for better performance, expose the heads dimension for sharding and customization, fuse QKV, and include @Artur-Galstyan's cache as well.

Users who want to explicitly adopt the newer features could switch over.

Or perhaps what's needed is a separate library of Equinox utilities with such feature-complete modules.

neel04, Sep 07 '24 21:09

@neel04 This might be relevant for you: I've started converting PyTorch models to Equinox models, and a lot of them use the multi_head_attention_forward function from PyTorch (see this). My version is a 1:1 copy of the PyTorch version (I've got some tests to ensure that they are numerically equivalent), though I'm not using sharding or anything like that. I'd be happy to take a contribution if you want to add sharding or anything else :)
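When porting PyTorch weights, the fused `in_proj_weight` used by PyTorch's multi-head attention has shape (3 * embed_dim, embed_dim), is applied as `x @ W.T`, and its output features are split into contiguous per-head chunks. A sketch of converting it into the head-explicit (embed_dim, num_heads, head_dim) layout discussed in this thread follows; `split_torch_in_proj` is a hypothetical helper (not from jaxonmodels), and biases are ignored.

```python
import jax
import jax.numpy as jnp

def split_torch_in_proj(in_proj_weight, num_heads):
    """Hypothetical helper: PyTorch-style fused (3*E, E) weight -> three
    (E, num_heads, head_dim) weights usable as einsum("sq,qhd->shd", x, w)."""
    three_e, embed_dim = in_proj_weight.shape
    if three_e != 3 * embed_dim or embed_dim % num_heads != 0:
        raise ValueError("expected shape (3 * embed_dim, embed_dim) with divisible heads")
    head_dim = embed_dim // num_heads
    # Each block is (out_features, in_features), as PyTorch stores it.
    w_q, w_k, w_v = jnp.split(in_proj_weight, 3, axis=0)
    # PyTorch computes x @ W.T and then splits output features head by head,
    # so transpose to (in, out) and reshape the output axis into (heads, head_dim).
    return tuple(w.T.reshape(embed_dim, num_heads, head_dim) for w in (w_q, w_k, w_v))

kx, kw = jax.random.split(jax.random.PRNGKey(0))
embed_dim, num_heads = 16, 4
x = jax.random.normal(kx, (5, embed_dim))
in_proj = jax.random.normal(kw, (3 * embed_dim, embed_dim))
wq, wk, wv = split_torch_in_proj(in_proj, num_heads)
k_heads = jnp.einsum("sq,qhd->shd", x, wk)  # (5, 4, 4)
```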

Artur-Galstyan, Apr 08 '25 11:04

That's pretty handy, @Artur-Galstyan. Have you considered spinning off a separate equinox-utils repo where people can contribute new (and useful) utilities, like a beefy MHA implementation? I feel almost everyone ends up re-implementing lots of stuff that would be better consolidated in an external repo.

We could add a PyTorch-friendly MHA, a highly performant SDPA-based MHA with a KV-cache, and maybe even a bare-bones auto-sharding setup similar to what I'm currently using for MP/TP.
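As an illustration of the KV-cache idea, here is a minimal sketch of a single-query decode step with a fixed-size cache that keeps the heads axis explicit. It is not @Artur-Galstyan's cache, the function name is hypothetical, and attention is written by hand (rather than via an SDPA call) so that the masking of not-yet-filled cache slots is visible.

```python
import jax
import jax.numpy as jnp

def decode_step(cache_k, cache_v, q, k_new, v_new, pos):
    """Hypothetical KV-cache decode step.

    cache_k, cache_v: (max_len, num_heads, head_dim)
    q, k_new, v_new:  (1, num_heads, head_dim)
    pos:              index of the current token
    """
    # Write the new key/value into slot `pos` of the fixed-size cache.
    cache_k = jax.lax.dynamic_update_slice(cache_k, k_new, (pos, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, v_new, (pos, 0, 0))
    # Only slots 0..pos hold real entries; mask out the rest.
    valid = jnp.arange(cache_k.shape[0]) <= pos
    scores = jnp.einsum("qhd,khd->hqk", q, cache_k) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(valid[None, None, :], scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("hqk,khd->qhd", weights, cache_v)  # (1, num_heads, head_dim)
    return out, cache_k, cache_v

max_len, num_heads, head_dim = 4, 2, 3
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
cache_k = jnp.zeros((max_len, num_heads, head_dim))
cache_v = jnp.zeros((max_len, num_heads, head_dim))
q = jax.random.normal(kq, (1, num_heads, head_dim))
k_new = jax.random.normal(kk, (1, num_heads, head_dim))
v_new = jax.random.normal(kv, (1, num_heads, head_dim))
out, cache_k, cache_v = decode_step(cache_k, cache_v, q, k_new, v_new, pos=0)
```

At the first step only one slot is valid, so the output is just that slot's value; the cache shape never changes, which keeps the step friendly to `jax.jit`.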

neel04, Apr 08 '25 12:04

Yes, although it isn't explicitly named equinox-utils, the jaxonmodels repo is meant to be exactly that: a place for layers and functions that extend Equinox, AND for popular ML models. For example, the current implementation of BatchNorm in Equinox differs from the PyTorch version, and since the goal was a 100%-ish match between the two frameworks, I also wrote a PyTorch-compliant BatchNorm (but it makes no sense to open a PR against Equinox, because it would be a breaking change to BatchNorm).
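For reference, PyTorch's documented BatchNorm training behaviour normalizes with the biased batch variance but updates `running_var` with the unbiased one, using `running = (1 - momentum) * running + momentum * batch_stat` with `momentum=0.1` by default. A minimal functional sketch of that behaviour in JAX follows; the function name is hypothetical, it handles 2D (batch, features) input only, and it does not describe how Equinox's BatchNorm differs.

```python
import jax.numpy as jnp

def torch_style_batchnorm_train(x, running_mean, running_var, weight, bias,
                                momentum=0.1, eps=1e-5):
    """Hypothetical PyTorch-style BatchNorm training step for x of shape (batch, features)."""
    n = x.shape[0]
    mean = x.mean(axis=0)
    var_biased = x.var(axis=0)  # normalization uses the biased variance
    y = (x - mean) / jnp.sqrt(var_biased + eps) * weight + bias
    var_unbiased = var_biased * n / (n - 1)  # running estimate uses the unbiased variance
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * var_unbiased
    return y, new_mean, new_var

x = jnp.array([[1.0], [3.0]])
y, new_mean, new_var = torch_style_batchnorm_train(
    x, jnp.zeros(1), jnp.ones(1), jnp.ones(1), jnp.zeros(1)
)
print(new_mean, new_var)  # running mean 0.2, running var 1.1
```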

And yes, sharding (and other things like KV-caches) will come, especially once I get to implementing LLMs like Llama and Gemma. It will just take some time (there is so much code left to write!)

I'm open to suggestions and collaboration on that front :) Feel free to contact me (let's not use GH issues as a chat interface 😄)

Artur-Galstyan, Apr 08 '25 13:04