flash-linear-attention
Reproduce T2R Experiments in the Gated Slot Attention Paper
Hi @yzhangcs
Congrats on your Gated Slot Attention paper! This work is really interesting.
I would like to reproduce your experiments in the "finetuning pretrained Transformers to RNNs" (T2R) setting.
Could you share the code and let us know the steps needed to finetune a model (like Mistral) into your GSA?
@ching-sui1995 Hi, thank you for your interest. You can simply use this script to convert an HF-style, Llama-like pretrained LLM into an fla model. The process keeps all matching parameters unchanged and newly initializes any mismatched ones. After conversion, you can easily finetune the converted checkpoint on SlimPajama/FineWeb data using the hyperparameters reported in the GSA paper.
The peak lr is set to $3\times 10^{-5}$ with 1K warmup steps, as stated in the paper.
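For reference, here is a minimal sketch (not the official GSA training recipe) of finetuning the converted checkpoint with those hyperparameters via the Hugging Face `Trainer`. The checkpoint path, batch size, sequence length, step budget, and scheduler below are placeholders/assumptions that you would need to align with the paper's setup:

```python
# Minimal sketch, assuming a converted checkpoint produced by the conversion script.
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
import fla  # noqa: F401 -- importing fla registers its model classes (e.g. GSA) with the Auto API

ckpt = "path/to/converted-gsa-ckpt"  # output of the conversion script (placeholder path)
model = AutoModelForCausalLM.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama-style tokenizers ship without a pad token

# Placeholder corpus; the paper finetunes on SlimPajama/FineWeb.
raw = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
train_dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=2048),
    batched=True,
    remove_columns=["text", "meta"],
)

args = TrainingArguments(
    output_dir="gsa-t2r",
    learning_rate=3e-5,              # peak lr reported in the paper
    warmup_steps=1000,               # 1K warmup steps
    lr_scheduler_type="cosine",      # assumption; check the paper/repo for the exact schedule
    max_steps=20_000,                # placeholder: set to match your token budget
    per_device_train_batch_size=8,   # placeholder: match the paper's global batch size
    bf16=True,
)

Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
).train()
```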
Thank you @yzhangcs for providing the conversion script. However, it only converts the HF-style Llama model into fla style; it does not provide any additional functionality to initialize the new GSA layers as discussed in your paper.
Am I missing something here?
I'm simply looking for an end-to-end solution for converting the model entirely, with the new layers, for RNN-like finetuning.
Actually, the script follows these steps:
1) initialize a GSA (or any other desired FLA) model;
2) load Mistral;
3) search for matching blocks;
4) replace the newly initialized blocks with the pretrained ones.

This approach ensures that any mismatched structures keep the same initialization as a from-scratch GSA model.
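For concreteness, here is a simplified sketch of that logic (the repo's conversion script is the authoritative version and handles the naming differences between the two architectures more carefully; the default `GSAConfig` and the name-and-shape matching rule below are illustrative assumptions):

```python
# Simplified sketch of the conversion steps, assuming fla.models exposes
# GSAConfig / GSAForCausalLM as in recent versions of the library.
from transformers import AutoModelForCausalLM
from fla.models import GSAConfig, GSAForCausalLM

# 1) initialize a GSA (or any other FLA) model
config = GSAConfig()  # assumption: set hidden_size / num_hidden_layers / vocab_size to match the source model
gsa = GSAForCausalLM(config)

# 2) load the pretrained Transformer, e.g. Mistral
src = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# 3) search for matching blocks: here, parameters that share a name and shape
gsa_state = gsa.state_dict()
matched = {name: param for name, param in src.state_dict().items()
           if name in gsa_state and param.shape == gsa_state[name].shape}

# 4) replace the newly initialized blocks with the pretrained ones;
#    anything that does not match (e.g. the new GSA attention layers)
#    keeps the same initialization as a from-scratch GSA model
gsa.load_state_dict(matched, strict=False)
print(f"copied {len(matched)}/{len(gsa_state)} tensors from the pretrained model")
gsa.save_pretrained("path/to/converted-gsa-ckpt")
```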
@ching-sui1995 Hi, you can follow the steps in https://github.com/sustcsonglin/flash-linear-attention/tree/main/training#continual-pretraining to reproduce the results for now.
Thank you @yzhangcs and @sustcsonglin!! You guys are amazing!