Feat: Add support for multiple datasets in config
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [x] update tests and/or documentation
- [ ] other (please add here)
Please link to any issues this PR addresses.
Changelog
I've added the ability to use multiple sources for any type of dataset.
Once this PR is merged, TorchTune users will be able to pass multiple datasets in different formats. For example, it will be possible to mix chat and instruct datasets that use different templates, splits, etc.
Example of the new version of the config:
```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True
```
For backward compatibility, users can continue using the previous format of the dataset field if they do not wish to use multiple datasets:
```yaml
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: tatsu-lab/alpaca
  template: AlpacaInstructTemplate
  split: train
  train_on_input: True
```
To run unit tests:
```shell
pytest ./tests/torchtune/datasets/test_multi_dataset.py
```
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [x] add unit tests for any new functionality
- [x] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [x] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/889
:white_check_mark: No Failures
As of commit f7a3f958a2a32577aeaea6efcfd60656fcea380b with merge base aa650129bcfa08771ba76ad9bdfe7dafc48daffe:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @EvilFreelancer!
Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Hi! I've signed the CLA. How do I rerun the failed check?
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for this awesome PR @EvilFreelancer! Handling multiple datasets has been a north star for us, as most mature data pipelines for fine-tuning models typically incorporate multiple data sources, so I appreciate you adding this.
> The only problem is: datasets must have similar formats.
This is my main concern: the concatenated datasets must have the same columns AND the same instruct template / chat format. Also, all the keyword arguments (max_seq_len, train_on_input) will need to be shared. This is quite restrictive and will require users to do a lot of offline preprocessing work. Ideally, using multiple datasets should be flexible enough that each dataset can have different columns and a different template, yet we're still able to coalesce them. In the end, all the data gets tokenized as Messages, so I do think this is possible.
What are your thoughts on letting InstructDataset and ChatDataset handle single data sources only, and creating some container class that can hold multiple InstructDatasets or ChatDatasets? Since the dataset classes return the token IDs ready to be used by the model, the container class can focus on handling the logic of sampling from the list of datasets.
```python
from typing import List, Tuple

from torch.utils.data import Dataset


class MultiDataset(Dataset):
    def __init__(self, datasets: List[Dataset]):
        self.datasets = datasets

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        # Figure out how to sample/interleave multiple datasets here
        ...
```
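To make that sampling logic concrete, here is one minimal realization: plain concatenation over cumulative offsets. This is only a sketch under the assumption that every wrapped dataset is map-style and fits in memory; the `torch.utils.data.Dataset` base class is omitted so the snippet runs standalone, and the class name is hypothetical, not torchtune's:

```python
import bisect
from typing import List, Sequence


class ConcatSketch:
    """Sketch: exposes several map-style datasets behind one index space.

    Any object with __len__ and __getitem__ works as a "dataset" here;
    in torchtune this class would subclass torch.utils.data.Dataset.
    """

    def __init__(self, datasets: List[Sequence]):
        self.datasets = datasets
        # Cumulative end offsets, e.g. lengths [3, 2] -> offsets [3, 5]
        self._offsets: List[int] = []
        total = 0
        for ds in datasets:
            total += len(ds)
            self._offsets.append(total)

    def __len__(self) -> int:
        return self._offsets[-1] if self._offsets else 0

    def __getitem__(self, index: int):
        # Locate the dataset whose range contains the global index,
        # then translate the global index into a local one
        ds_idx = bisect.bisect_right(self._offsets, index)
        start = self._offsets[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][index - start]
```

With this design, interleaving or weighted sampling can live in the DataLoader's sampler rather than in the dataset itself.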
Then, in the config, you could specify multiple datasets like this (or something similar):
```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    ...
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    ...
```
And the recipe can instantiate a MultiDataset if the dataset param is a list. This way, each dataset can keep its individual parameters.
Ideally, we could make use of `concatenate_datasets` and `interleave_datasets` from HF, but since we have preprocessing logic in our dataset classes that is specific to each data source, we might have to create similar logic ourselves.
Hi @RdoubleA, thank you for your response!
Hm, the MultiDataset class plus the ability to pass an array of datasets through the dataset parameter is a great idea, and it's much simpler for end users to understand than requiring all datasets to share a similar format.
I have a couple of spare days, so I may start working on implementing this solution, as it is critical for the project I am working on to have the ability to train on various combinations of datasets.
Hi! I've implemented the MultiDataset logic and enabled it in the training recipes.
I've also removed the original multi-source logic and cleaned up the tests.
How to use:
```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
  - _component_: torchtune.datasets.chat_dataset
    source: ajibawa-2023/Code-290k-ShareGPT
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True
```
By the way, this has an interesting side effect: with this logic, you can freely mix datasets of any format.
In the example above, the instruct and chat dataset formats are mixed; it seems to me that this will be a very interesting and convenient feature of TorchTune.
@RdoubleA hi! Thanks for your review; the requested fixes have been added.
By the way, I've noticed one small thing: the vicgalle/alpaca-gpt4 dataset has 52k rows, so why does TorchTune show me 26k items during the training stage? Maybe it's because of train/val splits?
UPD: Ah, it's because of batch_size: 2 (52k rows / 2 rows per batch = 26k steps), got it.
@RdoubleA hi! I've added the fixes you mentioned, plus a couple of simple tests for the MultiDataset class.
What do you think about:
```python
if not isinstance(cfg_dataset, ListConfig):
    cfg_dataset = [cfg_dataset]
datasets = [
    config.instantiate(cfg_item, tokenizer=self._tokenizer)
    for cfg_item in cfg_dataset
]
ds = utils.MultiDataset(datasets=datasets)
```
instead of:
```python
if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
        for single_cfg_dataset in cfg_dataset
    ]
    ds = utils.MultiDataset(datasets=datasets)
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
```
@EvilFreelancer Thanks for the updates! I'll take another pass soon.
In the meantime, since this is a significant feature we are adding, I'd like to make sure this is rigorously tested. Do you mind confirming the following:
- The unit test for `MultiDataset` passes (please add the test command to the PR summary)
- Run one of the distributed recipes using `MultiDataset` and confirm that loss curves and tokens/sec look reasonable (you can use WandBLogger for easy visualization). The main thing I want to confirm here is that we can still sample from multiple datasets correctly in a distributed environment
- Run one of the single-device recipes using `MultiDataset` for a similar reason as above
Also, I forgot to mention this earlier, but I think a better location for `MultiDataset` would be in `torchtune/datasets` instead of `torchtune/utils`, what do you think?
@EvilFreelancer Can you update the README with examples on how this will now look in the YAML config?
@RdoubleA Hi! I've made some fixes to the PR.
> Also, I forgot to mention this earlier, but I think a better location for `MultiDataset` would be in `torchtune/datasets` instead of `torchtune/utils`, what do you think?
The reason I believed the torchtune/utils namespace was more suitable for this class is that MultiDataset does not represent an actual dataset, such as Alpaca or OpenOrca; instead, it serves merely as a wrapper over several datasets. However, I agree that it is most logical to move this class to the torchtune/datasets namespace. (Code refactored.)
> The unit test for `MultiDataset` passes
The description of the PR was updated, and a note on how to run unit tests was added.
```
(venv) [pasha-pc] ~/Documents/Repository/nn-nlp/torchtune $ pytest ./tests/torchtune/datasets/test_multi_dataset.py
========================== test session starts ==========================
platform linux -- Python 3.11.2, pytest-8.1.2, pluggy-1.5.0
rootdir: /home/pasha/Documents/Repository/nn-nlp/torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 3 items

tests/torchtune/datasets/test_multi_dataset.py ...
```
> Run one of the distributed recipes using `MultiDataset` and confirm that loss curves and tokens/sec look reasonable (you can use WandBLogger for easy visualization). Main thing I want to confirm here is that we can still sample from multiple datasets correctly in a distributed environment
Unfortunately, my local server has only one GPU, so I've written up instructions and enlisted the help of one of my subscribers who owns a multi-GPU server. A detailed update will follow tomorrow. For now, I can only confirm that training has started on the multi-GPU server and that the number of training steps matches what I see on my side.
> Run one of the single device recipes using `MultiDataset` for a similar reason above
I've also written up instructions and trained Gemma 2B on a single device. Here is the WandB report of my attempt.
GPU: 1x RTX 4090, CPU: 1x AMD 5950X.
@joecummings hi!
> Can you update the README with examples on how this will now look in the YAML config?
I've added a small section to the dataset page in the tutorial.
The tests for multi-GPU Gemma 2B training are available here.
UPD: A second training run on the same hardware, without the MultiDataset class, on a single Alpaca dataset.
As you can see, the tokens_per_seconds distribution is almost the same.
GPU: 2x RTX 4090, CPU: 1x Xeon Gold 6336Y.
@EvilFreelancer Thanks for sharing the test runs! Sorry for the delay in following up on this. An update from our side is that we've been having a lot of discussions around iterable datasets, which is something that we want to support more and design around moving forward. This would enable more flexibility around interleaving, weighted sampling, etc. for multiple datasets that won't fit in memory.
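To make the "interleaving, weighted sampling" idea above concrete, here is a rough sketch of probability-weighted interleaving over sample streams. It is purely illustrative; the function name and signature are my own, not a torchtune or Hugging Face API:

```python
import random
from typing import Iterator, List, Sequence


def interleave(
    streams: List[Iterator], weights: Sequence[float], seed: int = 0
) -> Iterator:
    """Yield samples from several iterators, picking each source with the
    given relative probability. Exhausted streams are dropped, and the
    remaining weights are implicitly renormalized by random.choices."""
    rng = random.Random(seed)
    streams = list(streams)
    weights = list(weights)
    while streams:
        i = rng.choices(range(len(streams)), weights=weights, k=1)[0]
        try:
            yield next(streams[i])
        except StopIteration:
            # This source is exhausted; keep drawing from the rest
            del streams[i]
            del weights[i]
```

An iterable-dataset design could wrap something like this so that sources too large for memory are still mixed at a controlled ratio.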
That being said, I still think the MultiDataset here is valuable for datasets that fit in memory and can still leverage map-style functionality. This will unblock users that want to quickly concatenate multiple data sources until we add iterable datasets and a more powerful MultiDataset. So to accurately reflect the scope of this class, my suggestion is to rename this to ConcatDataset (you might've started with this name to begin with, I apologize) and make it clear in the docstrings that this is for map datasets that all fit in memory.
Other than that, I have no major concerns for the rest of the changes. Let me know if this makes sense!
Hey @EvilFreelancer, are you still planning to make the changes? If not, I'm happy to do it for you and get this merged in. Let me know how you'd like to proceed.
Hi @RdoubleA,
I wanted to update you on the progress regarding the MultiDataset class. It has been renamed to ConcatDataset, along with corresponding updates to all related tests and imports in the recipes. I've also added a comprehensive docstring to the class to enhance clarity and usability.
Apologies for the delay in implementing these changes; the last couple of days have been particularly hectic at work, and I couldn't address this sooner. Thank you for your understanding.
Ah crap one more thing: my suggestion screwed up the rendering of the multi-dataset YAML. Sorry about that! I think you just need to add back a newline before the codeblock statement.
@ebsmothers typo fixed
Thank you for the meticulous code-review and excellent advice, I thoroughly enjoyed working on this project with all of you!