long-context-attention icon indicating copy to clipboard operation
long-context-attention copied to clipboard

pytorch_attn_forward的op_type默认参数是否应该改为efficient?

Open genghisun opened this issue 3 months ago • 10 comments

我没有A100 H100这种卡,所以用不了flash attention,只能用pytorch的sdpa实现。 最近跑的开源代码里用到了xFuserLongContextAttention(),我加上了xFuserLongContextAttention(attn_type=AttnType.TORCH),然而排查发现内部在调用select_flash_attn_impl的时候,虽然用的是pytorch_attn_forward https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/kernels/init.py#L126-L130 但这里pytorch_attn_forwardop_type默认参数还是flash,调用不到flash attention还是会报错。想问下这里的默认参数是否应该改为efficient? https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/kernels/attention.py#L41-L53 我理解用户如果用到了pytorch_attn_forward这个函数,就说明他不想用flash attention了。而如果他想用flash attention,他的attn_type就应该是FA什么的,也就不会用到pytorch_attn_forward这个函数了。

还不太熟悉这个库,只是提个小建议,如果说的不对请见谅

genghisun avatar Sep 24 '25 03:09 genghisun

这个是个 default 参数,真正使用时候会被正确赋值。

feifeibear avatar Sep 26 '25 02:09 feifeibear

不是的,我不是小白,我说的就是真正使用的时候,就算设置了attn_type=AttnType.TORCHselect_flash_attn_impl选择出了原生pytorch实现pytorch_attn_forward,但由于很多代码不会手动传attn_type参数到pytorch_attn_forward函数,导致使用了pytorch_attn_forward函数的默认参数,导致调用的实际上是FA而不是sdpa。

以下是我看到的一些使用场景,可能还有更多 https://github.com/feifeibear/long-context-attention/blob/2c9b7120e70392c83acd2006a4f716aa407143ac/yunchang/ring/ring_flash_attn.py#L36-L48 https://github.com/xdit-project/xDiT/blob/cd061157d13d074d9db69b2fcf5e26b408d8b74d/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py#L103-L132

genghisun avatar Sep 26 '25 03:09 genghisun

select_flash_attn_impl 不就已经根据 attn_type 选择除了正确的函数么?有什么函数

你看最新 main 分支吧?你看的 commid 是不是太久了

feifeibear avatar Sep 28 '25 02:09 feifeibear

我看的就是最新的main分支。 当设置了attn_type=AttnType.TORCH,select_flash_attn_impl选择出了原生pytorch实现pytorch_attn_forward,但pytorch_attn_forward默认用的是flash(FA)而不是efficient

genghisun avatar Sep 28 '25 02:09 genghisun

明白了。你能不能交一个 PR 把 attn_type 传给调用的 pytorch_attn_forward 函数

feifeibear avatar Sep 28 '25 11:09 feifeibear

我觉得还是应该把pytorch_attn_forward的op_type默认参数改为efficient更好一些,因为select_flash_attn_impl选择出的其他fn(比如flash_attn_forward_aiterflash_attn_forwardflash_attn3_func_forward等等)都没有传attn_type进去。

很愿意贡献一个PR~,我想了两个方案:

  1. 直接把pytorch_attn_forward的op_type默认参数改为efficient。
  2. 在pytorch_attn_forward旁边新增一个torch_attn_forward 🟰 pytorch_attn_forward(op_type='efficient')

方佬看看怎么做比较好

genghisun avatar Sep 29 '25 02:09 genghisun

为啥默认是 efficient 更好呢,fa 不如 efficient 么?

feifeibear avatar Sep 30 '25 01:09 feifeibear

我的理解是attn_type=AttnType.TORCH的时候,select_flash_attn_impl才会选出pytorch_attn_forward,所以应该优先efficient,因为用户是用不了或者不想用FA才设置attn_type=AttnType.TORCH的,默认op_type是FA的话不符合用户预期。

反过来,如果用户想用FA,就会设置attn_type=AttnType.FA,select_flash_attn_impl就会选出flash_attn_forward而不是pytorch_attn_forward。

genghisun avatar Sep 30 '25 02:09 genghisun

可能主要问题在于pytorch_attn_forward内部本身就有不同的实现版本吧,包括op_type in ["flash", "efficient"]。 Pytorch2.8.0文档里现在更是有MATH、FLASH_ATTENTION、EFFICIENT_ATTENTION、CUDNN_ATTENTION四种实现。 我觉得可以把AttnType.TORCH分的更细一些,像AttnType.SAGE_xxx有很多版本一样,分成AttnType.TORCH_MATHAttnType.TORCH_FLASHAttnType.TORCH_EFFICIENTAttnType.TORCH_CUDNN怎么样?

genghisun avatar Sep 30 '25 02:09 genghisun

AttnType.TORCH分的更细 没问题的

feifeibear avatar Oct 01 '25 06:10 feifeibear