
[Feature Request] Customize initialization / Add a switch for turning off FLA's initialization

Open Triang-jyed-driung opened this issue 8 months ago • 4 comments

Feature Request

  1. Include a flag (e.g., use_default_init) in the model configuration or constructor. When set to False, this flag would disable FLA's default initialization logic entirely, allowing users to apply their own initialization methods.
  2. Provide examples demonstrating how to customize initialization or disable it entirely (see the sketch after this list).
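
For item 2, a minimal sketch of one way to apply a custom initializer after the model is built (custom_init, the width-scaled std, and the use of AutoModelForCausalLM here are illustrative choices, not existing FLA APIs):

import math
import torch.nn as nn

def custom_init(module: nn.Module):
    # Example policy: width-scaled normal init for every linear layer, zero bias.
    if isinstance(module, nn.Linear):
        fan_in = module.weight.shape[1]
        nn.init.normal_(module.weight, mean=0.0, std=1.0 / math.sqrt(fan_in))
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# model = AutoModelForCausalLM.from_config(config)  # any FLA model config
# model.apply(custom_init)  # re-initializes weights, overriding whatever init ran before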

Motivation

  1. Changing the default initializer globally affects all models and may lead to suboptimal results for certain architectures (especially RWKV7).
  2. Currently, I see no option to disable or customize the initialization process introduced by FLA.
  3. This initialization is compatible with neither the PyTorch default initialization nor the Hugging Face Transformers default initialization. This can leave some layers initialized with the PyTorch default, some with a std of 0.02, and others with a std of 0.006.
  4. This initialization does not scale with model width (under either standard parametrization or maximal update parametrization). A scalable initialization (PyTorch default / Xavier / Kaiming) generally uses a standard deviation proportional to 1/sqrt(width); see the sketch after this list.
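
To make item 4 concrete, a small illustration (the widths and the fixed 0.02 value are just examples): with a fixed std the output scale of a linear map grows with sqrt(width), whereas a 1/sqrt(width) std keeps it roughly constant.

import math
import torch

torch.manual_seed(0)
for d in (512, 2048, 8192):
    x = torch.randn(64, d)
    w_fixed = torch.randn(d, d) * 0.02           # fixed std, independent of width
    w_scaled = torch.randn(d, d) / math.sqrt(d)  # std = 1/sqrt(d), width-scaled
    # fixed: output std grows like 0.02 * sqrt(d); scaled: output std stays ~1
    print(d, (x @ w_fixed).std().item(), (x @ w_scaled).std().item())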

Your Contribution

I am willing to test and explore more initializations.

Triang-jyed-driung avatar Mar 10 '25 05:03 Triang-jyed-driung

@Triang-jyed-driung Good point! Your contributions are welcome; we will test 0.1 * sqrt(1/d) soon.
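
For reference, that proposal would correspond to something like the following sketch (d is the layer's input width; 1024 is an arbitrary example):

import math
import torch.nn as nn

d = 1024  # example width
linear = nn.Linear(d, d)
nn.init.normal_(linear.weight, mean=0.0, std=0.1 * math.sqrt(1.0 / d))  # 0.1 * sqrt(1/d)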

yzhangcs avatar Mar 10 '25 05:03 yzhangcs

Include a flag (e.g., use_default_init) in the model configuration or constructor. When set to False, this flag would disable FLA's default initialization logic entirely, allowing users to apply their own initialization methods.

Do you mean overriding init_weights fn in modeling_xxx entirely?

yzhangcs avatar Mar 10 '25 05:03 yzhangcs

I mean, if use_default_init is set to False, then init_weights does nothing and returns.
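
A hypothetical sketch of that behaviour (the use_default_init flag does not exist in FLA yet; _init_weights is the usual Transformers hook in modeling_xxx):

def _init_weights(self, module):
    # Proposed: when use_default_init is False, leave every module untouched so a
    # user-supplied initializer (e.g. model.apply(custom_init)) is not overwritten.
    if not getattr(self.config, "use_default_init", True):
        return
    # ... FLA's existing initialization logic for `module` goes here ...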

Triang-jyed-driung avatar Mar 11 '25 05:03 Triang-jyed-driung

You can try:

from transformers import AutoModelForCausalLM
from transformers.modeling_utils import no_init_weights

# Skip weight initialization entirely while constructing/loading the model
with no_init_weights():
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)  # your model path here

And I will add Bo's init to RWKV7 in the next few days.

zhiyuan1i avatar Mar 12 '25 06:03 zhiyuan1i