Support for unstructured text corpus datasets for CPT
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [x] update tests and/or documentation
- [ ] other (please add here)
Please link to any issues this PR addresses: #845, #809
Continued pre-training uses the same data processing pipeline as standard pre-training: the model simply predicts the next token and completes the text. It does not require any templating or prompt formatting. Our existing dataset classes are specifically designed for instruct and chat tuning, but don't support free-form text corpora.
Here we add a dataset class, `TextDataset`, that simply calls `load_dataset` and tokenizes the text directly without any further processing. It uses `encode`, which adds BOS/EOS tokens as needed, so it should be compatible with the formatting requirements of Llama2, Llama3, and other models.
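For illustration, a minimal sketch of the idea looks something like the following. This is not the exact code added in the PR; the constructor signature, the `split`/`column` arguments, and the returned dict keys are assumptions, and it presumes a torchtune-style tokenizer exposing `encode(text, add_bos=..., add_eos=...)`:

```python
# Minimal sketch only -- names and signatures are assumptions, not the exact
# implementation in this PR.
from typing import Dict, List

from datasets import load_dataset
from torch.utils.data import Dataset


class TextDataset(Dataset):
    """Tokenize a free-form text column with no prompt templating."""

    def __init__(
        self, tokenizer, source: str, column: str = "text", split: str = "train", **load_dataset_kwargs
    ) -> None:
        self._tokenizer = tokenizer
        self._column = column
        self._data = load_dataset(source, split=split, **load_dataset_kwargs)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        # Plain next-token prediction: tokenize the raw text, let the tokenizer
        # add BOS/EOS, and reuse the same token ids as labels.
        tokens = self._tokenizer.encode(sample[self._column], add_bos=True, add_eos=True)
        return {"tokens": tokens, "labels": tokens.copy()}
```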
Changelog
What are the changes made in this PR?
- Add `TextDataset` and appropriate tests
- Add an example dataset builder, `cnn_dailymail_articles_dataset`, and appropriate tests (a sketch of such a builder follows below)
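As a rough illustration of such a builder (the Hub id, config name, and column are educated guesses about the CNN/DailyMail dataset, not a copy of the PR's code):

```python
# Hypothetical builder mirroring what the changelog describes; the actual
# cnn_dailymail_articles_dataset in this PR may differ in signature/defaults.
def cnn_dailymail_articles_dataset(tokenizer, **load_dataset_kwargs) -> TextDataset:
    return TextDataset(
        tokenizer=tokenizer,
        source="cnn_dailymail",  # Hugging Face Hub id (assumption)
        column="article",        # article body; "highlights" holds the summaries
        name="3.0.0",            # dataset config version on the Hub
        **load_dataset_kwargs,
    )
```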
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [x] add unit tests for any new functionality
- [x] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [x] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
Ran the full finetune distributed recipe with the cnn dailymail dataset.
2024-04-24:23:12:06,234 INFO [_utils.py:34] Running FullFinetuneRecipeDistributed with resolved config:
batch_size: 2
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B/original/
  checkpoint_files:
  - consolidated.00.pth
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama-3-8B/
  recipe_checkpoint: null
dataset:
  _component_: torchtune.datasets.cnn_dailymail_articles_dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
epochs: 3
gradient_accumulation_steps: 1
log_every_n_steps: null
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: /tmp/alpaca-llama3-finetune
  name: cnn_dailymail
  project: torchtune
model:
  _component_: torchtune.models.llama3.llama3_8b
optimizer:
  _component_: torch.optim.AdamW
  foreach: false
  lr: 2.0e-05
output_dir: /tmp/alpaca-llama3-finetune
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-8B/original/tokenizer.model
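For context, torchtune recipes instantiate components such as the `dataset` entry above from the resolved config. A simplified sketch of that step (the config path is hypothetical and this is not the recipe's exact code):

```python
from omegaconf import OmegaConf
from torchtune import config

# cfg corresponds to the resolved config printed in the log above;
# the YAML path here is hypothetical.
cfg = OmegaConf.load("cnn_dailymail_full_finetune.yaml")
tokenizer = config.instantiate(cfg.tokenizer)
dataset = config.instantiate(cfg.dataset, tokenizer=tokenizer)
```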
CI status: no failures as of commit 55a6de32745b9c5c29ec2e568823c12a8e07d1de with merge base 29ae975fc6d2f8e85ce33634116e0bda0472253c.
Do we in general plan to support CPT in torchtune, and possibly plan for an RFC around this?
I don't know that CPT will be sufficiently different from our finetune recipes to warrant its own recipe, besides different data handling. @rohan-varma would you know :) or maybe @pbontrager ?

> I don't know that CPT will be sufficiently different from our finetune recipes to warrant its own recipe, besides different data handling.
I don't think so initially. I think this dataset is good and we could add a recipe in the future if there are additional custom things that people want for advanced use cases.
Codecov Report
Attention: Patch coverage is 55.20833%, with 43 lines in your changes missing coverage. Please review.
Project coverage is 27.20%. Comparing base (cb8e65a) to head (6ead093). Report is 2 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #868 +/- ##
===========================================
- Coverage 67.10% 27.20% -39.91%
===========================================
Files 174 180 +6
Lines 7423 7518 +95
===========================================
- Hits 4981 2045 -2936
- Misses 2442 5473 +3031
Any update here? I want to train on https://huggingface.co/datasets/allenai/dolma soon.