
Sample packing for pretraining/fine-tuning

Open alitirmizi23 opened this issue 1 year ago • 16 comments

I was wondering if there are sample packing approaches defined somewhere for preprocessing and tokenization of datasets? I looked through the different prepare_*.py scripts, but couldn't find anything related to packing multiple sequences into max_length for efficiency, etc.

Also, wondering how the data prep works as of now in the lit-gpt framework:

  • If a document, article, or instruction/output pair exceeds the max sequence length, how is it treated? What about if a doc/article/instruction-output pair falls short of the max sequence length? Are the remaining time steps padded, or are more sequences packed until the max length is reached?
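
For concreteness, by packing I mean something like the following; a minimal sketch where `tokenize` and `eos_id` are hypothetical placeholders, not lit-gpt APIs:

```python
def pack(documents, tokenize, eos_id, max_length):
    """Concatenate tokenized documents into one stream and slice it into fixed-length blocks."""
    stream = []
    for doc in documents:
        stream.extend(tokenize(doc))
        stream.append(eos_id)  # separator token so packed sequences can be told apart
    # drop the trailing remainder that does not fill a whole block
    n_blocks = len(stream) // max_length
    return [stream[i * max_length:(i + 1) * max_length] for i in range(n_blocks)]
```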

alitirmizi23 avatar Oct 06 '23 00:10 alitirmizi23

cc @awaelchli, if you'd like to answer

carmocca avatar Nov 18 '23 22:11 carmocca

Just chiming in, from what I understand, this is not a simple feature to implement in general.

As one current example, the axolotl finetuning harness implements efficient sample packing with correct block-diagonal attention masking through a series of monkey patches of the underlying Hugging Face model definitions for a few of the very popular models like Llama and Mistral. Though I have not looked through the code in detail, I believe it leverages the fact that the flash attention API supports the masking required to implement this scheme.

It seems like the simplicity of the lit-gpt model definition might actually make this easier to implement as a first-class feature. It is relevant for efficient finetuning (the reason it's incorporated into axolotl), and general wisdom (and whispers from inside large corps) suggests that this type of block-diagonal masking is better for large-scale training code.
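
To make the masking scheme concrete, here is a minimal sketch of a block-diagonal causal mask built from per-token sequence IDs and passed to PyTorch's SDPA (illustrative only, not axolotl's or lit-gpt's code):

```python
import torch
import torch.nn.functional as F

def block_diagonal_causal_mask(seq_ids: torch.Tensor) -> torch.Tensor:
    # seq_ids: (T,) int tensor, e.g. [0, 0, 0, 1, 1, 2, 2] for three packed sequences
    same_seq = seq_ids.unsqueeze(0) == seq_ids.unsqueeze(1)  # (T, T) True if same sequence
    causal = torch.tril(torch.ones_like(same_seq))           # lower-triangular causal part
    return same_seq & causal                                 # True = position may attend

T, H, D = 7, 4, 16
q = k = v = torch.randn(1, H, T, D)
mask = block_diagonal_causal_mask(torch.tensor([0, 0, 0, 1, 1, 2, 2]))
# Passing attn_mask keeps the math correct, but SDPA then falls back to a non-flash kernel.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```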

I (and other collaborators) would be very interested in this feature and it would increase the attractiveness of lit-gpt's model building code as a hf alternative. Just my 2c!

jwkirchenbauer avatar Nov 30 '23 22:11 jwkirchenbauer

If a document, article, or instruction/output pair exceeds the max sequence length, how is it treated?

Depends on the data preparation, but our scripts trim it: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/scripts/prepare_alpaca.py#L116-L117

What about if a doc/article/instruction-output pair falls short of the max sequence length? Are the remaining time steps padded, or are more sequences packed until the max length is reached?

They are packed in the pretraining scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/pretrain/redpajama.py#L249-L257 and padded in the fine-tuning scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/finetune/full.py#L232-L246
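
Roughly speaking, the fine-tuning path pads every sample in a batch to a common length; a simplified sketch of that idea (not the exact code from the linked script):

```python
import torch

def pad_collate(batch, pad_id=0, ignore_index=-100):
    # batch: list of dicts holding 1-D "input_ids" and "labels" tensors of varying length
    max_len = max(x["input_ids"].size(0) for x in batch)

    def pad(t, value):
        return torch.cat([t, torch.full((max_len - t.size(0),), value, dtype=t.dtype)])

    input_ids = torch.stack([pad(x["input_ids"], pad_id) for x in batch])
    labels = torch.stack([pad(x["labels"], ignore_index) for x in batch])  # padding ignored by the loss
    return input_ids, labels
```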

suggests that this type of block-diagonal masking is better for large-scale training code.

The inconvenience is that torch.nn.functional.scaled_dot_product_attention will not use flash-attn if an attention mask is passed. It would be necessary to integrate this specific flavor of flash attention: https://github.com/Dao-AILab/flash-attention/issues/654, which would again require building it.
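
Concretely, that flavor is the varlen interface: instead of a mask it takes cumulative sequence lengths and computes block-diagonal causal attention over the packed tokens directly. A sketch, assuming the flash-attn package is installed (layout is (total_tokens, n_heads, head_dim)):

```python
import torch
from flash_attn import flash_attn_varlen_func

# three packed sequences of lengths 3, 5 and 2 -> 10 tokens total, no padding
seqlens = [3, 5, 2]
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32, device="cuda")
n_heads, head_dim = 8, 64
q = k = v = torch.randn(sum(seqlens), n_heads, head_dim, device="cuda", dtype=torch.float16)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,  # causal attention restricted to each packed sequence
)
```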

carmocca avatar Jan 18 '24 22:01 carmocca

https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752

vgoklani avatar Jan 31 '24 02:01 vgoklani

We'd also be very interested in this feature!

corbt avatar Apr 28 '24 14:04 corbt

@carmocca let's revive this issue. It doesn't look like SDPA from PyTorch 2.3 has solved the underlying issue; if that's the case, let's add flash attention as an optional dependency.

lantiga avatar Apr 28 '24 15:04 lantiga

PyTorch has added support for arbitrary custom masks, which are meant to be performant when used with torch.compile: https://github.com/pytorch/pytorch/pull/121845

They are also considering more generic API changes that are in discussion: https://github.com/pytorch/pytorch/issues/110681.
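
For reference, a sketch of what that mask_mod-style API looks like, assuming the FlexAttention interface that later shipped as torch.nn.attention.flex_attention (still a prototype at the time of writing):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 1, 8, 128, 64
q = k = v = torch.randn(B, H, S, D, device="cuda")
# one document id per token, e.g. tokens 0..63 belong to doc 0, tokens 64..127 to doc 1
document_id = (torch.arange(S, device="cuda") >= 64).long()

def document_causal(b, h, q_idx, kv_idx):
    # attend only within the same document, and causally
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)  # wrap in torch.compile for speed
```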

As of today, Tri Dao's package is the only option as far as I know.

carmocca avatar Apr 29 '24 12:04 carmocca

PyTorch has added support for arbitrary custom masks, which are meant to be performant when used with torch.compile: pytorch/pytorch#121845

They are also considering more generic API changes that are in discussion: pytorch/pytorch#110681.

As of today, Tri Dao's package is the only option as far as I know.

I think that xformers is doing it as well

samsja avatar Jul 06 '24 01:07 samsja

@samsja cuDNN attention is most likely the best option today that supports attention masks (see the FlashAttention-3 paper figures). xformers is not as competitive, on H100s at least.
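
If it helps, newer PyTorch releases let you restrict SDPA dispatch to a specific backend; a sketch assuming SDPBackend.CUDNN_ATTENTION is available in your build (whether it runs with a custom attn_mask still depends on that backend's support):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, D = 1, 8, 128, 64
q = k = v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# restrict dispatch to the cuDNN fused attention kernel
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```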

carmocca avatar Jul 15 '24 12:07 carmocca

Oh I see, any chance litgpt will integrate one of these options at this point?

By any chance, do you have any benchmarks comparing FA2/FA3/xformers/torch SDPA?

samsja avatar Jul 15 '24 13:07 samsja

That's a good question. We don't have a benchmark, but LitGPT already supports FlashAttention-2 via PyTorch's SDPA. The plan is to also support FlashAttention-3 (#1578).

rasbt avatar Jul 15 '24 13:07 rasbt

That's a good question. We don't have a benchmark, but LitGPT already supports FlashAttention-2 via PyTorch's SDPA. The plan is to also support FlashAttention-3 (#1578).

Unfortunately torch SDPA cannot leverage flash attention with custom masks (for context stuffing), unlike xformers and the original flash attention implementation. Bit of a blocker. I am currently implementing it with xformers and litgpt.
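
For anyone following along, the xformers route looks roughly like this; a sketch assuming xformers' memory_efficient_attention with its BlockDiagonalCausalMask attn_bias (not litgpt code):

```python
import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# three packed sequences of lengths 3, 5 and 2, flattened into a single row of 10 tokens
seqlens = [3, 5, 2]
n_heads, head_dim = 8, 64
q = k = v = torch.randn(1, sum(seqlens), n_heads, head_dim, device="cuda", dtype=torch.float16)

attn_bias = BlockDiagonalCausalMask.from_seqlens(seqlens)  # block-diagonal + causal
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```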

samsja avatar Jul 15 '24 13:07 samsja