
[Bug Report] Tiny stories models have longer n_ctx than they were trained with

Open · nix-apollo opened this issue

Describe the bug

The Hugging Face TinyStories configs claim to support n_ctx=2048. However, the models were only trained with a sequence length of 512 (as mentioned here). Past position 512 they in fact perform much worse:

[Figure: per-position loss for each TinyStories model; loss rises sharply after position 512.]
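For reference, the inflated context length is visible directly in the upstream HF config (a quick check added here for illustration; it was not part of the original report, though I believe TransformerLens derives n_ctx from this field for GPT-Neo models):

from transformers import AutoConfig

# TinyStories models use the GPT-Neo architecture; the context length
# is stored in the config as max_position_embeddings.
cfg = AutoConfig.from_pretrained("roneneldan/TinyStories-1M")
print(cfg.max_position_embeddings)  # 2048, despite the 512-token training context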

This is not the fault of TransformerLens; the problem is upstream in the HF config. I opened a PR upstream here. If that doesn't get fixed promptly, however, I think we should special-case the TinyStories models to have n_ctx=512 when loading their configs from Hugging Face here, roughly as sketched below.
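A minimal sketch of what that special case might look like (the exact placement inside TransformerLens's config-loading code and the variable names are assumptions, not the actual implementation):

# Hypothetical patch where the converted HF config dict is assembled
# (e.g. in transformer_lens's loading_from_pretrained code); names are illustrative.
if "tinystories" in official_model_name.lower():
    # The TinyStories models were trained with a 512-token context,
    # even though their HF configs report 2048.
    cfg_dict["n_ctx"] = 512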

Code example

Code to produce the above plot:

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate


def get_dataset(n_stories=1_000, n_ctx=256) -> Dataset:
    full_dataset = load_dataset(
        "skeskinen/TinyStories-GPT4",
        streaming=True,
        split="train",
    )
    short_iterable_dataset = full_dataset.take(n_stories)

    def ds_generator():
        yield from short_iterable_dataset

    raw_dataset = Dataset.from_generator(ds_generator, features=short_iterable_dataset.features)
    tokenizer = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M").tokenizer
    return tokenize_and_concatenate(
        raw_dataset,
        tokenizer=tokenizer,
        max_length=n_ctx,
        column_name="story",
        add_bos_token=False,
    )


N_CTX = 2048
dataset = get_dataset(n_stories=500, n_ctx=N_CTX)
print("number of batch elements:", len(dataset))
dataloader = DataLoader(dataset, batch_size=10)


def get_losses_by_position(model, dataloader):
    losses = []
    with torch.inference_mode():
        for batch in tqdm(dataloader):
            inputs = batch["tokens"][:, :-1].cuda()
            labels = batch["tokens"][:, 1:].cuda()
            logits = model(inputs).flatten(0, 1)  # [batch*(n_ctx-1), vocab]
            flat_losses = F.cross_entropy(logits, labels.flatten(), reduction="none")
            losses.append(flat_losses.reshape(labels.shape))
    losses_tensor = torch.cat(losses)  # [n_batches*batch, n_ctx-1]
    return losses_tensor.mean(dim=0).cpu()  # [n_ctx-1]


model_ids = [
    "roneneldan/TinyStories-1M",
    "roneneldan/TinyStories-3M",
    "roneneldan/TinyStories-8M",
    "roneneldan/TinyStories-28M",
    "roneneldan/TinyStories-33M",
    "roneneldan/TinyStories-Instruct-1M",
    "roneneldan/TinyStories-Instruct-3M",
    "roneneldan/TinyStories-Instruct-8M",
    "roneneldan/TinyStories-Instruct-28M",
    "roneneldan/TinyStories-Instruct-33M",
    "roneneldan/TinyStories-1Layer-21M",
    "roneneldan/TinyStories-2Layers-33M",
    "roneneldan/TinyStories-Instuct-1Layer-21M",
    "roneneldan/TinyStories-Instruct-2Layers-33M",
]

losses_by_position = {}
for model_id in model_ids:
    tl_model = HookedTransformer.from_pretrained(model_id)
    tl_model.to("cuda")
    losses_by_position[model_id] = get_losses_by_position(tl_model, dataloader)

# %%
fig, axs = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)
for i, (model_id, losses_by_pos) in enumerate(losses_by_position.items()):
    ax = axs.flatten()[i]
    ax.plot(losses_by_pos.cpu().numpy(), "ok", markersize=1)
    # str.strip removes a character set, not a prefix; removeprefix does what's intended
    ax.set_title(model_id.removeprefix("roneneldan/TinyStories-"))
    ax.axvline(512, ls='--', c='k', alpha=0.3)
    # ax.set_xlim(500, 520)

for i in range(4):
    axs[-1, i].set_xlabel("position")
    axs[i, 0].set_ylabel("loss")

plt.tight_layout()

System Info

On the latest TransformerLens version.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

nix-apollo · Jan 24 '24