
Input batch size does not match target batch size when using more than one positive context when training DPR

Open Raldir opened this issue 3 years ago • 7 comments

Describe the bug When fine-tuning DPR on a custom dataset with the parameter num_positives set larger than one, I get a batch size mismatch error as soon as training starts. Training works perfectly fine with num_positives=1. The length of each entry's 'positive_ctxs' list is always equal to num_positives. I have also tried the default dataset (i.e. biencoder-nq) -- same issue. For example, when I train with a batch size of 16 and num_positives=2, I get the error below.
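For reference, each entry in my training file (and, as far as I can tell, in biencoder-nq) roughly follows the DPR JSON format sketched below, shown here as a Python literal with placeholder values; with num_positives=2, every entry's 'positive_ctxs' list contains two passages:

```python
# One training entry in the DPR / biencoder-nq style JSON format (placeholder values).
# With num_positives=2, "positive_ctxs" always contains two passages.
entry = {
    "question": "Who wrote the novel Dune?",
    "answers": ["Frank Herbert"],
    "positive_ctxs": [
        {"title": "Dune (novel)", "text": "Dune is a 1965 science fiction novel by Frank Herbert ..."},
        {"title": "Frank Herbert", "text": "Franklin Patrick Herbert Jr. was an American author ..."},
    ],
    "negative_ctxs": [],
    "hard_negative_ctxs": [
        {"title": "Dune (1984 film)", "text": "Dune is a 1984 science fiction film ..."},
    ],
}
```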

Error message

```
  File "/home/XXX/miniconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 213, in forward
    return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/XXX/miniconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/functional.py", line 2262, in nll_loss
    .format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (16) to match target batch_size (32).
```
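
The ValueError itself is just PyTorch's shape check in F.nll_loss, which requires the scores and targets to agree on the batch dimension. A standalone illustration of the same failure (plain PyTorch, not Haystack code):

```python
import torch
import torch.nn.functional as F

# F.nll_loss expects one target index per row of the input log-probabilities.
batch_size, num_passages = 16, 32
log_probs = torch.log_softmax(torch.randn(batch_size, num_passages), dim=-1)

# One target per query in the batch -> works.
ok_targets = torch.randint(num_passages, (batch_size,))
print(F.nll_loss(log_probs, ok_targets))

# Two target entries per query (as if one per positive context) -> raises
# "ValueError: Expected input batch_size (16) to match target batch_size (32)".
bad_targets = torch.randint(num_passages, (2 * batch_size,))
print(F.nll_loss(log_probs, bad_targets))
```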

Additional context Code is essentially identical to the tutorial. I am using the latest version of the repository.
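
The relevant part of my script looks roughly like the sketch below (adapted from the DPR training tutorial; file names and paths are placeholders, and the exact import paths and keyword arguments may differ between Haystack versions):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)

# Each entry in train.json has exactly num_positives passages in "positive_ctxs"
# (see the entry format above).
retriever.train(
    data_dir="data/dpr_training",   # placeholder directory
    train_filename="train.json",
    dev_filename="dev.json",
    n_epochs=1,
    batch_size=16,
    num_positives=2,                # > 1 triggers the batch size mismatch
    num_hard_negatives=1,
    save_dir="saved_models/dpr",
)
```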

Raldir · May 09 '21

I have the same issue, also using the code from the tutorial; it only works with num_positives=1. Did you find a way to solve it? The error still occurs when using the JSON datasets from the tutorial.

ottozastrow · Jul 23 '22

Thanks for reopening - happy to provide more info, but since it's the code from the tutorial, I assume that's all you need.

ottozastrow · Jul 25 '22

Hi @ottozastrow, just to confirm: are you referring to this tutorial, Tutorial9_DPR_training.ipynb?

And could you tell me which version of Haystack you are using?

sjrl · Jul 25 '22

Sure - I tested with both 1.5.0 and 1.6.0, and the error occurs on both. And yes, that is the tutorial I followed (though locally, not on Colab); the link I used for the tutorial is this. Here is the full traceback:

```
  File "/home/otto/anaconda3/lib/python3.9/site-packages/haystack/modeling/training/base.py", line 291, in train
    loss = self.compute_loss(batch, step)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/haystack/modeling/training/base.py", line 376, in compute_loss
    per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/haystack/modeling/model/biadaptive_model.py", line 194, in logits_to_loss
    all_losses = self.logits_to_loss_per_head(logits, **kwargs)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/haystack/modeling/model/biadaptive_model.py", line 181, in logits_to_loss_per_head
    all_losses.append(head.logits_to_loss(logits=logits_for_one_head, **kwargs))
  File "/home/otto/anaconda3/lib/python3.9/site-packages/haystack/modeling/model/prediction_head.py", line 1071, in logits_to_loss
    loss = self.loss_fct(softmax_scores, targets)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 211, in forward
    return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/otto/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2532, in nll_loss
    return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (8) to match target batch_size (32).
```

ottozastrow · Jul 25 '22

Hi @ottozastrow. Unfortunately, this doesn't look like a quick fix. I'll be adding it to our sprint backlog so we can dedicate time to this issue in the near future.

sjrl · Jul 25 '22

Ok, thanks! (Just in case this is helpful: for my use case, DPR with only one positive is not usable.)

ottozastrow · Jul 25 '22

Maybe @bogdankostic knows what's causing the problem or even has an idea for a quick fix? He's a DPR expert! 🙂

julian-risch · Jul 26 '22