Export attention function similar to torch.nn.functional.scaled_dot_product_attention
Currently, the functions exist in the `_attention.py` file but are not explicitly exported. However, a lot of people want to write their own custom MHA implementations and could use these functions.
(I'm aware that I can simply import them anyway, but since they're not in the docs and not everyone reads the source code, they can easily be overlooked.)
WDYT? Other frameworks have a dedicated "functional" package; it'd be great to have something similar.
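For reference, here's a minimal sketch of what such an exported function would compute, along the lines of `torch.nn.functional.scaled_dot_product_attention`. The name, signature, and shapes below are illustrative assumptions, not necessarily what Equinox's internal `_attention.py` provides:

```python
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(query, key, value, mask=None):
    """Illustrative single-head scaled dot-product attention.

    query: (seq_q, d), key: (seq_kv, d), value: (seq_kv, d_v)
    mask:  optional boolean (seq_q, seq_kv); True means "may attend".
    """
    # Scale logits by sqrt(d), as in the usual attention definition.
    logits = query @ key.T / jnp.sqrt(query.shape[-1])
    if mask is not None:
        # Masked-out positions get -inf so they vanish under softmax.
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)  # (seq_q, seq_kv)
    return weights @ value  # (seq_q, d_v)
```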
I'd be happy to add these to the public API. Just never had a request for that before!
On the topic of functional APIs, one of the nice things about the functional-programming nature of JAX+Equinox is how we kind of get that for free! If you want functions that look like this:
```python
weight_and_bias = init_params(...)
linear(weight_and_bias, x)
```
then these can be obtained as just:
```python
init_params = eqx.nn.Linear.__init__
linear = eqx.nn.Linear.__call__
```
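A small runnable sketch of that equivalence (constructing the module directly rather than via `__init__`, since in Equinox the `Module` instance is itself the parameter pytree; the shapes here are arbitrary):

```python
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

# The Linear instance plays the role of `weight_and_bias` above.
params = eqx.nn.Linear(in_features=3, out_features=4, key=jr.PRNGKey(0))
linear = eqx.nn.Linear.__call__

x = jnp.ones(3)
y = linear(params, x)  # identical to params(x)
```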
For what it's worth, I second this! I've implemented attention more than once in Equinox and didn't know this was hidden in the library :) Would have saved some time / debugging 👍