Flax models should allow `inputs_embeds`
Feature request
Currently, non-Flax models allow `inputs_embeds` instead of `input_ids` (e.g., GPT2):

```python
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    ...
    inputs_embeds: Optional[torch.FloatTensor] = None,
    ...
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    ...
```
However, Flax models have no such option (`input_ids` only).
It would be great if Flax models also had this option so that, optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
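For concreteness, here is a minimal sketch (a toy module, not the transformers implementation) of what accepting either `input_ids` or `inputs_embeds` could look like in Flax:

```python
# Toy sketch only: a stand-in for a Flax LM whose __call__ accepts either
# token ids or precomputed embeddings. Not the actual transformers code.
from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyCausalLM(nn.Module):
    vocab_size: int = 100
    hidden_size: int = 16

    @nn.compact
    def __call__(
        self,
        input_ids: Optional[jnp.ndarray] = None,      # (batch, seq) int ids
        inputs_embeds: Optional[jnp.ndarray] = None,  # (batch, seq, hidden)
    ) -> jnp.ndarray:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Pass exactly one of input_ids or inputs_embeds")
        wte = nn.Embed(self.vocab_size, self.hidden_size, name="wte")
        if inputs_embeds is None:
            inputs_embeds = wte(input_ids)
        hidden = nn.Dense(self.hidden_size)(inputs_embeds)  # stand-in for the block stack
        return wte.attend(hidden)  # logits via tied embeddings


model = TinyCausalLM()
ids = jnp.ones((1, 4), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), input_ids=ids)

logits_from_ids = model.apply(params, input_ids=ids)
logits_from_embeds = model.apply(params, inputs_embeds=jnp.zeros((1, 4, 16)))
```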
Motivation
This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. (from the docs)
Additionally, this can be useful for things like tuning "soft-prompts" (e.g., https://aclanthology.org/2021.emnlp-main.243/)
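As a hedged illustration of the soft-prompt use case: a small matrix of learned prompt vectors is prepended to the token embeddings, which only works if the model accepts `inputs_embeds` directly (all names below are hypothetical):

```python
# Sketch of soft-prompt tuning: only the prompt embeddings are trained; the
# pretrained model stays frozen. Passing the result as inputs_embeds is the
# capability this issue asks for.
import jax
import jax.numpy as jnp

n_prompt_tokens, hidden_size = 8, 16
prompt_embeds = jax.random.normal(jax.random.PRNGKey(0), (n_prompt_tokens, hidden_size))


def embed_with_soft_prompt(token_embeds: jnp.ndarray, prompt: jnp.ndarray) -> jnp.ndarray:
    # token_embeds: (batch, seq, hidden); prepend the same learned prompt to every example
    batch = token_embeds.shape[0]
    tiled = jnp.broadcast_to(prompt, (batch, *prompt.shape))
    return jnp.concatenate([tiled, token_embeds], axis=1)


token_embeds = jnp.zeros((2, 5, hidden_size))
inputs_embeds = embed_with_soft_prompt(token_embeds, prompt_embeds)
print(inputs_embeds.shape)  # (2, 13, 16): 8 prompt vectors + 5 token vectors
```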
Your contribution
I will try to implement this myself, but I haven't yet found a solution.
WDYT @patil-suraj ?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
WDYT, @sanchit-gandhi? :)
Hey @mattf1n!
It'll be quite straightforward to modify the `__call__` methods in the `FlaxGPT2For...` classes to handle `inputs_embeds`. Here, we can borrow the logic used in the PyTorch counterparts, for example:
https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_gpt2.py#L783
and
https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_gpt2.py#L848
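Roughly, the logic at those two PyTorch locations boils down to the following selection step; here it is as a standalone sketch (the helper name and its placement at the top of `FlaxGPT2Module.__call__` are assumptions, not final code):

```python
# Sketch of the ids-vs-embeds selection borrowed from the PyTorch model,
# written as a standalone helper for illustration.
from typing import Callable, Optional

import jax.numpy as jnp


def resolve_inputs(
    embed_tokens: Callable[[jnp.ndarray], jnp.ndarray],  # e.g. the module's wte
    input_ids: Optional[jnp.ndarray] = None,
    inputs_embeds: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is not None:
        return embed_tokens(input_ids.astype("i4"))  # look up embeddings as before
    if inputs_embeds is not None:
        return inputs_embeds  # caller supplied the embeddings directly
    raise ValueError("You have to specify either input_ids or inputs_embeds")
```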
We'll then have to update the `init_weights` method to reflect these changes:
https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_flax_gpt2.py#L404
and also the `__call__` method for `FlaxGPT2PreTrainedModel`:
https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_flax_gpt2.py#L459
The first two changes are pretty straightforward! It's the latter two that I envision being more involved and introducing a bit more code. Do you want to have a go at adding this in a PR @mattf1n? Happy to work with you on adding this feature!
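For the `init_weights` / `FlaxGPT2PreTrainedModel.__call__` side, the gist is that the default `attention_mask` and `position_ids` have to be derived from whichever input is present. A hedged sketch of that shape handling (assumptions, not final transformers code):

```python
# Default attention_mask / position_ids built from whichever input is provided.
from typing import Optional, Tuple

import jax.numpy as jnp


def infer_batch_and_seq(
    input_ids: Optional[jnp.ndarray],
    inputs_embeds: Optional[jnp.ndarray],
) -> Tuple[int, int]:
    if input_ids is not None:
        return input_ids.shape[0], input_ids.shape[1]  # ids: (batch, seq)
    if inputs_embeds is not None:
        return inputs_embeds.shape[0], inputs_embeds.shape[1]  # embeds: (batch, seq, hidden)
    raise ValueError("You have to specify either input_ids or inputs_embeds")


batch_size, sequence_length = infer_batch_and_seq(jnp.ones((2, 7), dtype=jnp.int32), None)
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
attention_mask = jnp.ones((batch_size, sequence_length))
```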
I would like to second @mattf1n's feature request. This would be super useful for vision-language modeling, where we often want to feed a concatenation of image and text features into a sequence-to-sequence model. This approach has become quite popular recently; see, for example, VL-T5, GPV-1, GPV-2, and UnifiedIO. And given that non-Flax models already support this, it would be great to have this implemented for Flax models as well for consistency!
Thanks for weighing in @BigRedT! Cool to see so much interest in this feature! Would you or @mattf1n be interested in opening a PR to implement this for a language model (GPT2)? As mentioned, the first two changes should be pretty trivial. More than happy to help with hacking around the `init_weights` and `__call__` methods to get the last two bits working!
@sanchit-gandhi I am traveling currently but gave it a quick try yesterday. Specifically, I was trying to update `FlaxT5ForConditionalGeneration` to accept `inputs_embeds`. It looks like `FlaxGenerationMixin` in `generation_flax_utils.py` would also need to be updated, as it assumes `input_ids` are always provided (not `Optional`). This would be needed to use the `generate()` method to generate free-form text via beam search, greedy decoding, etc.
I can share my partial attempt next week when I am back at my desk, but overall this seems a bit involved. It would be great if someone like yourself, who is more familiar with the Hugging Face code base, could take a stab at it!
P.S. Apologies for any typos; I am writing this on my phone.
Hey @BigRedT! Thanks for jumping on this so quickly. That's a good catch - we will indeed need to handle the case where `inputs_embeds` are used for generation. Here, we'll need to do three things:
- Allow both `input_ids` and `inputs_embeds` to be passed to the `generate()` method: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/generation_flax_utils.py#L163
- Pass both `input_ids` and `inputs_embeds` to the `prepare_inputs_for_generation()` method: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/generation_flax_utils.py#L483
- Modify the `prepare_inputs_for_generation()` method to handle both `input_ids` and `inputs_embeds`: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/models/t5/modeling_flax_t5.py#L1683
There'll then be some cleaning up to make sure the batch-size and sequence-length terms are set correctly (from either `input_ids` or `inputs_embeds` accordingly). I've given examples for greedy search and GPT2, but the same logic holds for beam search or sampling and other Causal LMs.
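For example, that bookkeeping might look roughly like this (a hypothetical helper, not the actual `generate()` code):

```python
# Hypothetical helper showing the shape bookkeeping generate() would need
# when it can receive inputs_embeds instead of input_ids.
from typing import Optional, Tuple

import jax.numpy as jnp


def init_generation_state(
    input_ids: Optional[jnp.ndarray] = None,
    inputs_embeds: Optional[jnp.ndarray] = None,
    max_length: int = 20,
    pad_token_id: int = 0,
) -> Tuple[jnp.ndarray, int]:
    if input_ids is not None:
        batch_size, cur_len = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, cur_len = inputs_embeds.shape[:2]  # (batch, seq, hidden)
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    # Buffer of generated ids; when starting from embeds there are no prompt
    # ids to copy in, so the prompt region stays as padding.
    sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
    if input_ids is not None:
        sequences = sequences.at[:, :cur_len].set(input_ids)
    return sequences, cur_len
```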
As you say @BigRedT, this is already getting quite involved, both in terms of the amount of code and the complexity of the problem. It's going to be a challenging task, but feel free to open a PR with what you've currently got; happy to help with the integration and guide you through! I'm out of office next week, but can take a look when I'm back :)
@sanchit-gandhi here's the PR you requested. I was actually able to get it to work with minimal modifications to `generation_flax_utils.py`.
@mattf1n a similar solution might work for GPT-2 as well?
Awesome! Replied on the PR 🙂
Hi! I was away for a bit. I'm happy to see so much activity this past week or so! I have implemented a working version for GPT-Neo. I'll try to make a pull request soon.
Here is the PR ^
This issue is still open with a WIP PR at https://github.com/huggingface/transformers/pull/18613
The PR is near completion - if anyone wants to work with @BigRedT to finish this one off feel free to have a go on the PR! More than happy to answer any questions and provide a review 🤗
The PR is still open if anyone would like to see it to completion! Happy to lend a hand with questions / queries!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Going to leave this one closed for now since interest seems to have dwindled. If you're interested in picking this back up, feel free to reopen the issue and tag me 🤗 We can go from there!