
Export attention function similar to torch.nn.functional.scaled_dot_product_attention

Open • Artur-Galstyan opened this issue 1 year ago • 2 comments

Currently, the attention functions live in the _attention.py file but are not explicitly exported. A lot of people want to write their own custom multi-head attention (MHA) implementation and could use these functions.

(I'm aware that I can simply import them anyway, but because they aren't in the docs and not everyone reads the source code, they can easily be overlooked.)

WDYT? Other frameworks have a dedicated "functional" package. It'd be great to have something similar.
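For anyone who needs something today, here is a minimal standalone sketch of scaled dot-product attention in JAX. It computes softmax(QK^T / sqrt(d)) V, the same formula behind torch.nn.functional.scaled_dot_product_attention with its default scale. This is not Equinox's private implementation: the names `attention_weights`, `scaled_dot_product_attention`, and `multihead_attention` are hypothetical, inputs are assumed unbatched, and the mask is assumed to be boolean with True meaning "may attend".

```python
import jax
import jax.numpy as jnp


def attention_weights(query, key, mask=None):
    """softmax(Q K^T / sqrt(d)) for unbatched inputs (illustrative helper).

    query: (q_seq, d), key: (kv_seq, d)
    mask: optional bool array (q_seq, kv_seq); True means "may attend".
    """
    logits = query @ key.T / jnp.sqrt(query.shape[-1])
    if mask is not None:
        # Push masked positions to the most negative finite value so that
        # softmax gives them (numerically) zero weight.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return jax.nn.softmax(logits, axis=-1)


def scaled_dot_product_attention(query, key, value, mask=None):
    """Weighted sum of value rows; value: (kv_seq, d_v) -> (q_seq, d_v)."""
    return attention_weights(query, key, mask) @ value


def multihead_attention(query, key, value):
    """Inputs shaped (seq, heads, d): run single-head attention per head."""
    return jax.vmap(scaled_dot_product_attention, in_axes=1, out_axes=1)(
        query, key, value
    )


# Causal self-attention on a toy sequence of 4 tokens with 3 features.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
causal = jnp.tril(jnp.ones((4, 4), dtype=bool))
weights = attention_weights(x, x, causal)
out = scaled_dot_product_attention(x, x, x, causal)

# Two heads of size 3 over the same 4 tokens.
h = jax.random.normal(jax.random.PRNGKey(1), (4, 2, 3))
mh_out = multihead_attention(h, h, h)
```

Adding a batch dimension is one more jax.vmap over the leading axis. A full MHA layer would also wrap multihead_attention in learned query/key/value/output projections, which this sketch leaves out.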

Artur-Galstyan, Jun 21 '24 07:06

I'd be happy to add these to the public API. Just never had a request for that before!

On the topic of functional APIs, one of the nice things about the functional-programming nature of JAX+Equinox is that we kind of get that for free! If you want functions that look like this:

weight_and_bias = init_params(...)
linear(weight_and_bias, x)

then these can be obtained as just:

init_params = eqx.nn.Linear.__init__
linear = eqx.nn.Linear.__call__

patrick-kidger, Jun 21 '24 20:06

For what it's worth, I second this! I've implemented attention more than once in Equinox and didn't know this was hidden in the library :) It would have saved some time and debugging 👍

TugdualKerjan, Dec 12 '24 17:12