Get output_hidden_states and output_scores from the Flax Whisper model
I need Whisper's output_scores and output_hidden_states as part of the result of the generate() method. With the PyTorch model, I can easily get them by setting these parameters in generate() as follows:
```python
whisper_output = model.generate(
    inputs=input_features,
    max_new_tokens=180,
    output_scores=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
```
and the resulting whisper_output contains 'scores' and the hidden-state outputs as keys alongside 'sequences'.
Now I want to do the same for the Flax Whisper model, but setting these parameters as static_argnames of the model has no effect on getting output_scores.
Is there any solution for getting output_scores or logits from the Flax Whisper model?
cc @sanchit-gandhi
I found that the Flax model, when set to use beam search, calculates a scores value: https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/generation/flax_utils.py#L83-L96
In the _beam_search method it is calculated and returned: https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/generation/flax_utils.py#L998-L1004
but scores are not returned when greedy search is used: https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/generation/flax_utils.py#L55-L65
I ran the Flax Whisper model in beam-search mode by setting generation_config.num_beams to a value larger than 1. It does return scores in the output, but they are totally different from the scores returned by the PyTorch model.
scores in Flax is just a scalar value per sequence, whereas the scores output of the PyTorch model is a list of n elements (n = number of generated tokens), where each element is a torch.Tensor of shape (1, vocab_size). In other words, PyTorch's scores give, for each generated token, a score over every vocabulary token.
So the Flax scores output is something totally different.
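For illustration, the two shapes described above can be mocked with NumPy (the vocabulary size and token count here are made-up toy values, not Whisper's real ones):

```python
import numpy as np

vocab_size = 50  # toy value; Whisper's real vocabulary is much larger
n_tokens = 4     # number of generated tokens in this toy example

# PyTorch-style scores: one (1, vocab_size) array per generated token
pt_style_scores = [np.random.randn(1, vocab_size) for _ in range(n_tokens)]

# Flax beam-search-style scores: a single scalar (sequence log-prob)
flax_style_scores = np.array(-3.27)

assert len(pt_style_scores) == n_tokens
assert pt_style_scores[0].shape == (1, vocab_size)
assert flax_style_scores.ndim == 0  # scalar: no per-token, per-vocab detail
```

The scalar carries no per-token information, which is why the beam-search scores cannot substitute for the PyTorch-style per-step score tensors.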
I found the logits of the Flax greedy search in flax_utils.py here: https://github.com/huggingface/transformers/blob/ed67286465c5e9e3d3005de3e21bc3c679d93072/src/transformers/generation/flax_utils.py#L610-L618
We just need to extract these logits out of the greedy-search function and return them.
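The idea can be sketched with a self-contained toy greedy decode loop (the function and variable names below are illustrative, not the actual transformers code): instead of discarding each step's logits after taking the argmax, accumulate them and return them alongside the sequence.

```python
import numpy as np

def toy_greedy_search(logits_fn, start_token, max_len):
    """Toy greedy decode that also collects per-step logits.

    logits_fn(token) -> np.ndarray of shape (vocab_size,) is a stand-in
    for the model's forward pass.
    """
    tokens, all_logits = [start_token], []
    for _ in range(max_len):
        logits = logits_fn(tokens[-1])
        all_logits.append(logits)              # keep this step's logits
        tokens.append(int(np.argmax(logits)))  # greedy pick
    # return the (max_len, vocab_size) logits stack alongside the sequence
    return np.array(tokens), np.stack(all_logits)

# deterministic toy "model": logits are one-hot at (token + 1) mod 5
fn = lambda t: np.eye(5)[(t + 1) % 5]
seq, scores = toy_greedy_search(fn, start_token=0, max_len=3)
# seq is [0, 1, 2, 3]; scores has one row of vocab logits per generated token
```

In the real Flax code the loop state would carry the accumulated logits through the decoding loop, but the principle is the same: nothing new needs to be computed, only retained.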
I've added support for output_scores to the flax_utils.py code in the following fork:
https://github.com/hannan72/transformers/commit/116d8f38722359ca5d2dad918975348359cc2ac1
And also added support for these parameters to the Flax Whisper model: https://github.com/hannan72/transformers/commit/accdcb2d66496c5ee8547739bf833c95e189344c
@sanchit-gandhi Could you review the changes and open a PR to support the scores value for the Flax model?
I have opened a PR for this feature: https://github.com/huggingface/transformers/pull/22700
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
https://github.com/huggingface/transformers/pull/22700 is still open and active 🤗
Hey everyone! @hannan72 has done a great job at working on the PR for this feature. The Flax generation code is more or less complete, but there are a few extra integration tests we want to add to make sure the code gives the expected results: https://github.com/huggingface/transformers/pull/22700#discussion_r1288921417
If anyone would like to finish this PR, contributions are more than welcome! Feel free to have a look through the pull request and familiarise yourself with the generation code changes. The last pending point is the integration test mentioned above, which should be quite straightforward to add by comparing the Flax outputs to the PyTorch ones.
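For anyone picking this up, the parity check that the integration test needs essentially boils down to comparing per-step score tensors from the two frameworks; roughly along these lines (the mock arrays below stand in for real model outputs, which in the actual test would come from running both models on the same audio):

```python
import numpy as np

# mock per-step scores standing in for real PyTorch / Flax generate outputs
pt_scores = [np.array([[0.1, 0.7, 0.2]]), np.array([[0.5, 0.3, 0.2]])]
flax_scores = [np.array([[0.1, 0.7, 0.2]]), np.array([[0.5, 0.3, 0.2]])]

# same number of decoding steps, numerically close scores at each step
assert len(pt_scores) == len(flax_scores)
for pt_step, flax_step in zip(pt_scores, flax_scores):
    np.testing.assert_allclose(pt_step, flax_step, atol=1e-4)
```

The tolerance would need tuning in practice, since PyTorch and Flax runs of the same checkpoint typically agree only to within small numerical differences.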
cc @teddius