BERT model fails on initialization
Description
The PretrainedBERT model fails during initialization when BERT is constructed with an init_checkpoint.
Environment information
OS: macOS Big Sur 11.4
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.0.0
tensorflow-text==2.5.0
$ pip freeze | grep jax
jax==0.2.13
jaxlib==0.1.67
$ python -V
Python 3.8.10
For bugs: reproduction and error logs
# Steps to reproduce:
import trax
trax.models.bert.BERT(init_checkpoint="bert-base-uncased")
# Error logs:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/models/research/bert.py", line 160, in BERT
bert = PretrainedBERT(
File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/models/research/bert.py", line 178, in __init__
self.init_checkpoint = None
File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/layers/base.py", line 703, in __setattr__
raise ValueError(
ValueError: Trax layers only allow to set ('weights', 'state', 'rng') as public attribues, not init_checkpoint.
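For context, the error comes from Trax's base Layer class, which whitelists the public attributes a layer instance is allowed to set. Below is a simplified sketch of that guard based on my reading of trax/layers/base.py in 1.3.9; the error message matches the traceback, but the body is paraphrased, not verbatim source:

# Simplified sketch of the attribute guard in trax.layers.base.Layer
# (paraphrased from trax 1.3.9, not the verbatim source).
class Layer:
    def _settable_attrs(self):
        # Default whitelist of public attributes a layer may set.
        return ('weights', 'state', 'rng')

    def __setattr__(self, attr, value):
        # Private attributes (leading underscore) pass through; public
        # attributes must be whitelisted, otherwise raise the ValueError
        # seen in the traceback above.
        if attr[0] != '_' and attr not in self._settable_attrs():
            raise ValueError(
                f'Trax layers only allow to set {self._settable_attrs()} '
                f'as public attribues, not {attr}.')  # typo is in trax itself
        super().__setattr__(attr, value)

PretrainedBERT.__init__ assigns self.init_checkpoint, which is not in that whitelist, hence the failure.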
In the PR above, I've overridden the _settable_attrs method of PretrainedBERT to allow setting the init_checkpoint attribute, which is required for loading the model from its checkpoint.
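A minimal sketch of that override, assuming PretrainedBERT subclasses tl.Serial as in trax/models/research/bert.py (the exact diff in the PR may differ):

import trax.layers as tl

class PretrainedBERT(tl.Serial):
    def _settable_attrs(self):
        # Extend the base whitelist ('weights', 'state', 'rng') so that
        # __init__ may assign self.init_checkpoint.
        return super()._settable_attrs() + ('init_checkpoint',)

With this override in place, trax.models.bert.BERT(init_checkpoint="bert-base-uncased") gets past the ValueError during initialization.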