OpenAI GPT Model Implementation in Flax

Open mayankagarwals opened this issue 2 years ago • 10 comments

Model description

https://huggingface.co/openai-gpt today supports TensorFlow and PyTorch but not Flax. I'd like to implement Flax support to enhance the current GPT offering by Hugging Face.

Open source status

  • [X] The model implementation is available
  • [X] The model weights are available

Provide useful links for the implementation

Given that the model is already implemented in the other two frameworks, I'll try to infer the implementation from there. Please feel free to provide additional resources that can help me wrap this up better and faster.

mayankagarwals · Apr 07 '23 05:04

@sanchit-gandhi

mayankagarwals · Apr 07 '23 05:04

@sanchit-gandhi @sgugger Are there any reservations around this? I have gone through the GPT architecture and the Flax code for GPT-2, and I'm fairly certain this is implementable. For exhaustiveness it seems worth doing: the OpenAI GPT model still sees almost a million downloads a month.

Please let me know. I'd like to start with a draft PR rather than just rushing in.

mayankagarwals · Apr 12 '23 19:04

Hey @mayankagarwals! Super sorry for not getting back to you earlier here. Let me give you my two cents: the OpenAI GPT model is definitely still super popular amongst PyTorch users (as you say, ~1 mil downloads per month). What we tend to see with Flax users though is a preference for newer, larger models (e.g. OPT, Flan-T5). This is primarily because of how easy it is to run super large models in JAX with data and model parallelism. So whilst I think this PR would be cool for completeness, I think porting a newer, more flashy model might get the JAX/Flax community more excited! How does this sound?

sanchit-gandhi · Apr 21 '23 14:04

No worries :) @sanchit-gandhi Yes, I had held off because of the same skepticism. Would you mind pointing me to a model that in your opinion is worth digging into and would benefit Hugging Face and the community? I have a good grasp of text-generation architectures, so something aligned with that would be best!

mayankagarwals · Apr 21 '23 17:04

LLaMA could be cool! What I would suggest is starting from the Flax GPT-Neo model (since this is the Flax model most similar to LLaMA) and then adding the new bits in.
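
To give a flavour of the "new bits": LLaMA swaps GPT-Neo's LayerNorm for RMSNorm, uses rotary position embeddings, and has a SwiGLU MLP. As a minimal, illustrative sketch (not the final implementation, and the class name is just a placeholder), the RMSNorm piece in flax.linen could look like:

import jax
import jax.numpy as jnp
import flax.linen as nn

class FlaxLlamaRMSNorm(nn.Module):
    # Sketch of LLaMA's RMSNorm: scale by the RMS of the hidden states,
    # with no mean-centering and no bias (unlike LayerNorm)
    hidden_size: int
    eps: float = 1e-6

    def setup(self):
        # Learned per-channel scale, initialised to ones
        self.weight = self.param("weight", nn.initializers.ones, (self.hidden_size,))

    def __call__(self, hidden_states):
        variance = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        return self.weight * hidden_states

# Quick shape check
layer = FlaxLlamaRMSNorm(hidden_size=16)
x = jnp.ones((1, 4, 16))
params = layer.init(jax.random.PRNGKey(0), x)
print(layer.apply(params, x).shape)  # (1, 4, 16)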

sanchit-gandhi · May 05 '23 09:05

@sanchit-gandhi I was also thinking of adding a Flax version of LLaMA (and also GPT-NeoX, maybe others) as some Flax practice. I couldn't find a guide on adding a new framework to an existing model, and I asked on the Discord without much luck (but was directed to this issue).

I'm familiar with the architectures, having already ported them to other frameworks where I work.

If you could point me in the right direction, I would be happy to port this for you! I wasn't sure whether it is as simple as adding a new modeling_flax_* file or whether there are more moving parts / best practices to be aware of.

Thanks 🤗

vvvm23 · Jun 06 '23 12:06

Hey @vvvm23! In this case, since we already have the PT model, the best thing to do would be to add a new Flax modelling file (modeling_flax_llama.py), initially copied from the Flax GPT-Neo modelling code. You can then start making changes to the Flax code to adapt it to LLaMA. The reason we copy from Flax GPT-Neo is that it contains optimised code for the attention layer, which we should try to re-use for Flax LLaMA.
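
As a rough sketch of how that file might start (class names here are hypothetical; transformers' "Copied from" markers let the repo's consistency checks keep unchanged code in sync with the original):

# Hypothetical skeleton for modeling_flax_llama.py, begun as a copy of the
# Flax GPT-Neo file
import flax.linen as nn

# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention with GPTNeo->Llama
class FlaxLlamaSelfAttention(nn.Module):
    ...  # re-use GPT-Neo's optimised attention here, then adapt it (rotary embeddings etc.)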

You'll then need to make sure that the weight names match and that you have equivalence between PyTorch LLaMA and Flax LLaMA. To do this, I would recommend creating a 'dummy' version of the PyTorch LLaMA model:

from transformers import LlamaConfig, LlamaForCausalLM

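# Tiny dimensions so the dummy checkpoint is fast to create and load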
config = LlamaConfig(hidden_size=16, intermediate_size=24, max_position_embeddings=128, num_attention_heads=2, num_hidden_layers=2)

model = LlamaForCausalLM(config)
model.save_pretrained("./path/to/save")

And then for your test script, load this same model in PyTorch and then in Flax (passing from_pt=True in the from_pretrained call), and verify with random inputs that you get the same logits out of a forward pass (example here: https://github.com/huggingface/transformers/issues/15476#issue-1121800731).
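
Roughly, that check could look like this (FlaxLlamaForCausalLM doesn't exist yet - it's the class this port would add, named following the library's Flax convention):

import numpy as np
import torch
from transformers import LlamaForCausalLM
from transformers import FlaxLlamaForCausalLM  # the new class this port adds

# Load the same dummy checkpoint in both frameworks
pt_model = LlamaForCausalLM.from_pretrained("./path/to/save")
fx_model = FlaxLlamaForCausalLM.from_pretrained("./path/to/save", from_pt=True)

# Feed identical random token ids to both models
rng = np.random.default_rng(0)
input_ids = rng.integers(0, pt_model.config.vocab_size, size=(2, 8))

with torch.no_grad():
    pt_logits = pt_model(torch.from_numpy(input_ids)).logits.numpy()
fx_logits = np.asarray(fx_model(input_ids).logits)

# Allow small float32 differences between PyTorch and XLA kernels
assert np.allclose(pt_logits, fx_logits, atol=1e-3), "PT/Flax logits diverge"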

You can then focus on the tests and on converting the actual model weights as required (rough sketch of the conversion below). Feel free to open a PR and tag me - more than happy to help with the integration here!
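
For the conversion step itself, once the logits match, something like this should be enough (paths are illustrative):

from transformers import FlaxLlamaForCausalLM

# Load the real PyTorch checkpoint through the cross-framework path, then
# write native Flax weights (flax_model.msgpack) to a new directory
fx_model = FlaxLlamaForCausalLM.from_pretrained("path/to/pytorch-llama", from_pt=True)
fx_model.save_pretrained("path/to/flax-llama")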

sanchit-gandhi · Jun 06 '23 17:06

Thanks @sanchit-gandhi, that was very comprehensive! I'll let you know how I get on. 🤗

vvvm23 · Jun 06 '23 18:06

Got a bit caught up with real-life stuff, but I will be working on this more intensively from Monday, aiming to finish something by the end of the week.

vvvm23 · Jun 16 '23 09:06

@sanchit-gandhi I made a draft PR with my current progress, see #24587. Sorry, I haven't finished the full model - I've been very busy 😓

vvvm23 · Jun 30 '23 07:06