mlr3torch

Add tensorboard callback

Open sebffischer opened this issue 1 year ago • 0 comments

The goal here is to add a callback that logs the training error and, if it exists, the validation error, so that both can be displayed in the browser via TensorBoard. The R package for TensorBoard logging can be found here: https://github.com/mlverse/tfevents.

For the first implementation, we can use the torch_callback helper function as defined here: https://mlr3torch.mlr-org.com/articles/callbacks.html

Once everything is working as expected, we can move away from this "syntactic sugar" and implement it directly as an R6 class, as is done e.g. here: https://github.com/mlr-org/mlr3torch/blob/main/R/CallbackSetProgress.R. This is necessary to generate the proper documentation for the class.

The training and validation loss can be accessed via these two fields of the torch context: https://github.com/mlr-org/mlr3torch/blob/fef4cdb4dafa8f725ecf83acac11ef10b3d4d6ae/R/ContextTorch.R#L59-L60

The validation loss is only present when a validation task is set, so we need to handle both cases. Another open question is which configuration options we want for the callback. We can also look at how this is implemented in keras: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard, although we don't need the whole feature set, at least not in the first iteration.
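A first version based on `torch_callback` could look roughly like the sketch below. It rests on several assumptions that should be checked against the current code: that the two context fields are called `last_scores_train` and `last_scores_valid` and hold named lists of scores, that both are filled by the time `on_epoch_end` runs, and that `tfevents::with_logdir()` and `tfevents::log_event()` are the right entry points for writing scalars. The callback id `"tb_sketch"`, the helper `collect_scalars()`, and the log directory are made up for illustration.

```r
# Combine training and validation scores into one named list of scalars.
# `valid` may be NULL or empty when no validation task is set, in which
# case only the training scores are returned.
collect_scalars = function(train, valid = NULL) {
  scalars = list()
  for (name in names(train)) scalars[[paste0("train.", name)]] = train[[name]]
  for (name in names(valid)) scalars[[paste0("valid.", name)]] = valid[[name]]
  scalars
}

# The callback itself needs mlr3torch and tfevents; the guard only keeps
# this sketch runnable when the packages are not installed.
if (requireNamespace("mlr3torch", quietly = TRUE) &&
    requireNamespace("tfevents", quietly = TRUE)) {
  tb_sketch = mlr3torch::torch_callback("tb_sketch",
    on_epoch_end = function() {
      # Assumed field names, see the ContextTorch link above.
      scalars = collect_scalars(
        self$ctx$last_scores_train,
        self$ctx$last_scores_valid
      )
      if (length(scalars) == 0) return(invisible(NULL))
      # Hypothetical fixed log directory; a real callback would expose
      # this as a configuration option.
      tfevents::with_logdir("logs/tb_sketch", {
        do.call(tfevents::log_event, scalars)
      })
    }
  )
}
```

Prefixing the names with `train.` and `valid.` keeps the two curves apart in TensorBoard, and the empty-list check covers the case where no scores are available at all. The log directory is the most obvious candidate for a first configuration option.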

sebffischer, Jun 14 '24 13:06