
Move MHAttention layer to Flux

Open CarloLucibello opened this issue 3 years ago • 9 comments

@theabhirath @darsnack do you think the multi-head attention layer is now general enough that we can move it to Flux?
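For reference, here is roughly the kind of generic layer I have in mind for the plain (features, length, batch) case, built on nothing but Dense and NNlib primitives (an untested sketch; the GenericMHA name and field layout are made up, not the Metalhead implementation):

```julia
using Flux, NNlib

# Untested sketch of a generic (features, length, batch) multi-head attention
# layer built only on Flux/NNlib primitives. GenericMHA and its fields are
# hypothetical names, not Metalhead's MHAttention.
struct GenericMHA{Q,O}
    nheads::Int
    qkv::Q   # Dense(dim => 3dim): fused query/key/value projection
    out::O   # Dense(dim => dim): output projection
end
Flux.@functor GenericMHA

GenericMHA(dim::Int, nheads::Int) =
    GenericMHA(nheads, Dense(dim => 3dim), Dense(dim => dim))

function (m::GenericMHA)(x::AbstractArray{T,3}) where {T}
    dim, len, batch = size(x)
    h = m.nheads
    @assert dim % h == 0 "feature dimension must be divisible by the number of heads"
    hd = dim ÷ h
    qkv = reshape(m.qkv(x), hd, h, 3, len, batch)
    # fold the head dimension into the batch so plain batched_mul does the rest
    slice(i) = reshape(permutedims(qkv[:, :, i, :, :], (1, 3, 2, 4)), hd, len, h * batch)
    q, k, v = slice(1), slice(2), slice(3)
    scores = NNlib.batched_mul(NNlib.batched_transpose(k), q) ./ T(sqrt(hd))  # (len, len, h*batch)
    attn = NNlib.softmax(scores; dims = 1)
    o = NNlib.batched_mul(v, attn)                                            # (hd, len, h*batch)
    o = reshape(permutedims(reshape(o, hd, len, h, batch), (1, 3, 2, 4)), dim, len, batch)
    return m.out(o)
end
```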

CarloLucibello avatar Apr 04 '22 12:04 CarloLucibello

Is NeuralAttentionlib.matmul simple enough that it can be ported without carrying over another library dependency?

CarloLucibello avatar Apr 04 '22 12:04 CarloLucibello

Flux would need to be willing to take on NeuralAttentionlib as a dep, and we would need to rework it to accept more inputs. Currently, it only accepts 3D.

darsnack avatar Apr 04 '22 12:04 darsnack

Oh sorry, I posted my comment before the page refreshed with yours. I think all we need to provide is a parallelized version of NNlib.batched_mul for 4D inputs. It could be something specific to this layer that calls batched_mul under the hood, since I know there were some concerns about a generic 4D implementation when it was brought up.
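Concretely, something along these lines (an untested sketch; the batched_mul4 name is made up) would cover the 4D case by just folding the head dimension into the batch dimension:

```julia
using NNlib

# Hypothetical helper: a 4D "batched" matmul that folds the trailing
# (head, batch) dimensions into one batch dimension and calls
# NNlib.batched_mul under the hood.
function batched_mul4(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    m, k, h, b = size(A)
    k2, n, h2, b2 = size(B)
    @assert (k, h, b) == (k2, h2, b2) "inner and batch dimensions must match"
    A3 = reshape(A, m, k, h * b)
    B3 = reshape(B, k, n, h * b)
    C3 = NNlib.batched_mul(A3, B3)   # (m, n, h * b), reuses batched_mul's CPU/GPU paths
    return reshape(C3, m, n, h, b)
end
```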

darsnack avatar Apr 04 '22 12:04 darsnack

NeuralAttentionlib already works with more than 3D inputs - one of the reasons I used it as a dep was that it would allow that functionality in the future (see https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r827697134). The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r829147436). And NeuralAttentionlib basically already provides a ready-made multi-headed self-attention layer. I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore - most attention layers involve some novelty and so have to be written out in a custom manner. (A 4D+ version of NNlib.batched_mul wouldn't hurt though, as I brought up in https://github.com/FluxML/NNlib.jl/issues/391.)

theabhirath avatar Apr 04 '22 13:04 theabhirath

So you're saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

CarloLucibello avatar Apr 06 '22 18:04 CarloLucibello

I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.

darsnack avatar Apr 06 '22 18:04 darsnack

So you're saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

No, this one is used quite often... I meant the layer in the form it takes in Metalhead, which is quite task-specific, as opposed to, say, the NeuralAttentionlib functions (https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.multihead_qkv_attention or https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.generic_multihead_qkv_attention), although the PyTorch function certainly exposes a lot more options that can be tweaked - something the Flux layer could possibly incorporate.
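For instance, the first of those is essentially usable out of the box. A rough usage sketch, assuming the (features, length, batch) input layout (check the linked API reference for the exact signature and mask handling):

```julia
using NeuralAttentionlib

nheads, dim, len, batch = 4, 64, 16, 8
q = rand(Float32, dim, len, batch)
k = rand(Float32, dim, len, batch)
v = rand(Float32, dim, len, batch)

# Multi-head scaled dot-product attention; q/k/v are split into nheads heads
# internally, and the output has the same (dim, len, batch) shape as the inputs.
y = NeuralAttentionlib.multihead_qkv_attention(nheads, q, k, v)
```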

theabhirath avatar Apr 06 '22 18:04 theabhirath

The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r829147436).

I can adapt the same multithreading approach that batched_mul uses if the CPU part is really a concern.
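Roughly, that would just mean spreading the batch dimension over threads on the CPU path, e.g. something like this (purely illustrative, not the actual NeuralAttentionlib or NNlib code):

```julia
using LinearAlgebra

# Illustrative threaded CPU batched matmul: one BLAS mul! per batch slice,
# with the slices distributed over Julia threads.
function threaded_batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3},
                               B::AbstractArray{T,3}) where {T}
    @assert size(A, 3) == size(B, 3) == size(C, 3) "batch dimensions must match"
    Threads.@threads for i in axes(C, 3)
        @views mul!(C[:, :, i], A[:, :, i], B[:, :, i])
    end
    return C
end
```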

chengchingwen avatar Apr 06 '22 23:04 chengchingwen

I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.

Sorry, I misread this remark:

I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore - most attention layers involve some novelty and so have to be written out in a custom manner

Having the new NeuralAttentionlib dependency in Flux should be fine; it seems like a well-designed and well-maintained library. Maybe it contains more than what is strictly needed, though, so I was hoping we could just consolidate things in NNlib and avoid dispersion. I would be OK with either path forward.

CarloLucibello avatar Apr 07 '22 06:04 CarloLucibello