qadence
[Refactoring] Alternative ways of defining tensorboard metrics
At the moment, train accepts a function write_tensorboard(writer, loss, metrics, iteration), where metrics is returned from a custom loss function.
However, we might want to log metrics that are not related to training, e.g. logging a plot, or logging the mean squared error against a known benchmark. These sorts of metrics are a) potentially expensive to calculate at every training step and b) conceptually independent of the training loop.
It would be cool if there was an alternative way of logging custom metrics etc. to tensorboard without having to do so via the loss function.
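To make the idea concrete, here is a minimal sketch of a metric that lives outside the loss function and is only evaluated every few iterations. The helper name make_periodic_logger, its parameters, and the callback signature (a single iteration argument) are assumptions made for illustration and are not part of qadence. The writer only needs an add_scalar(tag, value, step) method, which torch.utils.tensorboard.SummaryWriter provides.

```python
from typing import Any, Callable


def make_periodic_logger(
    writer: Any,
    tag: str,
    metric_fn: Callable[[], float],
    every: int = 100,
) -> Callable[[int], None]:
    """Build a callback that logs `metric_fn()` every `every` iterations.

    Hypothetical helper: `writer` is any object exposing
    `add_scalar(tag, value, step)`, e.g. a tensorboard SummaryWriter.
    """

    def callback(iteration: int) -> None:
        # Skip the (potentially expensive) metric on most iterations.
        if iteration % every == 0:
            writer.add_scalar(tag, metric_fn(), iteration)

    return callback
```

A callback built this way never touches the loss function, so an expensive benchmark comparison only runs when it is actually logged.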
I think the most general way of implementing this would be to redefine our train_with_grad function and do:
# outer epoch loop
for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
    try:
        # in case there is not data needed by the model
        # this is the case, for example, of quantum models
        # which do not have classical input data (e.g. chemistry)
        if dataloader is None:
            loss, metrics = optimize_step(model, optimizer, loss_fn, None)
            loss = loss.item()
        elif isinstance(dataloader, (DictDataLoader, DataLoader)):
            data = data_to_device(next(dl_iter), device)  # type: ignore[arg-type]
            loss, metrics = optimize_step(model, optimizer, loss_fn, data)
        else:
            raise NotImplementedError(
                f"Unsupported dataloader type: {type(dataloader)}. "
                "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
            )
        iteration_callback()
    except KeyboardInterrupt:
        print("Terminating training gracefully after the current iteration.")
        break

# Final writing and checkpointing
final_callback()
return model, optimizer
instead of what's currently being done with "hardcoded" functions. The hardcoded functions should be called in the default iteration_callback/final_callback, and we need a way to nicely construct one callback from a list of callbacks.
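One possible way to build a single callback from a list, sketched under assumptions: compose_callbacks and the Callback alias are hypothetical names, and the sketch forwards whatever arguments the training loop passes (none in the snippet above) to every callback in order.

```python
from typing import Callable, Iterable

# Hypothetical alias: a callback takes arbitrary arguments and returns nothing.
Callback = Callable[..., None]


def compose_callbacks(callbacks: Iterable[Callback]) -> Callback:
    """Combine several callbacks into one that calls each of them in order."""
    # Freeze the callbacks up front so later changes to the input have no effect.
    frozen = tuple(callbacks)

    def composed(*args: object, **kwargs: object) -> None:
        for callback in frozen:
            callback(*args, **kwargs)

    return composed


# Usage: default behaviour first, then user-supplied extras.
log: list = []
iteration_callback = compose_callbacks(
    [lambda: log.append("default"), lambda: log.append("custom")]
)
iteration_callback()
```

With this shape, the default iteration_callback could simply be the composition of the currently hardcoded functions, and users would append their own callbacks to that list.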
@DanieleCucurachi FYI!
Closes with #533