[BUG] Incorrect readings of cross-entropy loss in Trax 1.3.7
Description
The cross-entropy loss in Trax 1.3.7 gives strange readings. Is this normal?
The same code, run under Trax 1.3.6, produced this result.
Both the loss layer and the evaluation metric use tl.CrossEntropyLoss().
It looks like tl.CrossEntropyLoss() is deprecated: https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.metrics.CrossEntropyLoss However, tl.WeightedCategoryCrossEntropy() may also have a memory leak (at least on TPU), so should we hold off on switching to Trax 1.3.7 for now?
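One version-independent way to judge whether a loss reading is "strange" is to compare it against a hand-computed value. The NumPy sketch below computes a weighted categorical cross-entropy: the mean negative log-likelihood of the target class, with per-position weights such as a padding mask. The helper name `weighted_cross_entropy` is hypothetical and not part of Trax, and it is an assumption that this matches what the Trax loss layers compute internally. A useful baseline: with uniform predictions over V classes the loss is ln V, so an untrained model should start near that value.

```python
import numpy as np


def weighted_cross_entropy(logits, targets, weights):
    """Hypothetical reference helper (not a Trax API).

    logits:  float array [batch, seq, vocab] of unnormalized scores
    targets: int array [batch, seq] of target class ids
    weights: float array [batch, seq], e.g. 0.0 at padding positions
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the target class at each position.
    target_log_probs = np.take_along_axis(
        log_probs, targets[..., None], axis=-1)[..., 0]
    # Weighted mean of the negative log-likelihood.
    return float(-(target_log_probs * weights).sum() / weights.sum())


# Uniform predictions over 4 classes: loss should be ln(4).
logits = np.zeros((2, 3, 4))
targets = np.zeros((2, 3), dtype=np.int32)
weights = np.ones((2, 3))
print(weighted_cross_entropy(logits, targets, weights))  # ≈ 1.386
```

If the value reported during training is far outside the plausible range (for example, negative, or well above ln V at step 0), that points to a change in how the loss is computed or reported rather than to the model itself.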