Easy-Transformer
[Bug Report] `tokenize_and_concatenate` doesn't work with small datasets.
Describe the bug
The docstring of `tokenize_and_concatenate` itself notes that the function doesn't work properly with small datasets, and I wanted to figure out whether there is a workaround I can use. The relevant note from the docstring:
> Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why.
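From skimming `utils.py`, my best guess at the cause (could be wrong): the function concatenates every tokenized document into one long token stream and keeps only whole `max_length` chunks, so a dataset with fewer than `max_length` tokens yields zero chunks and the `tokens` column is never created. A rough sketch of that arithmetic, with made-up numbers:

```python
# Simplified sketch of the chunking step inside tokenize_and_concatenate
# (not the actual library code); the remainder after full chunks is dropped.
num_tokens = 7        # e.g. a one-word dataset after tokenization
max_length = 128      # the context size I pass in (sae.cfg.context_size)

num_chunks = num_tokens // max_length   # 0 whenever the dataset is shorter than one context
tokens_kept = num_chunks * max_length   # 0 tokens survive, so the batched map() returns
                                        # an empty "tokens" batch and set_format() fails
print(num_chunks, tokens_kept)          # 0 0
```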
Code example
The dataset I'm using is small and sometimes contains only a single word. Here is minimal code that reproduces the error:
```python
from datasets.load import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

# ----------- Utility Functions -------------
def create_dataset(link):
    ds = load_dataset(
        path=link,
        split="train",
        streaming=False,
    )
    return ds

# `model` (a HookedTransformer) and `sae` are loaded earlier in the notebook;
# they supply the tokenizer, context size, and BOS setting.
def get_tokens(
    dataset,
    tokenizer=model.tokenizer,
    streaming=True,
    max_length=sae.cfg.context_size,
    column_name="text",
    add_bos_token=sae.cfg.prepend_bos,
):
    return tokenize_and_concatenate(
        dataset=dataset,
        tokenizer=tokenizer,
        streaming=streaming,
        max_length=max_length,
        column_name=column_name,
        add_bos_token=add_bos_token,
    )

# ----------- Example Usage -------------
DATASET_1 = "link-to-big-dataset"
DATASET_2 = "link-to-small-dataset"

dataset_1 = create_dataset(DATASET_1)
dataset_1_tokens = get_tokens(dataset_1)  # This executes fine, no issues.

dataset_2 = create_dataset(DATASET_2)
dataset_2_tokens = get_tokens(dataset_2)  # This line breaks.
```
Here's what the error stack trace looks like:
```
...
File /opt/conda/lib/python3.10/site-packages/transformer_lens/utils.py:358, in tokenize_and_concatenate(dataset, tokenizer, streaming, max_length, column_name, add_bos_token, num_proc)
    350     return {"tokens": tokens}
    352 tokenized_dataset = dataset.map(
    353     tokenize_function,
    354     batched=True,
    355     num_proc=(num_proc if not streaming else None),
    356     remove_columns=[column_name],
    357 )
--> 358 tokenized_dataset.set_format(type="torch", columns=["tokens"])
    359 return tokenized_dataset

File /opt/conda/lib/python3.10/site-packages/datasets/fingerprint.py:482, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
    478     validate_fingerprint(kwargs[fingerprint_name])
    480 # Call actual function
--> 482 out = func(dataset, *args, **kwargs)
    484 # Update fingerprint of in-place transforms + update in-place history of transforms
    486 if inplace:  # update after calling func so that the fingerprint doesn't change if the function fails

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:2596, in Dataset.set_format(self, type, columns, output_all_columns, **format_kwargs)
   2594 missing_columns = set(columns) - set(self._data.column_names)
   2595 if missing_columns:
-> 2596     raise ValueError(
   2597         f"Columns {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
   2598     )
   2599 if columns is not None:
   2600     columns = columns.copy()  # Ensures modifications made to the list after this call don't cause bugs

ValueError: Columns ['tokens'] not in the dataset. Current columns in the dataset: ['text']
```
This works perfectly well for DATASET_1, but breaks for DATASET_2.
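As a stopgap, the workaround I'm considering (a rough, untested sketch; `tokenize_small_dataset` is a hypothetical helper, not part of transformer_lens) is to bypass `tokenize_and_concatenate` for tiny datasets and tokenize the text column row by row, padding/truncating to the context size:

```python
def tokenize_small_dataset(dataset, tokenizer, max_length, column_name="text", add_bos_token=True):
    """Fallback for datasets shorter than one context: pad/truncate each row
    instead of concatenating and chunking."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        texts = examples[column_name]
        if add_bos_token:
            # Assumes the tokenizer defines a BOS token.
            texts = [tokenizer.bos_token + t for t in texts]
        # Each row becomes one fixed-length sequence; short rows are padded.
        tokens = tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )["input_ids"]
        return {"tokens": tokens}

    tokenized = dataset.map(tokenize_function, batched=True, remove_columns=[column_name])
    tokenized.set_format(type="torch", columns=["tokens"])
    return tokenized

# dataset_2_tokens = tokenize_small_dataset(dataset_2, model.tokenizer, sae.cfg.context_size)
```

The semantics are different (one padded row per document rather than packed contexts), so this is only a stopgap; it would be nicer if `tokenize_and_concatenate` handled the less-than-one-chunk case itself, e.g. by padding the final partial chunk.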
System Info
Describe the characteristics of your environment:
- How `transformer_lens` was installed: pip
- OS: Linux (Kaggle Notebook)
- Python version: 3.10
Checklist
- [X] I have checked that there is no similar issue in the repo (required)