Fuse MLP in attention mechanism
Due to https://github.com/facebookresearch/xformers/issues/286 we cannot currently fuse the bias/GeLU/activation into a single kernel using Triton. This means we're just using a standard MLP and are probably taking a perf hit.
In Megatron-DeepSpeed, they use a torch-scripted GeLU/bias function with additional flags to fuse the operation: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/master/megatron/model/fused_bias_gelu.py
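For reference, the core of that approach is just a TorchScript-compiled tanh-approximation GeLU with the bias add folded in, so the JIT can fuse the elementwise ops into one kernel. A minimal sketch (the constant 0.79788456 ≈ sqrt(2/π), matching the tanh GeLU approximation):

```python
import torch


@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Fuse the bias add into the tanh-approximation GeLU so TorchScript
    # can emit a single fused elementwise kernel.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
```

This should match `torch.nn.functional.gelu(y + bias, approximate="tanh")` up to float rounding.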
I haven't managed to get this to work, as the global settings in this file cause the xFormers rotary embeddings to fail: https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py#L21
Combining this scripted function with standard Linear operations + Dropout may give us a slight performance boost. Waiting for Triton dropout support in BF16 seems like it might take a while (I think related: https://github.com/openai/triton/pull/431)
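A minimal sketch of what that combination could look like (module and parameter names here are hypothetical, not from any of the linked code): make the first `nn.Linear` bias-free, keep its bias as a separate parameter, and feed it into the scripted bias-GeLU so the bias add and activation fuse, then finish with a standard `Linear` + `Dropout`:

```python
import torch
import torch.nn as nn


@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Tanh-approximation GeLU with the bias add folded in, as in
    # Megatron-DeepSpeed's fused_bias_gelu.py.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


class FusedBiasGeluMLP(nn.Module):
    # Hypothetical sketch: the first projection is bias-free so its bias
    # can be fused into the scripted GeLU; everything else is stock PyTorch.
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.fc1_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.fc2(bias_gelu(self.fc1_bias, self.fc1(x))))
```

The dropout stays an ordinary `nn.Dropout` here, which is why the Triton BF16 dropout issue above still matters for getting the whole MLP into fused kernels.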
cc @blefaudeux
This may also be a viable solution: https://github.com/facebookresearch/xformers/pull/352
I tried this briefly but ran into errors caused by the ZeRO-3 hooks. Going to retry and see if I can get this to work!