
Long Context

Open · peregilk opened this issue 6 months ago · 1 comment

I am trying to adapt Llama3 for long context, i.e. 128k. I am training on a v5-256, and am trying to follow the procedure explained in https://arxiv.org/pdf/2407.14482. Basically, it states:

> We set the batch size to 32 to fit 4 million tokens in a batch and use a learning rate of 3e-5 to train 2000 steps (8B tokens in total).
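For reference, a quick sketch of the arithmetic behind that setup (assuming 128k means 131,072 tokens per sequence):

```python
# Paper setup: 32 sequences per batch at 128k context, 2000 steps
seq_len = 128 * 1024      # 131,072 tokens per sequence
batch_size = 32
steps = 2000

tokens_per_batch = batch_size * seq_len   # ~4.2M tokens, i.e. the "4 million tokens in a batch"
total_tokens = tokens_per_batch * steps   # ~8.4B tokens, i.e. the "8B tokens in total"

print(f"{tokens_per_batch:,} tokens/batch, {total_tokens:,} tokens total")
```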

I have prepared a dataset with 128k context, using the HF dataset pipeline (thanks, @aireenmei).

My challenge, however, is that setting `per_device_batch_size=1` gives me a global batch size of 256. This is way too high, and I get OOM errors. I want to split batches across devices. I attempted setting `num_pipeline_microbatches=8`, but this does not seem to work.
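To make the gap concrete, here is a minimal sketch of the batch arithmetic on a 256-device slice (assuming the global batch is simply `per_device_batch_size` times the number of devices, which matches what I observe):

```python
# Batch arithmetic on a v5-256 slice (256 devices), 128k context
num_devices = 256
seq_len = 128 * 1024
per_device_batch_size = 1

global_batch = per_device_batch_size * num_devices   # 256 sequences per step
tokens_per_step = global_batch * seq_len             # ~33.6M tokens/step, ~8x the paper's 4M target

# To match the paper's global batch of 32, the per-device batch would have to be fractional:
needed_per_device = 32 / num_devices                  # 0.125
print(global_batch, f"{tokens_per_step:,}", needed_per_device)
```

So with `per_device_batch_size=1` I am already roughly 8x over the paper's token budget per step, which is why I am looking for a way to get an effective global batch of 32.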

Are there other ways of accomplishing this? I understand gradient accumulation is not implemented (and I am not sure if it will work here).

peregilk · Jul 28 '24