Using IterableDataset crashes SFTTrainer

Status: Open. helloworld1 opened this issue 1 year ago (0 comments).

When a `datasets.IterableDataset` is passed instead of a `datasets.Dataset`, the trl `SFTTrainer` crashes in its constructor while preparing the dataset:

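```python
import datasets
import trl

# model, tokenizer, peft_config, training_args, custom_args and get_training_data
# are defined elsewhere in the training script.
# from_generator expects a generator function, so get_training_data(...) is
# assumed to return one.
dataset = datasets.IterableDataset.from_generator(
    get_training_data(custom_args.data_path),
    features=datasets.Features({"prompt": datasets.Value("string")}),
)

# The same streaming dataset is used for both training and evaluation.
dataset_train, dataset_eval = dataset, dataset

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    dataset_text_field="prompt",
    max_seq_length=custom_args.max_seq_length,
    peft_config=peft_config,
    args=training_args,
)
```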

It fails with this error:

```
Traceback (most recent call last):
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 134, in <module>
    main()
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 108, in main
    trainer = trl.SFTTrainer(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 362, in __init__
    train_dataset = self._prepare_dataset(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 508, in _prepare_dataset
    return self._prepare_non_packed_dataloader(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 582, in _prepare_non_packed_dataloader
    tokenized_dataset = dataset.map(
TypeError: IterableDataset.map() got an unexpected keyword argument 'num_proc'
```
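The last frame shows the mismatch: the non-packed preparation path calls `dataset.map(...)` with a `num_proc` keyword, and `IterableDataset.map()` in the installed `datasets` version has no such parameter. The sketch below reproduces that mismatch without `datasets` installed, using two hypothetical stand-in classes (`MapStyleDataset`, `StreamingDataset`; neither is part of `datasets`), and shows one possible guard, a hypothetical helper `map_with_supported_kwargs` that drops keyword arguments the target `map()` does not accept. It is an illustration under those assumptions, not trl's actual code or fix.

```python
import inspect


class MapStyleDataset:
    """Hypothetical stand-in for a map-style dataset: map() accepts num_proc."""

    def __init__(self, rows):
        self.rows = list(rows)

    def map(self, function, num_proc=None):
        # num_proc is accepted but ignored in this stand-in (no worker processes).
        return MapStyleDataset(function(row) for row in self.rows)


class StreamingDataset:
    """Hypothetical stand-in for a streaming dataset: map() has no num_proc."""

    def __init__(self, rows):
        self.rows = rows

    def map(self, function):
        # Lazy: function runs only when the result is iterated.
        return StreamingDataset(function(row) for row in self.rows)

    def __iter__(self):
        return iter(self.rows)


def map_with_supported_kwargs(dataset, function, **map_kwargs):
    """Hypothetical helper: call dataset.map(), dropping kwargs its signature rejects."""
    params = inspect.signature(dataset.map).parameters
    takes_var_kwargs = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    supported = {
        key: value
        for key, value in map_kwargs.items()
        if takes_var_kwargs or key in params
    }
    return dataset.map(function, **supported)


def tokenize(row):
    # Toy "tokenizer" that splits the prompt on whitespace.
    return {"input_ids": row["prompt"].split()}


rows = [{"prompt": "hello world"}, {"prompt": "foo bar baz"}]

# Passing num_proc unconditionally reproduces the TypeError from the traceback.
error_message = ""
try:
    StreamingDataset(iter(rows)).map(tokenize, num_proc=4)
except TypeError as err:
    error_message = str(err)
print(error_message)

# With the guard, both dataset kinds are mapped; num_proc reaches only the map-style one.
eager_rows = map_with_supported_kwargs(MapStyleDataset(rows), tokenize, num_proc=4).rows
lazy_rows = list(
    map_with_supported_kwargs(StreamingDataset(iter(rows)), tokenize, num_proc=4)
)
print(eager_rows)
print(lazy_rows)
```

Signature inspection keeps the guard independent of the concrete dataset classes. A narrower variant would pass `num_proc` only when the dataset is map-style, for example behind an `isinstance(dataset, datasets.Dataset)` check.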

helloworld1, Jun 23 '24 05:06