Yu Zhang
@yzhangcs Be careful that `RCP_IN2` should be multiplied after `tl.dot`; otherwise it would lead to a big precision loss.
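To make the numerical argument concrete, here is a minimal PyTorch sketch (shapes and the scale value are only for illustration; `scale` stands in for a constant like `RCP_IN2`). Scaling the bf16 operands before the dot adds one extra rounding per element, while scaling the fp32 accumulator afterwards does not:

```py
import torch

torch.manual_seed(0)
scale = 1.4426950408889634  # e.g. 1/ln(2); stands in for a constant like RCP_IN2

# operands as they would be loaded inside the kernel: already bf16
q = torch.randn(64, 128).bfloat16()
k = torch.randn(64, 128).bfloat16()

# fp32 reference
ref = (q.float() * scale) @ k.float().T

# scale BEFORE the dot: q * scale is rounded back to bf16 element-wise
# before accumulation, i.e. one extra rounding per operand
before = (q * scale).float() @ k.float().T

# scale AFTER the dot: the fp32 accumulator is scaled once, no extra rounding
after = (q.float() @ k.float().T) * scale

print("max |err| when scaling before the dot:", (before - ref).abs().max().item())
print("max |err| when scaling after the dot :", (after - ref).abs().max().item())
```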
@ching-sui1995 Hi, thank you for your interest. You can just use [this script](https://github.com/sustcsonglin/flash-linear-attention/blob/main/utils/convert_from_llama.py) to convert an HF-style Llama-like pretrained LLM into an FLA model. This process will keep the matching...
The peak lr is set to $3\times 10^{-5}$ with 1K warmup steps, as stated in the paper.
Actually the script follows these steps: 1) initialize a GSA or any desired FLA model; 2) load Mistral; 3) search for matching blocks; 4) replace the newly initialized blocks with...
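For reference, here is a rough Python sketch of that flow (not the actual script; the `fla.models` class names and the simple key-matching logic are assumptions, and in practice some parameter prefixes may need remapping):

```py
from transformers import AutoModelForCausalLM
from fla.models import GSAConfig, GSAForCausalLM  # assumed import path / class names

# 1) initialize a GSA (or any other FLA) model from scratch
fla_model = GSAForCausalLM(GSAConfig())

# 2) load the pretrained Mistral checkpoint
src_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# 3) search for parameters whose names and shapes match between the two models
src_state, dst_state = src_model.state_dict(), fla_model.state_dict()
matched = {k: v for k, v in src_state.items()
           if k in dst_state and v.shape == dst_state[k].shape}

# 4) replace the newly initialized weights with the pretrained ones;
#    unmatched modules (e.g. the new token-mixing layers) keep their fresh init
dst_state.update(matched)
fla_model.load_state_dict(dst_state)
print(f"copied {len(matched)}/{len(dst_state)} tensors from Mistral")
```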
@ching-sui1995 Hi, you can follow the steps in https://github.com/sustcsonglin/flash-linear-attention/tree/main/training#continual-pretraining to reproduce the results for now.
@muellerzr Hi, wondering if there is any progress on this bug. 👀 I also ran into it when trying the latest accelerate.
@AACengineer Hi, Transformer++ also conducts decoding in an autoregressive manner. During training, Transformer++ can be fully parallelized. However, we can also make use of the parallel scan to improve the...
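For intuition, here is a minimal sketch of such a parallel (associative) scan for a diagonal linear recurrence $h_t = a_t \odot h_{t-1} + b_t$, written in plain PyTorch (the Hillis–Steele doubling scheme below is just one simple way to realize it, not the kernel used in the repo):

```py
import torch

def combine(a1, b1, a2, b2):
    # compose two affine maps f2 ∘ f1, where f_i(h) = a_i * h + b_i
    return a2 * a1, a2 * b1 + b2

def parallel_scan(a, b):
    # inclusive scan over the sequence dim in O(log T) doubling steps;
    # a, b have shape (T, d); returns h with h[t] = a[t]*h[t-1] + b[t], h_0 = 0
    T = a.shape[0]
    A, B = a.clone(), b.clone()
    step = 1
    while step < T:
        A_prev, B_prev = torch.ones_like(A), torch.zeros_like(B)
        A_prev[step:], B_prev[step:] = A[:-step], B[:-step]
        A, B = combine(A_prev, B_prev, A, B)
        step *= 2
    return B  # with h_0 = 0 the accumulated offset equals the hidden state

# sanity check against the sequential recurrence
T, d = 128, 16
a, b = torch.rand(T, d) * 0.9, torch.randn(T, d)
h_par = parallel_scan(a, b)
h_seq = torch.zeros(d)
for t in range(T):
    h_seq = a[t] * h_seq + b[t]
print("max diff vs sequential:", (h_par[-1] - h_seq).abs().max().item())
```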
@ByronHsu

> Curious how did you measure the loss?

```py
import torch

torch.manual_seed(42)
batch_size, seq_len, hidden_size, vocab_size = 8, 4096, 2048, 128000
x = torch.randn(batch_size * seq_len, hidden_size).cuda().bfloat16().requires_grad_()
target = torch.randint(0, vocab_size, ...
```
@ByronHsu That's the diff between the current impls. I think this loss gap makes sense given that the vocab is large and the first output is computed under fp32 while the second...
@ByronHsu Hi, have you compared the final loss of FLCE with the naive counterpart? I think chunking the input into several pieces might be problematic for very large V and...
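To make that comparison concrete without the fused kernel, one can emulate the chunking in plain PyTorch and check both variants against an fp32 reference (sizes below are scaled down so it runs anywhere; this is only a sketch of the measurement, not the FLCE implementation itself):

```py
import torch
import torch.nn.functional as F

torch.manual_seed(42)
# deliberately small; the discussion above uses a much larger vocab
N, D, V, chunks = 1024, 256, 32000, 8

x = torch.randn(N, D, dtype=torch.bfloat16)
weight = torch.randn(V, D, dtype=torch.bfloat16) * 0.02
target = torch.randint(0, V, (N,))

# fp32 reference over the full vocab
ref = F.cross_entropy(x.float() @ weight.float().T, target)

# naive bf16: materialize the full (N, V) logits at once
naive = F.cross_entropy((x @ weight.T).float(), target)

# chunked bf16: split the tokens into pieces and accumulate the per-chunk losses,
# as a fused linear cross-entropy would do to avoid the full logits tensor
losses = []
for xc, tc in zip(x.chunk(chunks), target.chunk(chunks)):
    losses.append(F.cross_entropy((xc @ weight.T).float(), tc, reduction="sum"))
chunked = torch.stack(losses).sum() / N

print(f"ref {ref.item():.6f}  naive {naive.item():.6f}  chunked {chunked.item():.6f}")
```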