Add TensorFlow Wav2Vec2 for sequence classification

Open nandwalritik opened this issue 2 years ago • 12 comments

What does this PR do?

Fixes # (issue)

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. #21778
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

@sanchit-gandhi

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

nandwalritik avatar Mar 10 '23 06:03 nandwalritik

The documentation is not available anymore as the PR was closed or merged.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 09 '23 15:04 github-actions[bot]

Kindly ping @sanchit-gandhi and adding @Rocketknight1 for the TensorFlow side.

sgugger avatar Apr 10 '23 13:04 sgugger

Hi @nandwalritik, and sorry for the extremely long delay in catching this! Ordinarily one of the TF maintainers reviews TF pull requests, but this one slipped through the cracks somehow. If you want to file TF PRs in future, you can directly ping me or @gante to make sure that we don't miss it.

This PR actually looks almost perfect, but there are a couple of TF-specific details that are causing some tests to fail. I'll mark them in a code review in just a sec, but they shouldn't take too long to fix. Thanks again for submitting this!

Rocketknight1 avatar Apr 11 '23 14:04 Rocketknight1

I added changes for the serving and serving_output methods, but I'm not sure whether they are correct.

nandwalritik avatar Apr 13 '23 06:04 nandwalritik

Hi @nandwalritik, I'm seeing the issue when you move it to build() - the problem is the weight name, as it usually is in our TensorFlow ports! TF isn't very consistent about the name scope used for weights: it can differ depending on whether the weight is created in __init__, in build(), or lazily in call(). That makes things tricky, because we use the names to match weights between PT and TF models.

I'll see if I can push a solution to your repo, hang on.

Rocketknight1 avatar Apr 14 '23 16:04 Rocketknight1

Ok

nandwalritik avatar Apr 14 '23 16:04 nandwalritik

Try:

with tf.name_scope(self._name_scope()):
    self.layer_weights = self.add_weight(
        shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
    )

in the __init__, not the build(). I know that contradicts what I said earlier, but it turns out to be a bit different for a base model class than a sublayer.

I also see a couple of other errors - you can see them by clicking the Details beside tests_tf in the checklist at the bottom of this PR. If you can't figure out what's causing them, ping me over the weekend or on Monday and I'll try to debug them!

Rocketknight1 avatar Apr 14 '23 18:04 Rocketknight1

Ok, so after adding this change, the weights load without any warning or error, but the outputs of the PyTorch and TensorFlow models don't match within an rtol of 1e-5. I did check the shape and absolute sum of the tensors from both models, though, and they are almost equal:

PT model

hidden_state
1,292,768 -> 29877.8750

1,292,256 -> 29711.7109

pooled_output
1,256 -> 38.7491

TF model

hidden_state
1,292,768 -> 29877.879

1,292,256 -> 29711.715

pooled_output
1,256 -> 38.811996

What should I try next to satisfy the rtol criterion?
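
As a rough check on the figures above, the relative difference between the PT and TF absolute sums can be compared against the 1e-5 tolerance. This is only a sketch in plain Python using the sums reported here. The dictionary keys and the helper name `relative_difference` are made up for illustration. Comparing sums is also a much weaker test than an element-wise comparison, so it can flag a mismatch but cannot confirm a match.

```python
# Absolute sums reported above, keyed by label and shape: (PT value, TF value).
reported = {
    "hidden_state 1,292,768": (29877.8750, 29877.879),
    "hidden_state 1,292,256": (29711.7109, 29711.715),
    "pooled_output 1,256": (38.7491, 38.811996),
}

RTOL = 1e-5  # the tolerance mentioned in this thread


def relative_difference(pt_value, tf_value):
    """Relative difference of the TF value with respect to the PT value."""
    return abs(tf_value - pt_value) / abs(pt_value)


# True where the relative difference of the sums is larger than RTOL.
exceeds = {
    name: relative_difference(pt, tf) > RTOL for name, (pt, tf) in reported.items()
}

for name, (pt, tf) in reported.items():
    status = "exceeds rtol" if exceeds[name] else "within rtol"
    print(f"{name}: relative difference {relative_difference(pt, tf):.2e} ({status})")
```

On these figures only the pooled_output sums differ by more than the tolerance, which hints that the divergence may start after the hidden states, although matching sums alone do not prove that the earlier tensors agree element-wise.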

nandwalritik avatar Apr 17 '23 04:04 nandwalritik

Hm, those are some fairly large discrepancies! The debugging process we recommend when something like that happens is:

  • Make a test environment and load the PT and TF models with the same weights
  • Try to isolate the earliest point where the model outputs diverge. You can use options like output_hidden_states to get the model to return all hidden states, not just the final ones.
  • Once you find the first point of divergence, try to see if you can dig into the layer where the divergence happened. You can place breakpoints, or extract sublayers and try passing test inputs into them.
  • Eventually you will find the single specific place where the divergence creeps in - now you can check what the cause is. Make sure the weights for that operation really do match between the two frameworks, and make sure both frameworks are doing the same thing at that point.
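
The steps above can be sketched roughly as follows. This is a minimal illustration, not the project's own tooling: it assumes the per-layer hidden states from both models have already been converted to NumPy arrays (for example from the tuples returned with output_hidden_states), and the helper name `first_divergence` and the random stand-in data are hypothetical.

```python
import numpy as np


def first_divergence(pt_states, tf_states, rtol=1e-5, atol=1e-8):
    """Return the index of the first pair of arrays that differ beyond the
    given tolerance (or have different shapes), or None if all pairs match."""
    for index, (pt_arr, tf_arr) in enumerate(zip(pt_states, tf_states)):
        pt_arr = np.asarray(pt_arr)
        tf_arr = np.asarray(tf_arr)
        if pt_arr.shape != tf_arr.shape:
            return index
        if not np.allclose(pt_arr, tf_arr, rtol=rtol, atol=atol):
            return index
    return None


# Stand-in data: five "layers" of identical activations, with an error
# deliberately injected from layer 3 onwards to mimic a diverging port.
rng = np.random.default_rng(0)
pt_states = [rng.standard_normal((1, 4, 8)) for _ in range(5)]
tf_states = [state.copy() for state in pt_states]
for layer in range(3, 5):
    tf_states[layer] = tf_states[layer] + 1e-2

print(first_divergence(pt_states, tf_states))  # prints 3
```

Once the first diverging index is known, the layer that produced that hidden state is the place to set breakpoints or to extract and test in isolation.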

As always, if you can't figure it out, let me know! This kind of work can be quite gruelling, but we really appreciate the work you're doing on the model port.

Rocketknight1 avatar Apr 17 '23 16:04 Rocketknight1

Hi @Rocketknight1, I added test cases and fixed the feed-forward part, but the CI is failing due to Flax. I don't think this is related to my changes. Please review the PR and let me know if any more changes are required.

nandwalritik avatar Apr 21 '23 06:04 nandwalritik

Yep, those flax issues are unrelated, just ignore them. I'll review everything today, but the CI looks good!

Rocketknight1 avatar Apr 21 '23 13:04 Rocketknight1

@sanchit-gandhi @Rocketknight1 let me know if any more changes are required. Otherwise, could you get this PR merged?

nandwalritik avatar Apr 26 '23 04:04 nandwalritik

Just looked over the last few changes - I'm happy to merge it at this point. Thanks again for putting in the work on this!

Rocketknight1 avatar Apr 26 '23 12:04 Rocketknight1