Metalhead.jl
Move MHAttention layer to Flux
@theabhirath @darsnack do you think the multi-head attention layer is now general enough that we can move it to Flux?
Is NeuralAttentionlib.matmul simple enough that it can be ported without carrying over another library dependency?
Flux would need to be willing to take on NeuralAttentionlib as a dep, and we would need to rework it to accept more inputs. Currently, it only accepts 3D inputs.
Oh sorry, I posted my comment before the page refreshed with yours. I think all we need to provide is a parallelized version of NNlib.batched_mul for 4D inputs. It could be something specific to this layer that calls batched_mul under the hood, since I know there were some concerns about a generic 4D implementation when it was brought up.
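Concretely, something like this sketch (the name `batched_mul_4d` is hypothetical, not an existing NNlib API; it just folds the head and batch dimensions together and defers to `batched_mul`):

```julia
using NNlib: batched_mul

# Hypothetical sketch: treat the two trailing dimensions (heads, batch) as a
# single batch dimension, call NNlib.batched_mul on the reshaped 3D arrays,
# and reshape the result back to 4D.
function batched_mul_4d(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    m, k, h, b = size(A)
    k2, n, h2, b2 = size(B)
    @assert (k, h, b) == (k2, h2, b2) "inner and batch dimensions must match"
    C3 = batched_mul(reshape(A, m, k, h * b), reshape(B, k, n, h * b))
    return reshape(C3, m, n, h, b)  # (m, n, heads, batch)
end
```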
NeuralAttentionlib already works with more than 3D inputs - one of the reasons I used it as a dep was that it would allow that functionality in the future (see https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r827697134). The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r829147436). And NeuralAttentionlib basically already provides a ready-made multi-headed self-attention layer. I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore - most attention layers involve some novelty and so have to be written out in a custom manner. (A 4D+ version of NNlib.batched_mul wouldn't hurt though, as I brought up in https://github.com/FluxML/NNlib.jl/issues/391.)
So are you saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.
> So are you saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
No, this one is used quite often... I meant the layer in the form it takes in Metalhead, which is quite task-specific, as opposed to, say, the NeuralAttentionlib functions (https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.multihead_qkv_attention or https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.generic_multihead_qkv_attention), although the PyTorch function certainly exposes a lot more knobs that can be tweaked - something the Flux layer could possibly incorporate.
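For reference, calling the first of the linked functions looks roughly like this; the argument order follows the linked docs, and the (features, length, batch) layout is my assumption of the usual Julia convention:

```julia
using NeuralAttentionlib: multihead_qkv_attention

# Rough usage sketch; shapes assume a (features, length, batch) layout and
# that the head count divides the feature dimension.
q = rand(Float32, 64, 10, 4)   # queries: (dim, q_len,  batch)
k = rand(Float32, 64, 12, 4)   # keys:    (dim, kv_len, batch)
v = rand(Float32, 64, 12, 4)   # values:  (dim, kv_len, batch)
y = multihead_qkv_attention(8, q, k, v)   # 8 heads; output of size (dim, q_len, batch)
```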
> The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (https://github.com/FluxML/Metalhead.jl/pull/135#discussion_r829147436).
I can adapt the same multithreading approach that batched_mul uses if the CPU part is really a concern.
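I.e. something along these lines, threading over the batch dimension (an illustration of the approach, not NNlib's actual implementation):

```julia
using LinearAlgebra: mul!

# Sketch of the threading pattern: one matrix multiply per batch slice,
# distributed across Julia threads.
function threaded_batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3},
                               B::AbstractArray{T,3}) where {T}
    Threads.@threads for i in axes(C, 3)
        @views mul!(C[:, :, i], A[:, :, i], B[:, :, i])
    end
    return C
end
```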
> I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.
Sorry, I misread this remark:
> I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore - most attention layers involve some novelty and so have to be written out in a custom manner
Having the new NeuralAttentionlib dependency in Flux should be fine; it seems like a well-designed and well-maintained library. Maybe it contains more than what is strictly needed, though, so I was hoping we could just consolidate things in NNlib and avoid fragmentation. I would be OK with either path forward.