mlr3torch

implement plotter for history state in mlr3viz

Open · sebffischer opened this issue · 2 comments

sebffischer · Jun 13 '24 07:06

Old code:

    # Needs checkmate (assert_*), mlr3misc (stopf), data.table (melt), ggplot2 and viridis;
    # `self` is the R6 object holding the "train" and "valid" history tables.
    #' @description Plots the history.
    #' @param measures (`character()`)\cr
    #'   Which measures to plot. No default.
    #' @param set (`character(1)`)\cr
    #'   Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
    #' @param epochs (`integer()`)\cr
    #'   An integer vector restricting which epochs to plot. Default is `NULL`, which plots all epochs.
    #' @param theme ([ggplot2::theme()])\cr
    #'   The theme; [ggplot2::theme_minimal()] is the default.
    #' @param ... (any)\cr
    #'   Currently unused.
    plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
      assert_choice(set, c("valid", "train"))
      data = self[[set]]
      assert_subset(measures, colnames(data))

      # keep the epoch column plus the requested measures, optionally only some epochs
      if (is.null(epochs)) {
        data = data[, c("epoch", measures), with = FALSE]
      } else {
        assert_integerish(epochs, unique = TRUE)
        data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
      }

      if ((!nrow(data)) || (ncol(data) < 2)) {
        stopf("No eligible measures to plot for set '%s'.", set)
      }

      # bind NSE column names to NULL to silence R CMD check notes about undefined globals
      epoch = score = measure = .data = NULL
      if (ncol(data) == 2L) {
        # single measure: plot it directly and use its name as the y-axis label
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = measures,
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      } else {
        # several measures: reshape to long format and color the lines by measure
        data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
          viridis::scale_color_viridis(discrete = TRUE) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = "Score",
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      }
    }

sebffischer · Jun 13 '24 07:06

This should dispatch on `LearnerTorch`.
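A minimal sketch of what dispatching on `LearnerTorch` could look like, assuming the plotter becomes an S3 method of `ggplot2::autoplot()`, the generic that mlr3viz builds its plots on. The method `autoplot.LearnerTorch`, its `measure` argument, and the `object$history$valid` layout are assumptions for illustration; the mock object only imitates a trained learner, and the place where mlr3torch actually stores the history may differ.

```r
library(ggplot2)

# Hypothetical autoplot() method: autoplot(x) dispatches here whenever
# "LearnerTorch" appears in class(x).
autoplot.LearnerTorch = function(object, measure, set = "valid", ...) {
  # Assumption: per-set history tables live in object$history$train and
  # object$history$valid (the layout of the mock below, not a documented API).
  history = object$history[[set]]
  if (is.null(history) || !measure %in% names(history)) {
    stop(sprintf("Measure '%s' not found for set '%s'.", measure, set))
  }
  ggplot(history, aes(x = .data$epoch, y = .data[[measure]])) +
    geom_line() +
    geom_point() +
    labs(
      x = "Epoch",
      y = measure,
      title = if (set == "valid") "Validation history" else "Training history"
    ) +
    theme_minimal()
}

# Mock object standing in for a trained torch learner (made-up values).
learner = structure(
  list(history = list(
    valid = data.frame(epoch = 1:3, loss = c(0.9, 0.7, 0.6))
  )),
  class = c("LearnerTorch", "Learner")
)

# S3 dispatch picks autoplot.LearnerTorch() based on class(learner)
p = autoplot(learner, measure = "loss")
```

The point of dispatching on the learner class is that users call the same `autoplot()` they use for other mlr3 objects instead of a method on the history callback.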

sebffischer · Jun 14 '24 13:06

It's simple enough to do this oneself.
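For doing it oneself, something along these lines should work, assuming the history is available as a table with an `epoch` column and one numeric column per measure, as the old code expects. `plot_history_table()` is a hypothetical helper, not part of mlr3torch or mlr3viz, and the toy values are made up.

```r
library(data.table)
library(ggplot2)

# Hypothetical helper (not part of mlr3torch or mlr3viz).
# `history`: table with an `epoch` column and one numeric column per measure.
plot_history_table = function(history, measures, epochs = NULL, title = "History") {
  history = as.data.table(history)
  unknown = setdiff(c("epoch", measures), colnames(history))
  if (length(unknown)) {
    stop(sprintf("Unknown column(s): %s", paste(unknown, collapse = ", ")))
  }
  # optionally restrict to a subset of epochs
  if (!is.null(epochs)) {
    history = history[epoch %in% epochs]
  }
  # keep only the epoch column and the requested measures
  history = history[, c("epoch", measures), with = FALSE]
  if (nrow(history) == 0L) {
    stop("No rows left to plot.")
  }
  # wide -> long: one row per (epoch, measure) pair
  long = melt(history, id.vars = "epoch",
              variable.name = "measure", value.name = "score")
  ggplot(long, aes(x = .data$epoch, y = .data$score, color = .data$measure)) +
    geom_line() +
    geom_point() +
    labs(x = "Epoch", y = "Score", color = "Measure", title = title) +
    theme_minimal()
}

# Usage with a made-up validation history
history = data.table(
  epoch = 1:5,
  loss = c(0.90, 0.72, 0.61, 0.58, 0.57),
  ce = c(0.35, 0.28, 0.24, 0.23, 0.23)
)
p = plot_history_table(history, measures = c("loss", "ce"), epochs = 2:5,
                       title = "Validation history")
# print(p) draws the plot
```

Because the long format also works for a single measure, one code path covers both branches of the old code; the old special case mainly serves to put the measure name on the y-axis.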

sebffischer · Feb 06 '25 09:02