
Add Flax Whisper implementation

Open andyehrenberg opened this issue 2 years ago • 24 comments

Adds a Flax Whisper implementation and adjusts the Flax generation utils to support it.

@ydshieh @ArthurZucker

See discussion in #19512

andyehrenberg avatar Nov 28 '22 22:11 andyehrenberg

@andyehrenberg

Thank you for the PR. However, a pull request should focus on a single objective/goal, rather than changing multiple things at the same time which are not absolutely coupled.

Please

  • follow the pytorch implementation regarding the past_key_values
  • revert the changes on the flax generation utils (You may want to have a backup branch to save these changes for future pull requests.)

The goal of this PR is to add Flax implementation of Whisper. For other changes, it's better to open issue tickets, and if we all agree with the proposals, a PR could proceed :-)

Thank you!

ydshieh avatar Nov 29 '22 09:11 ydshieh

I see a few other instances in this repo where the pytorch implementation computes past_key_values_length while the flax implementation uses position_ids (BART, OPT, etc) - to me, keeping consistency among the APIs of the flax models is something we should strive for. What do you think @ydshieh @patrickvonplaten ?
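The two conventions can be sketched side by side. This is an illustrative NumPy sketch, not the actual model code; the function names `positions_from_past` and `explicit_position_ids` are hypothetical:

```python
import numpy as np

# PyTorch-style convention: positions are offset internally by the length
# of the cached keys/values (past_key_values_length).
def positions_from_past(past_key_values_length, seq_len):
    return np.arange(past_key_values_length, past_key_values_length + seq_len)

# Flax-style convention: the caller passes explicit position_ids; during
# cached decoding they are typically derived from the cache index.
def explicit_position_ids(cache_index, seq_len):
    return cache_index + np.arange(seq_len)

# Both conventions index the same positions for a cached decoding step.
assert (positions_from_past(5, 1) == explicit_position_ids(5, 1)).all()
```

Either way the attention sees identical position indices; the difference is purely where the offset is computed.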

Happy to remove the changes to the generation stuff and open a separate PR for that - will definitely do this to make flax Whisper generation work!

andyehrenberg avatar Nov 29 '22 16:11 andyehrenberg

I wasn't aware of that inconsistency, thank you for pointing it out. This is a good question! But I don't think it's a very serious problem so far - the most important thing is that the different frameworks produce the same outputs when fed the same (supported) inputs, and that the API at the top model level is consistent.

(The internal computation could be somehow different - if there is good reason)
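The equivalence property described here is usually checked with a tolerance comparison. A hedged sketch with placeholder arrays standing in for the real PyTorch and Flax model outputs; the helper name and tolerance are illustrative:

```python
import numpy as np

def check_pt_flax_equivalence(pt_logits, fx_logits, atol=4e-2):
    # Internals may differ between frameworks; outputs must still agree
    # within a small numerical tolerance.
    assert np.allclose(pt_logits, fx_logits, atol=atol), "outputs diverged"

# Placeholder arrays standing in for pt_model(**inputs).logits and
# fx_model(**inputs).logits converted to NumPy.
check_pt_flax_equivalence(np.array([0.1234, -1.5]), np.array([0.1235, -1.5]))
```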

In any case, this could be discussed in an issue and we can proceed with a PR once decided :-)

ydshieh avatar Nov 29 '22 17:11 ydshieh

BTW, there is some issue with triggering CircleCI. The message is:

Could not find a usable config.yml, you may have revoked the CircleCI OAuth app.
Please sign out of CircleCI and log back in with your VCS before triggering a new pipeline.

Do you use an IDE to push the commits? Could you try to push the commit with a command-line tool or a git GUI tool instead?

ydshieh avatar Nov 29 '22 17:11 ydshieh

The documentation is not available anymore as the PR was closed or merged.

Also cc @sanchit-gandhi

patrickvonplaten avatar Nov 30 '22 11:11 patrickvonplaten

Hey! Thanks for opening the follow-up PR 🤗

I don't think I agree with @ydshieh here: adding the flax generation utils changes along with Whisper totally makes sense, as it was done for pytorch and tf, and they are required for the generation tests which are currently missing! Regarding past_key_values, we don't really strive to match transformers with other APIs; rather, I think we prefer consistency within our own library and code clarity. However, you can still open an issue and we can discuss whether we should refactor the design of past_key_values for our flax models!

Will have a look at the PR 😉

ArthurZucker avatar Dec 02 '22 09:12 ArthurZucker

You are right! I wasn't aware those generation features were introduced when you added Whisper, @ArthurZucker. Sorry about that, @andyehrenberg!

ydshieh avatar Dec 02 '22 11:12 ydshieh

Super excited by this PR! 🚀 Feel free to tag me with questions / review requests as well @andyehrenberg 🤗

sanchit-gandhi avatar Dec 02 '22 16:12 sanchit-gandhi

Hey @andyehrenberg! Looks like you found my old PR for implementing scan with Flax nn.Modules and copied the logic across: https://github.com/huggingface/transformers/pull/18341

I'm happy to answer @ArthurZucker's questions regarding scan here. In the end, we decided not to pursue adding scan in Transformers - this is why you haven't seen the PR merged or scan in any of our Flax models.

The reason is that scan adds a lot of complexity to the modelling code. Whilst it does give faster compile times for training, it is actually slower for inference. On balance, the myriad of extra code isn't worth a small speed-up in training compile time. We prefer readability and ease of understanding over highly optimised code in Transformers. Because of this, scan is unfortunately not a good fit.

Note: since Whisper pads/truncates the audio inputs to 30s, the inputs to Whisper are always of fixed dimension. This means you only ever need one compile step! The compilation cost is entirely amortised over the subsequent compiled runs during training/inference. For this reason, I advise sticking to the regular way of implementing unrolled Flax nn.Modules for Whisper.
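The trade-off can be illustrated with a small NumPy sketch that emulates `jax.lax.scan` semantics with a Python loop; names and shapes are hypothetical, and real scan code would use `flax.linen.scan` over parameters stacked on a leading layer axis:

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, d = 4, 8
weights = [rng.standard_normal((d, d)) for _ in range(n_layers)]

# Unrolled (what Transformers uses): the Python loop over per-layer weights
# is baked into the traced graph, so compile time grows with layer count.
def unrolled(x):
    for w in weights:
        x = np.tanh(x @ w)
    return x

# Scan-style: weights stacked on a leading "layer" axis; jax.lax.scan would
# trace the body once and fold it over the stack (faster compiles, but the
# modelling code becomes considerably harder to read).
stacked = np.stack(weights)

def scanned(x):
    def body(carry, w):
        return np.tanh(carry @ w), None
    for w in stacked:  # stands in for: carry, _ = jax.lax.scan(body, x, stacked)
        x, _ = body(x, w)
    return x

x = rng.standard_normal(d)
assert np.allclose(unrolled(x), scanned(x))  # identical outputs either way
```

Since both layouts compute the same function, the choice is purely about compile time versus code complexity, which is the balance described above.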

Happy to answer any questions regarding scan and why we don't include it in our modelling code!

The optimum library might be a better place for highly optimised Flax code: https://github.com/huggingface/optimum

sanchit-gandhi avatar Dec 02 '22 16:12 sanchit-gandhi

Hey @ydshieh! Is there a way of enabling the Flax CI in this PR? Before merging it'd be awesome to verify that the Flax CI is ✅

sanchit-gandhi avatar Dec 22 '22 11:12 sanchit-gandhi

cc @sanchit-gandhi @sgugger for a final review here maybe :-)

patrickvonplaten avatar Dec 30 '22 14:12 patrickvonplaten

@andyehrenberg thanks for the changes in the last commit <3

Green light for this PR on my end [generate]

gante avatar Jan 02 '23 17:01 gante

Mmmm, before merging this PR, there is something wrong going on with the tests: only one of the test jobs is actually run (no tests_flax/tests_tf etc...)

Will investigate later today unless someone beats me to it.

sgugger avatar Jan 03 '23 11:01 sgugger

It looks like the CI is running under the wrong CircleCI project (the PR author's, not huggingface/transformers), and it got:

Resource class docker for xlarge is not available for your project, or is not a valid resource class. This message will often appear if the pricing plan for this project does not support docker use.

See https://app.circleci.com/pipelines/github/andyehrenberg/transformers?branch=flax_whisper

ydshieh avatar Jan 03 '23 12:01 ydshieh

@andyehrenberg

Could you follow the instructions mentioned here and see if they fix the CI issue?

If you're following the fork instead of the upstream repo: a user submits a pull request to your repository from a fork, but no pipeline is triggered with the pull request. This can happen when the user is following the project fork on their personal account rather than the project itself on CircleCI.

This causes the jobs to trigger under the user's personal account. If the user is following a fork of the repository on CircleCI, we will only build on that fork and not the parent, so the parent's PR will not get status updates.

In these cases, the user should unfollow their fork of the project on CircleCI. This will trigger their jobs to run under the organization when they submit pull requests. Those users can optionally follow the source project if they wish to see the pipelines.

ydshieh avatar Jan 03 '23 12:01 ydshieh

Mmmm, before merging this PR, there is something wrong going on with the tests: only one of the test jobs is actually run (no tests_flax/tests_tf etc...)

Will investigate later today unless someone beats me to it.

@sgugger Fixed, and all tests are passing now (I had to override some tests because input_features has a different shape from the usual test inputs)

andyehrenberg avatar Jan 04 '23 15:01 andyehrenberg

Thanks @andyehrenberg !

@sanchit-gandhi Can you have one final look?

sgugger avatar Jan 04 '23 18:01 sgugger

@sanchit-gandhi - How can I rerun the checks without pushing further commits? The error looks like an account limit overshoot and doesn't seem related to the two newer commits.

andyehrenberg avatar Jan 16 '23 18:01 andyehrenberg

@andyehrenberg We can re-run the failed tests from the job run page (screenshot omitted).

But I think only HF members can do that - I will launch it.

ydshieh avatar Jan 16 '23 19:01 ydshieh

@sanchit-gandhi I think it's ready for another look by you! The torch tests currently failing seem unrelated to the PR, so rerunning CI may give all passes

andyehrenberg avatar Jan 21 '23 19:01 andyehrenberg

Also sorry! We just modified Whisper quite a bit 😅

ArthurZucker avatar Jan 26 '23 14:01 ArthurZucker

Also sorry! We just modified Whisper quite a bit 😅

@ArthurZucker - Doesn't actually look too bad to catch up with those changes! Can do that soon-ish. I already have a jax timestamp processor that's compilable.
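For context, a logits processor is compilable under `jax.jit` when its branching is expressed as data rather than Python control flow on traced values. A minimal hypothetical sketch using NumPy's `np.where`, which mirrors `jnp.where`; the token id and function name are made up for illustration and are not the PR's actual code:

```python
import numpy as np

TIMESTAMP_BEGIN = 50364  # hypothetical id of the first timestamp token

def force_timestamp_tokens(scores, force):
    # A Python `if force:` on a traced value won't compile under jax.jit,
    # so the branch is expressed as data with (j)np.where instead.
    vocab = np.arange(scores.shape[-1])
    suppressed = np.where(vocab < TIMESTAMP_BEGIN, -np.inf, scores)
    return np.where(force, suppressed, scores)
```

When `force` is true, every non-timestamp token is masked to `-inf`; when it is false, the scores pass through unchanged, and the same traced code path handles both cases.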

andyehrenberg avatar Jan 26 '23 14:01 andyehrenberg

Oh no - sorry you have to iterate again here @andyehrenberg! Feel free to ping me with any questions / discussions - more than happy to help with the final sprint of the integration! Otherwise super excited to review a final time before merge! 🚀

sanchit-gandhi avatar Jan 27 '23 15:01 sanchit-gandhi

@sanchit-gandhi - I think this is ready for another look - the recent commits (I think) get us to feature parity with the torch version.

andyehrenberg avatar Feb 04 '23 15:02 andyehrenberg

@sanchit-gandhi Bump

andyehrenberg avatar Feb 13 '23 15:02 andyehrenberg

@sanchit-gandhi @ArthurZucker - Addressed Arthur's comments and cleaned up the timestamp logits processor a bit. Hopefully we're close to getting this merged!

andyehrenberg avatar Feb 15 '23 21:02 andyehrenberg

Very nice @andyehrenberg! Thanks for iterating here - reviewed the new changes and the PR is looking super clean. Last request from me is if we can avoid defining the if_true() functions if possible and just add the code explicitly! Good for merge otherwise :)
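For readers following along, the refactor being requested can be sketched as follows; this is a hypothetical NumPy example mirroring the `jnp.where` usage, not the actual PR code:

```python
import numpy as np

logits = np.array([1.0, -2.0, 3.0])

# Before: tiny helper closures whose only job is to feed the conditional.
def if_true():
    return logits * 0.5

def if_false():
    return logits

y_helpers = np.where(logits > 0, if_true(), if_false())

# After: the same branches written explicitly at the call site.
y_inline = np.where(logits > 0, logits * 0.5, logits)

assert np.allclose(y_helpers, y_inline)  # behaviour is unchanged
```

The inline form reads top-to-bottom without jumping to helper definitions, which matches the readability preference stated in the review.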

For sure, made those changes :)

andyehrenberg avatar Feb 17 '23 22:02 andyehrenberg

Are there any instructions for opening the Google Cloud TPU port, admin?

Tungbillee avatar Feb 06 '24 10:02 Tungbillee