nanoGPT
Why use np.int64 instead of np.int32 in train.py?
Here the training data is converted to int64:
https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/train.py#L121
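For reference, a minimal paraphrase of the pattern at that line (assuming a train.bin produced by the repo's prepare scripts, which store the token ids on disk as uint16):

```python
import numpy as np
import torch

# tokens are stored on disk as uint16 (a ~50k vocab fits easily);
# np.memmap gives cheap random access without loading the whole file
data = np.memmap('train.bin', dtype=np.uint16, mode='r')

block_size = 1024
batch_size = 4
ix = torch.randint(len(data) - block_size, (batch_size,))

# this .astype(np.int64) is the conversion in question: the uint16
# token ids are widened to int64 before being handed to the model
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
```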
I read on the web that even for GPT-3:
https://enjoymachinelearning.com/blog/the-gpt-3-vocabulary-size/#:~:text=After%20crunching%20the%20numbers%2C%20we,languages%20that%20use%20many%20words.
""" After crunching the numbers, we found that GPT-3’s vocabulary size is 14,735,746 words. """
That can be held by int32, so I am wondering: why not use np.int32?
Would this work, or even improve training speed?
Or am I missing something?
Thanks for your comments!
It is because the embedding layer "wte = nn.Embedding(config.vocab_size, config.n_embd)" requires an int64 tensor as input.
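A quick sketch of that requirement (a minimal example; the exact behavior is version-dependent):

```python
import torch
import torch.nn as nn

# config.vocab_size and config.n_embd as in nanoGPT's GPT-2 config
wte = nn.Embedding(50304, 768)

idx = torch.tensor([[15496, 995]])   # torch defaults integer tensors to int64
out = wte(idx)                       # works: the indices are a LongTensor
print(idx.dtype, out.shape)          # torch.int64 torch.Size([1, 2, 768])

# on older PyTorch versions, int32 indices raised an error along the lines of
# "Expected tensor for argument #1 'indices' to have scalar type Long",
# which is why the training data is widened to int64 up front
```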
And why does that require an int64 tensor as input? Why can it not be implemented differently?
I am not entirely sure, but it has something to do with preventing overflow in the gradient calculations during backpropagation. You can check out this forum thread for more information: https://discuss.pytorch.org/t/why-does-nn-embedding-layers-expect-longtensor-type-input-tensors/21376
Hmm, I understand the precision issues when using float vs. double, but I don't really understand the precision issues when using int64 vs. int32 here. The initial x fed into the model is just an index, and this index will always be between 0 and the vocabulary length (~50k here). I don't really see how this index can overflow; it's interesting to try to figure out how that would happen. But anyway, the link you posted mentions that int32 support has been added to the embedding layer, and I think that would be very helpful for this repo, since it deals with large datasets.
I don't really know how this change would interact with loading the OpenAI weights, but I'm sure it's fixable if someone spends some time testing things.
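To check the int32 point concretely (assuming a recent PyTorch, whose docs list IntTensor or LongTensor as valid embedding input):

```python
import torch
import torch.nn as nn

wte = nn.Embedding(50304, 768)

# int32 indices go through on current PyTorch versions
idx32 = torch.randint(0, 50304, (12, 1024), dtype=torch.int32)
out = wte(idx32)
print(out.shape)  # torch.Size([12, 1024, 768])

# the saving is in the index batch itself: int32 halves its footprint
print(idx32.nelement() * idx32.element_size())  # 49152 bytes
print(idx32.nelement() * 8)                     # 98304 bytes as int64
```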
Just a general thought... take the 32-bit vs. 64-bit question. Say you use 3 bits instead: that's a maximum signed value of 3. If you multiply two of those values you can get 9, which overflows unless you use twice the capacity, 6 bits instead of 3. Do you see what I'm trying to describe here? Using int64 means less chance of an overflow corrupting data.
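The same idea can be seen with NumPy's fixed-width integers (a toy demonstration of wrap-around, not a claim about anything PyTorch does internally):

```python
import numpy as np

# 100 * 2 = 200 does not fit in int8 (max 127), so the result wraps
a = np.int8(100)
b = np.int8(2)
print(a * b)   # -56, and NumPy typically emits a RuntimeWarning about overflow
```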
I don't understand why any arithmetic operation would be done on (or using) the input indices when taking gradients. My intuition is that these indices are only used to determine which rows of the embedding matrix will be updated with the gradients. I'd be happy to know how the gradient update involves arithmetic operations on or using the indices, if it does.
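That matches what a quick experiment shows (a toy sketch, using plain row indexing to stand in for the embedding lookup):

```python
import torch

# tiny check: the backward pass of an embedding lookup only scatters
# gradient rows into the positions *named* by the indices; no arithmetic
# is performed on the index values themselves
vocab_size, n_embd = 10, 4
W = torch.randn(vocab_size, n_embd, requires_grad=True)
idx = torch.tensor([1, 3, 3])

out = W[idx]          # an embedding lookup is just row indexing
out.sum().backward()

print(W.grad)         # row 1 gets ones, row 3 accumulates twice, rest zero
```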