
Adding replay into GPT-NeoX

Open · AIproj opened this issue 10 months ago · 1 comment

This PR adds replay to GPT-NeoX. I originally implemented it for the paper Simple and Scalable Strategies to Continually Pre-train Large Language Models, which shows simple, efficient ways to continue pretraining: adapting to new data while mitigating forgetting of previously seen data. Note that this PR can also serve as a basis for adding the ability to resume training from a given index in a dataset, based on how I implemented that mechanism for the replay datasets.

How to use

I tried to make the descriptions of the replay args informative enough to serve as documentation. An example of a config using replay is also provided in tests/config/example_replay_config.yml.
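For illustration, here is a rough sketch of the replay-related settings, written as a Python dict mirroring the YAML keys. Only train_data_paths and replay_data_paths appear in this PR's code below; replay_fraction is an assumed name for the "replay fraction" mentioned in the tests section, and the paths are placeholders. The example config and the arg descriptions remain the authoritative reference.

    # Sketch only: assumed replay-related options, expressed as a Python dict
    # mirroring the YAML config keys. See tests/config/example_replay_config.yml
    # and the arg descriptions for the authoritative names and defaults.
    replay_settings = {
        # New data to continue pretraining on (placeholder paths).
        "train_data_paths": ["data/new_domain_text_document"],
        # Previously seen data to replay in order to mitigate forgetting.
        "replay_data_paths": ["data/original_pretraining_mix_text_document"],
        # Fraction of training samples drawn from the replay datasets (assumed key name).
        "replay_fraction": 0.05,
    }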

Unsupported/untested features:

  • (UNTESTED) Using replay AND weighting by number of documents. An assert raises an error if someone tries to use both.
  • (UNSUPPORTED) Using replay AND splitting the datasets automatically instead of providing separate train, validation, and test paths. An assert raises an error if someone tries to use both.
  • (UNSUPPORTED) Using replay AND label data. An assert raises an error if someone tries to use both. As indicated in the comments, support might be doable by adding a replay_label_data arg that specifies the prefix for the idx and data paths of the replay label data, generating the specific replay label data paths from that prefix, and then treating them the same way as the training data in the block below (a sketch of one possible shape follows this list):
    # The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time.
    if neox_args.is_replay_enabled and concatenate_train_replay_paths:
        # Merge the replay data paths into train_data_paths, but keep track of
        # which paths in train_data_paths came from replay
        num_replay_data_paths = len(neox_args.replay_data_paths)
        num_non_replay_data_paths = len(neox_args.train_data_paths)
        neox_args.train_data_paths += neox_args.replay_data_paths
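
For concreteness, below is a minimal sketch of how replay label data could be wired in, mirroring the train-path merging above. The replay_label_data_prefix argument, the path-naming scheme, and the standalone helper are hypothetical and not part of this PR.

    # Hypothetical sketch only: derive replay label data paths from a prefix and
    # merge them into label_data_paths, mirroring how replay_data_paths are merged
    # into train_data_paths above.
    def merge_replay_label_data_paths(neox_args, concatenate_label_replay_paths=True):
        if neox_args.is_replay_enabled and concatenate_label_replay_paths:
            # Assume one replay label path per replay data path, generated from a
            # (hypothetical) replay_label_data_prefix argument.
            replay_label_data_paths = [
                f"{neox_args.replay_label_data_prefix}_{i}_label_document"
                for i in range(len(neox_args.replay_data_paths))
            ]
            num_replay_label_data_paths = len(replay_label_data_paths)
            num_non_replay_label_data_paths = len(neox_args.label_data_paths)
            # Track how many entries came from replay so downstream code can treat
            # them the same way as the replay train paths.
            neox_args.label_data_paths += replay_label_data_paths
            return num_non_replay_label_data_paths, num_replay_label_data_paths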

Pending tests

Currently, the tests required are:

  1. Sanity check that the first few batches are the same with and without these changes.
  2. Similarly to the above, check that label data support was not broken by these changes.
  3. Sanity check that, given two datasets, weighting them 0.5 each without replay yields the same batches as setting one as the training dataset and the other as a replay dataset with replay fraction 0.5.

The tests can follow the procedure described in tests/model/test_batch_replicability.py. Tests 1 and 3 passed with the Summit version of NeoX, but I'll need to run them again on the replay implementation based on the current main branch of NeoX. I'll probably need someone else to verify that label data support (test 2) did not break, as I'm unfamiliar with this feature of NeoX and am currently too busy to take it on.
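
To make the intended checks concrete, below is a rough sketch of tests 1 and 3, not the actual test code. The get_first_batches helper, the reference-batch file, and the non-replay config paths are hypothetical; the real procedure should follow tests/model/test_batch_replicability.py.

    # Rough sketch of tests 1 and 3. get_first_batches(config_path, n) is a
    # hypothetical helper that builds the data iterator for the given config with
    # fixed seeds and returns the first n token batches as tensors.
    import torch

    NUM_BATCHES = 8

    def test_first_batches_match_reference():
        # Test 1: batches produced by this branch should match batches dumped from
        # current main with the same config and seeds (hypothetical reference file).
        reference = torch.load("tests/data/reference_first_batches.pt")
        current = get_first_batches("tests/config/baseline_config.yml", len(reference))
        for ref, cur in zip(reference, current):
            assert torch.equal(ref, cur)

    def test_replay_matches_equal_weighting():
        # Test 3: two datasets weighted 0.5/0.5 without replay should yield the same
        # batches as one training dataset plus one replay dataset with replay
        # fraction 0.5 (both config paths hypothetical).
        weighted = get_first_batches("tests/config/two_datasets_half_weights.yml", NUM_BATCHES)
        replayed = get_first_batches("tests/config/one_dataset_plus_half_replay.yml", NUM_BATCHES)
        for b_w, b_r in zip(weighted, replayed):
            assert torch.equal(b_w, b_r)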

AIproj avatar Apr 13 '24 00:04 AIproj

Please ignore the above commits. I accidentally pushed to upstream when modifying this branch in my fork.

bentherien avatar Apr 14 '24 20:04 bentherien