
vsft_llava: ValueError: Expected input batch_size (78528) to match target batch_size (41728).

Open · kishan-character opened this issue 1 year ago · 2 comments

Running the example commands from https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py results in the following error:

/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
Traceback (most recent call last):
  File "/home/kishan/character-tech/flow/archive/vsft_llava.py", line 220, in <module>
    trainer.train()
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 418, in train
    output = super().train(*args, **kwargs)
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 1903, in train
    return inner_training_loop(
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 3266, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 3305, in compute_loss
    outputs["loss"] = F.cross_entropy(logits_NV, labels_N)
  File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/torch/nn/functional.py", line 3086, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (78528) to match target batch_size (41728).
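The two "batch sizes" in this ValueError are most likely token counts, not example counts. The names logits_NV and labels_N in the traceback suggest the loss is computed on logits and labels flattened over batch × sequence length. The sketch below checks one possible reading, which is an assumption and is not confirmed anywhere in this thread: each <image> placeholder was expanded into 576 patch embeddings on the logits side but not in the labels. That figure comes from a LLaVA-1.5 checkpoint at 336 px with 14 px patches. The helper name explain_mismatch is hypothetical.

```python
# Hypothesis check (assumption, not confirmed in the issue): logits were computed
# after each <image> placeholder was expanded into 576 patch embeddings, while
# the labels were not, so every sequence is 575 tokens longer on the logits side.

INPUT_TOKENS = 78528   # "input batch_size" from the ValueError (flattened logits rows)
TARGET_TOKENS = 41728  # "target batch_size" from the ValueError (flattened labels)
PATCHES_PER_IMAGE = 576  # assumed: 336 px image / 14 px patches -> 24 * 24 patches
EXTRA_PER_SEQUENCE = PATCHES_PER_IMAGE - 1  # one placeholder replaced by 576 embeddings


def explain_mismatch(input_tokens: int, target_tokens: int, extra_per_seq: int):
    """Return (num_sequences, label_len, logit_len) if the gap between the two
    totals is a whole number of per-sequence expansions that also splits the
    label total evenly; otherwise return None."""
    gap = input_tokens - target_tokens
    if gap <= 0 or gap % extra_per_seq != 0:
        return None
    num_sequences = gap // extra_per_seq
    if target_tokens % num_sequences != 0:
        return None
    label_len = target_tokens // num_sequences
    return num_sequences, label_len, label_len + extra_per_seq


print(explain_mismatch(INPUT_TOKENS, TARGET_TOKENS, EXTRA_PER_SEQUENCE))
```

This prints (64, 652, 1227). The totals are consistent with 64 sequences whose labels are 652 tokens long and whose logits are 1227 tokens long. That only shows the arithmetic fits the hypothesis. It does not establish the cause.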

kishan-character · May 31 '24 18:05

Installing datasets from source with pip install git+https://github.com/huggingface/datasets.git seems to have fixed this issue.
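Before confirming that a reinstall like this took effect, it can help to report which versions are actually active in the environment. Below is a minimal sketch that uses only the standard library's importlib.metadata. The package list is an assumption, chosen to match the libraries that appear in this thread. installed_version is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Libraries involved in this thread (assumed list; adjust as needed).
for pkg in ("datasets", "transformers", "trl", "torch"):
    print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

Pasting this output into an issue makes it easier for others to tell which builds were in use when the error appeared or disappeared.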

kishan-character · May 31 '24 18:05

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] · Jul 01 '24 15:07