Long Context
I am trying to adapt Llama3 for long context, i.e. a 128k context length. I am training on a v5-256 and am trying to follow the procedure explained in https://arxiv.org/pdf/2407.14482, which basically states:
We set the batch size to 32 to fit 4 million tokens in a batch and use a learning rate of
3e-5 to train 2000 steps (8B tokens in total).
I have prepared a dataset with a 128k context length using the HF dataset pipeline (thanks @aireenmei).
My challenge, however, is that setting per_device_batch_size=1
gives me a global batch size of 256. That is far too large, and I get OOM errors. What I want is to split a single batch across multiple devices, so the global batch size can be smaller than the device count. I attempted setting num_pipeline_microbatches=8,
but this does not seem to work.
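
To make the mismatch concrete, here is the arithmetic as I understand it (my own numbers, assuming 128k = 131072 tokens):

```python
# Tokens per batch under the paper's recipe vs. my v5-256 setup (128k = 131072 tokens).
SEQ_LEN = 128 * 1024                                  # 131,072 tokens per sequence

paper_tokens_per_batch = 32 * SEQ_LEN                 # 4,194,304 ≈ 4M tokens, as in the paper
paper_total_tokens = 2000 * paper_tokens_per_batch    # ≈ 8.4B tokens ("8B tokens in total")

my_tokens_per_batch = 256 * SEQ_LEN                   # 33,554,432 ≈ 33.5M tokens at global batch 256
ratio = my_tokens_per_batch / paper_tokens_per_batch  # 8.0: my global batch is 8x too large

needed_per_device_batch = 32 / 256                    # 0.125 sequences per device to match batch 32
```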
Are there other ways of accomplishing this? I understand gradient accumulation is not implemented (and I am not sure whether it would work here anyway).
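
For context, the kind of gradient accumulation I have in mind looks roughly like the standalone JAX sketch below (not MaxText code; the loss function and shapes are placeholders): gradients are accumulated over several microbatches before a single optimizer update.

```python
# Standalone JAX sketch of gradient accumulation (placeholder model/loss, not MaxText code):
# sum gradients over microbatches with lax.scan, then average before one optimizer update.
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy linear-regression loss as a stand-in for the real model loss.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

def accumulated_grads(params, microbatches):
    # `microbatches` has a leading num_microbatches axis; scan over it,
    # summing per-microbatch gradients, then divide by the microbatch count.
    def body(grad_sum, microbatch):
        grads = jax.grad(loss_fn)(params, microbatch)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(body, zeros, microbatches)
    num_micro = jax.tree_util.tree_leaves(microbatches)[0].shape[0]
    return jax.tree_util.tree_map(lambda g: g / num_micro, grad_sum)

# Example: 8 microbatches of 4 examples each behave like one batch of 32.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jnp.zeros((16, 1))}
x = jax.random.normal(kx, (8, 4, 16))
y = jax.random.normal(ky, (8, 4, 1))
grads = jax.jit(accumulated_grads)(params, {"x": x, "y": y})
```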