transformers
Update the original mapping in _LazyConfigMapping to fix AutoTokenizer registration
What does this PR do?
Currently, to register a new config + tokenizer + model, the instructions suggest doing the following:
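from transformers import AutoConfig, AutoModel, AutoTokenizer
# NewModelConfig, TokenizerSlow, TokenizerFast and NewModel stand for the user's own classes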
AutoConfig.register("new-model", NewModelConfig)
AutoTokenizer.register(NewModelConfig, TokenizerSlow, TokenizerFast)
AutoModel.register(NewModelConfig, NewModel)
AutoTokenizer.from_pretrained("xxx") # <--- error `Unrecognized configuration class <xxx> to build an AutoTokenizer.`
However, there is one potential bug in the current AutoTokenizer registration code:
- In https://github.com/huggingface/transformers/blob/280db2e39c1e586389df4e46f2b895fc092911bb/src/transformers/models/auto/tokenization_auto.py#L605, `AutoTokenizer` uses `config_class_to_model_type` to determine whether the class of the input config has been registered.
- The `config_class_to_model_type` function searches `CONFIG_MAPPING_NAMES` to find the newly registered config class: https://github.com/huggingface/transformers/blob/280db2e39c1e586389df4e46f2b895fc092911bb/src/transformers/models/auto/configuration_auto.py#L438
- However, according to https://github.com/huggingface/transformers/blob/280db2e39c1e586389df4e46f2b895fc092911bb/src/transformers/models/auto/configuration_auto.py#L781, registering a config only updates `CONFIG_MAPPING._extra_content`, not the original mapping or `CONFIG_MAPPING_NAMES` (https://github.com/huggingface/transformers/blob/280db2e39c1e586389df4e46f2b895fc092911bb/src/transformers/models/auto/configuration_auto.py#L492). As a result, `config_class_to_model_type` cannot find the newly registered config, and `AutoTokenizer` ends up throwing the error `Unrecognized configuration class <xxx> to build an AutoTokenizer.`
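The lookup path above can be reproduced with a small toy model. This is a simplified, hypothetical stand-in: the names mirror `configuration_auto.py`, but the classes and functions below are not the library code. It only shows that a registration stored in `_extra_content` is invisible to a lookup that scans `CONFIG_MAPPING_NAMES` by class name.

```python
# Toy model (hypothetical, simplified) of the registration/lookup path.
# NOT the transformers implementation; names only mirror configuration_auto.py.

# Built-in model types -> config class *names* (strings), like CONFIG_MAPPING_NAMES.
CONFIG_MAPPING_NAMES = {"bert": "BertConfig", "gpt2": "GPT2Config"}


class LazyConfigMapping:
    """Simplified stand-in for _LazyConfigMapping."""

    def __init__(self, mapping):
        self._mapping = mapping   # the original mapping (same dict object)
        self._extra_content = {}  # where user registrations end up

    def register(self, key, value):
        if key in self._mapping:
            raise ValueError(f"'{key}' is already used by a Transformers config")
        # Only the extra content is updated; self._mapping is left untouched.
        self._extra_content[key] = value


CONFIG_MAPPING = LazyConfigMapping(CONFIG_MAPPING_NAMES)


def config_class_to_model_type(config_name):
    """Scans only CONFIG_MAPPING_NAMES, so registered configs are never found."""
    for key, cls_name in CONFIG_MAPPING_NAMES.items():
        if cls_name == config_name:
            return key
    return None


class NewModelConfig:  # hypothetical user-defined config
    pass


CONFIG_MAPPING.register("new-model", NewModelConfig)
print(config_class_to_model_type("BertConfig"))             # bert
print(config_class_to_model_type(NewModelConfig.__name__))  # None
```

In this toy model the lookup returns `None` for the registered class, which corresponds to the case where `AutoTokenizer` reports the "Unrecognized configuration class" error. It also suggests why the hot fix below works: it writes `NewModelConfig.__name__` directly into `CONFIG_MAPPING_NAMES`, the dictionary the lookup scans.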
A temporary local hot fix can be:
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
AutoConfig.register("new-model", NewModelConfig)
# Manually add the new config's class name to the mapping that the tokenizer lookup scans
CONFIG_MAPPING_NAMES["new-model"] = NewModelConfig.__name__
AutoTokenizer.register(NewModelConfig, TokenizerSlow, TokenizerFast)
AutoModel.register(NewModelConfig, NewModel)
But I thought it would be better to fix this upstream.
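One way to fix this upstream, following the PR title ("update the original mapping in `_LazyConfigMapping`"), is to have `register` write the class name into the original mapping as well. The sketch below is hypothetical and uses a simplified toy model, not the actual diff; it assumes, as in the linked `configuration_auto.py`, that the lazy mapping holds a reference to `CONFIG_MAPPING_NAMES` as its original mapping.

```python
# Hypothetical sketch of the fix on a toy model (NOT the actual transformers diff):
# registering a config also records its class name in the original mapping.

CONFIG_MAPPING_NAMES = {"bert": "BertConfig"}


class LazyConfigMapping:
    def __init__(self, mapping):
        self._mapping = mapping   # same dict object as CONFIG_MAPPING_NAMES
        self._extra_content = {}

    def register(self, key, value):
        if key in self._mapping:
            raise ValueError(f"'{key}' is already used by a Transformers config")
        self._extra_content[key] = value
        # The fix: also update the original mapping. Because self._mapping *is*
        # CONFIG_MAPPING_NAMES, the name-based lookup below now sees the entry.
        self._mapping[key] = value.__name__


CONFIG_MAPPING = LazyConfigMapping(CONFIG_MAPPING_NAMES)


def config_class_to_model_type(config_name):
    """Scans CONFIG_MAPPING_NAMES for a matching config class name."""
    for key, cls_name in CONFIG_MAPPING_NAMES.items():
        if cls_name == config_name:
            return key
    return None


class NewModelConfig:  # hypothetical user-defined config
    pass


CONFIG_MAPPING.register("new-model", NewModelConfig)
print(config_class_to_model_type(NewModelConfig.__name__))  # new-model
```

An alternative would be to leave `register` alone and have `config_class_to_model_type` also scan `_extra_content`. Note that in this sketch a second `register("new-model", ...)` now raises, because the key is present in the original mapping after the first call.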
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@n1t0, @LysandreJik
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
cc @sgugger
Do you have a full example of the error you are reporting that I could run? I am unable to reproduce it. With something like a test of this feature, we could investigate further.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.