ml-ane-transformers

Upgrade torch and correct dim mismatch

Open KE7 opened this issue 2 years ago • 2 comments

KE7 avatar Apr 16 '23 01:04 KE7

+1 on this. Library code shouldn't pin maximum versions of its dependencies (unless the library is known to be incompatible with a specific version); it just makes life awkward for consumers of the library.

anentropic avatar Apr 18 '23 11:04 anentropic

The shape mismatch errors should be fixed in the models, not in the tests. The _register_load_state_dict_pre_hook() logic is not correct: currently the hook is only registered in the base DistilBert model, so it skips the pre_classifier and classifier weights added by the other models. Until the model __init__ methods are fixed, it is correct to leave the test failing.

test_distilbert.py passes once a pre-hook is registered in DistilBertForSequenceClassification's __init__:

self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

I think the same needs to be done for every model that creates pre_classifier and classifier weights. At least, that appears to be the intention of linear_to_conv2d_map, which expects to handle layers with classifier.weights in the name.
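To illustrate the idea outside the repo, here is a minimal sketch on a toy model. The classes `ToyLinearClassifier` and `ToyClassifier` and the hook `linear_to_conv2d` are hypothetical stand-ins written for this example, not the library's actual code; the hook only assumes that Linear weights of shape (out, in) need to become Conv2d weights of shape (out, in, 1, 1). Registering the hook on the module that owns pre_classifier and classifier lets it see those keys before the shapes are checked:

```python
import torch
import torch.nn as nn


def linear_to_conv2d(state_dict, prefix, local_metadata, strict,
                     missing_keys, unexpected_keys, error_msgs):
    """Hypothetical stand-in for linear_to_conv2d_map.

    Reshapes 2-D weights (out, in) to the Conv2d layout (out, in, 1, 1).
    The substring "classifier.weight" also matches "pre_classifier.weight".
    """
    for key, value in list(state_dict.items()):
        if "classifier.weight" in key and value.dim() == 2:
            state_dict[key] = value[:, :, None, None]


class ToyLinearClassifier(nn.Module):
    """Plays the role of the reference checkpoint that uses nn.Linear."""

    def __init__(self, dim=4, num_labels=2):
        super().__init__()
        self.pre_classifier = nn.Linear(dim, dim)
        self.classifier = nn.Linear(dim, num_labels)


class ToyClassifier(nn.Module):
    """Plays the role of the Conv2d-based model that loads that checkpoint."""

    def __init__(self, dim=4, num_labels=2):
        super().__init__()
        self.pre_classifier = nn.Conv2d(dim, dim, kernel_size=1)
        self.classifier = nn.Conv2d(dim, num_labels, kernel_size=1)
        # Register on the module that owns the classifier layers, so the
        # hook runs before their weights are shape-checked and copied.
        self._register_load_state_dict_pre_hook(linear_to_conv2d)


linear_weights = ToyLinearClassifier().state_dict()
model = ToyClassifier()
# Without the hook registered above, this call raises a size-mismatch error.
model.load_state_dict(linear_weights)
```

Note that _register_load_state_dict_pre_hook is a private PyTorch method, so its behaviour may differ between torch versions; the sketch should be checked against the version the repo ends up pinning.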

mikowals avatar Apr 23 '23 16:04 mikowals