[Bug] Attention qkv layers should use QKVParallelLinear
Environment
Agnostic
Describe the bug
When using LoRA to adapt Wan attention, I found the qkv layers were not replaced as intended because they use ReplicatedLinear instead of QKVParallelLinear. Other models use a single ReplicatedLinear with hidden_dim * 3.
This should be a quick fix. We should always use QKVParallelLinear because it fuses the three matmuls into one larger matmul, saving kernel launch overhead (and it keeps the style consistent).
This will need a separate PR, along with an update to _param_names_mapping.
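For illustration, here is a minimal PyTorch sketch of the fusion being asked for; the module and attribute names (`UnfusedQKV`, `FusedQKV`, `to_q`, `to_qkv`) are placeholders, not the actual FastVideo classes:

```python
import torch
import torch.nn as nn

hidden_dim = 1024

class UnfusedQKV(nn.Module):
    """Three separate projections -> three matmuls / kernel launches per call."""
    def __init__(self, dim: int):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor):
        return self.to_q(x), self.to_k(x), self.to_v(x)

class FusedQKV(nn.Module):
    """One (dim, 3*dim) matmul plus a cheap split -- the same idea
    QKVParallelLinear implements (with tensor-parallel sharding on top)."""
    def __init__(self, dim: int):
        super().__init__()
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=False)

    def forward(self, x: torch.Tensor):
        return self.to_qkv(x).chunk(3, dim=-1)

x = torch.randn(2, 16, hidden_dim)
q, k, v = FusedQKV(hidden_dim)(x)
assert q.shape == k.shape == v.shape == x.shape
```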
cc @SolitaryThinker @jzhang38
Reproduction
None needed
Oh I see, that's probably because the Wan authors originally wrote it that way... we might have to keep it for weight loading, but we can still rename the other layers to QKVParallelLinear.
https://github.com/hao-ai-lab/FastVideo/blob/bdfdf1dfeea2aebac8a462df9c3bcb2d1d11a01c/fastvideo/v1/models/loader/fsdp_load.py#L206
https://github.com/hao-ai-lab/FastVideo/blob/bdfdf1dfeea2aebac8a462df9c3bcb2d1d11a01c/fastvideo/v1/configs/models/dits/hunyuanvideo.py#L55
EDIT: We don't apply TP to the DiT, so there's no need for QKVParallelLinear. We can potentially submit a PR later to fuse Wan weight loading.
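If anyone does pick up the weight-loading fusion later, the core of it is just concatenating the separate q/k/v checkpoint tensors along the output dimension before handing them to the fused layer. A rough sketch, with hypothetical key names (`to_q.weight`, etc.) that would need to match the real Wan state dict and _param_names_mapping:

```python
import torch

def fuse_qkv_weights(state_dict: dict) -> dict:
    """Merge per-projection q/k/v weights into a single fused qkv weight.

    Hypothetical key layout: '<prefix>.to_q.weight' / '.to_k.weight' /
    '.to_v.weight' -> '<prefix>.to_qkv.weight'. Substitute the real Wan
    checkpoint keys and the target name used by _param_names_mapping.
    """
    fused = {}
    for key, value in state_dict.items():
        if key.endswith(".to_q.weight"):
            prefix = key[: -len(".to_q.weight")]
            q = value
            k = state_dict[prefix + ".to_k.weight"]
            v = state_dict[prefix + ".to_v.weight"]
            # The fused layer expects the projections stacked along dim 0:
            # shape (3 * hidden_dim, hidden_dim).
            fused[prefix + ".to_qkv.weight"] = torch.cat([q, k, v], dim=0)
        elif key.endswith((".to_k.weight", ".to_v.weight")):
            continue  # already folded into the fused tensor above
        else:
            fused[key] = value
    return fused
```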
Would TP be needed for the 14B-parameter model? I wanted to finetune it on a new VAE I was writing up. Happy to test this for you, but I wanted to know whether you had any existing tests for this.
@philippe-eecs the paper says they do not use TP (too much comm overhead), just FSDP + unified sequence parallel
I think CPU offload with FSDP is almost always faster and more memory-efficient than TP for encoders. Happy to hear if you find a counterexample.
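For anyone who wants to try that route, parameter offload in plain PyTorch FSDP is a one-liner. A generic sketch (not the FastVideo training entrypoint), assuming the script is launched with torchrun and a stand-in encoder module:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Assumes a torchrun launch so the process-group env vars are set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Placeholder encoder; swap in the actual text/image encoder module.
encoder = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(d_model=1024, nhead=16), num_layers=12
)

# Shard parameters across ranks and park them in host RAM between uses,
# instead of paying TP's per-layer collective traffic.
encoder = FSDP(encoder, cpu_offload=CPUOffload(offload_params=True))
```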