Added GPTNeoForTokenClassification
What does this PR do?
It adds the class GPTNeoForTokenClassification, which allows GPT Neo models to be used for token classification tasks. The implementation closely follows those of other models (such as GPT2) and simply adds a linear classification layer on top of the hidden states.
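To illustrate what "a linear layer on top of the hidden states" means for token classification, here is a minimal, dependency-free sketch. The function names `token_classification_head` and `predict_labels` are hypothetical and are not part of the transformers API. In the PyTorch models, this role is played by an `nn.Dropout` followed by `nn.Linear(config.hidden_size, config.num_labels)`, as in GPT2ForTokenClassification.

```python
import random
from typing import List, Optional, Sequence


def token_classification_head(
    hidden_states: Sequence[Sequence[float]],  # one hidden vector per token
    weight: Sequence[Sequence[float]],  # num_labels rows, each of length hidden_size
    bias: Sequence[float],  # one bias per label
    dropout_p: float = 0.1,
    training: bool = False,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Map each token's hidden state to per-label logits (dropout, then linear)."""
    rng = rng or random.Random(0)
    logits = []
    for h in hidden_states:
        if training and dropout_p > 0:
            # Inverted dropout: drop each unit with probability p and rescale the
            # survivors by 1 / (1 - p). At inference time this step is skipped.
            h = [0.0 if rng.random() < dropout_p else x / (1 - dropout_p) for x in h]
        # The same linear projection (hidden_size -> num_labels) is applied to
        # every token independently, which yields one prediction per token.
        logits.append([sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weight, bias)])
    return logits


def predict_labels(logits: Sequence[Sequence[float]]) -> List[int]:
    """Pick the highest-scoring label index for each token."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# Three tokens with hidden_size = 2, and three labels.
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
bias = [0.0, 0.0, 0.1]
print(predict_labels(token_classification_head(hidden, weight, bias)))  # [0, 1, 2]
```

The real model computes the per-token loss with cross-entropy over these logits. The sketch stops at the logits to keep the core idea visible.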
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada
Hey! Could you make sure the CI tests are green? I can review it then!
@ArthurZucker Sure. I'm getting the hang of it. Now the only failing tests are related to Flax and seem unrelated to this pull request.
If the Flax errors are not due to the PR, this is ready to be reviewed, @ArthurZucker and @younesbelkada :-)
I just checked the logs for the remaining errors one more time. The errors occur when importing the optax library, which uses jax.Array in a type annotation. Apparently there is no name "Array" in the top-level namespace of the jax module.
I cannot see how this could be related to my PR.
The jax version used in the examples_flax test is 0.3.6:
    Collecting jax!=0.3.2,<=0.3.6,>=0.2.8 (from transformers==4.28.0.dev0)
    Using cached jax-0.3.6-py3-none-any.whl
This version clearly has no Array class. I am not sure why such an old version is being used.
I figured out that optax <= 0.1.4 is needed, and upstream/main already has that change 👍 Now everything should be ready for review.
Definitely ready for review, @ArthurZucker and @younesbelkada :-)
Cool! Reviewing now
All done and ready to be merged, @ArthurZucker and @younesbelkada 👍
I made the same change as for GPTNeoXForTokenClassification, i.e., I removed the hasattr checks and now use config.classifier_dropout directly.
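The reasoning behind dropping the hasattr checks can be sketched as follows. This assumes, as the comment implies, that the GPT Neo config always defines classifier_dropout. `SketchNeoConfig`, `LegacyConfig`, and both resolver functions are hypothetical illustrations. The defensive version only approximates the hasattr-based fallback pattern used in some other heads and is not copied from the transformers source.

```python
from dataclasses import dataclass


@dataclass
class SketchNeoConfig:
    """Hypothetical config that always declares classifier_dropout with a default."""
    hidden_size: int = 768
    num_labels: int = 2
    classifier_dropout: float = 0.1


class LegacyConfig:
    """Hypothetical older config without a classifier_dropout field."""
    hidden_dropout = 0.2


def resolve_dropout_defensive(config) -> float:
    # Defensive pattern: probe for optional attributes, then fall back to a constant.
    if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
        return config.classifier_dropout
    if hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
        return config.hidden_dropout
    return 0.1


def resolve_dropout_direct(config: SketchNeoConfig) -> float:
    # Once the config class guarantees that the field exists, read it directly.
    return config.classifier_dropout


print(resolve_dropout_defensive(SketchNeoConfig()))  # 0.1
print(resolve_dropout_direct(SketchNeoConfig(classifier_dropout=0.3)))  # 0.3
```

The direct read is shorter, and a config that lacks the field fails loudly with an AttributeError instead of silently using a fallback value. That is why it is only appropriate once the config class itself declares the attribute.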
@sgugger Ready to merge when the checks complete. Thanks for the fast action 👍
... and more to come in the next weeks!