Using IterableDataset crashes SFTTrainer
When datasets.IterableDataset is used instead of datasets.Dataset, trl.SFTTrainer crashes while it is being constructed:
import datasets
import trl

# model, tokenizer, peft_config, training_args, custom_args and get_training_data
# are defined elsewhere in the training script.
dataset = datasets.IterableDataset.from_generator(
    get_training_data(custom_args.data_path),
    features=datasets.Features({"prompt": datasets.Value("string")}),
)
dataset_train, dataset_eval = dataset, dataset

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    dataset_text_field="prompt",
    max_seq_length=custom_args.max_seq_length,
    peft_config=peft_config,
    args=training_args,
)
This results in the following error:
Traceback (most recent call last):
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 134, in <module>
    main()
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 108, in main
    trainer = trl.SFTTrainer(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 362, in __init__
    train_dataset = self._prepare_dataset(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 508, in _prepare_dataset
    return self._prepare_non_packed_dataloader(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 582, in _prepare_non_packed_dataloader
    tokenized_dataset = dataset.map(
TypeError: IterableDataset.map() got an unexpected keyword argument 'num_proc'
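The last frame shows the cause: _prepare_non_packed_dataloader forwards num_proc to dataset.map(), and in the installed datasets version IterableDataset.map() has no such parameter. Below is a minimal sketch of one possible guard, assuming the fix is to forward only the keyword arguments the target map() actually declares. The names supported_kwargs, MapStyleDataset and StreamingDataset are hypothetical; the two classes are toy stand-ins for the signature difference, not the real datasets classes.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def supported_kwargs(fn: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments that `fn` declares.

    If `fn` accepts **kwargs, every argument is passed through unchanged.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {name: value for name, value in kwargs.items() if name in params}


class MapStyleDataset:
    """Hypothetical stand-in whose map() accepts num_proc (like datasets.Dataset.map)."""

    def __init__(self, rows: List[dict]) -> None:
        self.rows = list(rows)

    def map(self, function: Callable[[dict], dict],
            num_proc: Optional[int] = None) -> "MapStyleDataset":
        # num_proc is accepted but ignored: this toy always maps in-process.
        return MapStyleDataset([function(row) for row in self.rows])


class StreamingDataset:
    """Hypothetical stand-in whose map() has no num_proc parameter,
    as with IterableDataset.map() in the traceback."""

    def __init__(self, rows: List[dict]) -> None:
        self.rows = list(rows)

    def map(self, function: Callable[[dict], dict]) -> "StreamingDataset":
        return StreamingDataset([function(row) for row in self.rows])


def count_chars(row: dict) -> dict:
    return {"n_chars": len(row["prompt"])}


if __name__ == "__main__":
    requested = {"num_proc": 4}  # what a caller such as SFTTrainer might pass
    for ds in (MapStyleDataset([{"prompt": "hi"}]), StreamingDataset([{"prompt": "hi"}])):
        # Passing **requested unfiltered would raise TypeError for StreamingDataset.
        out = ds.map(count_chars, **supported_kwargs(ds.map, **requested))
        print(type(ds).__name__, out.rows)
```

An explicit isinstance(dataset, datasets.IterableDataset) check is the more direct alternative when the datasets library is already imported; signature inspection simply avoids hard-coding which class lacks which parameter.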