
Adding flash attention for sequence parallel

Open: dianaml0 opened this issue 1 year ago • 4 comments

Patch Description

Creating this PR off of #511 so it can be reviewed by @stephenroller.

The last commit (3d709dba5c4be713fd821dc4e0f6b6f90f5ead40) removes some of the changes to the sequence parallel code that enabled testing with a world size of 1. CI does not currently run this test anyway, because CI needs to be updated before the test can run.

The forward and backward tests currently pass. However, in some cases about 0.2% of the elements fail the element-wise comparison.
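A figure like "about 0.2% of the elements fail" is usually the fraction of elements that fall outside an element-wise tolerance. The sketch below shows how that fraction can be computed, assuming the test uses the same rule as `torch.isclose` / `numpy.isclose` (`|actual - expected| <= atol + rtol * |expected|`, NaN counted as a failure). The helper name `mismatch_fraction` and its default tolerances are hypothetical and are not taken from `gpu_tests/test_sequence_parallel_transformer_layer.py`; plain Python lists stand in for the flattened output or gradient tensors.

```python
from typing import Sequence


def mismatch_fraction(actual: Sequence[float], expected: Sequence[float],
                      rtol: float = 1e-5, atol: float = 1e-8) -> float:
    """Fraction of elements where |actual - expected| > atol + rtol * |expected|.

    Hypothetical helper: mirrors the element-wise rule of torch.isclose /
    numpy.isclose with equal_nan=False, applied to flat lists for illustration.
    """
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    if not actual:
        return 0.0
    # `not (x <= y)` also counts NaN as a failure, since any comparison with NaN is False.
    failures = sum(
        1 for a, e in zip(actual, expected)
        if not abs(a - e) <= atol + rtol * abs(e)
    )
    return failures / len(actual)


# Usage: two out of 1000 elements drift beyond tolerance, i.e. a 0.2% failure rate.
expected = [1.0] * 1000
actual = list(expected)
actual[10] += 1e-2
actual[500] -= 1e-2
print(mismatch_fraction(actual, expected))  # 0.002
```

Reporting the fraction, rather than only a pass/fail from an `allclose`-style check, makes it easier to tell a handful of tolerance outliers (common when fused attention kernels reorder floating-point reductions) from a systematic bug.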

Testing steps

Unit test: gpu_tests/test_sequence_parallel_transformer_layer.py

dianaml0, Dec 23 '22