
Don't cache reinit_modules

Open · JohnGiorgi opened this pull request on Jan 16, 2022 • 4 comments

Fixes https://github.com/allenai/allennlp/pull/5505#issuecomment-1007540627

Changes proposed in this pull request:

  • Don't cache transformers when reinit_modules is provided.
  • Remove reinit_modules from the transformer spec.
  • Always load a new model when reinit_modules is not None (see the sketch after this list).
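
For concreteness, here is a minimal Python sketch of the intended behavior, not the actual AllenNLP implementation: a simplified stand-in for cached_transformers.get() with a plain tuple in place of the real TransformerSpec, and a hypothetical _reinit() helper that assumes a BERT-style encoder.

```python
from typing import Any, Dict, Optional, Tuple

import transformers

# Simplified stand-in for the module-level cache in allennlp.common.cached_transformers.
_model_cache: Dict[Tuple[str, ...], transformers.PreTrainedModel] = {}


def get(
    model_name: str,
    reinit_modules: Optional[Tuple[int, ...]] = None,
    **kwargs: Any,
) -> transformers.PreTrainedModel:
    spec = (model_name,)  # the real code builds a richer TransformerSpec
    transformer = _model_cache.get(spec)
    if transformer is None or reinit_modules is not None:
        # Always load a fresh model when reinit_modules is given, so the
        # re-initialized weights never replace the cached pretrained ones.
        transformer = transformers.AutoModel.from_pretrained(model_name, **kwargs)
        if reinit_modules is not None:
            _reinit(transformer, reinit_modules)  # hypothetical helper, see below
        else:
            _model_cache[spec] = transformer  # only pristine models go in the cache
    return transformer


def _reinit(transformer: transformers.PreTrainedModel, reinit_modules: Tuple[int, ...]) -> None:
    # Re-run the model's own weight initialization on the selected encoder layers
    # (assumes a BERT-style model exposing transformer.encoder.layer).
    for layer_index in reinit_modules:
        transformer.encoder.layer[layer_index].apply(transformer._init_weights)
```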

Before submitting

  • [X] I've read and followed all steps in the Making a pull request section of the CONTRIBUTING docs.
  • [X] I've updated or added any relevant docstrings following the syntax described in the Writing docstrings section of the CONTRIBUTING docs.
  • [ ] If this PR fixes a bug, I've added a test that will fail without my fix.
  • [ ] If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

After submitting

  • [ ] All GitHub Actions jobs for my pull request have passed.
  • [ ] codecov/patch reports high test coverage (at least 90%). You can find this under the "Actions" tab of the pull request once the other checks have finished.

JohnGiorgi · Jan 16, 2022

I think one of @dirkgr's points was that we don't want to add a reinitialized transformer to the cache. So maybe only run _model_cache[spec] = transformer when reinit_modules is None.

Oh right, good catch. I have fixed this.

JohnGiorgi · Jan 18, 2022

The way reinit should work is that if reinit is specified, it loads the model first using the normal cached_transformers code path, with make_copy = True, and then re-inits the layers it needs. If you already have the weights loaded (which happens a lot in AllenNLP), then it'll be much faster the second time.

I would almost say that we should have an entirely separate function that reinits some layers from a given transformer model. It doesn't have to be part of cached_transformers at all. But I don't know how you're using this downstream, so maybe that's not practical.

dirkgr · Jan 18, 2022
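
A rough sketch of the flow described above, assuming the existing cached_transformers.get() with its make_copy argument; the load_and_reinit name, the integer layer indices, and the BERT-style encoder.layer access are illustrative assumptions, not AllenNLP API.

```python
from typing import Iterable

import transformers

from allennlp.common import cached_transformers


def load_and_reinit(model_name: str, layers_to_reinit: Iterable[int]) -> transformers.PreTrainedModel:
    # Go through the normal cached code path: a second call for the same model
    # name reuses the already-loaded weights, and make_copy=True means we only
    # mutate our own copy, never the cached one.
    transformer = cached_transformers.get(model_name, make_copy=True)
    for layer_index in layers_to_reinit:
        # Re-run the model's own weight init on the selected layers
        # (assumes a BERT-style model exposing transformer.encoder.layer).
        transformer.encoder.layer[layer_index].apply(transformer._init_weights)
    return transformer
```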

@dirkgr I had originally added this functionality to PretrainedTransformerEmbedder, but then I figured it made sense to move it to cached_transformers, so any part of the library that loads a transformer could reinit some layers/modules. Is there somewhere else in the library it could live that makes more sense? If it were its own function, how would a user access it, e.g., in a training config?

JohnGiorgi · Jan 20, 2022

I'd say for ease of use, just have it in the cached_transformers module as well. While strictly speaking it doesn't have to be used with cached transformers, I think in practice that's what's going to happen.

dirkgr · Feb 24, 2022
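
For completeness, one way a user reaches this downstream is through PretrainedTransformerEmbedder, which gained reinit_modules in the original feature PR (#5505); the exact accepted values are an assumption here, the sketch below supposes an integer count of final layers to re-initialize.

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Re-initialize the final two layers of the pretrained transformer; with this
# PR, the re-initialized model is no longer written back to the shared cache,
# so later loads of "bert-base-uncased" still see the pristine pretrained weights.
embedder = PretrainedTransformerEmbedder(
    model_name="bert-base-uncased",
    reinit_modules=2,
)
```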