
Flax models should allow `inputs_embeds`


Feature request

Currently, non-Flax models allow passing inputs_embeds instead of input_ids (e.g., GPT2):

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    ...
    inputs_embeds: Optional[torch.FloatTensor] = None,
    ...
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    ...

However, Flax models have no such option (input_ids only).

It would be great if Flax models also had this option, so that (quoting the docs):

Optionally, instead of passing input_ids you can choose to directly pass an embedded representation.

Motivation

This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix. (from the docs)

Additionally, this can be useful for things like tuning "soft-prompts" (e.g., https://aclanthology.org/2021.emnlp-main.243/)
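
For example, with the existing PyTorch support you can already do something like the following rough sketch (model choice, prompt length, and variable names are arbitrary, just to illustrate the idea):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
token_embeds = model.get_input_embeddings()(input_ids)  # (1, seq, hidden)

# Learnable "soft prompt" vectors prepended to the token embeddings.
soft_prompt = torch.nn.Parameter(torch.randn(1, 10, token_embeds.shape[-1]))
inputs_embeds = torch.cat([soft_prompt, token_embeds], dim=1)

# The model consumes the embeddings directly; no input_ids needed.
outputs = model(inputs_embeds=inputs_embeds)
```

Flax models currently offer no equivalent entry point.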

Your contribution

I will try to implement this myself, but I haven't yet found a solution.

mattf1n · Jul 05 '22

WDYT @patil-suraj ?

LysandreJik · Jul 11 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Aug 05 '22

WDYT, @sanchit-gandhi? :)

LysandreJik · Aug 09 '22

Hey @mattf1n!

It'll be quite straightforward to modify the __call__ methods in the FlaxGPT2For... classes to handle inputs_embeds. Here we can borrow the logic used in the PyTorch counterparts, for example: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_gpt2.py#L783 and https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_gpt2.py#L848

We'll then have to update the init_weights method to reflect these changes: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_flax_gpt2.py#L404

And also the __call__ method of FlaxGPT2PreTrainedModel: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/gpt2/modeling_flax_gpt2.py#L459

The first two changes are pretty straightforward! It's the latter two that I envision being more involved and introducing a bit more code. Do you want to have a go at adding this in a PR @mattf1n? Happy to work with you on adding this feature!
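
For the module-level change, a minimal sketch of the input_ids / inputs_embeds switch in Flax could look roughly like this (a toy module with illustrative names and sizes, not the actual FlaxGPT2Module):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyCausalLMModule(nn.Module):
    # Toy stand-in for a Flax causal LM module, just to illustrate the switch.
    vocab_size: int = 50257
    hidden_size: int = 768

    def setup(self):
        self.wte = nn.Embed(self.vocab_size, self.hidden_size)

    def __call__(self, input_ids=None, inputs_embeds=None):
        # Same mutual-exclusion check as the PyTorch models.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.wte(input_ids.astype("i4"))
        # ... the rest of the transformer would consume inputs_embeds as the initial hidden states ...
        return inputs_embeds

# Usage: init from input_ids, then call with either argument.
module = ToyCausalLMModule()
params = module.init(jax.random.PRNGKey(0), input_ids=jnp.ones((1, 4), dtype="i4"))
out = module.apply(params, inputs_embeds=jnp.zeros((1, 4, 768)))
```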

sanchit-gandhi · Aug 09 '22

I would like to second @mattf1n's feature request. This would be super useful for vision-language modeling, where we often want to feed a concatenation of image and text features into a sequence-to-sequence model. This approach has become quite popular recently; see, for example, VL-T5, GPV-1, GPV-2, and UnifiedIO. And given that non-Flax models already support this, it would be great to have it implemented for Flax models as well for consistency!

BigRedT · Aug 10 '22

Thanks for weighing in @BigRedT! Cool to see so much interest in this feature! Would you or @mattf1n be interested in opening a PR to implement this for a language model (GPT2)? As mentioned, the first two changes should be pretty trivial. More than happy to help with hacking around the init_weights and __call__ methods to get the last two bits working!

sanchit-gandhi · Aug 12 '22

@sanchit-gandhi I am traveling currently but gave it a quick try yesterday. Specifically, I was trying to update FlaxT5ForConditionalGeneration to accept inputs_embeds. It looks like FlaxGenerationMixin in generation_flax_utils.py would also need to be updated, as it assumes input_ids are always provided (not Optional). This would be needed to use the generate() method to produce free-form text via beam search, greedy decoding, etc.

I can share my partial attempt next week when I am back at my desk, but overall this seems a bit involved. It would be great if someone like yourself who is more familiar with the Hugging Face code base takes a stab at it!

P.S. Apologies for any typos; I am writing this on my phone.

BigRedT · Aug 12 '22

Hey @BigRedT! Thanks for jumping on this so quickly. That's a good catch: we will indeed need to handle the case where inputs_embeds is used for generation. Here, we'll need to do three things:

  1. Allow both input_ids and inputs_embeds to be passed to the generate() method: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/generation_flax_utils.py#L163
  2. Pass both input_ids and inputs_embeds to the prepare_inputs_for_generation() method: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/generation_flax_utils.py#L483
  3. Modify the prepare_inputs_for_generation() method to handle both input_ids and inputs_embeds: https://github.com/huggingface/transformers/blob/1ccd2515ed6d7da4ec46fe94aedbd8a86a2cde8e/src/transformers/models/t5/modeling_flax_t5.py#L1683

There'll then be some cleaning up to make sure the batch-size and sequence-length terms are set correctly (from either input_ids or inputs_embeds accordingly). I've given examples for greedy search and T5, but the same logic holds for beam search or sampling and for other models.

As you say @BigRedT, this is already getting quite involved, both in terms of the amount of code and the complexity of the problem. It's going to be a challenging task, but feel free to open a PR with what you've currently got; happy to help with the integration and guide you through! I'm out of office next week, but can take a look when I'm back :)
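
As a rough illustration of the shape handling (a hypothetical helper, not the actual transformers code), generate() could derive its shapes from whichever argument is provided, with prepare_inputs_for_generation then branching on the two inputs:

```python
import jax.numpy as jnp

def infer_generation_shapes(input_ids=None, inputs_embeds=None):
    # Hypothetical helper: derive (batch_size, seq_length) from whichever input was passed.
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape  # (batch, seq, hidden)
    else:
        raise ValueError("Either input_ids or inputs_embeds must be passed to generate()")
    return batch_size, seq_length

# e.g. an inputs_embeds-only call:
batch_size, seq_length = infer_generation_shapes(inputs_embeds=jnp.zeros((2, 7, 768)))
```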

sanchit-gandhi · Aug 13 '22

@sanchit-gandhi here's the PR you requested. I was actually able to get it to work with minimal modifications to generation_flax_utils.py.

@mattf1n a similar solution might work for GPT-2 as well?

BigRedT · Aug 13 '22

Awesome! Replied on the PR 🙂

sanchit-gandhi · Aug 13 '22

Hi! I was away for a bit. I'm happy to see so much activity this past week or so! I have implemented a working version for GPT-Neo. I'll try making a pull request soon.

mattf1n · Aug 17 '22

Here is the PR ^

mattf1n · Aug 17 '22

This issue is still open with a WIP PR at https://github.com/huggingface/transformers/pull/18613

The PR is near completion - if anyone wants to work with @BigRedT to finish this one off feel free to have a go on the PR! More than happy to answer any questions and provide a review 🤗

sanchit-gandhi · Oct 10 '22

The PR is still open if anyone would like to see it to completion! Happy to lend a hand with questions / queries!

sanchit-gandhi · Nov 04 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Nov 28 '22

Going to leave this one closed for now since interest seems to have dwindled. If you're interested in picking this back up, feel free to reopen the issue and tag me 🤗 We can go from there!

sanchit-gandhi · Dec 07 '22