[Feature Request] Customize initialization / Add a switch for turning off FLA's initialization
Feature Request
- Include a flag (e.g., `use_default_init`) in the model configuration or constructor. When set to `False`, this flag would disable FLA's default initialization logic entirely, allowing users to apply their own initialization methods (a sketch follows below).
- Provide examples demonstrating how to customize initialization or disable it entirely.
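A minimal sketch of what such a switch could look like, assuming an HF-style `PreTrainedModel` with an `_init_weights` hook; the class and attribute names other than `use_default_init` are illustrative, not FLA's actual API:

```python
import torch.nn as nn

class ExampleConfig:
    def __init__(self, hidden_size=1024, initializer_range=0.02, use_default_init=True):
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.use_default_init = use_default_init  # proposed switch

class ExamplePreTrainedModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # When the switch is off, keep PyTorch's own reset_parameters()
        # and let the user apply whatever scheme they prefer afterwards.
        if not self.config.use_default_init:
            return
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

# Usage: opt out of the built-in init and apply a custom scheme instead.
config = ExampleConfig(use_default_init=False)
model = ExamplePreTrainedModel(config)
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
```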
Motivation
- Changing the default initializer globally affects all models and may lead to suboptimal results for certain architectures (especially RWKV7).
- Currently, I see no option to disable or customize the initialization process introduced by FLA.
- This initialization is compatible with neither the PyTorch default initialization nor the Hugging Face Transformers default initialization. This can leave some layers initialized with the PyTorch defaults, some with a standard deviation of 0.02, and others with 0.006.
- This initialization does not scale (in terms of both standard parametrization and maximal update parametrization). Generally, a scalable initialization (PyTorch default / Xavier / Kaiming) uses a standard deviation proportional to 1/sqrt(width); see the sketch below.
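A short sketch of a width-aware initializer, contrasting it with a fixed standard deviation such as 0.02 or 0.006; the function name is illustrative:

```python
import math
import torch.nn as nn

def scaled_init_(module: nn.Module) -> None:
    # Standard deviation shrinks like 1/sqrt(fan_in), as in the PyTorch
    # default / Xavier / Kaiming schemes, instead of a fixed constant.
    if isinstance(module, nn.Linear):
        fan_in = module.weight.shape[1]
        std = 1.0 / math.sqrt(fan_in)
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# A fixed std (e.g. 0.02) gives the same weight scale whether hidden_size is
# 256 or 8192, so activation/gradient magnitudes drift as the model is widened;
# the fan_in-scaled version keeps them roughly constant across widths.
wide, narrow = nn.Linear(8192, 8192), nn.Linear(256, 256)
scaled_init_(wide)    # std ~ 0.011
scaled_init_(narrow)  # std = 0.0625
```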
Your Contribution
I am willing to test and explore more initializations.