
feat: added packed continued pretraining dataset functionality

Open calmitchell617 opened this issue 10 months ago • 5 comments

Context

This PR seeks to resolve issue #809 by adding the ability to pack and tokenize input/label pairs from a Hugging Face dataset and then continue a model's pretraining on that processed dataset. The dataset can be local or streamed.

This is my first attempt contributing to this repo. I expect a few changes will be requested, but hope you will consider it a good effort!

Ccing @RdoubleA.

One limitation of this new feature is that you can only pack as many examples as will fit in your system's RAM. This does not mean you can't use large datasets - just that you can only use as many examples as fit in RAM. This is because when creating a dataset class, you need to implement __getitem__, which takes an index and returns an item, and you will not know what a given index should return without packing the examples ahead of time.
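To illustrate the constraint, here is a minimal sketch (the class and argument names are made up for illustration; this is not the actual ConcatDataset implementation) of a map-style dataset that must hold every packed sequence in RAM:

from torch.utils.data import Dataset

class PackedDatasetSketch(Dataset):
    def __init__(self, tokenized_samples, max_seq_len):
        # Pack everything up front so that __getitem__ has a stable index to serve
        self.packs = []
        buffer = []
        for tokens in tokenized_samples:  # each item is a list of token ids
            buffer.extend(tokens)
            while len(buffer) >= max_seq_len:
                self.packs.append(buffer[:max_seq_len])  # every pack lives in RAM
                buffer = buffer[max_seq_len:]

    def __len__(self):
        return len(self.packs)

    def __getitem__(self, index):
        # The index -> pack mapping only exists because packing already happened
        return self.packs[index]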

One other thing to note is the train_on_input argument in the stack_dataset function, which is only there for compatibility with existing config YAML files. I'm not convinced it's good practice to include that kind of thing, but without it, we may have to change many of the existing recipe configs.

Changelog

The PR implements a new class, ConcatDataset, which was inspired by the ConcatDataset class in llama-recipes.

It also implements a new example dataset, stack_dataset, which uses ConcatDataset to stream and pack the Stack V1 Dedup dataset. This is a popular (and very large) dataset for training coding assistants.

Test plan

The new ConcatDataset class and stack_dataset function both get tests that attempt to mimic the structure and spirit of existing tests - specifically, I used the tests covering InstructDataset and alpaca_dataset as a guide.

Also, I tested the code by training on a custom dataset which is designed to increase a model's knowledge of a certain Python library, and after a few thousand optimizer steps, the model had clearly learned about that library.

You can kick off a training run with a command like this:

clear && tune run \
    --nproc_per_node=4 \
    full_finetune_distributed \
    --config llama3/8B_full \
    batch_size=1 \
    seed=29 \
    tokenizer.path=<checkpoint_dir> \
    checkpointer.checkpoint_dir=<checkpoint_dir> \
    checkpointer.output_dir=<checkpoint_dir> \
    dataset=torchtune.datasets.stack_dataset \
    gradient_accumulation_steps=5 \
    lr_scheduler.num_warmup_steps=100 \
    enable_activation_checkpointing=True \
    epochs=3 \
    dataset.max_rows=200 \
    dataset.streaming=True \
    dataset.token='<HF_token_goes_here>'

calmitchell617 avatar Apr 21 '24 15:04 calmitchell617


@SLR722 thank you for catching the basic copy/paste errors.

@RdoubleA, I'll address your questions and comments one by one:

How do you plan to extend packing functionality to all our other datasets (InstructDataset and ChatDataset)? Is there a way we can keep the packing logic isolated so that we can just turn it on with a flag in any of our datasets?

It's almost certainly possible to make the packing feature available to other dataset types. I will investigate what is required to do that.

There are two significant features here - streaming and packing. Which is most important for what you are trying to accomplish?

Packing a customized dataset is more important than streaming for my use case.

I imagine streaming would be best served by an iterable dataset, which we don't currently have.

I have not implemented an iterable dataset before, but maybe using one would alleviate the issue of having to fit the entire dataset in CPU RAM. I will look into that.
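As a rough sketch of the idea (hypothetical names, not an existing torchtune class), an iterable dataset could tokenize and pack on the fly, keeping only a small buffer in memory:

from torch.utils.data import IterableDataset

class StreamingPackedSketch(IterableDataset):
    def __init__(self, sample_iterable, tokenize_fn, max_seq_len):
        self.sample_iterable = sample_iterable  # e.g. a streaming HF dataset
        self.tokenize_fn = tokenize_fn          # maps a raw sample to token ids
        self.max_seq_len = max_seq_len

    def __iter__(self):
        buffer = []
        for sample in self.sample_iterable:
            buffer.extend(self.tokenize_fn(sample))
            while len(buffer) >= self.max_seq_len:
                # Yield one packed sequence at a time; only the buffer is held in RAM
                yield buffer[:self.max_seq_len]
                buffer = buffer[self.max_seq_len:]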

I ask because I think we should try to tackle both of these features separately.

Yes, that may be a good idea. It will be trivial to switch to a smaller example dataset to finish developing the packing functionality, then we can address streaming a larger dataset afterwards.

can we just use load_dataset's ability to load from a local file, or is there some missing functionality?

load_dataset() does not allow you to load a customized dataset from your local disk. This is a very important part of many people's workflow. Here is a minimal reproducible example showing what happens when you try to load a local dataset, saved with save_to_disk(), with load_dataset():

from datasets import load_dataset, load_from_disk

dataset_id = 'hails/mmlu_no_train'
subset = 'astronomy'
local_path = 'my_dataset'

# Load the dataset
dataset = load_dataset(
    dataset_id,
    subset,
    trust_remote_code=True, # you can comment this out, but it will raise a warning
    )['dev']
print(f'Length of dataset before customizing: {len(dataset)}')

# Customize our dataset (select the first 3 rows)
dataset = dataset.select(range(3))
print(f'Length of dataset after customizing: {len(dataset)}')

# Save our customized dataset to disk
dataset.save_to_disk(local_path)

# Load our customized dataset from disk, will work
dataset = load_from_disk(local_path)
print(dataset[0])

# Attempt to load local dataset with load_dataset (will fail)
try:
    dataset = load_dataset(local_path)
except Exception as e:
    print(f"Got error: {e}")

calmitchell617 avatar Apr 23 '24 07:04 calmitchell617

load_dataset() does not allow you to load a customized dataset from your local disk. This is a very important part of many people's workflow. Here is a minimal reproducible example showing what happens when you try to load a local dataset, saved with save_to_disk(), with load_dataset()

@calmitchell617 Drive-by comment on this, but is this an issue with load_dataset, or is it an issue with how you're saving the dataset in that example? For instance, in our repo we take a common open-source dataset, select a few rows, and store it locally for testing. This local dataset gets passed to some of our recipe tests, which call load_dataset with no issues. Here is the script I used to construct the local dataset; it is just a json.dump. Then we build the config for our local dataset class for testing here. This goes through our normal recipe flow and so calls into load_dataset here.
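For reference, a minimal sketch of that pattern (file name and fields are made up, not the actual test fixture) looks like this:

import json
from datasets import load_dataset

rows = [
    {"instruction": "Say hello", "output": "Hello!"},
    {"instruction": "Say goodbye", "output": "Goodbye!"},
]

# Dump the rows as a plain JSON list, then point load_dataset at the file
with open("my_local_dataset.json", "w") as f:
    json.dump(rows, f)

dataset = load_dataset("json", data_files="my_local_dataset.json", split="train")
print(dataset[0])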

ebsmothers avatar Apr 23 '24 15:04 ebsmothers

@ebsmothers yes, it is most likely an issue with how I'm saving the dataset, but I believe the save_to_disk() function is quite popular in the HF ecosystem.

Whenever I try to load a dataset saved with save_to_disk with load_dataset, I get an error:

You are trying to load a dataset that was saved using save_to_disk. Please use load_from_disk instead.

You will see that error if you run the script pasted above.

calmitchell617 avatar Apr 23 '24 16:04 calmitchell617

@ebsmothers @RdoubleA @kartikayk, thank you all very much for taking my use case and problems into account! I really appreciate it.

Since you are all asking for my opinion on various use cases and functions in several scattered places, I would like to provide one central comment listing the functions that, in my experience, are the most commonly used entrypoints into the HF ecosystem:

Datasets

The most common way to load a dataset is load_dataset, followed by load_from_disk. load_from_disk is used if you downloaded a dataset, did some customization, and then saved it with save_to_disk.

As far as I know, you cannot load a dataset saved with save_to_disk with load_dataset - you must use load_from_disk. See comment above.

Models

The most common way I've seen to load a model is from_pretrained. Many people will expect the models they train with torchtune to work OOTB with that function. I outlined my methodology for converting a model output by torchtune into one that can be loaded with that function here.
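For reference, the standard entrypoints look like this (the path below is just a placeholder for a converted checkpoint directory):

from transformers import AutoModelForCausalLM, AutoTokenizer

# "path/to/converted_checkpoint" is a placeholder for a torchtune output
# that has been converted to the HF format
model = AutoModelForCausalLM.from_pretrained("path/to/converted_checkpoint")
tokenizer = AutoTokenizer.from_pretrained("path/to/converted_checkpoint")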

EDIT: Maybe I should have made this its own issue - oops. Sorry to hijack an existing issue.

calmitchell617 avatar Apr 23 '24 16:04 calmitchell617

Closing this PR in favor of #875

calmitchell617 avatar May 07 '24 07:05 calmitchell617