ring-flash-attention
Ring attention implementation with flash attention
Multi-node training speed issue
I tried training on multiple nodes with multiple GPUs and found the time cost increases a lot compared with single-node training: in the same training environment it takes three times as long as DeepSpeed Ulysses, while on a single node there is no such problem. What could be causing this?
I use 4 GPUs to run the code. My command is:
```
torchrun --nproc_per_node 4 test/test_ring_flash_attn_varlen_func.py
```
My error is:
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/xxxx/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py",...
```
Were you able to find out the reason for the small numerical errors in the backward pass with ring flash attention? I found the errors increase as you increase the world...
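One likely contributor to such world-size-dependent drift is accumulation order: splitting a low-precision reduction across ranks changes the order in which partial results are rounded. The following is only a numpy toy sketch of that effect, not the repo's actual kernel; the `ring_style_sum` helper and the chunking scheme are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float16)

# Single-pass fp16 sum, as a one-GPU kernel would accumulate.
single = x.sum(dtype=np.float16)

def ring_style_sum(x, world_size):
    """Illustrative helper: each 'rank' reduces its own chunk in fp16,
    then the fp16 partials are combined one by one, mimicking how a
    ring of `world_size` ranks combines per-block results."""
    partials = [chunk.sum(dtype=np.float16) for chunk in np.array_split(x, world_size)]
    total = np.float16(0)
    for p in partials:
        total = np.float16(total + p)  # each combine step rounds again
    return total

# Deviation from the single-pass result for growing "world sizes".
errs = [abs(float(ring_style_sum(x, w) - single)) for w in (2, 4, 8)]
print(errs)
```

Because each combine step rounds to fp16, the chunked result generally differs slightly from the single-pass one, and more ranks means more rounding points, which is consistent with errors growing with world size.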
Hello, I get the error above when using EasyContext's zigzag_ring_flash_attn mode. All of my data is grouped by length to 32768+1 tokens (following https://github.com/jzhang38/EasyContext/issues/31#issue-2308064466). It runs fine in data-parallel mode, but errors out under sequence parallelism. **code:**
```
def main(args):
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.wandb:
        import wandb
        wandb.login()
    set_seed(args.seed)
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulate_every,
        mixed_precision="bf16",
        log_with="wandb" if...
```
Hi @zhuzilin, following up from https://github.com/zhuzilin/ring-flash-attention/issues/15. I just wanted to verify the causal case, and I simply use a loop because I don't have multiple GPUs, but it should be working; when I...
Hi, when I increase the seqlen from 1024 * 8 to 1024 * 64 here: https://github.com/zhuzilin/ring-flash-attention/blob/9e2a7e543d6461cc935d44142fc99660de7b8579/benchmark/benchmark_varlen_qkvpacked_func.py#L18 and then run the code with
```
torchrun benchmark/benchmark_varlen_qkvpacked_func.py
```
the program starts to...
1. It seems the batch dimension disappears after the _upad_input function (this function is usually copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input). Then the block_lse obtained at L118 in zigzag_ring_flash_attn_varlen.py only has 2...
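The batch dimension disappearing is expected for the varlen path: unpadding gathers only the valid tokens of every sample, merging batch and sequence into a single "total tokens" axis tracked by `cu_seqlens`. Below is a hedged numpy sketch of that flattening; the shapes and names are illustrative, not the exact `flash_attn.bert_padding.unpad_input` API.

```python
import numpy as np

# Padded hidden states: (batch, seqlen, nheads, headdim)
batch, seqlen, nheads, headdim = 2, 8, 4, 16
hidden = np.random.randn(batch, seqlen, nheads, headdim)

# Per-sample valid lengths 5 and 8 give a boolean attention mask.
lengths = np.array([5, 8])
mask = np.arange(seqlen)[None, :] < lengths[:, None]   # (batch, seqlen)

# Unpadding selects only the valid tokens from every sample, so the
# batch and sequence dims collapse into one total-tokens dim.
unpadded = hidden[mask]                                # (total_tokens, nheads, headdim)
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)]) # cumulative offsets per sample

print(unpadded.shape)   # batch dim is gone: (13, 4, 16)
print(cu_seqlens)       # [ 0  5 13]
```

Since the varlen kernels work on this flattened layout, outputs such as the log-sum-exp tensor are indexed by (heads, total tokens) rather than (batch, heads, seqlen), which would explain `block_lse` losing a dimension compared with the batched path.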