Bugs with def classify_review
Bug description
On page 201:
def classify_review(
        text, model, tokenizer, device, max_length=None,
        pad_token_id=50256):
    model.eval()
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[1]
    input_ids = input_ids[:min(
        max_length, supported_context_length
    )]
    input_ids += [pad_token_id] * (max_length - len(input_ids))
Issue 1: model.pos_emb.weight.shape is [1024, 768], where 1024 is the context length and 768 is the embedding dimension. So should the index be shape[0] instead of shape[1]?
Issue 2: if supported_context_length is smaller than max_length, it makes sense to truncate input_ids to supported_context_length. But in that case, what is the point of padding back up to max_length? That said, the later code calls classify_review with a max_length of 120 (train_dataset.max_length), which is much smaller than supported_context_length.
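To make issue 2 concrete, here is a minimal sketch that reproduces only the truncate-and-pad lines on a plain list, with no model or tokenizer involved. The helper name `original_truncate_and_pad` and the token values are made up for illustration, and `supported_context_length` is passed in as an integer rather than read from `model.pos_emb`.

```python
def original_truncate_and_pad(input_ids, max_length, supported_context_length,
                              pad_token_id=50256):
    # Same two steps as the snippet from the book, applied to a plain list
    input_ids = input_ids[:min(max_length, supported_context_length)]
    input_ids = input_ids + [pad_token_id] * (max_length - len(input_ids))
    return input_ids


tokens = list(range(10))  # hypothetical token IDs

# Case 1: max_length is larger than the context length.
# Truncation stops at 8 tokens, but padding extends the list back to 12.
padded = original_truncate_and_pad(
    tokens, max_length=12, supported_context_length=8
)
print(len(padded))  # 12, which exceeds the context length of 8

# Case 2: max_length is left at its default of None.
# min(None, 8) cannot be evaluated, so a TypeError is raised.
try:
    original_truncate_and_pad(tokens, max_length=None, supported_context_length=8)
except TypeError as err:
    print("TypeError:", err)
```

In the first case the padded sequence is longer than the positional embedding table, so the failure would only surface later, when the sequence is fed to the model.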
So a better way to handle both cases is:
supported_context_length = model.pos_emb.weight.shape[0]
max_len = min(max_length, supported_context_length) if max_length else supported_context_length
input_ids = input_ids[:max_len]
input_ids += [pad_token_id] * (max_len - len(input_ids))
This also handles the case where max_length is None.
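The proposed lines can be tried out in isolation with a small sketch like the one below. The function name `truncate_and_pad` is hypothetical, and the context length is passed in as a plain integer instead of being read from `model.pos_emb.weight.shape[0]`, so that it runs without a model.

```python
def truncate_and_pad(input_ids, max_length, supported_context_length,
                     pad_token_id=50256):
    # Cap the target length at the context length; fall back to the
    # context length when max_length is None (or 0, since 0 is falsy)
    max_len = (min(max_length, supported_context_length)
               if max_length else supported_context_length)
    # Truncate, then pad up to the same capped length
    input_ids = input_ids[:max_len]
    input_ids = input_ids + [pad_token_id] * (max_len - len(input_ids))
    return input_ids


# Short input, max_length below the context length: padded to max_length
short = truncate_and_pad([1, 2, 3], max_length=5, supported_context_length=8)

# max_length above the context length: capped at the context length
capped = truncate_and_pad(list(range(10)), max_length=12,
                          supported_context_length=8)

# max_length=None: padded to the full context length
default = truncate_and_pad([1, 2, 3], max_length=None,
                           supported_context_length=4)

print(short, len(capped), default)
```

One consequence worth noting: with max_length=None, every input is padded to the full context length, which may be more padding than needed for short reviews.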
To add more context, this is following discussion #659.
Concerning issue 1, there's already a note/comment in the classify_review function, so that's fine.
For issue 2, @d-kleine mentioned it was already discussed in #499. I believe that discussion was about the edge case where someone manually passes an integer larger than the model's context length, or leaves the max_length parameter as None. In both cases, the call results in an error.
IMHO, a less intrusive option would be to add a note/comment, as was done for issue 1, or to add two asserts.
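As a rough illustration of the two-assert option, the checks could look something like this. The helper name `check_max_length` and the error messages are hypothetical. In practice the two asserts would presumably sit at the top of classify_review, with the context length taken from the model.

```python
def check_max_length(max_length, supported_context_length):
    # Assert 1: the caller must pass an explicit max_length
    assert max_length is not None, (
        "max_length must be specified"
    )
    # Assert 2: max_length must fit into the model's context window
    assert max_length <= supported_context_length, (
        f"max_length ({max_length}) exceeds the supported context "
        f"length ({supported_context_length})"
    )


# Passes silently for the value used later in the thread (120 vs. 1024)
check_max_length(120, 1024)

# Both edge cases now fail early with a readable message
for bad_value in (None, 2048):
    try:
        check_max_length(bad_value, 1024)
    except AssertionError as err:
        print("AssertionError:", err)
```

This keeps the original truncate-and-pad lines untouched and only makes the two edge cases fail with a clearer message. Note that Python skips assert statements when run with the -O flag.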
Thanks for opening the discussion. Regarding issue 1, I think you may have an older version of the book, as this issue was fixed in a reprint back in January.
Regarding your 2nd suggestion, I agree that this is a bit more robust and cleaner. Thanks, I'll add it.