flash-linear-attention

[Bug] [RWKV7] `fuse_norm` causes a drop in throughput

Triang-jyed-driung opened this issue 7 months ago • 3 comments

Checklist

  • [x] I have checked FAQs and existing issues for similar problems
  • [x] My GPU is H100 and I have installed triton-nightly built by the fla team, and double-checked the FAQs
  • [x] Please report this bug in English to ensure wider understanding and support

Describe the Bug

Setting `fuse_norm=True` and training an RWKV7 model results in the following error: `TypeError: LayerNorm.forward() takes 2 positional arguments but 4 were given`.

Steps to Reproduce the Bug

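```python
from transformers import AutoModelForCausalLM

# args.model_name and fake_labels come from the surrounding training script
model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True)
model.train()
model.config.fuse_norm = True  # flag flipped after the model has already been built
model.forward(fake_labels)
```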

Alternatively, you can reproduce the bug with flame.

Expected Behavior

`fuse_norm` should work.

Environment Information

  1. Torch: 2.7.0 preview
  2. Triton: 3.3.0 preview

Triang-jyed-driung • Apr 21 '25 08:04

I had to modify the config manually to get `fuse_norm` working. However, training speed drops by roughly 10% (118 kt/s -> 107 kt/s, in thousands of tokens per second) compared with the previous run without `fuse_norm`. I also see the following warning:

```
python3.13/site-packages/torch/_dynamo/variables/functions.py:1271: UserWarning: Dynamo does not know how to trace the builtin `cuda_utils.get_device_properties.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
```
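The 118 kt/s -> 107 kt/s figures can be read in two ways: as fewer tokens processed per second, or as more time spent per token. A small helper (the name `compare_throughput` is only illustrative) shows that both readings land near the quoted 10%:

```python
def compare_throughput(before: float, after: float) -> tuple[float, float]:
    """Return (percent fewer tokens per second, percent more time per token)."""
    throughput_loss = (before - after) / before * 100.0  # relative to the faster run
    time_increase = (before / after - 1.0) * 100.0       # per-token time ratio minus one
    return throughput_loss, time_increase


loss, extra_time = compare_throughput(118.0, 107.0)
print(f"{loss:.1f}% fewer tokens/s, {extra_time:.1f}% more time per token")
# prints: 9.3% fewer tokens/s, 10.3% more time per token
```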

Triang-jyed-driung • Apr 21 '25 09:04

@Triang-jyed-driung Hi, I think a better approach is to set `fuse_norm` in the config of the pretrained checkpoint. Have you tried the following?

```python
AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True, fuse_norm=True)
```

yzhangcs • Apr 21 '25 17:04

> @Triang-jyed-driung Hi, I think a better way is to modify the config file for the pretrained ckpt. Have you tried
>
> `AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True, fuse_norm=True)`

Yes. The `fuse_norm` setting determines which member modules are instantiated when the model is loaded, so such configuration changes should be made before loading the pretrained model. This also minimizes interruptions in the computation graph.
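One mechanism consistent with the reported `TypeError` can be sketched with a hypothetical toy block. `ToyConfig`, `ToyBlock`, `PlainNorm`, `FusedNorm` and `forward_after_late_flip` are illustrative names, not flash-linear-attention's actual classes, and the sketch assumes the flag is consulted both when modules are built and when `forward` runs. If the norm module is chosen in `__init__` but the calling convention is chosen per call, flipping the flag afterwards sends fused-style arguments to a plain norm:

```python
# Hypothetical toy model, NOT flash-linear-attention's code.
# It only illustrates why a flag read in __init__ must be set before construction.

class ToyConfig:
    def __init__(self, fuse_norm: bool = False):
        self.fuse_norm = fuse_norm


class PlainNorm:
    """Stand-in for a plain LayerNorm: forward takes a single input."""
    def forward(self, x):
        return x  # no real normalization, just a placeholder


class FusedNorm:
    """Stand-in for a fused norm: forward takes (x, residual, prenorm)."""
    def forward(self, x, residual, prenorm):
        return x + residual if prenorm else x  # placeholder arithmetic


class ToyBlock:
    def __init__(self, config: ToyConfig):
        self.config = config
        # The norm module is picked once, at construction time.
        self.norm = FusedNorm() if config.fuse_norm else PlainNorm()

    def forward(self, x, residual=0.0):
        # The calling convention is picked on every call, from the *current* config.
        if self.config.fuse_norm:
            return self.norm.forward(x, residual, True)
        return self.norm.forward(x)


def forward_after_late_flip():
    """Build with fuse_norm=False, flip the flag afterwards, then call forward."""
    block = ToyBlock(ToyConfig(fuse_norm=False))
    block.config.fuse_norm = True  # mirrors `model.config.fuse_norm = True` in the repro
    try:
        block.forward(1.0, 2.0)
    except TypeError as err:
        return str(err)  # a plain norm received the three-argument fused call
    return None
```

In this toy, `forward_after_late_flip()` returns a message ending in "takes 2 positional arguments but 4 were given", the same shape as the report, while `ToyBlock(ToyConfig(fuse_norm=True))` works from the start. Setting the flag before construction is the toy analogue of passing `fuse_norm=True` to `from_pretrained`.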

@Triang-jyed-driung Thank you for your report :) The drop in throughput is a known issue. I will investigate in detail and optimize all norm-related kernels to improve throughput across a wide range of shapes.

zhiyuan1i • Apr 21 '25 23:04

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] • Jul 13 '25 00:07

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] • Jul 20 '25 00:07