
[Bug Report] HookedTransformer.generate() with model.tokenizer unset gives pad_token_id error

Open JackCai1206 opened this issue 1 year ago • 2 comments


Describe the bug When calling HookedTransformer.generate() on a model that was initialized without a tokenizer, with the default use_past_kv_cache=True, I get the error 'NoneType' object has no attribute 'pad_token_id'. I believe this is because, when use_past_kv_cache=True, the library needs to determine the pad tokens in the input, so it assumes model.tokenizer is set and reads the pad token id from it. However, in my use case I am not using a tokenizer for my model. Maybe there could be a way to manually supply the pad_token_id in this case?

Code example

from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch as t

# Minimal config; the model is built without any tokenizer attached
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=1,
    d_model=4,
    d_vocab=4,
    n_ctx=4,
    d_head=2,
    act_fn='relu'
)

model = HookedTransformer(cfg)

# Token ids are supplied directly, since there is no tokenizer
input = t.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=t.long)
# Raises: AttributeError: 'NoneType' object has no attribute 'pad_token_id'
model.generate(input, eos_token_id=0)

System Info

  • OS: Linux
  • transformer_lens version: 1.12.0


Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

JackCai1206 avatar Jan 18 '24 16:01 JackCai1206

Interesting! What's the use case?

Either way, it'd be an easy fix to just add an optional pad_token_id parameter to the generate function; feel free to make a PR.

neelnanda-io avatar Jan 18 '24 19:01 neelnanda-io
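For reference, a minimal sketch of what that fallback could look like (hypothetical helper, not the actual TransformerLens code; the name resolve_pad_token_id and the error message are assumptions):

from typing import Optional

def resolve_pad_token_id(tokenizer, pad_token_id: Optional[int] = None) -> int:
    # Prefer an explicitly supplied pad_token_id over the tokenizer's
    if pad_token_id is not None:
        return pad_token_id
    # Fall back to the tokenizer, failing loudly if none is attached
    if tokenizer is None:
        raise ValueError(
            "model.tokenizer is unset; pass pad_token_id explicitly to generate()"
        )
    return tokenizer.pad_token_id

generate() would then call a helper like this instead of reading self.tokenizer.pad_token_id unconditionally, so tokenizer-free models could pass pad_token_id directly.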

Sounds good, I can try to make a PR here. I am training a model to solve integer addition problems, and it uses simple character-level tokenization. Maybe it is possible to load such a tokenizer from Hugging Face, but I haven't found one, so I just used custom encode and decode functions. This means model.tokenizer is unset for me.

JackCai1206 avatar Jan 18 '24 21:01 JackCai1206
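
For context, a character-level setup like the one described might look roughly like this (illustrative sketch only; the vocabulary and the encode/decode definitions are assumptions, not the issue author's actual code):

# Hypothetical character-level vocabulary for integer addition, e.g. "12+34=46"
VOCAB = list("0123456789+=")
STOI = {ch: i for i, ch in enumerate(VOCAB)}
ITOS = {i: ch for ch, i in STOI.items()}

def encode(s: str) -> list[int]:
    # Map each character to its integer token id
    return [STOI[ch] for ch in s]

def decode(ids: list[int]) -> str:
    # Map token ids back to characters
    return "".join(ITOS[i] for i in ids)

Since nothing here is a Hugging Face tokenizer object, model.tokenizer stays None, which is exactly the situation that trips the pad_token_id lookup in generate().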