ml-ane-transformers
Upgrade torch and correct dim mismatch
+1 on this. Maximum versions of dependencies shouldn't be specified in library code (unless the library is known to be incompatible with a specific version); it just makes life awkward for consumers of the library.
The shape-mismatch errors should be fixed in the models, not in the tests. The _register_load_state_dict_pre_hook() logic is not correct: currently the hook is only registered on the base DistilBert model, which skips the pre_classifier and classifier weights added by the derived models. Until the model __init__ methods are fixed, it is correct to leave the test failing.
test_distilbert.py will pass after adding a pre-hook in DistilBertForSequenceClassification's __init__:
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
I think the same needs to be done for every model that creates pre_classifier and classifier weights. At least, that appears to be the intention of linear_to_conv2d_map, which expects to handle layers with classifier.weight in the name.
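For anyone following along, here is a rough sketch of the kind of remapping linear_to_conv2d_map is expected to perform: the ANE-optimized models replace nn.Linear layers with nn.Conv2d, so pretrained linear weights of shape (out, in) need to become (out, in, 1, 1) before loading. This is an illustrative stand-in, not the actual implementation; it uses plain dicts holding shape tuples instead of real tensors, and the key suffixes are assumptions.

```python
def linear_to_conv2d_map_sketch(state_dict):
    """Illustrative sketch: reshape 2-D linear weights to 4-D conv2d weights
    in-place, for the classifier-head keys the real pre-hook is meant to cover.
    Operates on a toy state dict of {key: {"shape": tuple}} entries."""
    for key in list(state_dict):
        # assumed key suffixes, mirroring the layers named in the discussion
        is_classifier_weight = any(
            key.endswith(suffix)
            for suffix in ("pre_classifier.weight", "classifier.weight")
        )
        shape = state_dict[key]["shape"]
        if is_classifier_weight and len(shape) == 2:
            # (out_features, in_features) -> (out_channels, in_channels, 1, 1)
            state_dict[key]["shape"] = shape + (1, 1)

# toy state dict using shape tuples in place of real tensors
sd = {
    "pre_classifier.weight": {"shape": (768, 768)},
    "pre_classifier.bias": {"shape": (768,)},   # biases are untouched
    "classifier.weight": {"shape": (2, 768)},
}
linear_to_conv2d_map_sketch(sd)
print(sd["classifier.weight"]["shape"])  # -> (2, 768, 1, 1)
```

The real hook additionally receives the prefix and other arguments from _register_load_state_dict_pre_hook, which is why it has to be registered on each subclass that introduces new classifier weights rather than only on the base model.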