trax
Custom logging of metrics during training loop?
How can I get access to current metrics in a training loop? I want to do custom logging while the model runs - specifically I want to log to Weights and Biases. Strangely I can't find a nice way to access things like the current loss, so what I assumed was an easy plan - just use a callback - doesn't seem to be working out. I've had a little look through the source of Loop and there doesn't seem to be a way to do this. This is sufficiently surprising that I'm fairly convinced that I'm simply missing something obvious.
I'm about to try patching Loop._log_training_progress to get access to the info, but surely this cannot be the expected way!
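For anyone wanting to try the same thing, here is a minimal sketch of that kind of patch. It does not import Trax and assumes nothing about the real signature of Loop._log_training_progress: the wrapper forwards whatever arguments it receives. FakeLoop, patch_method and the (step, loss) arguments are hypothetical names invented for illustration only.

```python
import functools


def patch_method(cls, name, hook):
    """Wrap cls.<name> so that hook(*args, **kwargs) runs before the original."""
    original = getattr(cls, name)

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        # Forward the logging arguments to the hook; a real hook might
        # pass selected values on to an external logger instead.
        hook(*args, **kwargs)
        return original(self, *args, **kwargs)

    setattr(cls, name, wrapper)
    return original  # keep a handle so the patch can be undone later


# Hypothetical stand-in for a training loop; NOT the real trax Loop,
# and its method signature is an assumption made for this example.
class FakeLoop:
    def _log_training_progress(self, step, loss):
        return f"step {step}: loss {loss}"

    def run(self, n_steps):
        for step in range(1, n_steps + 1):
            self._log_training_progress(step, 1.0 / step)


logged = []
patch_method(FakeLoop, "_log_training_progress",
             lambda *args, **kwargs: logged.append((args, kwargs)))
FakeLoop().run(3)
```

After the run, logged holds one (args, kwargs) pair per logging call. The same wrapper could in principle be pointed at the real class, but which arguments it would then receive depends on the installed Trax version, so check the source before relying on any of them.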
I'm not quite sure what this issue is - if there is a way to do this then it's a question; if there isn't then it's a feature request for a new callback!
P.S. This should probably be its own whole issue, but having looked through the Loop source it looks like there isn't a way to do early stopping either, which is going to be the thing that annoys me in about 15 minutes. I have a horrible feeling that I'm going to have to raise an exception in a callback for that, at which point I'll probably give up on using Trax's training and write my own loop from scratch.
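The exception-in-a-callback idea can be sketched roughly as follows. This is a generic pattern, not Trax's API: StopTraining, EarlyStopper, on_eval and run_with_early_stopping are all hypothetical names, and the stand-in loop simply replays a list of eval losses.

```python
class StopTraining(Exception):
    """Raised from a callback to end training early."""


class EarlyStopper:
    """Raises StopTraining after `patience` evals without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_evals = 0

    def on_eval(self, step, eval_loss):
        if eval_loss < self.best:
            self.best = eval_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                raise StopTraining(f"no improvement by step {step}")


def run_with_early_stopping(eval_losses, stopper):
    """Stand-in training loop: calls the callback once per eval loss."""
    steps_run = 0
    try:
        for step, loss in enumerate(eval_losses, start=1):
            steps_run = step
            stopper.on_eval(step, loss)
    except StopTraining:
        pass  # treat the exception as a normal end of training
    return steps_run


stopped_at = run_with_early_stopping(
    [1.0, 0.8, 0.9, 0.85, 0.7], EarlyStopper(patience=2))
```

Here the loss fails to improve at steps 3 and 4, so the loop stops at step 4 and never sees the final value. Catching the exception outside the loop is what makes this feel like a hack: the loop itself has no notion of stopping, so anything it would normally do at the end of a run is skipped unless the caller handles it.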
@SamPruden can I ask how you went with this? I'm about to look into a workaround for this myself.
Bad news, I'm afraid. When I submitted this I was experimenting with the various options in the JAX ecosystem, and little things like this pushed me away from Trax. I just moved over to Haiku and Optax. Trax's API and the small amount of code required to do complex things were appealing, but a little too restrictive for my level of experimentation.