unsloth
unsloth copied to clipboard
TemplateError: Conversation roles must alternate user/assistant/user/assistant/...
Code:
from unsloth import FastLanguageModel
import torch
max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "mistralai/Mistral-7B-Instruct-v0.2", # Choose ANY! eg mistralai/Mistral-7B-Instruct-v0.2
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
#@title Alignment Handbook utils
import os
import re
from typing import List, Literal, Optional
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def apply_chat_template(
example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
def _strip_prefix(s, pattern):
# Use re.escape to escape any special characters in the pattern
return re.sub(f"^{re.escape(pattern)}", "", s)
if task == "dpo":
if all(k in example.keys() for k in ("chosen", "rejected")):
# Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
# Insert system message
# if example["chosen"][0]["role"] != "system":
# prompt_messages.insert(0, {"role": "system", "content": ""})
# else:
# print(msg['role'] == 'system' in example['chosen'])
# print(msg['role'] == 'system' in example['rejected'])
for msg, msg1 in zip(example['chosen'], example['rejected']):
print(msg['role'], msg1['role'])
prompt_messages.insert(0, example["chosen"][0])
# TODO: handle case where chosen/rejected also have system messages
chosen_messages = example["chosen"][1:]
rejected_messages = example["rejected"][1:]
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
example["text_prompt"] = tokenizer.apply_chat_template(
prompt_messages, tokenize=False, add_generation_prompt=True
)
example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
else:
raise ValueError(
f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
else:
raise ValueError(
f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
)
return example
def get_datasets(
data_config: dict,
splits: List[str] = ["train", "test"],
shuffle: bool = True,
) -> DatasetDict:
"""
Loads one or more datasets with varying training set proportions.
Args:
data_config (`DataArguments` or `dict`):
Dataset configuration and split proportions.
splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training and testing/validation data.
Returns
[`DatasetDict`]: The dataset dictionary containing the loaded datasets.
"""
if type(data_config) is dict:
# Structure of the input is:
# dataset_mixer = {
# "dataset1": 0.5,
# "dataset1": 0.3,
# "dataset1": 0.2,
# }
dataset_mixer = data_config
else:
raise ValueError(f"Data config {data_config} not recognized.")
raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
return raw_datasets
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
"""
Loads and mixes datasets according to proportions specified in `dataset_mixer`.
Args:
dataset_mixer (`dict`):
Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
splits (Optional[List[str]], *optional*, defaults to `None`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training and testing/validation data.
"""
raw_datasets = DatasetDict()
raw_train_datasets = []
raw_val_datasets = []
fracs = []
for ds, frac in dataset_mixer.items():
fracs.append(frac)
for split in splits:
try:
# Try first if dataset on a Hub repo
dataset = load_dataset(ds, split=split)
except DatasetGenerationError:
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))
if "train" in split:
raw_train_datasets.append(dataset)
elif "test" in split:
raw_val_datasets.append(dataset)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")
if any(frac < 0 for frac in fracs):
raise ValueError("Dataset fractions cannot be negative.")
if len(raw_train_datasets) > 0:
train_subsets = []
for dataset, frac in zip(raw_train_datasets, fracs):
train_subset = dataset.select(range(int(frac * len(dataset))))
train_subsets.append(train_subset)
if shuffle:
raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
else:
raw_datasets["train"] = concatenate_datasets(train_subsets)
# No subsampling for test datasets to enable fair comparison across models
if len(raw_val_datasets) > 0:
if shuffle:
raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
else:
raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
if len(raw_datasets) == 0:
raise ValueError(
f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
)
return raw_datasets
raw_datasets = get_datasets(
{"HuggingFaceH4/ultrafeedback_binarized" : 0.005}, # 0.5% sampled
splits = ["train_prefs", "test_prefs"],
)
column_names = list(raw_datasets["train"].features)
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs = {"tokenizer": tokenizer, "task": "dpo"},
# num_proc = 12,
remove_columns = column_names,
desc = "Formatting comparisons with prompt template",
)
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
)
Error:
user user
assistant assistant
---------------------------------------------------------------------------
TemplateError Traceback (most recent call last)
[<ipython-input-25-5f055c4a6ef3>](https://localhost:8080/#) in <cell line: 7>()
5 column_names = list(raw_datasets["train"].features)
6
----> 7 raw_datasets = raw_datasets.map(
8 apply_chat_template,
9 fn_kwargs = {"tokenizer": tokenizer, "task": "dpo"},
12 frames
[/usr/local/lib/python3.10/dist-packages/datasets/dataset_dict.py](https://localhost:8080/#) in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)
866 cache_file_names = {k: None for k in self}
867 return DatasetDict(
--> 868 {
869 k: dataset.map(
870 function=function,
[/usr/local/lib/python3.10/dist-packages/datasets/dataset_dict.py](https://localhost:8080/#) in <dictcomp>(.0)
867 return DatasetDict(
868 {
--> 869 k: dataset.map(
870 function=function,
871 with_indices=with_indices,
[/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
591 self: "Dataset" = kwargs.pop("self")
592 # apply actual function
--> 593 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
594 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
595 for dataset in datasets:
[/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
556 }
557 # apply actual function
--> 558 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
559 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
560 # re-apply format to the output
[/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3103 desc=desc or "Map",
3104 ) as pbar:
-> 3105 for rank, done, content in Dataset._map_single(**dataset_kwargs):
3106 if done:
3107 shards_done += 1
[/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in _map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
3456 _time = time.time()
3457 for i, example in shard_iterable:
-> 3458 example = apply_function_on_filtered_inputs(example, i, offset=offset)
3459 if update_data:
3460 if i == 0:
[/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
3359 if with_rank:
3360 additional_args += (rank,)
-> 3361 processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
3362 if isinstance(processed_inputs, LazyDict):
3363 processed_inputs = {
[<ipython-input-23-f7edf0bc29c5>](https://localhost:8080/#) in apply_chat_template(example, tokenizer, task, assistant_prefix)
35 chosen_messages = example["chosen"][1:]
36 rejected_messages = example["rejected"][1:]
---> 37 example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
38 example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
39 example["text_prompt"] = tokenizer.apply_chat_template(
[/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py](https://localhost:8080/#) in apply_chat_template(self, conversation, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, **tokenizer_kwargs)
1743 compiled_template = self._compile_jinja_template(chat_template)
1744
-> 1745 rendered = compiled_template.render(
1746 messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
1747 )
[/usr/local/lib/python3.10/dist-packages/jinja2/environment.py](https://localhost:8080/#) in render(self, *args, **kwargs)
1299 return self.environment.concat(self.root_render_func(ctx)) # type: ignore
1300 except Exception:
-> 1301 self.environment.handle_exception()
1302
1303 async def render_async(self, *args: t.Any, **kwargs: t.Any) -> str:
[/usr/local/lib/python3.10/dist-packages/jinja2/environment.py](https://localhost:8080/#) in handle_exception(self, source)
934 from .debug import rewrite_traceback_stack
935
--> 936 raise rewrite_traceback_stack(source=source)
937
938 def join_path(self, template: str, parent: str) -> str:
<template> in top-level template code()
[/usr/local/lib/python3.10/dist-packages/jinja2/sandbox.py](https://localhost:8080/#) in call(_SandboxedEnvironment__self, _SandboxedEnvironment__context, _SandboxedEnvironment__obj, *args, **kwargs)
391 if not __self.is_safe_callable(__obj):
392 raise SecurityError(f"{__obj!r} is not safely callable")
--> 393 return __context.call(__obj, *args, **kwargs)
394
395
[/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py](https://localhost:8080/#) in raise_exception(message)
1788
1789 def raise_exception(message):
-> 1790 raise TemplateError(message)
1791
1792 jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
TemplateError: Conversation roles must alternate user/assistant/user/assistant/...
How can I solve this???