added GPTNeoForTokenClassification

Open peter-sk opened this issue 2 years ago • 10 comments

What does this PR do?

It adds the class GPTNeoForTokenClassification, which allows GPT Neo models to be used for token classification tasks. The implementation closely follows those of other models (such as GPT2) and simply adds a linear layer on top of the hidden states.
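For readers less familiar with these heads, here is a framework-agnostic sketch of what a token classification head computes. It is written in NumPy rather than PyTorch and is not the code from this PR: optional dropout on the per-token hidden states, followed by one linear layer that maps each hidden vector to num_labels logits. The function name, argument names, and sizes are hypothetical. Dropout is included because the head reads config.classifier_dropout, as mentioned later in the thread.

```python
import numpy as np


def token_classification_head(hidden_states, weight, bias, dropout_p=0.1,
                              training=False, rng=None):
    """Map per-token hidden states to per-token label logits (hypothetical helper).

    hidden_states: (batch, seq_len, hidden_size)
    weight:        (hidden_size, num_labels)
    bias:          (num_labels,)
    returns:       (batch, seq_len, num_labels)
    """
    if training and dropout_p > 0.0:
        if rng is None:
            rng = np.random.default_rng()
        # Inverted dropout: zero each activation with probability p and
        # rescale the survivors so the expected value is unchanged.
        keep = rng.random(hidden_states.shape) >= dropout_p
        hidden_states = hidden_states * keep / (1.0 - dropout_p)
    # The same linear map is applied independently at every token position.
    return hidden_states @ weight + bias


# Usage with random stand-ins for the GPT Neo hidden states (hypothetical sizes).
batch, seq_len, hidden_size, num_labels = 2, 5, 8, 3
rng = np.random.default_rng(0)
hidden = rng.standard_normal((batch, seq_len, hidden_size))
W = rng.standard_normal((hidden_size, num_labels)) * 0.02
b = np.zeros(num_labels)

logits = token_classification_head(hidden, W, b)
print(logits.shape)  # (2, 5, 3)
predicted_label_ids = logits.argmax(axis=-1)  # one label id per token, shape (2, 5)
```

Because the classifier acts on each position separately, the sequence length never enters the weight shapes. This is why the same head design transfers across GPT2, GPT Neo, and similar decoder models.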

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@ArthurZucker @younesbelkada

peter-sk • Apr 21 '23 06:04

The documentation is not available anymore as the PR was closed or merged.

Hey! Could you make sure the CI tests are green? I can review then!

ArthurZucker • Apr 21 '23 14:04

@ArthurZucker Sure. I'm getting the hang of it. Now the only failing tests are the Flax ones, and they seem unrelated to this pull request.

peter-sk • Apr 22 '23 02:04

If the flax errors are not due to the PR, this is ready to be reviewed, @ArthurZucker and @younesbelkada :-)

peter-sk • Apr 22 '23 22:04

I just checked the logs for the remaining errors one more time. The errors come from importing the optax library, which uses jax.Array in a type hint. Apparently the jax module has no top-level name "Array".

I cannot see how this could be related to my PR.
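The failure mode can be checked directly. The sketch below uses hypothetical helper names and assumes only that jax, when installed, is importable as jax. It tests whether a module exposes a top-level Array name, and it uses a stand-in module to show the AttributeError that an import-time reference to jax.Array raises on a jax release that lacks it.

```python
import types


def exposes_array(jax_module) -> bool:
    """True if the module has a top-level ``Array`` name, as newer jax releases do."""
    return hasattr(jax_module, "Array")


def describe_installed_jax() -> str:
    """Report whether the locally installed jax (if any) exposes jax.Array."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    version = getattr(jax, "__version__", "unknown")
    if exposes_array(jax):
        return f"jax {version}: jax.Array is available"
    return f"jax {version}: no jax.Array, so code that references it at import time fails"


# A stand-in module without the attribute reproduces the failure mode seen in CI.
old_jax = types.ModuleType("old_jax")  # hypothetical stand-in, not the real jax
try:
    old_jax.Array  # what an eagerly evaluated type hint or alias effectively does
except AttributeError as exc:
    print("AttributeError:", exc)

print(exposes_array(old_jax))  # False
print(describe_installed_jax())
```

Annotations and type aliases at module level are normally evaluated when the module is imported. A missing name therefore breaks the import itself, not just a later call, which matches errors surfacing while collecting the Flax tests.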

peter-sk • Apr 23 '23 06:04

The jax version used in the examples_flax test is 0.3.6:

Collecting jax!=0.3.2,<=0.3.6,>=0.2.8 (from transformers==4.28.0.dev0)
Using cached jax-0.3.6-py3-none-any.whl

This version clearly has no Array class. I am unsure why such an old version is being used.
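Why the pin lands on 0.3.6 can be reproduced with a small stdlib-only sketch. The clauses are combined with AND, and pip generally picks the highest candidate that satisfies all of them. Versions must be compared numerically: 0.3.17 is newer than 0.3.6, so it is excluded by <=0.3.6. The candidate list is illustrative and the helper names are hypothetical. The parser only handles plain numeric releases, not the full PEP 440 grammar.

```python
import operator

# Operators checked longest-first so "<=" is not mistaken for "<".
_OPS = [("==", operator.eq), ("!=", operator.ne), ("<=", operator.le),
        (">=", operator.ge), ("<", operator.lt), (">", operator.gt)]


def parse_version(v: str) -> tuple:
    """Parse a plain release version such as '0.3.6' (no pre/post/dev tags)."""
    parts = [int(x) for x in v.strip().split(".")]
    # Drop trailing zeros so '0.3' and '0.3.0' compare equal.
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def satisfies(version: str, spec: str) -> bool:
    """Check a version against a comma-separated specifier; all clauses must hold."""
    v = parse_version(version)
    for clause in spec.split(","):
        clause = clause.strip()
        for op_text, op in _OPS:
            if clause.startswith(op_text):
                if not op(v, parse_version(clause[len(op_text):])):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True


def best_match(candidates, spec):
    """Return the highest candidate allowed by spec, or None if nothing matches."""
    allowed = [c for c in candidates if satisfies(c, spec)]
    return max(allowed, key=parse_version) if allowed else None


pin = "!=0.3.2,<=0.3.6,>=0.2.8"  # the jax requirement from the CI log
candidates = ["0.2.8", "0.3.2", "0.3.6", "0.3.17", "0.4.8"]  # illustrative list
print(best_match(candidates, pin))  # 0.3.6
```

pip itself relies on the packaging library (packaging.specifiers.SpecifierSet), which also handles pre-releases, wildcards, and ~=. The hand-rolled version here is only meant to show the selection logic.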

peter-sk • Apr 23 '23 19:04

I figured out that optax <= 0.1.4 is needed, and found that upstream/main already has that change 👍 Now everything should be cleared for review.

peter-sk • Apr 23 '23 20:04

Definitely ready for review, @ArthurZucker and @younesbelkada :-)

peter-sk • Apr 24 '23 21:04

Cool! Reviewing now

ArthurZucker • Apr 25 '23 07:04

All done and ready to be merged, @ArthurZucker and @younesbelkada 👍

peter-sk • Apr 26 '23 00:04

I implemented the same change as for GPTNeoXForTokenClassification, i.e., I removed the hasattr checks and now use config.classifier_dropout directly.
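To make the difference concrete, here is a hedged sketch of the two patterns. The first function paraphrases the kind of hasattr-based fallback chain that some token classification heads use: check classifier_dropout, then hidden_dropout, then a hard-coded default such as 0.1. The second function assumes, as the comment implies, that the config declares classifier_dropout itself. HypotheticalNeoConfig and both function names are made up for illustration and are not the transformers classes.

```python
from dataclasses import dataclass
from types import SimpleNamespace


def dropout_with_fallbacks(config) -> float:
    """Defensive lookup: tolerate configs that lack classifier_dropout."""
    if getattr(config, "classifier_dropout", None) is not None:
        return config.classifier_dropout
    if getattr(config, "hidden_dropout", None) is not None:
        return config.hidden_dropout
    return 0.1  # hard-coded fallback when neither attribute is set


@dataclass
class HypotheticalNeoConfig:
    """Stand-in config; assumes classifier_dropout is a declared field with a default."""
    hidden_size: int = 768
    num_labels: int = 2
    classifier_dropout: float = 0.1


def dropout_direct(config: HypotheticalNeoConfig) -> float:
    # No hasattr chain: the config class guarantees the attribute exists.
    return config.classifier_dropout


print(dropout_with_fallbacks(SimpleNamespace()))                     # 0.1
print(dropout_with_fallbacks(SimpleNamespace(hidden_dropout=0.2)))    # 0.2
print(dropout_direct(HypotheticalNeoConfig(classifier_dropout=0.3)))  # 0.3
```

Declaring the field on the config keeps the value visible and serialized alongside the other hyperparameters. It no longer depends silently on which attributes a given config happens to define.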

peter-sk • Apr 27 '23 14:04

@sgugger Ready to merge when the checks complete. Thanks for the fast action 👍

... and more to come in the coming weeks!

peter-sk • Apr 27 '23 15:04