
Reproduce T2R Experiments in Gated Slot Attention Paper

ching-sui1995 opened this issue 1 year ago · 4 comments

Hi @yzhangcs

Congrats on your Gated Slot Attention paper! This work is really interesting.

I want to be able to reproduce your experiments in the “finetuning pretrained Transformers to RNNs” (T2R) setting.

Could you share the code and let us know the steps needed to finetune a model (like Mistral) into your GSA?

ching-sui1995 · Oct 08 '24 22:10

@ching-sui1995 Hi, thank you for your interest. You can use this script to convert an HF-style Llama-like pretrained LLM into an fla model. The conversion keeps the matching parameters unchanged while newly initializing any mismatched ones. After that, you can finetune the converted checkpoint on SlimPajama/FineWeb data using the hyperparameters reported in the GSA paper.
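
For intuition, here is a minimal sketch of that parameter transfer, assuming `fla.models` exposes `GSAConfig`/`GSAForCausalLM` (the class names, config fields, and output path below are assumptions for illustration, not the actual conversion script):

```python
import torch
from transformers import AutoModelForCausalLM
# Assumed import path; the actual class names in fla may differ.
from fla.models import GSAConfig, GSAForCausalLM

# 1) Load the pretrained HF-style (Llama/Mistral-like) model.
src = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16
)

# 2) Build a GSA model whose shared dimensions mirror the source model.
cfg = GSAConfig(
    hidden_size=src.config.hidden_size,
    num_hidden_layers=src.config.num_hidden_layers,
    vocab_size=src.config.vocab_size,
)
dst = GSAForCausalLM(cfg).to(torch.bfloat16)

# 3) Copy every tensor whose name and shape match; everything else
#    (e.g. the new gating/slot parameters) keeps its fresh initialization.
#    A real conversion may also need to remap submodule names where the
#    two architectures label the same block differently.
src_sd, dst_sd = src.state_dict(), dst.state_dict()
matched = {k: v for k, v in src_sd.items()
           if k in dst_sd and v.shape == dst_sd[k].shape}
dst.load_state_dict(matched, strict=False)

dst.save_pretrained("mistral-7b-gsa-init")  # hypothetical output path
```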

yzhangcs · Oct 12 '24 05:10

The peak lr is set to $3\times 10^{-5}$ with 1K warmup steps, as stated in the paper.
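
If you are wiring this up yourself, the schedule could be set up roughly like this (only the $3\times 10^{-5}$ peak and 1K warmup come from the paper; the cosine decay shape, total step count, and weight decay below are placeholders):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# `model` is the converted GSA checkpoint from the sketch above.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

# Peak lr 3e-5 with 1K warmup steps as reported; cosine decay and the
# 20K total steps are assumptions, not the paper's values.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1_000,
    num_training_steps=20_000,
)

# `dataloader`: your SlimPajama/FineWeb loader yielding input_ids + labels.
for step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```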

yzhangcs · Oct 12 '24 05:10

Thank you @yzhangcs for providing the conversion script. However, it only converts the HF-style Llama model into the fla style; it does not provide additional functionality to initialize the new GSA layers as discussed in your paper.

Am I missing something here?

I'm simply looking for an end-to-end solution for converting the model entirely, with the new layers, for RNN-like finetuning.

ching-sui1995 · Oct 12 '24 18:10

Actually the script follows these steps:

1. initialize a GSA (or any desired FLA) model;
2. load Mistral;
3. search for matching blocks;
4. replace the newly initialized blocks with the pretrained ones.

This approach ensures that mismatched structures are initialized in the same manner as a from-scratch GSA model.
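
As a quick sanity check before launching the finetuning run, you could load the converted checkpoint and generate a few tokens. The checkpoint path is the hypothetical one from the sketch above, and it assumes that importing `fla` registers the GSA classes with the transformers Auto* factories:

```python
import torch
import fla  # assumed to register fla model classes with transformers' Auto* factories
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistral-7b-gsa-init", torch_dtype=torch.bfloat16  # hypothetical converted ckpt
).cuda()

prompt = "Gated Slot Attention is"
inputs = tok(prompt, return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))
# Output quality will be degraded until the newly initialized GSA parameters
# have been finetuned on SlimPajama/FineWeb.
```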

yzhangcs · Oct 12 '24 18:10

@ching-sui1995 Hi, you can follow the steps in https://github.com/sustcsonglin/flash-linear-attention/tree/main/training#continual-pretraining to reproduce the results for now.

yzhangcs · Oct 18 '24 11:10

Thank you @yzhangcs and @sustcsonglin!! You guys are amazing!

ching-sui1995 · Oct 19 '24 15:10