
Using custom data for `Continue pretraining an LLM`

Open simon-bachhuber opened this issue 1 year ago • 14 comments

The example (https://github.com/Lightning-AI/litgpt?tab=readme-ov-file#continue-pretraining-an-llm) works fine on my machine, but as soon as I replace the provided data with custom text files that each contain just one English sentence with no special characters, the example no longer works.

I don't understand what the difference is between the provided example data and my custom data. Is there some special formatting that I am not seeing?

litgpt pretrain --model_name Meta-Llama-3-8B-Instruct --tokenizer_dir $WORK/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct --initial_checkpoint_dir $WORK/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct --data TextFiles --data.train_data_path "/data/custom_texts/" --train.max_tokens 100_000 --out_dir $WORK/out/custom-model

which results in

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
[rank: 3] Seed set to 42
[rank: 2] Seed set to 42
[rank: 1] Seed set to 42
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

{'data': {'batch_size': 1,
          'max_seq_length': -1,
          'num_workers': 4,
          'seed': 42,
          'tokenizer': None,
          'train_data_path': PosixPath('/home/woody/iwb0/iwb0003h/custom_texts'),
          'val_data_path': None},
 'devices': 'auto',
 'eval': {'initial_validation': False,
          'interval': 1000,
          'max_iters': 100,
          'max_new_tokens': None},
 'initial_checkpoint_dir': PosixPath('/home/woody/iwb0/iwb0003h/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct'),
 'logger_name': 'tensorboard',
 'model_config': None,
 'model_name': 'Meta-Llama-3-8B-Instruct',
 'out_dir': PosixPath('/home/woody/iwb0/iwb0003h/out/custom-model'),
 'precision': None,
 'resume': False,
 'seed': 42,
 'tokenizer_dir': PosixPath('/home/woody/iwb0/iwb0003h/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct'),
 'train': {'beta1': 0.9,
           'beta2': 0.95,
           'epochs': None,
           'global_batch_size': 512,
           'learning_rate': 0.0004,
           'log_interval': 1,
           'lr_warmup_fraction': None,
           'lr_warmup_steps': 2000,
           'max_norm': 1.0,
           'max_seq_length': None,
           'max_steps': None,
           'max_tokens': 100000,
           'micro_batch_size': 4,
           'min_lr': 4e-05,
           'save_interval': 1000,
           'tie_embeddings': False,
           'weight_decay': 0.1}}
[rank: 0] Seed set to 42
Time to instantiate model: 0.03 seconds.
Total parameters: 8,030,261,248
Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Storing the files under /home/woody/iwb0/iwb0003h/custom_texts/train
Setup started with fast_dev_run=False.
Worker 0 gets 0.0 MB (1 files)
Setup finished in 0.002 seconds. Found 1 items to process.
Starting 1 workers with 1 items. The progress bar is only updated when a worker finishes.
Workers are ready ! Starting data processing...
Rank 0 inferred the following `['no_header_tensor:16']` data format.
Worker 0 is terminating.
Worker 0 is done.
Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.68it/s]
Workers are finished.
Finished data processing!
Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Storing the files under /home/woody/iwb0/iwb0003h/custom_texts/val
Setup started with fast_dev_run=False.
Worker 0 gets 0.0 MB (1 files)
Setup finished in 0.001 seconds. Found 1 items to process.
Starting 1 workers with 1 items. The progress bar is only updated when a worker finishes.
Workers are ready ! Starting data processing...
Rank 0 inferred the following `['no_header_tensor:16']` data format.
Worker 0 is terminating.
Worker 0 is done.
Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.69it/s]
Workers are finished.
Finished data processing!
Validating ...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/bin/litgpt", line 8, in <module>
[rank0]:     sys.exit(main())
[rank0]:              ^^^^^^
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litgpt/__main__.py", line 143, in main
[rank0]:     fn(**kwargs)
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litgpt/pretrain.py", line 123, in setup
[rank0]:     main(
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litgpt/pretrain.py", line 207, in main
[rank0]:     fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litgpt/pretrain.py", line 235, in fit
[rank0]:     validate(fabric, model, val_dataloader, max_iters=2)   # sanity check
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litgpt/pretrain.py", line 362, in validate
[rank0]:     val_loss = torch.stack(losses).mean()
[rank0]:                ^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: stack expects a non-empty TensorList
[rank1], [rank2], [rank3]: (identical tracebacks to rank 0, each ending in RuntimeError: stack expects a non-empty TensorList)

simon-bachhuber avatar May 29 '24 14:05 simon-bachhuber

Good question. Maybe the dataset is too small, so it can't generate the validation set. Does the same issue occur if you make the dataset larger, e.g., by duplicating the sentence?

rasbt avatar May 29 '24 15:05 rasbt

I created the files ex1.txt and ex2.txt, which each contain 1000 lines with the sentence "Roses are always blue and the world is populated by roses."

The error is the same.
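
For reference, test files like these can be generated with a few lines of Python (the output directory is a placeholder for whatever TextFiles points at):

# Generate two small test files, each repeating the same sentence 1000 times.
from pathlib import Path

sentence = "Roses are always blue and the world is populated by roses.\n"
out_dir = Path("/data/custom_texts")  # placeholder path
out_dir.mkdir(parents=True, exist_ok=True)

for name in ("ex1.txt", "ex2.txt"):
    (out_dir / name).write_text(sentence * 1000)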

simon-bachhuber avatar May 29 '24 16:05 simon-bachhuber

Thanks, this definitely sounds like an issue to look into then.

rasbt avatar May 29 '24 19:05 rasbt

After some initial investigation, I believe this is related to the fact that max_seq_length=-1.

Consider the following example:

from litgpt.data import TextFiles
from pathlib import Path
from litgpt.tokenizer import Tokenizer

tokenizer = Tokenizer("/home/woody/iwb0/iwb0003h/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct")
text_files = TextFiles(Path("/home/woody/iwb0/iwb0003h/custom_texts_roses"))
text_files.connect(tokenizer, max_seq_length=-1)  # -1 is the default shown in the config dump above
text_files.prepare_data()
text_files.setup()
dl = text_files.val_dataloader()
sample = next(iter(dl))

This throws the error below (shown here from calling len(dl) on the resulting dataloader):

---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
Cell In[37], line 1
----> 1 len(dl)

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/torch/utils/data/dataloader.py:475, in DataLoader.__len__(self)
    457 def __len__(self) -> int:
    458     if self._dataset_kind == _DatasetKind.Iterable:
    459         # NOTE [ IterableDataset and __len__ ]
    460         #
   (...)
    473 
    474         # Cannot statically verify that dataset is Sized
--> 475         length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
    476         if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
    477             from math import ceil

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/dataset.py:162, in StreamingDataset.__len__(self)
    161 def __len__(self) -> int:
--> 162     return self.get_len(1, 1)

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/dataset.py:169, in StreamingDataset.get_len(self, num_workers, batch_size)
    167 worker_env = _WorkerEnv.detect()
    168 if self.shuffler is None:
--> 169     cache = self._create_cache(worker_env=worker_env)
    170     self.shuffler = self._create_shuffler(cache)
    171 return self.shuffler.get_len(self.distributed_env, num_workers, batch_size, self.current_epoch)

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/dataset.py:142, in StreamingDataset._create_cache(self, worker_env)
    133         self.input_dir.path = cache_path
    135 cache = Cache(
    136     input_dir=self.input_dir,
    137     item_loader=self.item_loader,
   (...)
    140     max_cache_size=self.max_cache_size,
    141 )
--> 142 cache._reader._try_load_config()
    144 if not cache.filled:
    145     raise ValueError(
    146         f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
    147         " HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
    148     )

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/reader.py:211, in BinaryReader._try_load_config(self)
    209 def _try_load_config(self) -> Optional[ChunksConfig]:
    210     """Try to load the chunks config if the index files are available."""
--> 211     self._config = ChunksConfig.load(self._cache_dir, self._serializers, self._remote_input_dir, self._item_loader)
    212     return self._config

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/config.py:213, in ChunksConfig.load(cls, cache_dir, serializers, remote_dir, item_loader)
    210 if not os.path.exists(cache_index_filepath):
    211     return None
--> 213 return ChunksConfig(cache_dir, serializers, remote_dir, item_loader)

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/config.py:61, in ChunksConfig.__init__(self, cache_dir, serializers, remote_dir, item_loader)
     58 self._config["data_spec"] = treespec_loads(self._config["data_spec"])
     60 self._item_loader.setup(self._config, self._chunks, serializers)
---> 61 self._intervals = self._item_loader.generate_intervals()
     62 self._length = self._intervals[-1][-1]
     63 self._downloader = None

File /home/woody/iwb0/iwb0003h/.conda/envs/litgpt/lib/python3.11/site-packages/litdata/streaming/item_loader.py:177, in TokensLoader.generate_intervals(self)
    175 for chunk in self._chunks:
    176     dim = chunk["dim"]
--> 177     num_blocks = dim // self._block_size
    178     end += num_blocks
    179     intervals.append((begin, end))

ZeroDivisionError: integer division or modulo by zero

However, when setting max_seq_length to a positive value, e.g. 10, the code works just fine.
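
That is consistent with the traceback: num_blocks = dim // self._block_size can only divide by zero if the block size ends up as 0, which would happen if the loader's block size is derived from max_seq_length + 1 (an assumption based on the error message, not verified against the litdata source). A minimal variant of the repro above that avoids it:

# Same objects as in the repro script above, but with a positive sequence length.
text_files.connect(tokenizer, max_seq_length=10)  # any value > 0
text_files.prepare_data()
text_files.setup()
dl = text_files.val_dataloader()
print(len(dl))            # no ZeroDivisionError anymore
sample = next(iter(dl))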

simon-bachhuber avatar May 29 '24 21:05 simon-bachhuber

Same issue here. The code should either assert or at least throw a warning about this.
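
For illustration, a guard along these lines (hypothetical, not existing litgpt code) would turn the ZeroDivisionError into a readable message:

def check_max_seq_length(max_seq_length):
    # Hypothetical sanity check: litdata's TokensLoader ends up with a block
    # size of 0 when max_seq_length is -1, so fail fast with a clear error.
    if max_seq_length is None or max_seq_length < 1:
        raise ValueError(
            f"max_seq_length must be a positive integer, got {max_seq_length!r}."
        )

check_max_seq_length(10)    # fine
# check_max_seq_length(-1)  # would raise ValueError instead of a ZeroDivisionError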

MaxGonzalezSaez-Diez avatar Jul 16 '24 08:07 MaxGonzalezSaez-Diez

Thanks for reporting. I totally missed your follow-up, @SimiPixel. Based on the error message, it kind of looks like a LitData issue.

The LitData version was just upgraded a few days ago via #1573, and I was wondering which version you currently have installed, @MaxGonzalezSaez-Diez? It could well be an issue that still persists, but I'm just double-checking before looking further into it.

rasbt avatar Jul 16 '24 12:07 rasbt

I have litdata version 0.2.6

MaxGonzalezSaez-Diez avatar Jul 16 '24 13:07 MaxGonzalezSaez-Diez

Oh I see. Could you try to upgrade (pip install litdata==0.2.16) and see if the error still persists then?

rasbt avatar Jul 16 '24 13:07 rasbt

Still persists.

MaxGonzalezSaez-Diez avatar Jul 16 '24 14:07 MaxGonzalezSaez-Diez

Arg, I was hoping this would fix it. I currently don't have any good explanation for this and would have to look into it. Thanks for sharing.

rasbt avatar Jul 16 '24 14:07 rasbt

@MaxGonzalezSaez-Diez Is it possible to attach the files here that are causing the problems?

rasbt avatar Jul 16 '24 16:07 rasbt

Tl;dr: I think this error appears when using a large max_seq_length with very little data.

I am unsure what exactly is causing the problem or which specific file is responsible for it. After some experiments, I at least have a feeling for it: I am trying to train a model with a long context window. When I set max_seq_length to, let's say, 10, everything works. However, once I increase it beyond 2048, it breaks.

Here is what I did that seems to have fixed the issue (for everyone else randomly reading through this); a rough sketch of steps 1 and 2 follows the list:

  1. Make sure to delete the train and val folders that were created in the directory you pass to TextFiles.
  2. Create two big .txt files. I created a bunch of random text lines and then copied and pasted them until each file was around 50k lines long (around 80B).
  3. Re-run the litgpt pretrain command. It now took around 3 to 4 minutes to create the train and val folders with the .bin and index.json files.
  4. You should be good to go.
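
A rough sketch of steps 1 and 2, assuming a placeholder data directory and a couple of made-up sample lines:

import shutil
from pathlib import Path

data_dir = Path("/data/custom_texts")  # placeholder: the directory passed to --data.train_data_path
data_dir.mkdir(parents=True, exist_ok=True)

# Step 1: remove the previously generated train/ and val/ folders so the
# text gets re-tokenized on the next `litgpt pretrain` run.
for sub in ("train", "val"):
    shutil.rmtree(data_dir / sub, ignore_errors=True)

# Step 2: write two large .txt files by repeating a handful of sample lines
# until each file is roughly 50k lines long.
sample_lines = [
    "Roses are always blue and the world is populated by roses.",
    "Another line of throwaway pretraining text.",
]
body = "\n".join(sample_lines) + "\n"
repeats = 50_000 // len(sample_lines)
for name in ("big1.txt", "big2.txt"):
    (data_dir / name).write_text(body * repeats)

Step 3 is then simply re-running the litgpt pretrain command from the top of the issue.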

Here is what I suspect is happening:

  1. I first ran the command with max_seq_length 10, using just 2 small .txt files to make sure everything ran smoothly. This seems to have created the train and val folders with the .bin and index.json files.
  2. I then increased the max_seq_length argument to 8192. However, it seems like the val folder had very little data (the length of the val_dataloader was 0), which (I assume) is causing the issue; see the quick check after this list.
  3. It seems to have been fixed after doing what I described previously.
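
A quick way to confirm suspicion 2 before launching a long run (paths are placeholders; this mirrors the repro script from the earlier comment):

from pathlib import Path
from litgpt.data import TextFiles
from litgpt.tokenizer import Tokenizer

tokenizer = Tokenizer("/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct")
data = TextFiles(Path("/data/custom_texts"))
data.connect(tokenizer, max_seq_length=8192)
data.prepare_data()
data.setup()
# A length of 0 here means validate() in litgpt.pretrain will average an
# empty list of losses and fail with "stack expects a non-empty TensorList".
print(len(data.val_dataloader()))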

MaxGonzalezSaez-Diez avatar Jul 17 '24 07:07 MaxGonzalezSaez-Diez

Thanks for the super detailed write-up. So I suspect there are two potential issues:

  1. Texts that are shorter than the max_seq_length.

  2. Leftover preprocessed files from earlier test runs.

> I ran the command first with max_seq_length 10 and I was using just 2 small .txt files to first make sure everything runs smoothly. This seems to have created the train and val folder with the .bin and index.json files.

If I remember correctly, I did this by design to avoid the long reprocessing of the text each time. It's probably also a relatively common use case to try multiple models or hparam configs on the same training set, so we shouldn't reprocess the dataset every time.

I am still wondering, though, how we could avoid such errors. Perhaps a warning that says these files already exist and that preprocessing is being skipped, plus a tip to delete the files in case the source files have changed in order to trigger reprocessing, would help?
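
For instance, the warning could look roughly like this (hypothetical, not actual litgpt code):

import warnings
from pathlib import Path

def maybe_warn_preprocessing_skipped(data_dir: Path) -> bool:
    """Warn and return True if a previously optimized dataset would be reused."""
    existing = [sub for sub in ("train", "val") if (data_dir / sub / "index.json").exists()]
    if existing:
        warnings.warn(
            f"Found existing preprocessed data ({', '.join(existing)}) under {data_dir}; "
            "skipping preprocessing. Delete these folders if the source .txt files or "
            "max_seq_length changed, so the dataset gets re-tokenized.",
            UserWarning,
        )
        return True
    return False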

rasbt avatar Jul 17 '24 22:07 rasbt

A warning would be great. I somehow assumed that reprocessing would happen automatically every time.

MaxGonzalezSaez-Diez avatar Jul 18 '24 07:07 MaxGonzalezSaez-Diez