`sft.py` example not working
Current version: commit 84156f179f91f519e48185414391d040112f2d34 (HEAD -> main, origin/main, origin/HEAD) updated on Jun 3 2024
I tried to run the following command for `examples/scripts/sft.py`:
# regular:
python examples/scripts/sft.py \
--model_name_or_path="facebook/opt-350m" \
--report_to="wandb" \
--learning_rate=1.41e-5 \
--per_device_train_batch_size=64 \
--gradient_accumulation_steps=16 \
--output_dir="sft_openassistant-guanaco" \
--logging_steps=1 \
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
--gradient_checkpointing
Error message:
Map: 0%| | 0/9846 [00:00<?, ? examples/s]
Traceback (most recent call last):
File "/Users/tatoaoliang/Downloads/Work/trl/examples/scripts/sft.py", line 137, in <module>
trainer = SFTTrainer(
^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 360, in __init__
train_dataset = self._prepare_dataset(
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 506, in _prepare_dataset
return self._prepare_non_packed_dataloader(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 574, in _prepare_non_packed_dataloader
tokenized_dataset = dataset.map(
^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3156, in map
for rank, done, content in Dataset._map_single(**dataset_kwargs):
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3548, in _map_single
batch = apply_function_on_filtered_inputs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3417, in apply_function_on_filtered_inputs
processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 545, in tokenize
element[dataset_text_field] if not use_formatting_func else formatting_func(element),
~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 271, in __getitem__
value = self.data[key]
~~~~~~~~~^^^^^
KeyError: None
I checked the code. Here is the original snippet of the `_prepare_non_packed_dataloader` function in `trl/trainer/sft_trainer.py` (line 529):
def _prepare_non_packed_dataloader(
self,
tokenizer,
dataset,
dataset_text_field,
max_seq_length,
formatting_func=None,
add_special_tokens=True,
remove_unused_columns=True,
):
#### The debugger shows that both formatting_func and dataset_text_field are None at this point
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False
#### ...so use_formatting_func is False, because it requires formatting_func to be non-None
# Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
def tokenize(element):
#### With use_formatting_func False, the text is read as element[dataset_text_field].
#### Since dataset_text_field is None, this becomes element[None], which raises KeyError: None
outputs = tokenizer(
element[dataset_text_field] if not use_formatting_func else formatting_func(element),
add_special_tokens=add_special_tokens,
truncation=True,
padding=False,
max_length=max_seq_length,
return_overflowing_tokens=False,
return_length=False,
)
#### (Snippet ends here. Per the traceback, the function continues with tokenized_dataset = dataset.map(...))
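So it seems that `formatting_func` should not be `None`. Because `dataset_text_field` is also `None`, `use_formatting_func` evaluates to `False`. As a result, `tokenize` looks up `element[None]`, which raises `KeyError: None`.
`formatting_func` is assigned in `sft_trainer.py`, line 313: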
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
`get_formatting_func_from_dataset` is defined in `trl/extras/dataset_formatting.py`, line 60:
def get_formatting_func_from_dataset(
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]
- `instruction` with [{"prompt": str, "completion": str}]
Args:
dataset (Dataset): User dataset
tokenizer (AutoTokenizer): Tokenizer used for formatting
Returns:
Callable: Formatting function if the dataset format is supported else None
"""
#### Only a plain `Dataset` is inspected; any other type (e.g. ConstantLengthDataset) ends up returning None
if isinstance(dataset, Dataset):
if "messages" in dataset.features:
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "messages")
if "conversations" in dataset.features:
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "conversations")
#### NOTE(review): indentation was lost in this paste. Upstream, this `elif` presumably pairs with the
#### "conversations" check above (verify). The instruction branch requires the whole feature set to
#### equal FORMAT_MAPPING["instruction"].
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)
#### A dataset whose only feature is "text" has no "messages"/"conversations" column and is not the
#### prompt/completion instruction format, so it falls through to here
return None
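But the openassistant-guanaco dataset only has a `text` feature. That matches neither the ChatML nor the instruction format, so `get_formatting_func_from_dataset` returns `None` and the dataset is incompatible. (Presumably `dataset_text_field` needs to be set to `"text"` for this dataset.)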
https://huggingface.co/datasets/timdettmers/openassistant-guanaco?row=0
I am also running into the same issue. What other package versions are you using? I can run some examples, such as the basic `SFTTrainer` example from the README, but `sft.py` is not working.
Same problem here.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.