mesh-transformer-jax
Is "to_hf_weights.py" specific to "6B_roto_256.json" only?
I was trying to make this codebase work for smaller models (e.g., "layers": 12, "d_model": 768, "n_heads": 16). However, the HF model produced by "to_hf_weights.py" generates very strange results on GPU, while "device_sample.py" works fine on TPU VM.
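One thing worth checking is whether the HF config matches the smaller model. The `GPTJConfig` defaults in transformers correspond to the 6B model (n_embd=4096, n_layer=28, n_head=16, rotary_dim=64), so a config.json carrying those defaults would not fit a 12-layer, 768-dim checkpoint. The sketch below maps mesh-transformer-jax JSON keys to `GPTJConfig` keyword names and sanity-checks the head and rotary dimensions. The `n_vocab`, `seq`, and `pe_rotary_dims` values and the helper `to_gptj_config` are hypothetical, for illustration only.

```python
import json

# Hypothetical mesh-transformer-jax config for the smaller model described above.
# "n_vocab" and "seq" are assumed to match 6B_roto_256.json; "pe_rotary_dims" is a guess.
mtj_config = {
    "layers": 12,
    "d_model": 768,
    "n_heads": 16,
    "n_vocab": 50400,
    "seq": 2048,
    "pe_rotary_dims": 24,
}


def to_gptj_config(cfg: dict) -> dict:
    """Hypothetical helper: map mesh-transformer-jax keys to GPTJConfig keyword names."""
    d_model, n_heads = cfg["d_model"], cfg["n_heads"]
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    head_dim = d_model // n_heads  # 768 // 16 = 48 for the small model
    rotary_dim = cfg["pe_rotary_dims"]
    # Rotary embeddings are applied to the first rotary_dim channels of each head,
    # so rotary_dim must be even and must not exceed the per-head dimension.
    if rotary_dim % 2 != 0 or rotary_dim > head_dim:
        raise ValueError(f"pe_rotary_dims={rotary_dim} is invalid for head_dim={head_dim}")
    return {
        "model_type": "gptj",
        "architectures": ["GPTJForCausalLM"],
        "vocab_size": cfg["n_vocab"],
        "n_positions": cfg["seq"],
        "n_embd": d_model,
        "n_layer": cfg["layers"],
        "n_head": n_heads,
        "rotary_dim": rotary_dim,
    }


hf_config = to_gptj_config(mtj_config)
print(json.dumps(hf_config, indent=2))  # candidate contents for a hand-written config.json
```

With pe_rotary_dims left at 64 (the 6B value), this check raises, because 64 exceeds the 48-dimensional heads of a 768-dim, 16-head model.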
After several hours of trying different combinations (e.g., fp16, bf16, fp32) of "to_hf_weights.py" with and without "slim_model.py", I still cannot build an HF model that produces plausible results with the following sample code:
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
(from https://huggingface.co/docs/transformers/model_doc/gptj)
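Since the sample code loads the weights with torch_dtype=torch.float16, one possible source of garbage output is dtype range: float16 overflows to inf above 65504, whereas bfloat16 keeps float32's exponent range at lower precision. The sketch below illustrates the difference in NumPy. NumPy has no native bfloat16, so the hypothetical helper `to_bf16_emulated` approximates it by truncating float32 mantissa bits, which rounds toward zero rather than to nearest-even like a real cast.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


def to_fp16(x: np.ndarray) -> np.ndarray:
    """Cast to float16; values beyond +/-65504 become +/-inf."""
    with np.errstate(over="ignore"):
        return np.asarray(x, dtype=np.float32).astype(np.float16)


def to_bf16_emulated(x: np.ndarray) -> np.ndarray:
    """Approximate a bfloat16 cast by zeroing the low 16 bits of each float32.

    This truncates (rounds toward zero); a real bf16 cast rounds to nearest-even.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def count_fp16_overflows(x: np.ndarray) -> int:
    """Number of entries whose magnitude exceeds the float16 range."""
    return int(np.sum(np.abs(x) > FP16_MAX))


weights = np.array([1e-3, 1.0, 3.14159, 70000.0], dtype=np.float32)
print(to_fp16(weights))               # last entry overflows to inf
print(to_bf16_emulated(weights))      # stays finite, with fewer mantissa bits
print(count_fp16_overflows(weights))  # 1
```

Running count_fp16_overflows over each tensor of a converted checkpoint is one cheap way to see whether a float16 load can be trusted.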
I have no idea what I might have done wrong. Is it possible that "to_hf_weights.py" is not compatible with smaller models? Or could you share more details about the following section of howto_finetune.md?
Running with HuggingFace
To use the model in HuggingFace's transformer library using pytorch, you'll need to transfer the weights into a format that it recognizes. This can be done using to_hf_weights.py. It's recommended that you use slim_model.py before attempting to move the weights to a pytorch/transformer format. Use python to_hf_weights.py --help to see usage details.
https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md
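Conceptually, moving a model-parallel Haiku checkpoint to a PyTorch state dict involves at least two steps: concatenating each parameter's per-core shards, and transposing dense kernels, because `haiku.Linear` stores weights as (in_features, out_features) while `torch.nn.Linear.weight` is (out_features, in_features). The toy sketch below shows both with NumPy. Which axis each parameter is sharded along (output axis for q_proj, input axis for out_proj here) is an assumption in the style of Megatron tensor parallelism, not a description of what to_hf_weights.py actually does; the helper names are hypothetical.

```python
import numpy as np


def merge_shards(shards, axis):
    """Hypothetical helper: concatenate per-core shards along the model-parallel axis."""
    return np.concatenate(shards, axis=axis)


def haiku_to_torch_linear(kernel):
    """Haiku Linear kernels are (in, out); torch.nn.Linear.weight is (out, in)."""
    return np.ascontiguousarray(kernel.T)


# Toy setup: d_model=8 split across 2 cores (assumed layouts, see lead-in).
rng = np.random.default_rng(0)
d_model, cores = 8, 2
q_full = rng.standard_normal((d_model, d_model)).astype(np.float32)
o_full = rng.standard_normal((d_model, d_model)).astype(np.float32)

q_shards = np.split(q_full, cores, axis=1)  # column-parallel: split on output axis
o_shards = np.split(o_full, cores, axis=0)  # row-parallel: split on input axis

q_weight = haiku_to_torch_linear(merge_shards(q_shards, axis=1))
o_weight = haiku_to_torch_linear(merge_shards(o_shards, axis=0))

# Haiku computes x @ kernel; torch.nn.functional.linear computes x @ weight.T.
x = rng.standard_normal((3, d_model)).astype(np.float32)
print(np.allclose(x @ q_full, x @ q_weight.T), np.allclose(x @ o_full, x @ o_weight.T))
```

Biases and layer-norm scales are 1-D, so they would be merged (or deduplicated, if replicated) but never transposed.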
In particular, what command-line arguments were used to produce the "EleutherAI/gpt-j-6B" model hosted on HF? I'd like to follow the same steps that were used to prepare it.
Thank you for any advice.