
Large batch size when pretraining E5 models

Open novoselrok opened this issue 1 year ago • 7 comments

In the E5 paper, you mention that you're using a batch size of 32768 during pretraining.

I'm planning to continue pretraining the e5-{small,base}-v2 models on a custom dataset, but I'm having a hard time fitting batch sizes larger than ~128 on a single GPU (T4, A100). I'm already using half-precision training (torch.autocast), and I'm trying to implement gradient checkpointing.

Do you have any tips or code on how to fit such a large batch size during training?

novoselrok avatar Jun 08 '23 18:06 novoselrok

Sure, the common techniques are:

  1. Use gradient checkpointing; it saves at least half of the GPU memory at the cost of roughly 30% slower training.
  2. Use the DeepSpeed launcher if possible; its ZeRO stage 2 partitions optimizer states and gradients across multiple GPUs.
  3. Restrict your maximum input length. You can freeze the position embeddings, train with a small input length (like 128 / 256), and then test with input length 512. (A sketch of items 1 and 3 follows below.)
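
A minimal sketch of items 1 and 3 for a BERT-style encoder such as e5-base-v2, using standard Hugging Face APIs; the attribute path model.embeddings.position_embeddings is an assumption that holds for BERT-based checkpoints and may differ for other architectures.

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("intfloat/e5-base-v2")

# 1. Gradient checkpointing: recompute activations during the backward pass,
#    trading roughly 30% extra compute for a large activation-memory saving.
model.gradient_checkpointing_enable()

# 3. Freeze the position embeddings so that training with short inputs does not
#    disturb the values learned during the original MLM pretraining.
for param in model.embeddings.position_embeddings.parameters():
    param.requires_grad = False

# Train with a short max_length (e.g. 128), then evaluate with 512.
tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")
train_inputs = tokenizer(["query: example text"], max_length=128,
                         truncation=True, padding=True, return_tensors="pt")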

Also, if you use mined hard negatives, you do not need a large batch size. One hard negative is worth thousands of in-batch negatives.
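
For reference, here is a minimal sketch (not the actual E5 training code) of an InfoNCE-style loss that mixes in-batch negatives with one mined hard negative per query; the function and tensor names are placeholders.

import torch
import torch.nn.functional as F

def infonce_with_hard_negatives(q_emb, p_emb, hard_emb, tau=0.01):
    # q_emb, p_emb, hard_emb: (batch, dim) embeddings of queries, positive
    # passages, and mined hard negatives. Each query is scored against all
    # in-batch passages plus all hard negatives; its own positive sits on
    # the diagonal of the score matrix.
    q = F.normalize(q_emb, dim=-1)
    candidates = F.normalize(torch.cat([p_emb, hard_emb], dim=0), dim=-1)
    scores = q @ candidates.t() / tau              # (batch, 2 * batch)
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(scores, labels)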

Best, Liang

intfloat avatar Jun 09 '23 05:06 intfloat

Thanks, that is very helpful!

  1. For gradient checkpointing, I've found this reference implementation for contrastive learning: https://sourcegraph.com/github.com/microsoft/LoRA/-/blob/examples/NLU/examples/research_projects/longform-qa/eli5_utils.py?L128-164 Is this a practical approach, or are there more effective approaches?
  2. Thanks, I'll try out DeepSpeed.
  3. Freezing the position embeddings is very interesting. I'm curious: why does that make a difference?

Unfortunately, no mined negatives 🙂

novoselrok avatar Jun 09 '23 06:06 novoselrok

If you are using the Trainer from the Hugging Face transformers library, gradient checkpointing is enabled by simply passing the argument --gradient_checkpointing True.
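
For reference, a minimal sketch of the equivalent programmatic setup; model, train_dataset, and ds_config.json are placeholders from your own training script, and the batch-size numbers are only illustrative.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="e5-continued-pretraining",
    per_device_train_batch_size=256,   # micro batch that fits on one GPU
    gradient_accumulation_steps=8,     # raises the optimizer-step batch size,
                                       # but not the number of in-batch negatives
    gradient_checkpointing=True,       # same effect as --gradient_checkpointing True
    fp16=True,
    deepspeed="ds_config.json",        # ZeRO stage 2 config when launching with DeepSpeed
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()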

Shorter inputs mean fewer activation values to store in GPU memory. But using different lengths during training and testing would create a discrepancy, so we freeze the position embeddings to make the model stick to the values learned during the original MLM pretraining. Empirically, this strategy gives a small gain over fine-tuning only part of the position embeddings.

intfloat avatar Jun 09 '23 06:06 intfloat

I'm using a native PyTorch training loop, but I can try to migrate to the Trainer class for the gradient checkpointing and the DeepSpeed integration. Hopefully, that gets me to batch size 32768 😄

novoselrok avatar Jun 09 '23 06:06 novoselrok

@intfloat One additional question: what is the difference between the e5-{small,base,large} and e5-{small,base,large}-v2 models on Hugging Face?

novoselrok avatar Jun 14 '23 07:06 novoselrok

The v2 models are pre-trained on larger text pair datasets; the network architecture and training recipe are the same.

intfloat avatar Jun 14 '23 11:06 intfloat

There is a way to train with a batch that is too large for a single GPU, at the cost of one extra forward pass, when separate encoders are used for queries and passages: https://arxiv.org/pdf/2101.06983.pdf. A minimal example for a single GPU:

# model is assumed to be a dual encoder that returns (query_emb, passage_emb)
import torch
import torch.nn.functional as F

# define some variables
bs = 1024  # target batch size
mbs = 256  # micro batch size (maximum that fits on the GPU)
hidden_dim = 768  # model dim
tau = 0.01
loss_fn = torch.nn.CrossEntropyLoss()
device = "cuda"

# generate batch
batch = next(batch_iterator)  # dict, each value is a tensor with bs elements

# 1) compute representations for the whole batch without building a graph
#    (eval() disables dropout here; if your model uses dropout, keep the same
#    mode in both passes so the cached gradients match the recomputation)
model.eval()
q = torch.zeros((bs, hidden_dim), device=device)
p = torch.zeros((bs, hidden_dim), device=device)
with torch.no_grad():
    for i in range(0, bs, mbs):
        j = i + mbs
        q[i:j], p[i:j] = model(**{k: v[i:j] for k, v in batch.items()})

# 2) compute the loss over the full batch and the derivatives w.r.t. q and p
q.requires_grad = True
p.requires_grad = True
s = F.normalize(q) @ F.normalize(p).t() / tau
labels = torch.arange(bs, device=device)
loss = loss_fn(s, labels)
loss.backward()

# 3) recompute each micro batch with gradients enabled and backpropagate the
#    cached derivatives into the model parameters
model.train()
for i in range(0, bs, mbs):
    j = i + mbs
    qi, pi = model(**{k: v[i:j] for k, v in batch.items()})
    # backprop both outputs in a single call; two separate .backward() calls
    # would fail if qi and pi share parts of the same autograd graph
    torch.autograd.backward([qi, pi], [q.grad[i:j], p.grad[i:j]])

# then step the optimizer and zero the gradients as usual

ololo123321 avatar Dec 11 '23 09:12 ololo123321