
Support for gradient_checkpointing

Open Richar-Du opened this issue 11 months ago • 3 comments

Thanks for your awesome work! There is a small problem: when I fine-tune long_llama with gradient_checkpointing enabled, it raises an error (see the attached screenshot). Could you please update the code in transformers so that long_llama supports gradient_checkpointing? I think this would be useful for the community using long_llama. @CStanKonrad
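
For reference, this is roughly how I hit the error (a minimal sketch; it assumes the model is loaded from the Hub with trust_remote_code=True, as in the README, and the exact traceback from the screenshot is not reproduced here):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b", torch_dtype=torch.float32, trust_remote_code=True
)

# Enabling checkpointing via the standard Hugging Face API is where it fails:
model.gradient_checkpointing_enable()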

Richar-Du avatar Jul 13 '23 13:07 Richar-Du

Hi, thanks for the request. In a recent commit, I added initial support for gradient checkpointing (it simply skips the memory layers). As I write this, the change is not yet in the Hugging Face repository, so to use it you can download the code from the src directory of this repository and write something like this:

import torch
from transformers import LlamaTokenizer
from modeling_longllama import LongLlamaForCausalLM  # local copy from this repo's src directory

MODEL_PATH = "syzymon/long_llama_3b"

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LongLlamaForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float32)
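
From there, you can check that checkpointing works end to end with something like the following (a sketch: gradient_checkpointing_enable() is the standard Hugging Face API, and the dummy batch is illustrative only):

model.gradient_checkpointing_enable()  # memory layers are skipped under checkpointing
model.train()  # checkpointing is only active in training mode

batch = tokenizer("A short training example.", return_tensors="pt")
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()  # activations are recomputed here instead of being stored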

CStanKonrad avatar Jul 13 '23 14:07 CStanKonrad

Thanks for your commit!

Now I would like to fine-tune longllama, but the sequences are so long that I run into CUDA OOM errors even on 4x 80GB GPUs. I wonder whether I could fine-tune longllama under a regular framework without special support for long contexts (e.g., the training framework of Alpaca or Vicuna). If not, could you please release the fine-tuning code for longllama?

Richar-Du avatar Jul 14 '23 02:07 Richar-Du

I apologize for the late response. We have recently published the code that allows fine-tuning the model on a single A100 80GB GPU. We use a total context size of 2048, with last_context_length set to 1024. For shorter inputs, we randomly decide how much data will be present in memory; we achieve this by randomly padding the input.
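
Roughly, the padding idea looks like this (a hypothetical sketch, not the released training code; one plausible reading is that left padding displaces real tokens from the memory part of the context):

import random
import torch

def randomly_pad(input_ids: torch.Tensor, pad_token_id: int, total_length: int = 2048) -> torch.Tensor:
    # Hypothetical helper: with last_context_length = 1024, everything before
    # the final 1024 tokens goes into memory. Randomizing the amount of left
    # padding therefore randomizes how much real data lands in memory.
    free = max(total_length - input_ids.shape[-1], 0)
    pad_len = random.randint(0, free)
    padding = torch.full((pad_len,), pad_token_id, dtype=input_ids.dtype)
    return torch.cat([padding, input_ids])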

You can try the instruction+chat fine-tuned model in the Colab.

For the Colab model, we provide the fine-tuning config and a log of the training loss.

CStanKonrad avatar Aug 08 '23 18:08 CStanKonrad