NaN CrossEntropyLossWithLogSoftmax while training an NMT Reformer
Description
I'm training a Reformer-based NMT model on a custom dataset; the code is nearly identical to https://github.com/google/trax/blob/283cbda9cb87f4a25a952d4c302aedfe54a65850/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb. The model itself looks like this:
model = trax.models.Reformer(
    input_vocab_size=39901,
    d_model=512, d_ff=2048, dropout=0.1,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=512, mode='train')
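The training loop is essentially the notebook's; a sketch of it is below in case the loss/optimizer wiring matters. The stream names, learning-rate values, and output directory are placeholders standing in for my actual pipeline, not the exact values I use:

import trax
from trax import layers as tl
from trax.supervised import training

# `train_batches` / `eval_batches` are placeholders for my custom
# dataset's batch generators yielding (inputs, targets, weights),
# built the same way as the notebook's data pipeline.
train_task = training.TrainTask(
    labeled_data=train_batches,
    loss_layer=tl.CrossEntropyLossWithLogSoftmax(),
    optimizer=trax.optimizers.Adam(0.01),
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint=1000,
)

eval_task = training.EvalTask(
    labeled_data=eval_batches,
    metrics=[tl.CrossEntropyLossWithLogSoftmax(),
             tl.WeightedCategoryAccuracy()],
)

loop = training.Loop(model,
                     train_task,
                     eval_tasks=[eval_task],
                     output_dir='~/train_dir')
loop.run(n_steps=11000)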
After training for about 10,000 steps, the loss goes to NaN:
Step  11000: Ran 1000 train steps in 1182.59 secs
Step  11000: train CrossEntropyLossWithLogSoftmax |  nan
Step  11000: eval  CrossEntropyLossWithLogSoftmax |  nan
Step  11000: eval      WeightedCategoryAccuracy |  0.00000000
Any idea what the reason might be?
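If it helps narrow things down, a minimal way to surface where the first NaN appears, assuming Trax's JAX backend picks up the flag (I'm not sure how well it interacts with pmap across the two GPUs), would be:

import jax

# Make JAX raise a FloatingPointError at the first op that produces a
# NaN, instead of letting it propagate into the loss. Note: this slows
# training, and jitted code is re-run op-by-op once a NaN is detected.
jax.config.update('jax_debug_nans', True)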
...
Environment information
Hardware: AWS g3.8xlarge / 2 × Tesla M60 GPUs
OS: Ubuntu 18.04
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensor2tensor==1.15.7
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.4
tensorflow-addons==0.14.0
tensorflow-datasets==4.4.0
tensorflow-estimator==2.4.0
tensorflow-gan==2.1.0
tensorflow-gpu==2.4.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.7.0
tensorflow-text==2.4.1
$ pip freeze | grep jax
jax==0.2.24
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.70+cuda110-cp37-none-manylinux2010_x86_64.whl
$ python -V
Python 3.7.12
For bugs: reproduction and error logs
# Steps to reproduce:
...
# Error logs:
...