transformers
Add support for output_scores to Flax models
What does this PR do?
Flax models do not support `output_scores` when the `generate()` method is called, whereas the PyTorch models fully support this feature.
The naming and format of these parameters follow the PyTorch model code (`utils.py`) as closely as possible.
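For illustration, a minimal usage sketch of what this enables (hedged: the checkpoint name and dummy audio are placeholders, and the flag handling shown is the behaviour added by this PR rather than an existing API guarantee):

```python
import numpy as np
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

# Placeholder checkpoint and dummy audio; from_pt=True mirrors the tests further below.
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000, dtype=np.float32)  # 1 second of silence as a stand-in input
input_features = processor(audio, sampling_rate=16_000, return_tensors="np").input_features

outputs = model.generate(
    jnp.array(input_features),
    max_length=128,
    output_scores=True,            # the new flag this PR adds for Flax
    return_dict_in_generate=True,
)
print(outputs.sequences.shape)     # (batch, max_length)
print(outputs.scores.shape)        # (batch, max_length, vocab_size) in this implementation
```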
Before submitting
- [x] This PR adds support for `output_scores` to Flax models.
- [x] The Flax Whisper model handles the `output_scores` and `num_beams` parameters during `generate()`.
Who can review?
@sanchit-gandhi
@sanchit-gandhi @ArthurZucker Could you please review this PR?
I get a new error on CI in the test code of all models:
`AttributeError: module 'jax' has no attribute 'Array'`
However, there was no such error two weeks ago (when all tests passed for my commit). Is there an update to JAX that this is not compatible with? Any ideas @sanchit-gandhi?
Update: the above-mentioned error is raised from the optax source code:
```
examples/flax/test_flax_examples.py:41: in <module>
    import run_clm_flax
examples/flax/language-modeling/run_clm_flax.py:40: in <module>
    import optax
../.pyenv/versions/3.8.12/lib/python3.8/site-packages/optax/__init__.py:18: in <module>
    from optax._src.alias import adabelief
../.pyenv/versions/3.8.12/lib/python3.8/site-packages/optax/_src/alias.py:23: in <module>
    from optax._src import clipping
../.pyenv/versions/3.8.12/lib/python3.8/site-packages/optax/_src/clipping.py:130: in <module>
    ) -> Tuple[List[chex.Array], jax.Array]:
E   AttributeError: module 'jax' has no attribute 'Array'
```
Could it be related to this recent merge: https://github.com/huggingface/transformers/pull/22895 ?
I've found that this issue is related to the optax version (0.1.5 was installed). In the updated version of the transformers repo, the installed version is pinned to 0.1.4.
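A quick way to confirm this locally (a diagnostic sketch, not part of the PR): `jax.Array` only exists in sufficiently recent JAX releases, while optax 0.1.5 references it in type annotations at import time.

```python
import jax

# If this prints False, `import optax` with optax >= 0.1.5 fails exactly as in the
# CI traceback above, because optax's annotations reference jax.Array at import time.
print("jax:", jax.__version__, "has jax.Array:", hasattr(jax, "Array"))

try:
    import optax
    print("optax:", optax.__version__)
except AttributeError as err:
    print("optax import failed:", err)
```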
Good catch regarding the `jax.Array` issue! I need to un-pin JAX in Transformers since the new Optax / Chex versions are running ahead: https://github.com/huggingface/transformers/issues/19842. Will do this tomorrow 🤗
Thanks for the review @gante 🙌 See https://github.com/huggingface/transformers/pull/22700#discussion_r1182794360 for the next steps @hannan72 🚀
Also see related: #22700
This might get merged before this PR, in which case we can rebase to get the beam score fixes from main! Your changes will still be valuable for greedy search @hannan72 🤗
Hey @hannan72! This PR is looking in good shape - would you like to get it over the line with the last bits of integration?
Sorry for the late response; I was busy with a product release. Yes, I really want to finalize it and get it into the next release of transformers. What is remaining? Please clarify the remaining steps to finalize the PR and close this issue.
Awesome! It's more or less as you left it - the major "TODO" is getting the correct vocab size in the first forward pass (see https://github.com/huggingface/transformers/pull/22700#discussion_r1178185868)
I had given it a try and posted the result: I tried to do this, but an error stopped me from working on it.
I get the `vocab_size` from the logits shape in the first step as follows:
```python
next_tokens_scores = logits_processor(state.sequences, logits, state.cur_len)
next_token = jnp.argmax(next_tokens_scores, axis=-1)

scores = state.scores
if output_scores and state.scores is None:
    vocab_size = next_tokens_scores.shape[-1]
    scores = jnp.ones((batch_size, max_length, vocab_size)) * np.array(-1.0e7)
tokens_scores = scores.at[:, state.cur_len, :].set(next_tokens_scores) if output_scores else None
```
But at this line: https://github.com/huggingface/transformers/blob/312b104ff65514736c0475814fec19e47425b0b5/src/transformers/generation/flax_utils.py#L641 it checks that the tensor structures between iterations are exactly the same, which causes the following error:
```
Exception has occurred: TypeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
body_fun output and input must have same type structure, got PyTreeDef(CustomNode(GreedyState[()], [*, *, *, *, *, {'decoder_attention_mask': *, 'decoder_position_ids': *, 'encoder_attention_mask': None, 'encoder_outputs': CustomNode(FlaxBaseModelOutput[()], [*, None, None]),...
```
So it seems the second suggestion is not going to work here, because in JAX every tensor shape must be defined before tracing, whereas we only obtain the `vocab_size` at run time.
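To make the constraint concrete, here is a minimal standalone example (not from the PR) of what `lax.while_loop` requires of its carry: the value returned by the body must have exactly the same pytree structure, shapes, and dtypes as the value it receives, so a field cannot start as `None` and become an array inside the loop.

```python
import jax.numpy as jnp
from jax import lax

def cond_fn(carry):
    step, _ = carry
    return step < 5

def body_fn(carry):
    step, scores = carry
    # The carry keeps the same structure: a scalar step counter and a fixed-shape array.
    # Returning an array where the input had None (or a differently shaped array) would
    # raise "body_fun output and input must have same type structure".
    return step + 1, scores.at[step].set(step.astype(scores.dtype))

init = (jnp.array(0), jnp.zeros(5))  # scores must already have its final shape here
_, scores = lax.while_loop(cond_fn, body_fn, init)
print(scores)  # [0. 1. 2. 3. 4.]
```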
The idea here would be to run the first pass outside of the lax while loop (which we already do), then get the logits shape, then run the loop with the correct vocab size. Picking up on L730: https://github.com/huggingface/transformers/blob/9ade58f0555430cec851e307c83c3a56c4a77d0b/src/transformers/generation/flax_utils.py#L730
This would look something like:
```python
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
if input_ids.shape[1] > 1:
    state = sample_search_body_fn(state)

# now get the vocab size
vocab_size = state.logits.shape[-1]

# do the other stuff that we need to do to init the state scores
# ...

# now run the main body
if not trace:
    state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
else:
    state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
```
I implemented your suggestion by applying the following changes in `greedy_search_body_fn`, getting the `vocab_size` from the first run as follows:
```python
def greedy_search_body_fn(state):
    """state update fn."""
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
    logits = model_outputs.logits[:, -1]

    # apply min_length, ...
    next_tokens_scores = logits_processor(state.sequences, logits, state.cur_len)
    next_token = jnp.argmax(next_tokens_scores, axis=-1)

    if output_scores:
        if state.scores is not None:
            tokens_scores = state.scores.at[:, state.cur_len, :].set(next_tokens_scores)
        else:
            # first call: the vocab size is only known now, from the logits shape
            scores = jnp.ones((batch_size, max_length, next_tokens_scores.shape[-1])) * np.array(-1.0e7)
            tokens_scores = scores.at[:, state.cur_len, :].set(next_tokens_scores)
    else:
        tokens_scores = None

    next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
    next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
    next_token = next_token[:, None]

    next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
    next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

    return GreedyState(
        cur_len=state.cur_len + 1,
        sequences=next_sequences,
        scores=tokens_scores,
        running_token=next_token,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )

# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU.
# Besides, when output_scores is True, the model's vocab_size is obtained from this first run to build the scores array.
if input_ids.shape[1] > 1 or output_scores:
    state = greedy_search_body_fn(state)
```
@sanchit-gandhi I think this PR is ready to merge. All tests pass. Could you please review it again and merge it?
@sanchit-gandhi I have checked out my latest commit (b82ef360c5d819efc10298344d7d2fb4c33e1c47) and run a test as follows:
- Model name: whisper_medium with Flax inference
- GPU: A100-40GB
- Input audio: 5 seconds
- transformers: git+https://github.com/huggingface/transformers.git@b82ef360c5d819efc10298344d7d2fb4c33e1c47
- PyTorch: 2.0.0
- JAX: [cuda12_local] 0.4.11
A. Normal inference (with `output_scores=False`):
The model was run for 5 consecutive inferences. Inference time is ~0.2 seconds:
```python
model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])

runtime = []
for i in range(5):
    start_time = time.time()
    input_features = jnp.array(input_features, dtype=jnp.float16)
    pred_ids = jit_generate(input_features, max_length=128, language="<|de|>", task="transcribe")
    runtime.append(time.time() - start_time)

print("Inference time: ", runtime)
print("output scores: ", scores)
```
Result:
Inference time: [57.01693844795227, 0.22632288932800293, 0.1981194019317627, 0.19892430305480957, 0.19736719131469727]
output scores: None
B. Inference with confidence scores (with `output_scores=True`):
The model was run for 5 consecutive inferences. Inference time is also ~0.2 seconds:
```python
model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
jit_generate = jax.jit(
    model.generate,
    static_argnames=["max_length", "language", "task", "output_hidden_states", "output_scores", "return_dict_in_generate"],
)

runtime = []
for i in range(5):
    start_time = time.time()
    input_features = jnp.array(input_features, dtype=jnp.float16)
    pred_ids = jit_generate(
        input_features, max_length=128, language="<|de|>", task="transcribe",
        output_scores=True, output_hidden_states=True, return_dict_in_generate=True,
    )
    runtime.append(time.time() - start_time)

print("Inference time: ", runtime)
print("output scores: ", scores)
```
Result:
Inference time: [82.8741066455841, 0.20504498481750488, 0.19746017456054688, 0.1972200870513916, 0.1973130702972412]
output scores: [[[-10000000. -10000000. -10000000. ... -10000000. -10000000. -10000000.] [ -inf -inf -inf ... -inf -inf -inf] [ -inf -inf -inf ... -inf -inf -inf] ... [-10000000. -10000000. -10000000. ... -10000000. -10000000. -10000000.] [-10000000. -10000000. -10000000. ... -10000000. -10000000. -10000000.] [-10000000. -10000000. -10000000. ... -10000000. -10000000. -10000000.]]]
It should also be noted that the model's inference results are exactly the same in both cases. The only difference is that the first run takes longer when `output_scores=True`; subsequent inferences take approximately the same time.
Could you please review and merge this PR?
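As a side note, a hedged sketch (not part of this PR) of how the returned scores could be turned into per-token confidences, assuming the `(batch, max_length, vocab_size)` layout built in the greedy-search change above; positions that were never written (the `-1.0e7` fill) or that belong to the prompt should be masked out before interpreting the values:

```python
import jax
import jax.numpy as jnp

def token_confidences(scores, sequences):
    """scores: (batch, max_length, vocab_size); sequences: (batch, max_length) token ids.
    Returns the softmax probability assigned to each generated token id."""
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.take_along_axis(probs, sequences[..., None], axis=-1).squeeze(-1)

# hypothetical usage with the generate() output shown above:
# confidences = token_confidences(outputs.scores, outputs.sequences)
```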
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hello @sanchit-gandhi, sorry for the late response. I added the test code for the Flax `output_scores` feature that you asked for, but the PR was closed automatically by github-actions. My latest commit, which includes some tests, is: https://github.com/hannan72/transformers/commit/0416becf86c65a3f32e72314715b79f5f84f52ce
Could you please re-open the PR to run the new tests?
Re-opened and running the tests! Thanks @hannan72! Let me know when this is ready for a re-review
Thank you @sanchit-gandhi for re-opening this PR. I've added some tests based on what you asked for. Could you please review my latest commit (https://github.com/huggingface/transformers/pull/22700/commits/0416becf86c65a3f32e72314715b79f5f84f52ce)?
@sanchit-gandhi Have you reviewed the test code I added?
Hey @hannan72 - yes I did! Please see the comment I left a couple of weeks ago: https://github.com/huggingface/transformers/pull/22700#discussion_r1316165469
Let me know if you need any help here @hannan72! More than happy to assist with the integration and think you're pretty close to finishing!
@sanchit-gandhi many thanks for your reviews of @hannan72's PR! We would need your help in finalizing it. Since you mentioned in your comment that the HF team has already tested that the ids do not change, I think it would be much easier if you extended the existing test case to show that the scores are also correct.
Quote: @sanchit-gandhi "At the moment, we've tested that output_scores=True doesn't change the ids, but not that the scores are correct"
We would really appreciate your help in getting this PR over the finish line, since you know your test code much better than we do and it would be much faster for you. Do you think you could ask someone on your team to help get the PR finalized and merged?
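For reference, one possible shape of such a check (a hedged sketch with placeholder names, not the actual Transformers test): run the same inputs through the PyTorch and Flax models with `output_scores=True` and compare the per-step scores.

```python
import numpy as np

def assert_flax_pt_scores_close(flax_outputs, pt_outputs, atol=1e-3):
    # PyTorch generate() returns `scores` as a tuple with one (batch, vocab_size) tensor per
    # generated step, whereas the Flax change in this PR stores a single
    # (batch, max_length, vocab_size) array, so the generated steps have to be aligned first.
    pt_scores = np.stack([s.float().cpu().numpy() for s in pt_outputs.scores], axis=1)
    num_steps = pt_scores.shape[1]
    # Assumption: the last `num_steps` positions of the Flax array hold the generated steps;
    # the slice would need adjusting if the layout differs.
    flax_scores = np.asarray(flax_outputs.scores)[:, -num_steps:, :]
    np.testing.assert_allclose(flax_scores, pt_scores, atol=atol)
```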
Cool to see that you're interested in this PR @teddius! I sadly won't have the bandwidth to work on this PR directly, but am more than happy to continue with PR reviews and answering any questions/queries. If @hannan72 is able to, it seems fitting that he gets the opportunity to finish the PR that he started! Otherwise, we can open this one up to the community and see if anyone is able to help here.
@sanchit-gandhi understood, many thanks for your fast reply. @hannan72 will be busy with other tasks, so please feel free to open up the task to the community so we can get some help finishing the last test case. Many thanks for your support along the way!
Cool, sounds good @teddius! See https://github.com/huggingface/transformers/issues/22612#issuecomment-1753324050 for the community contribution request.