Add Flax Whisper implementation
Adds a Flax Whisper implementation, and adjusts the Flax generation utils to support it.
@ydshieh @ArthurZucker
See discussion in #19512
@andyehrenberg
Thank you for the PR. However, a pull request should focus on a single objective/goal, rather than changing multiple things at the same time which are not absolutely coupled.
Please
- follow the PyTorch implementation regarding the past_key_values
- revert the changes on the flax generation utils (you may want to keep a backup branch to save these changes for future pull requests)
The goal of this PR is to add Flax implementation of Whisper. For other changes, it's better to open issue tickets, and if we all agree with the proposals, a PR could proceed :-)
Thank you!
I see a few other instances in this repo where the PyTorch implementation computes past_key_values_length while the Flax implementation uses position_ids (BART, OPT, etc.) - to me, keeping consistency among the APIs of the Flax models is something we should strive for. What do you think @ydshieh @patrickvonplaten ?
Happy to remove the changes to the generation stuff and open a separate PR for that - will definitely do this to make flax Whisper generation work!
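For illustration, the two conventions being compared can be sketched as follows; make_position_ids is a hypothetical helper written for this example, not code from either implementation:

```python
import jax.numpy as jnp

# PyTorch-style models track past_key_values_length directly;
# Flax-style models derive explicit position_ids from the cache length instead.
def make_position_ids(past_key_values_length, seq_len):
    # positions of the new tokens, offset by how many tokens are already cached
    return jnp.arange(past_key_values_length, past_key_values_length + seq_len)

# prefill: 5 prompt tokens, empty cache
prefill_positions = make_position_ids(0, 5)
# a single decode step with 5 tokens already in the cache
decode_positions = make_position_ids(5, 1)
```

Both conventions carry the same information; the question in this thread is only which one the model's forward signature should expose.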
I wasn't aware of that inconsistency, thank you for pointing it out. This is a good question! But I don't think it's a very serious problem so far - the most important thing is that the different frameworks produce the same outputs when fed the same (supported) inputs, plus the API at the top model level being consistent.
(The internal computation could be somehow different - if there is good reason)
In any case, this could be discussed in an issue and we can proceed with a PR once decided :-)
BTW, there is an issue with triggering CircleCI. The message is
Could not find a usable config.yml, you may have revoked the CircleCI OAuth app.
Please sign out of CircleCI and log back in with your VCS before triggering a new pipeline.
Do you use some IDE to push the commits? Could you try to push the commit with a commandline tool or some git GUI tools instead?
Also cc @sanchit-gandhi
Hey! Thanks for opening the follow-up PR!
I don't think I agree with @ydshieh here; adding the flax_generation_utils along with Whisper totally makes sense, as it was done for pytorch and tf, and is required to add the generation tests which are currently missing!
Regarding the past_key_values, we don't really strive to match transformers with other APIs; rather, I think we prefer consistency within our own library, and code clarity. However, you can still open an issue and we can discuss whether we should refactor the design of past_key_values for our flax models!
Will have a look at the PR!
You are right! I was not aware that those generation features were introduced when you added Whisper @ArthurZucker . Sorry about that, @andyehrenberg !
Super excited by this PR! Feel free to tag me with questions / review requests as well @andyehrenberg
Hey @andyehrenberg! Looks like you found my old PR for implementing scan with Flax nn.Modules and copied the logic across: https://github.com/huggingface/transformers/pull/18341
I'm happy to answer @ArthurZucker's questions regarding scan here. In the end, we decided not to pursue adding scan in Transformers - this is why you haven't seen the PR merged, or scan in any of our Flax models.
The reason for this is that scan adds a lot of complexity to the modelling code. Whilst it does give faster compile times for training, it is actually slower for inference. On balance, it's not worth the myriad of extra code for a small speed-up to compile time for training. We prefer readability and ease of understanding over highly optimised code in Transformers. Because of this, unfortunately scan is not a good fit.
Note: since Whisper pads/truncates the audio inputs to 30s, the inputs to Whisper are always of fixed dimension. This means that you only ever need 1 compile step! So the compilation time is entirely amortised by the subsequent compiled times during training/inference. For this reason, I advise that you stick to the regular way of implementing unrolled Flax nn.Modules for Whisper.
Happy to answer any questions regarding scan and why we don't include it in our modelling code!
The optimum library might be a better place for highly optimised Flax code: https://github.com/huggingface/optimum
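To make the scan vs. unrolled trade-off concrete, here is a minimal sketch of the two ways of stacking identical layers; the toy affine layer and the shapes are made up for the example and are not taken from the Whisper code:

```python
import jax
import jax.numpy as jnp

num_layers, d = 4, 8

def layer(carry, params):
    # one toy "layer": an affine transform of the hidden state
    w, b = params
    return carry @ w + b, None

# scan requires the per-layer parameters stacked along a leading num_layers axis
key = jax.random.PRNGKey(0)
ws = jax.random.normal(key, (num_layers, d, d)) * 0.1
bs = jnp.zeros((num_layers, d))

x = jnp.ones((d,))

# scan: the layer body is traced/compiled once and iterated by XLA
y_scan, _ = jax.lax.scan(layer, x, (ws, bs))

# unrolled: a Python loop, each layer traced separately (the Transformers style)
y_unrolled = x
for i in range(num_layers):
    y_unrolled = y_unrolled @ ws[i] + bs[i]
```

The two produce identical outputs; scan only changes how much code XLA has to trace and compile, which is why the benefit is confined to compile time.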
Hey @ydshieh! Is there a way of enabling the Flax CI in this PR? Before merging it'd be awesome to verify that the Flax CI is green.
cc @sanchit-gandhi @sgugger for a final review here maybe :-)
@andyehrenberg thanks for the changes in the last commit <3
Green light for this PR on my end [generate]
Mmmm, before merging this PR, there is something wrong going on with the tests: only one of the tests job is actually run (no tests_flax/tests_tf etc...)
Will investigate later today unless someone beats me to it.
It looks like the pipeline is running under the wrong CircleCI project (the PR author's one, not huggingface/transformers), and it got
Resource class docker for xlarge is not available for your project, or is not a valid resource class. This message will often appear if the pricing plan for this project does not support docker use.
See https://app.circleci.com/pipelines/github/andyehrenberg/transformers?branch=flax_whisper
@andyehrenberg Could you follow the instructions mentioned here, and see if it fixes the CI issue?
If you're following the fork instead of the upstream repo: a user submits a pull request to your repository from a fork, but no pipeline is triggered with the pull request. This can happen when the user is following the project fork on their personal account rather than the project itself on CircleCI.
This will cause the jobs to trigger under the user's personal account. If the user is following a fork of the repository on CircleCI, we will only build on that fork and not the parent, so the parent's PR will not get status updates.
In these cases, the user unfollows their fork of the project on CircleCI. This will trigger their jobs to run under the organization when they submit pull requests. Those users can optionally follow the source project if they wish to see the pipelines.
@sgugger Fixed, and all tests are passing now (had to override some tests due to input_features being different from its usual shape in the tests).
Thanks @andyehrenberg !
@sanchit-gandhi Can you have one final look?
@sanchit-gandhi - How can I rerun the checks without further commits? The error looks like an account limit overshoot and doesn't seem related to the two newer commits.
@andyehrenberg We can re-run the failed tests on the job run page, but I think only HF members can do that - I will launch it.
@sanchit-gandhi I think it's ready for another look by you! The torch tests it's currently failing seem unrelated to the PR, so rerunning CI may give all passes.
Also sorry! We just modified Whisper quite a bit!
@ArthurZucker - Doesn't actually look too bad to catch up with those changes! Can do that soon-ish. I already have a jax timestamp processor that's compilable.
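As an aside, the usual trick for making a logits processor compilable is to replace Python control flow with masked array ops. Here is a minimal sketch of one such rule (suppressing timestamp tokens earlier than the last one emitted); the token layout and the TIMESTAMP_BEGIN value are made up for this example and this is not the actual Whisper timestamp processor:

```python
import jax.numpy as jnp

TIMESTAMP_BEGIN = 10  # hypothetical: token ids >= this are timestamp tokens

def enforce_monotonic_timestamps(logits, last_timestamp_id):
    # jit-friendly: jnp.where instead of a Python `if`, so the function
    # traces to the same computation graph regardless of runtime values
    token_ids = jnp.arange(logits.shape[-1])
    stale = (token_ids >= TIMESTAMP_BEGIN) & (token_ids < last_timestamp_id)
    return jnp.where(stale, -jnp.inf, logits)

logits = jnp.zeros(16)
out = enforce_monotonic_timestamps(logits, last_timestamp_id=12)
```

Because the masking is pure array arithmetic, the same function works unchanged under jax.jit and inside a generation loop.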
Oh no - sorry you have to iterate again here @andyehrenberg! Feel free to ping me with any questions / discussions - more than happy to help with the final sprint of the integration! Otherwise super excited to review a final time before merge!
@sanchit-gandhi - I think this is ready for another look - the recent commits (I think) get us to feature parity with the torch version.
@sanchit-gandhi Bump
@sanchit-gandhi @ArthurZucker - Addressed Arthur's comments and cleaned up the timestamp logits processor a bit. Hopefully we're close to getting this merged!
Very nice @andyehrenberg! Thanks for iterating here - reviewed the new changes and the PR is looking super clean. Last request from me is if we can avoid defining the if_true() functions if possible and just add the code explicitly! Good for merge otherwise :)
For sure, made those changes :)
Are there any instructions for opening the Google Cloud TPU port, admin?