DPR Training doesn't support num_positives > 1
Describe the bug
Training DPR with num_positives > 1 fails with a ValueError raised during the loss computation (see the traceback under "Error message").
This was discovered in #3084.
Error message
Traceback (most recent call last):
  File "/home/user/train_dpr/train_dpr.py", line 141, in <module>
    retriever.train(
  File "/home/user/.local/lib/python3.9/site-packages/haystack/nodes/retriever/dense.py", line 686, in train
    trainer.train()
  File "/home/user/.local/lib/python3.9/site-packages/haystack/modeling/training/base.py", line 291, in train
    loss = self.compute_loss(batch, step)
  File "/home/user/.local/lib/python3.9/site-packages/haystack/modeling/training/base.py", line 393, in compute_loss
    per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
  File "/home/user/.local/lib/python3.9/site-packages/haystack/modeling/model/biadaptive_model.py", line 198, in logits_to_loss
    all_losses = self.logits_to_loss_per_head(logits, **kwargs)
  File "/home/user/.local/lib/python3.9/site-packages/haystack/modeling/model/biadaptive_model.py", line 185, in logits_to_loss_per_head
    all_losses.append(head.logits_to_loss(logits=logits_for_one_head, **kwargs))
  File "/home/user/.local/lib/python3.9/site-packages/haystack/modeling/model/prediction_head.py", line 1071, in logits_to_loss
    loss = self.loss_fct(softmax_scores, targets)
  File "/home/user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/.local/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 211, in forward
    return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/user/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 2689, in nll_loss
    return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (8) to match target batch_size (224).
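The final error says the loss received 8 score rows but 224 targets. One way a mismatch like this can arise is when there is one score row per question but one target per positive passage. The sketch below illustrates that idea in plain Python. It is an assumption about the cause, not confirmed against the Haystack source, and `count_rows_and_targets` and `label_mask` are hypothetical names rather than Haystack code.

```python
# Hypothetical illustration of a rows-vs-targets mismatch; not Haystack code.

def count_rows_and_targets(label_mask):
    """Return the number of questions and the flat indices of all positives.

    label_mask has one row per question; 1 marks a positive passage,
    0 marks a (hard) negative passage.
    """
    num_rows = len(label_mask)
    # Flatten the mask the way a batch of passages would be laid out.
    flat = [label for row in label_mask for label in row]
    # Assumption: one target index is produced per positive passage.
    targets = [i for i, label in enumerate(flat) if label == 1]
    return num_rows, targets


# Toy batch: 2 questions, each with 2 positives and 1 hard negative.
label_mask = [
    [1, 1, 0],
    [1, 1, 0],
]
num_rows, targets = count_rows_and_targets(label_mask)

# With one positive per question the two numbers would be equal.
# With two positives per question there are more targets than rows,
# which is the kind of mismatch an NLL loss rejects.
print(f"score rows: {num_rows}, targets: {len(targets)}")
```

If this reading is right, the loss would need either one target per question or a formulation that accepts several positives per row.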
Expected behavior
The training works with multiple positive contexts.
Additional context
https://github.com/deepset-ai/haystack/issues/3084#issuecomment-1227669048
To Reproduce
Use a dataset with multiple positive contexts for each question and pass num_positives > 1 to retriever.train().
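A minimal sketch of such a dataset, assuming the DPR JSON layout (`question`, `answers`, `positive_ctxs`, `negative_ctxs`, `hard_negative_ctxs`). The sample content and the file name `train_multi_positive.json` are made up for illustration. The training call is left as a comment because it needs Haystack and a loaded retriever, and its argument values are assumptions, not the ones from the failing run.

```python
import json
import tempfile
from pathlib import Path

# Toy sample with two positive contexts for one question (illustrative data).
sample = [
    {
        "question": "What is the capital of France?",
        "answers": ["Paris"],
        "positive_ctxs": [
            {"title": "Paris", "text": "Paris is the capital of France."},
            {"title": "France", "text": "The capital of France is Paris."},
        ],
        "negative_ctxs": [],
        "hard_negative_ctxs": [
            {"title": "Lyon", "text": "Lyon is a city in France."},
        ],
    }
]

# Write the file to a temporary directory so the sketch leaves nothing behind.
with tempfile.TemporaryDirectory() as data_dir:
    train_path = Path(data_dir) / "train_multi_positive.json"  # hypothetical name
    train_path.write_text(json.dumps(sample, indent=2), encoding="utf-8")

    # Read it back to confirm every question has more than one positive.
    loaded = json.loads(train_path.read_text(encoding="utf-8"))
    positives_per_question = [len(item["positive_ctxs"]) for item in loaded]
    print(positives_per_question)

    # Assumed call that triggers the error (requires Haystack v1.8.0):
    # retriever.train(
    #     data_dir=data_dir,
    #     train_filename="train_multi_positive.json",
    #     num_positives=2,
    #     num_hard_negatives=1,
    # )
```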
FAQ Check
- [x] Have you had a look at our new FAQ page?
System:
- OS:
- GPU/CPU:
- Haystack version (commit or version number): v1.8.0
- DocumentStore:
- Reader:
- Retriever: DensePassageRetriever