
[Bug Report] `tokenize_and_concatenate` doesn't work with small datasets.

Open · yash-srivastava19 opened this issue 6 months ago · 1 comment

Describe the bug

The docstring itself mentions that the tokenize_and_concatenate function doesn't work properly with small datasets. I wanted to figure out whether there is a workaround that can be used.

Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why.
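
If I'm reading the implementation right, the batch's text gets concatenated, tokenized, and split into chunks of max_length tokens, and anything that doesn't fill a whole chunk is dropped. So a dataset shorter than one context window ends up with zero chunks and no "tokens" column. A toy illustration of that arithmetic (the names below are mine, not transformer_lens internals):

seq_len = 1024            # e.g. the model's context size
num_tokens = 3            # a "dataset" that is only a single short sentence

num_chunks = num_tokens // seq_len   # 0 -> everything gets dropped
kept = num_chunks * seq_len          # 0 tokens survive the chunking step
print(num_chunks, kept)              # 0 0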

Code example

The dataset I'm using is small and sometimes contains only a single word. Here is a minimal example that reproduces the error.

from datasets.load import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

#----------- Utility Functions -------------

def create_dataset(link):
    # `link` is the Hugging Face Hub dataset id.
    ds = load_dataset(
        path=link,
        split="train",
        streaming=False,
    )
    return ds

# `model` and `sae` are loaded earlier in the notebook.
def get_tokens(dataset, tokenizer=model.tokenizer, streaming=True, max_length=sae.cfg.context_size, column_name="text", add_bos_token=sae.cfg.prepend_bos):
    return tokenize_and_concatenate(
        dataset=dataset,
        tokenizer=tokenizer,
        streaming=streaming,
        max_length=max_length,
        column_name=column_name,
        add_bos_token=add_bos_token,
    )

#----------- Example Usage -------------
DATASET_1 = "link-to-big-dataset"
DATASET_2 = "link-to-small-dataset"

dataset_1 = create_dataset(DATASET_1)
dataset_1_tokens  = get_tokens(dataset_1) # This gets executed. No issues.

dataset_2 = create_dataset(DATASET_2)
dataset_2_tokens  = get_tokens(dataset_2) # This line breaks.

Here's what the error stack trace looks like:

...
File /opt/conda/lib/python3.10/site-packages/transformer_lens/utils.py:358, in tokenize_and_concatenate(dataset, tokenizer, streaming, max_length, column_name, add_bos_token, num_proc)
    350     return {"tokens": tokens}
    352 tokenized_dataset = dataset.map(
    353     tokenize_function,
    354     batched=True,
    355     num_proc=(num_proc if not streaming else None),
    356     remove_columns=[column_name],
    357 )
--> 358 tokenized_dataset.set_format(type="torch", columns=["tokens"])
    359 return tokenized_dataset

File /opt/conda/lib/python3.10/site-packages/datasets/fingerprint.py:482, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
    478             validate_fingerprint(kwargs[fingerprint_name])
    480 # Call actual function
--> 482 out = func(dataset, *args, **kwargs)
    484 # Update fingerprint of in-place transforms + update in-place history of transforms
    486 if inplace:  # update after calling func so that the fingerprint doesn't change if the function fails

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:2596, in Dataset.set_format(self, type, columns, output_all_columns, **format_kwargs)
   2594     missing_columns = set(columns) - set(self._data.column_names)
   2595     if missing_columns:
-> 2596         raise ValueError(
   2597             f"Columns {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
   2598         )
   2599 if columns is not None:
   2600     columns = columns.copy()  # Ensures modifications made to the list after this call don't cause bugs

ValueError: Columns ['tokens'] not in the dataset. Current columns in the dataset: ['text']

This works perfectly well for DATASET_1, but breaks for DATASET_2.
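
One workaround I'm considering is to skip the concatenate-and-chunk step for tiny datasets and just pad/truncate each example to a fixed length with the tokenizer directly. A rough sketch of what I mean (the GPT-2 tokenizer, max_length=128, and the one-word dataset below are placeholders for my actual setup, and this doesn't reproduce the BOS handling of tokenize_and_concatenate):

from datasets import Dataset
from transformers import AutoTokenizer

# Placeholders for my real setup (model.tokenizer / sae.cfg.context_size).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
max_length = 128

small_ds = Dataset.from_dict({"text": ["a single word"]})  # tiny dataset stand-in

def tokenize_batch(batch):
    # Pad/truncate each row to a fixed length instead of concatenating and
    # chunking, so even a one-row dataset produces a "tokens" column.
    encoded = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    return {"tokens": encoded["input_ids"]}

tokenized = small_ds.map(tokenize_batch, batched=True, remove_columns=["text"])
tokenized.set_format(type="torch", columns=["tokens"])
print(tokenized["tokens"].shape)  # torch.Size([1, 128])

Is something like this a reasonable stand-in, or is there a supported way to make tokenize_and_concatenate handle datasets smaller than one context window?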

System Info

Describe the characteristics of your environment:

  • How transformer_lens was installed: pip
  • OS: Linux (Kaggle Notebook)
  • Python version: 3.10

Checklist

  • [X] I have checked that there is no similar issue in the repo (required)

yash-srivastava19 · Aug 23 '24 06:08