Easy-Transformer
[Bug Report] TinyStories models have a longer n_ctx than they were trained with
Describe the bug
The Hugging Face TinyStories configs claim to support n_ctx=2048. However, the models were only trained with sequence length 512 (as mentioned here). The models in fact get much worse performance after position 512:

[Plot: loss by token position for each TinyStories model; loss rises sharply past position 512.]

This is not the fault of TransformerLens; the problem is upstream in the HF config. I opened a PR upstream here. If that doesn't get promptly fixed, however, I think we should special-case the TinyStories models to have n_ctx=512 when loading their configs from Hugging Face here.
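For a quick sanity check of the mismatch, here is a minimal sketch (assuming transformers and transformer_lens are installed; the TinyStories checkpoints use the GPT-Neo architecture, so the relevant HF config field should be max_position_embeddings):

from transformers import AutoConfig
from transformer_lens import HookedTransformer

# The HF config advertises a 2048-token context window...
hf_cfg = AutoConfig.from_pretrained("roneneldan/TinyStories-1M")
print(hf_cfg.max_position_embeddings)  # 2048

# ...and TransformerLens inherits it as n_ctx, even though training used length 512.
tl_model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M")
print(tl_model.cfg.n_ctx)  # 2048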
Code example
Code to produce the above plot:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate
def get_dataset(n_stories=1_000, n_ctx=256) -> Dataset:
    full_dataset = load_dataset(
        "skeskinen/TinyStories-GPT4",
        streaming=True,
        split="train",
    )
    short_iterable_dataset = full_dataset.take(n_stories)

    def ds_generator():
        yield from short_iterable_dataset

    raw_dataset = Dataset.from_generator(ds_generator, features=short_iterable_dataset.features)
    tokenizer = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M").tokenizer
    return tokenize_and_concatenate(
        raw_dataset,
        tokenizer=tokenizer,
        max_length=n_ctx,
        column_name="story",
        add_bos_token=False,
    )
N_CTX = 2048
dataset = get_dataset(n_stories=500, n_ctx=N_CTX)
print("number of batch elements:", len(dataset))
dataloader = DataLoader(dataset, batch_size=10)
def get_losses_by_position(model, dataloader):
    losses = []
    with torch.inference_mode():
        for batch in tqdm(dataloader):
            inputs = batch["tokens"][:, :-1].cuda()
            labels = batch["tokens"][:, 1:].cuda()
            logits = model(inputs).flatten(0, 1)  # [batch*n_ctx, vocab]
            flat_losses = F.cross_entropy(logits, labels.flatten(), reduction="none")
            losses.append(flat_losses.reshape(labels.shape))
    losses_tensor = torch.cat(losses)  # [n_batches*batch, n_ctx]
    return losses_tensor.mean(dim=0).cpu()  # [n_ctx]
model_ids = [
    "roneneldan/TinyStories-1M",
    "roneneldan/TinyStories-3M",
    "roneneldan/TinyStories-8M",
    "roneneldan/TinyStories-28M",
    "roneneldan/TinyStories-33M",
    "roneneldan/TinyStories-Instruct-1M",
    "roneneldan/TinyStories-Instruct-3M",
    "roneneldan/TinyStories-Instruct-8M",
    "roneneldan/TinyStories-Instruct-28M",
    "roneneldan/TinyStories-Instruct-33M",
    "roneneldan/TinyStories-1Layer-21M",
    "roneneldan/TinyStories-2Layers-33M",
    "roneneldan/TinyStories-Instuct-1Layer-21M",
    "roneneldan/TinyStories-Instruct-2Layers-33M",
]
losses_by_position = {}
for model_id in model_ids:
    tl_model = HookedTransformer.from_pretrained(model_id)
    tl_model.to("cuda")
    losses_by_position[model_id] = get_losses_by_position(tl_model, dataloader)
# %%
fig, axs = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)
for i, (model_id, losses_by_pos) in enumerate(losses_by_position.items()):
    ax = axs.flatten()[i]
    ax.plot(losses_by_pos.cpu().numpy(), "ok", markersize=1)
    # removeprefix, not strip: str.strip treats its argument as a set of characters
    ax.set_title(model_id.removeprefix("roneneldan/TinyStories-"))
    ax.axvline(512, ls='--', c='k', alpha=0.3)
    # ax.set_xlim(500, 520)
for i in range(4):
    axs[-1, i].set_xlabel("position")
    axs[i, 0].set_ylabel("loss")
plt.tight_layout()
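In the meantime, the special case I have in mind would be roughly the following (a hypothetical sketch, not TransformerLens's actual loading code; the real fix would live wherever the HF config is converted, e.g. transformer_lens/loading_from_pretrained.py, and the helper name here is made up):

# Hypothetical helper sketching the proposed special case: clamp n_ctx for
# TinyStories checkpoints after the HF config has been converted.
def clamp_tinystories_n_ctx(cfg_dict: dict, model_name: str) -> dict:
    # The TinyStories models were trained with sequence length 512, even though
    # their HF configs report max_position_embeddings=2048.
    if "TinyStories" in model_name:
        cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], 512)
    return cfg_dict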
System Info
Running the latest TransformerLens version.
Checklist
- [x] I have checked that there is no similar issue in the repo (required)