
Incorrect use of `dtype` resulting in incorrect `token_ids`

Open · damin604 opened this issue 1 year ago · 1 comment

In the `__iter__` method of the `ConcatTokensDataset` class, the `dtype` argument is not specified in the statement `yield {'tokens': np.asarray(concat_sample).tobytes()}`. The default integer dtype NumPy picks here is platform-dependent; on my system it is `np.int32`.

On the other hand, in the `_read_binary_tokenized_sample` method of the `StreamingTextDataset` class, the `dtype` is explicitly specified as `np.int64`. When the write-side default differs, this mismatch results in incorrect `token_ids`.
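Since the default integer dtype is platform-dependent (typically int64 on 64-bit Linux/macOS, but int32 on 32-bit systems and on Windows with NumPy < 2.0), a quick way to check what `np.asarray` picks locally:

```python
import numpy as np

# What dtype does np.asarray choose for a list of Python ints on this
# machine? This is the dtype ConcatTokensDataset would serialize with
# when none is specified.
default_dtype = np.asarray([1, 10, 100]).dtype
print(default_dtype)
```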

Below is an example demonstrating the issue:

import numpy as np
import torch

tokenizer = config.tokenizer

def load_sample(tokens, dtype):
    np_tokens = np.frombuffer(tokens, dtype=dtype)
    pt_tokens = torch.from_numpy(np_tokens)
    print(f"dtype = {dtype}", pt_tokens, tokenizer.decode(pt_tokens), sep="\n")

# dataset is an object of StreamingTextDataset
tokens = dataset.get_item(0)["tokens"]
load_sample(tokens, np.int64)
print()
load_sample(tokens, np.int32)

"""
>>> Output:
dtype = <class 'numpy.int64'>
tensor([   77309411330, 31048318582927, 55344948613018,  ...,
           30064791424,    30064786861,  5420248739272])
<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>...

dtype = <class 'numpy.int32'>
tensor([    2,    18,   143,  ...,     7, 11720,  1262], dtype=torch.int32)
</s>The present invention relates to novel compounds of formula (I) that are capable...
"""

damin604 avatar Jul 11 '23 17:07 damin604

Are you possibly on a 32-bit system? This is what I get on my machine:

In [11]: import numpy as np

In [12]: token_ids = [1,10,100]

In [13]: np_arr = np.asarray(token_ids)

In [14]: np_bytes = np_arr.tobytes()

In [15]: np_arr.dtype
Out[15]: dtype('int64')

In [16]: read_in = np.frombuffer(np_bytes, dtype=np.int64)

In [17]: read_in
Out[17]: array([  1,  10, 100])

In [18]: read_in.dtype
Out[18]: dtype('int64')

We should probably make the dtype explicit anyway, to rule out this class of issue, but what system are you running into this problem on?
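A sketch of what making the dtype explicit would look like (hypothetical one-line change, not the actual patch): pin the dtype at write time so it matches the `np.int64` used at read time, regardless of platform.

```python
import numpy as np

concat_sample = [2, 18, 143, 7]  # stand-in for a concatenated token sample

# Write with an explicit dtype instead of relying on the platform default.
packed = np.asarray(concat_sample, dtype=np.int64).tobytes()

# The round trip is now platform-independent.
unpacked = np.frombuffer(packed, dtype=np.int64)
print(unpacked.tolist())  # [2, 18, 143, 7]
```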

dakinggg avatar Jul 23 '23 23:07 dakinggg