vsft_llava: ValueError: Expected input batch_size (78528) to match target batch_size (41728).
Running the example commands in https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py results in the following error:
/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
warnings.warn('Was asked to gather along dimension 0, but all '
Traceback (most recent call last):
File "/home/kishan/character-tech/flow/archive/vsft_llava.py", line 220, in <module>
trainer.train()
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 418, in train
output = super().train(*args, **kwargs)
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 1903, in train
return inner_training_loop(
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 3266, in training_step
loss = self.compute_loss(model, inputs)
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/transformers/trainer.py", line 3305, in compute_loss
outputs["loss"] = F.cross_entropy(logits_NV, labels_N)
File "/home/kishan/miniconda3/envs/cai-env/lib/python3.10/site-packages/torch/nn/functional.py", line 3086, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (78528) to match target batch_size (41728).
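For context, that ValueError is raised directly by `torch.nn.functional.cross_entropy` when the flattened logits and the flattened labels disagree on the number of token positions. Below is a minimal sketch that reproduces the same message with the shapes from the traceback; the tiny vocab size is only there to keep the tensors small, and the guess that the mismatch comes from image placeholder tokens being expanded in the inputs but not in the labels is an assumption, not something the traceback proves.

```python
import torch
import torch.nn.functional as F

# Shapes copied from the error message: 78528 flattened logit positions vs.
# 41728 flattened label positions. vocab_size=8 is an arbitrary small value
# used only for illustration; the real model's vocabulary is far larger.
vocab_size = 8
logits = torch.randn(78528, vocab_size)          # (batch * input_seq_len, vocab)
labels = torch.randint(0, vocab_size, (41728,))  # (batch * label_seq_len,)

F.cross_entropy(logits, labels)
# ValueError: Expected input batch_size (78528) to match target batch_size (41728).
```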
Installing datasets from source with `pip install git+https://github.com/huggingface/datasets.git` seems to have fixed this issue.
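As a quick sanity check after the reinstall (a sketch; the exact version string will differ on your machine), the source build should report a development version rather than a plain release:

```python
import datasets

# An install from the main branch typically reports a dev version,
# e.g. something like "2.20.1.dev0"; a plain PyPI install would not.
print(datasets.__version__)
```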
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.