llama2.c

Incorrect parameter counts for 15M, 42M, 110M models?

Open xenova opened this issue 2 years ago • 3 comments

While converting these models to ONNX to be used in transformers.js (e.g., 15M), I was encountering a problem where the ONNX version would be slightly larger than the pytorch version. After a few hours of debugging the conversion script, I decided to just check the parameter counts of the pytorch models.

!wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.pt
import torch
data = torch.load("stories42M.pt")
sum(v.numel() for v in data["model"].values())

outputs 58073600, which is significantly larger than expected (42M).

The same thing happens for the other versions:

  • 15M → 24407712
  • 42M → 58073600
  • 110M → 134105856

However, the file sizes of the checkpoints on the HF hub suggest that the advertised names are indeed correct (e.g., 15M params × 32 bits = 15M × 4 bytes ≈ 60 MB).

If anyone knows what's going on, I'd greatly appreciate some explanation - thanks!

xenova avatar Sep 02 '23 22:09 xenova

Those numbers are larger than expected by exactly vocab_size * dim (32000 * 288 for the 15M model), which is the size of the input embedding / output layer. If you count the parameters of a loaded model, rather than of the state dict, the expected numbers are returned. This is due to weight tying between the large input embedding and the output layer, as seen here:

https://github.com/karpathy/llama2.c/blob/master/model.py#L224
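As a quick sanity check (my own sketch, not from the thread — the per-model dims 288/512/768 are assumed from the llama2.c training configs), subtracting one shared vocab_size × dim tensor from each state-dict count lands right on the advertised sizes:

```python
vocab_size = 32000

# name: (state-dict numel from the issue above, assumed model dim)
models = {
    "15M": (24_407_712, 288),
    "42M": (58_073_600, 512),
    "110M": (134_105_856, 768),
}

for name, (sd_numel, dim) in models.items():
    unique = sd_numel - vocab_size * dim  # drop the double-counted tied tensor
    print(name, unique)
# 15M 15191712
# 42M 41689600
# 110M 109529856
```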

janimo avatar Sep 03 '23 16:09 janimo

Good observation @janimo, thanks! That's something I was partially aware of - and in fact we've just made some updates to Optimum (the library used to convert to ONNX). However, the logs indicate that it wasn't able to detect the tied weights, which does seem to be a bug on our end, since the model does use weight tying.

cc @fxmarty

xenova avatar Sep 03 '23 18:09 xenova

See this commit (https://github.com/karpathy/llama2.c/pull/395/commits/fc11cc387b47efd98ca4ac0956f715d2e5451c41) or line L224 in model.py to see where the weights are tied; there is more discussion in this issue: https://github.com/karpathy/llama2.c/issues/321#issuecomment-1722272404
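For reference, a minimal PyTorch sketch of this tying pattern (the class and attribute names are illustrative stand-ins, not the actual llama2.c `Transformer`), showing why `state_dict()` double-counts the shared tensor while `parameters()` does not:

```python
import torch.nn as nn

vocab_size, dim = 32000, 288  # 15M-model dims, per the thread above

class TinyLM(nn.Module):  # illustrative stand-in for the model in model.py
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # weight tying: input embedding and output projection share one tensor
        self.output.weight = self.tok_embeddings.weight

model = TinyLM()

# state_dict() keeps both keys, so the shared tensor is counted twice:
sd_numel = sum(v.numel() for v in model.state_dict().values())
# parameters() de-duplicates shared tensors:
param_numel = sum(p.numel() for p in model.parameters())

print(sd_numel - param_numel)  # vocab_size * dim = 9216000
```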

nickypro avatar Sep 16 '23 17:09 nickypro