pytorch_attn_forward的op_type默认参数是否应该改为efficient?
我没有A100 H100这种卡,所以用不了flash attention,只能用pytorch的sdpa实现。
最近跑的开源代码里用到了xFuserLongContextAttention(),我加上了xFuserLongContextAttention(attn_type=AttnType.TORCH),然而排查发现内部在调用select_flash_attn_impl的时候,虽然用的是pytorch_attn_forward
https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/kernels/init.py#L126-L130
但这里pytorch_attn_forward的op_type默认参数还是flash,调用不到flash attention还是会报错。想问下这里的默认参数是否应该改为efficient?
https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/kernels/attention.py#L41-L53
我理解用户如果用到了pytorch_attn_forward这个函数,就说明他不想用flash attention了。而如果他想用flash attention,他的attn_type就应该是FA什么的,也就不会用到pytorch_attn_forward这个函数了。
还不太熟悉这个库,只是提个小建议,如果说的不对请见谅
这个是个 default 参数,真正使用时候会被正确赋值。
不是的,我不是小白,我说的就是真正使用的时候,就算设置了attn_type=AttnType.TORCH,select_flash_attn_impl选择出了原生pytorch实现pytorch_attn_forward,但由于很多代码不会手动传attn_type参数到pytorch_attn_forward函数,导致使用了pytorch_attn_forward函数的默认参数,导致调用的实际上是FA而不是sdpa。
以下是我看到的一些使用场景,可能还有更多 https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/ring/ring_flash_attn.py#L36-L48 https://github.com/xdit-project/xDiT/blob/cd061157d13d074d9db69b2fcf5e26b408d8b74d/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py#L103-L132
我看的就是最新的main分支。 当设置了attn_type=AttnType.TORCH,select_flash_attn_impl选择出了原生pytorch实现pytorch_attn_forward,但pytorch_attn_forward默认用的是flash(FA)而不是efficient。
明白了。你能不能交一个 PR 把 attn_type 传给调用的 pytorch_attn_forward 函数
我觉得还是应该把pytorch_attn_forward的op_type默认参数改为efficient更好一些,因为select_flash_attn_impl选择出的其他fn(比如flash_attn_forward_aiter、flash_attn_forward、flash_attn3_func_forward等等)都没有传attn_type进去。
很愿意贡献一个PR~,我想了两个方案:
- 直接把pytorch_attn_forward的op_type默认参数改为efficient。
- 在pytorch_attn_forward旁边新增一个torch_attn_forward 🟰 pytorch_attn_forward(op_type='efficient')
方佬看看怎么做比较好
为啥默认是 efficient 更好呢,fa 不如 efficient 么?
我的理解是attn_type=AttnType.TORCH的时候,select_flash_attn_impl才会选出pytorch_attn_forward,所以应该优先efficient,因为用户是用不了或者不想用FA才设置attn_type=AttnType.TORCH的,默认op_type是FA的话不符合用户预期。
反过来,如果用户想用FA,就会设置attn_type=AttnType.FA,select_flash_attn_impl就会选出flash_attn_forward而不是pytorch_attn_forward。
可能主要问题在于pytorch_attn_forward内部本身就有不同的实现版本吧,包括op_type in ["flash", "efficient"]。
Pytorch2.8.0文档里现在更是有MATH、FLASH_ATTENTION、EFFICIENT_ATTENTION、CUDNN_ATTENTION四种实现。
我觉得可以把AttnType.TORCH分的更细一些,像AttnType.SAGE_xxx有很多版本一样,分成AttnType.TORCH_MATH,AttnType.TORCH_FLASH,AttnType.TORCH_EFFICIENT,AttnType.TORCH_CUDNN怎么样?
AttnType.TORCH分的更细 没问题的