mlr3torch
mlr3torch copied to clipboard
Implement a plotter for the history state in mlr3viz
old code:
#' @description
#' Plots the selected measures of one set of the history against the epoch.
#' A single measure is drawn as one line with points; several measures are
#' reshaped to long format and drawn as one colored line per measure.
#' @param measures (`character()`)\cr
#' Which measures to plot. No default. Must be a subset of the column names of the table stored for `set`.
#' @param set (`character(1)`)\cr
#' Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
#' @param epochs (`integer()`)\cr
#' An integer-like vector of unique values restricting which epochs to plot. Default is `NULL`, which plots all epochs.
#' @param theme ([ggplot2::theme()])\cr
#' The theme that is added to the plot, [ggplot2::theme_minimal()] is the default.
#' @param ... (any)\cr
#' Currently unused.
#' @return The plot assembled with [ggplot2::ggplot()]. An error is raised via `stopf()` when the
#' selected data has no rows or fewer than two columns.
plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
assert_choice(set, c("valid", "train"))
# Look up the table for the requested set on the object itself.
data = self[[set]]
assert_subset(measures, colnames(data))
# Keep only the epoch column and the requested measures; `with = FALSE` and the later `melt()`
# suggest `data` is a data.table -- TODO confirm.
if (is.null(epochs)) {
data = data[, c("epoch", measures), with = FALSE]
} else {
assert_integerish(epochs, unique = TRUE)
# `get("epoch")` presumably refers to the `epoch` column by name inside the row filter -- verify.
data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
}
# Nothing to draw if the filter removed every row or no measure column was selected.
if ((!nrow(data)) || (ncol(data) < 2)) {
stopf("No eligible measures to plot for set '%s'.", set)
}
# Local bindings for the names used inside `aes()` below; presumably this silences
# R CMD check notes about undefined global variables -- TODO confirm.
epoch = score = measure = .data = NULL
# NOTE(review): both branches title the plot "<Validation|Training> Loss", even when `measures`
# names something other than a loss.
# Two columns means `epoch` plus one measure column: plot it directly and label the y axis with `measures`.
if (ncol(data) == 2L) {
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
x = "Epoch",
y = measures,
title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
) +
theme
} else {
# Several measures: reshape to long format (columns `epoch`, `measure`, `score`) so that
# each measure gets its own color.
data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
viridis::scale_color_viridis(discrete = TRUE) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
x = "Epoch",
y = "Score",
title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
) +
theme
}
# NOTE(review): the closing `}` of `plot` is not part of this excerpt.
This should dispatch on `LearnerTorch`.
It's simple enough to do this oneself