Documentation about model stitching
I can't seem to find any good documentation of the complete model architecture. Specifically, I'm looking into how the tensor weights are stitched together across files. Since all `.pth` files are needed for inference, I assume they are stitched together before execution.
I see that the same tensors are present in every file (e.g. `tok_embeddings.weight` appears in every `consolidated.XX.pth`), so they must need to be put together somehow.
In a Python implementation, for example, is the correct solution just to:

```python
import torch

dict1 = torch.load('consolidated.00.pth', map_location='cpu')
dict2 = torch.load('consolidated.01.pth', map_location='cpu')
weights1 = dict1['tok_embeddings.weight']
weights2 = dict2['tok_embeddings.weight']
# concatenate the tensor weights along dimension 0
weights = torch.cat([weights1, weights2], dim=0)
```
This feels extremely inefficient; presumably there is a smarter way to load only the parts of each checkpoint that are needed. But am I on the right track?
Any explanations or references are very welcome. Thank you in advance.
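For what it's worth, one common pattern (assuming the shards come from Megatron-style tensor parallelism, where column-parallel weights are split along dim 0, row-parallel weights along dim 1, and norms are replicated in every shard) is to merge shard state dicts tensor by tensor. The `SPLIT_DIM` table and `merge_shards` helper below are hypothetical names for illustration, and the per-tensor axes are an assumption that should be checked against the actual model code:

```python
import torch

# Assumed per-tensor split axes (Megatron-style tensor parallelism).
# None means the tensor is replicated identically in every shard.
SPLIT_DIM = {
    "tok_embeddings.weight": 1,   # assumed split along the embedding dim
    "attention.wq.weight": 0,     # column-parallel (assumed)
    "attention.wo.weight": 1,     # row-parallel (assumed)
    "norm.weight": None,          # replicated
}

def merge_shards(shards, split_dim):
    """Merge a list of per-shard state dicts into one full state dict."""
    merged = {}
    for name in shards[0]:
        dim = split_dim.get(name)
        if dim is None:
            # Replicated tensor: any shard's copy will do.
            merged[name] = shards[0][name]
        else:
            # Sharded tensor: concatenate along its split axis.
            merged[name] = torch.cat([s[name] for s in shards], dim=dim)
    return merged
```

Loading one `consolidated.XX.pth` at a time with `map_location='cpu'` and merging incrementally keeps peak memory closer to one shard plus the merged result, rather than all shards at once.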
Maybe the `load` function here will be useful?