[Data] Implement dataset mixer for combining datasets in training
Feature request
In the alignment-handbook, we implemented a "dataset mixer" that allows one to easily combine datasets in varying proportions, provided they all share the same schema.
It could be interesting to port this mixer to TRL so that users can easily combine datasets during training. The only caveat I see is that to support CLI training, e.g. trl sft ..., we'd need a compatible data structure, because dict objects don't play nice with CLIs.
Motivation
Advanced post-training typically combines different datasets / proportions. Supporting this in TRL would allow us to gradually deprecate the handbook in favour of using the lib directly.
Your contribution
Open to discussion :)
@lewtun Shouldn't this be a part of Datasets instead?
Datasets already has interleave_datasets(), which mixes datasets in a similar way. It's not quite a drop-in fit as it is, but it may be useful.
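For reference, it currently looks roughly like this (toy datasets just for illustration):

from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(10)]})
ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(100)]})

# Samples from each dataset according to the given probabilities.
# "first_exhausted" (the default) stops once the smaller dataset runs out;
# "all_exhausted" keeps oversampling the smaller one until the larger is done.
mixed = interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.3, 0.7],
    seed=42,
    stopping_strategy="first_exhausted",
)
print(len(mixed))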
As for integrating it into the CLI, I think the best way to do it would be with a config file, e.g. a JSON like this:
trl sft --dataset_mixer --mix_config mix_config.json
mix_config.json:
{
    "dataset_mixer": {
        "dataset_1": 0.4,
        "dataset_2": 0.3,
        "dataset_3": 0.2
    },
    "splits": ["train", "train", "test"],
    "configs": ["main", "math", "logic"],
    "columns_to_keep": ["text", "input", "text"],
    "shuffle": true
}
or a simpler JSON:
[
    ["dataset_1", 0.4, "main", "train", "text"],
    ["dataset_2", 0.3, "math", "train", "input"],
    ["dataset_3", 0.2, "logic", "test", "text"]
]
if the parsing doesn't get too complicated.
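A rough sketch of what the parsing for the first format could look like, just to show it stays manageable (the helper name and the choice to rename every kept column to "text" are made up):

# Very rough sketch; none of this exists in TRL today.
import json

from datasets import concatenate_datasets, load_dataset


def load_mixed_dataset(path):
    with open(path) as f:
        cfg = json.load(f)

    parts = []
    mixer_items = list(cfg["dataset_mixer"].items())
    for (name, frac), config, split, column in zip(
        mixer_items, cfg["configs"], cfg["splits"], cfg["columns_to_keep"]
    ):
        ds = load_dataset(name, config, split=split)
        # Keep only the requested column and give it a common name so the
        # datasets can be concatenated (one possible convention among many).
        ds = ds.select_columns([column])
        if column != "text":
            ds = ds.rename_column(column, "text")
        parts.append(ds.select(range(int(frac * len(ds)))))

    mixed = concatenate_datasets(parts)
    return mixed.shuffle(seed=42) if cfg.get("shuffle") else mixed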
How does that sound?
Yes, good point about whether this should live in datasets or trl. Gently pinging @lhoestq for his thoughts on this.
Context for Quentin: a common workflow in LLM fine-tuning is to mix datasets in varying proportions (e.g. 30% of dataset A together with 50% of dataset B). The interleave_datasets() method in datasets does something similar, but uses sampling and a stopping strategy, which isn't quite what we want here: we don't want to stop when one dataset is exhausted, and we don't want to oversample either. One possibility would be to add a new stopping strategy like none_exhausted which samples the precise proportions (or fewer if there are not enough samples).
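To make the desired behaviour concrete, here is a hand-rolled sketch of what such a none_exhausted strategy could produce, assuming the probabilities describe proportions of the resulting mix (this strategy does not exist in datasets today):

from datasets import Dataset, concatenate_datasets

datasets_list = [
    Dataset.from_dict({"text": [f"a{i}" for i in range(100)]}),
    Dataset.from_dict({"text": [f"b{i}" for i in range(50)]}),
]
probabilities = [0.7, 0.3]

# Take exact counts derived from the target proportions, capped at what each
# dataset actually has: no early stopping and no oversampling.
total = sum(len(ds) for ds in datasets_list)
parts = []
for ds, p in zip(datasets_list, probabilities):
    n = min(int(p * total), len(ds))
    parts.append(ds.shuffle(seed=42).select(range(n)))

mixed = concatenate_datasets(parts).shuffle(seed=42)
print(len(mixed))  # 145: the first dataset is capped at its 100 available rows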
This works, no? And it feels easier to understand than a mixer (are the examples taken at random? does the fraction represent the proportion of the original dataset or of the resulting dataset?)
concatenate_datasets([
    dataset.select(range(int(frac * len(dataset))))
    for dataset, frac in zip(datasets, fracs)
])
Yes, that is effectively what we did in the alignment-handbook here. In our case, the fractions represent the proportion per dataset and the samples are just taken from the first N rows.
My question is more about whether it makes sense for datasets to have a utility method to load and combine multiple datasets from the Hub. Here's how we did it in the handbook (link) - happy to keep this logic in trl if it's too niche!
Since the mixing logic is kinda specific, I don't think it should be in datasets. Though if there is something more general / explicit, we can reconsider.