Proposal to add "forecast.persistence" learner for seasonal persistence prediction

Open MarkusLeyser opened this issue 6 months ago • 3 comments

Description: I am interested in contributing to the mlr3temporal package and would like to propose the addition of a simple "forecast.persistence" learner. (This would be my first contribution to the mlr3 project. 🙂)

Motivation and Idea: I am using mlr3 for a benchmarking study on the prediction of photovoltaic power output. In the context of photovoltaics, seasonal persistence prediction is a common (naive) baseline for evaluating forecasting models. It predicts each future value with the observation at the same position in the most recent seasonal cycle. Regular persistence prediction, in contrast, repeats the last available value. In my case, for example, the seasonal persistence prediction is the photovoltaic power output from 24 hours earlier. While this approach looks simple, it can be surprisingly hard to beat.

Both regular and seasonal persistence predictions are valuable baselines for many time series modeling applications. Therefore, I see a benefit in providing a convenient and simple "forecast.persistence" learner in mlr3temporal.
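
To make the difference between the two baselines concrete, here is a minimal base-R sketch on made-up data. The series `x`, the period of 4, and the horizon `h` are illustrative assumptions, not values from my study.

```r
# Toy series with a seasonal period of 4 (hypothetical data).
x <- c(1, 5, 9, 3, 2, 6, 10, 4)
period <- 4L
h <- 6L  # forecast horizon

# Regular persistence: repeat the last observed value for every step.
naive <- rep(x[length(x)], h)

# Seasonal persistence: recycle the last complete cycle over the horizon.
last_cycle <- x[(length(x) - period + 1L):length(x)]
seasonal <- rep(last_cycle, length.out = h)

print(naive)     # 4 4 4 4 4 4
print(seasonal)  # 2 6 10 4 2 6
```

With `period = 1` the two forecasts coincide, which is why a single learner with a `period` parameter could cover both cases.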

Implementation: There are two viable options for implementing this simple learner:

  1. Using the data.table package:

    • Pros: No additional dependencies; core functionality can be realized using shift().
    • Cons: Limited functionality.
  2. Using the forecast package:

    • Pros: The (seasonal) persistence prediction can be represented as a suitably specified ARIMA model (no non-seasonal terms, seasonal differencing of order 1 with the appropriate period, no constant). It also leaves room for extensions, in particular with forecast::findfrequency(x), which estimates the period of the dominant frequency of a time series.
    • Cons: Requires a package that is not distributed with R itself. Considering that an ARIMA learner (forecast.arima) is already implemented in mlr3temporal, I personally believe that option 2 is still viable.
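
The ARIMA representation mentioned under option 2 can be written out as follows. The notation is assumed for illustration only: m is the seasonal period, B the backshift operator, T the last training index, and h the forecast horizon.

```latex
% Seasonal differencing of order 1, no AR/MA terms, no constant:
(1 - B^{m})\, y_t = \varepsilon_t
\quad\Longleftrightarrow\quad
y_t = y_{t-m} + \varepsilon_t

% Point forecast: the most recent observation at the same seasonal position
\hat{y}_{T+h \mid T} = y_{T+h-m(k+1)},
\qquad k = \left\lfloor \frac{h-1}{m} \right\rfloor
```

For m = 1 this reduces to \hat{y}_{T+h \mid T} = y_T, i.e. regular persistence.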

In both cases, it is necessary to define how the learner behaves when a complete period is not present in the training data, i.e., when the lag order is greater than the length of the data. Hence, the existing forecast.arima learner is not sufficient.
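
As a sketch of one possible convention for this edge case (not necessarily the one used in the concept code), the positions that fall before the start of the training data could be padded with a fill value such as the training mean. The helper name `last_cycle_padded` and its `fill` argument are hypothetical.

```r
# Hypothetical helper: return the values that a seasonal persistence forecast
# would repeat, padding positions that lie before the start of the data.
last_cycle_padded <- function(x, period, fill = mean(x)) {
  n <- length(x)
  if (period <= n) {
    # A complete cycle is available: take the last `period` observations.
    return(x[(n - period + 1L):n])
  }
  # Assumption: the unobserved positions come first in the cycle, so the
  # fill values are placed in front of the observed values.
  c(rep(fill, period - n), x)
}

x <- c(2, 4, 6)
print(last_cycle_padded(x, period = 2L))  # 4 6
print(last_cycle_padded(x, period = 5L))  # 4 4 2 4 6
```

Other choices are conceivable, for example returning NA for the unobserved positions or raising an error, so the behaviour would need to be agreed on and documented.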

Request: I have attached a few lines of concept code for reference (modified versions of forecast.arima and forecast.average, respectively). I would appreciate it if you could provide feedback on whether such a contribution would be welcomed by the maintainers.

Thank you for your time.

LearnerRegrForecastPersistenceShift:

#' @title Lag-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_shift
#'
#' @description
#' Persistence model using data.table::shift()
#' Calls [data.table::shift] from package \CRANpkg{data.table} to return a (seasonal) persistence forecast.
#'
#' @templateVar id forecast.persistence_shift
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceShift = R6::R6Class("LearnerRegrForecastPersistenceShift",
                                                  inherit = LearnerForecast,
                                                  
                                                  public = list(
                                                    
                                                    #' @description
                                                    #' Creates a new instance of this [R6][R6::R6Class] class.
                                                    initialize = function() {
                                                      ps = ps(
                                                        period = p_int(default = 1, lower = 1, tags = "train")
                                                      )
                                                      
                                                      super$initialize(
                                                        id = "forecast.persistence_shift",
                                                        feature_types = "numeric",
                                                        predict_types = c("response"),
                                                        packages = "data.table",
                                                        param_set = ps,
                                                        properties = c("univariate"),
                                                        man = "mlr3temporal::mlr_learners_regr.persistence_shift"
                                                      )
                                                    },
                                                    
                                                    #' @description
                                                    #' Returns forecasts after the last training instance.
                                                    #'
                                                    #' @param h (`numeric(1)`)\cr
                                                    #'   Number of steps ahead to forecast. Default is 10.
                                                    #'
                                                    #' @param task ([Task]).
                                                    #'
                                                    #' @param newdata ([data.frame()])\cr
                                                    #'   Ignored
                                                    #'
                                                    #' @return [Prediction].
                                                    forecast = function(h = 10, task, newdata = NULL) {
                                                      h = assert_int(h, lower = 1, coerce = TRUE)
                                                      indices = 1:h%%length(self$model$last_values)
                                                      indices[indices == 0] = length(self$model$last_values)
                                                      forecast = self$model$last_values[indices]
                                                      response = as.data.table(forecast)
                                                      colnames(response) = task$target_names
                                                      
                                                      truth = copy(response)
                                                      truth[, colnames(truth) := 0]
                                                      p = PredictionForecast$new(task,
                                                                                 response = response, truth = truth,
                                                                                 row_ids = (self$date_span$end$row_id + 1):(self$date_span$end$row_id + h)
                                                      )
                                                    }
                                                  ),
                                                  
                                                  private = list(
                                                    .train = function(task) {
                                                      span = range(task$date()[[task$date_col]])
                                                      self$date_span = list(
                                                        begin = list(time = span[1], row_id = task$row_ids[1]),
                                                        end = list(time = span[2], row_id = task$row_ids[task$nrow])
                                                      )
                                                      pv = self$param_set$get_values(tags = "train")
                                                      # fall back to the declared default if the user did not set a period
                                                      period = if (is.null(pv$period)) self$param_set$default$period else pv$period
                                                      x = task$data(cols = task$target_names)[[1L]]
                                                      mean_x = mean(x)
                                                      last_values = c(
                                                        x[max((task$nrow - period + 1), 1):task$nrow], # values already present
                                                        rep(mean_x, times = max(0, period - task$nrow)) # fill up with the mean of the data if period is longer than the training data
                                                      )
                                                      list(
                                                        "fill_value" = mean_x,
                                                        "period" = period,
                                                        "fitted" = data.table::shift(x, n = period, type = "lag", fill = mean_x),
                                                        "last_values" = last_values
                                                      )
                                                    },
                                                    
                                                    .predict = function(task) {
                                                      all_values = c(self$model$fitted, self$model$last_values)
                                                      indices = task$row_ids%%length(all_values)
                                                      indices[indices == 0] = length(all_values)
                                                      response = all_values[indices]
                                                      list("response" = response)
                                                    }
                                                  )
)

#' @include aaa.R
learners[["forecast.persistence_shift"]] = LearnerRegrForecastPersistenceShift

LearnerRegrForecastPersistenceArima:

#' @title Arima-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_arima
#'
#' @description
#' Persistence model as an ARIMA model
#' Calls [forecast::Arima] from package \CRANpkg{forecast} with suitable parameter values to return a (seasonal) persistence forecast.
#'
#' @templateVar id forecast.persistence_arima
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceArima = R6::R6Class("LearnerRegrForecastPersistenceArima",
                                       inherit = LearnerForecast,
                                       
                                       public = list(
                                         
                                         #' @description
                                         #' Creates a new instance of this [R6][R6::R6Class] class.
                                         initialize = function() {
                                           ps = ps(
                                             period = p_int(default = 1, lower = 1, tags = "train")
                                           )
                                           
                                           super$initialize(
                                             id = "forecast.persistence_arima",
                                             feature_types = "numeric",
                                             predict_types = c("response", "se"),
                                             packages = "forecast",
                                             param_set = ps,
                                             properties = c("univariate"),
                                             man = "mlr3temporal::mlr_learners_regr.persistence_arima"
                                           )
                                         },
                                         
                                         #' @description
                                         #' Returns forecasts after the last training instance.
                                         #'
                                         #' @param h (`numeric(1)`)\cr
                                         #'   Number of steps ahead to forecast. Default is 10.
                                         #'
                                         #' @param task ([Task]).
                                         #'
                                         #' @param newdata ([data.frame()])\cr
                                         #'   New data to predict on.
                                         #'
                                         #' @return [Prediction].
                                         forecast = function(h = 10, task, newdata = NULL) {
                                           h = assert_int(h, lower = 1, coerce = TRUE)
                                           if (length(task$feature_names) > 0) {
                                             newdata = as.matrix(newdata)
                                             forecast = invoke(forecast::forecast, self$model, xreg = newdata)
                                           } else {
                                             forecast = invoke(forecast::forecast, self$model, h = h)
                                           }
                                           response = as.data.table(as.numeric(forecast$mean))
                                           colnames(response) = task$target_names
                                           
                                           se = as.data.table(as.numeric(
                                             ci_to_se(width = forecast$upper[, 1] - forecast$lower[, 1], level = forecast$level[1])
                                           ))
                                           colnames(se) = task$target_names
                                           
                                           truth = copy(response)
                                           truth[, colnames(truth) := 0]
                                           p = PredictionForecast$new(task,
                                                                      response = response, se = se, truth = truth,
                                                                      row_ids = (self$date_span$end$row_id + 1):(self$date_span$end$row_id + h)
                                           )
                                         }
                                       ),
                                       
                                       private = list(
                                         .train = function(task) {
                                           span = range(task$date()[[task$date_col]])
                                           self$date_span = list(
                                             begin = list(time = span[1], row_id = task$row_ids[1]),
                                             end = list(time = span[2], row_id = task$row_ids[task$nrow])
                                           )
                                           pv = self$param_set$get_values(tags = "train")
                                           seasonal = list(order = c(0L, 1L, 0L), period = pv$period)
                                           invoke(forecast::Arima, 
                                                  y = task$data(
                                                    rows = task$row_ids,
                                                    cols = task$target_names
                                                  ), 
                                                  seasonal = seasonal,
                                                  include.mean = FALSE)
                                         },
                                         
                                         .predict = function(task) {
                                           se = NULL
                                           fitted_ids = task$row_ids[task$row_ids <= self$date_span$end$row_id]
                                           predict_ids = setdiff(task$row_ids, fitted_ids)
                                           
                                           if (length(predict_ids) > 0) {
                                             if (length(task$feature_names) > 0) {
                                               newdata = as.matrix(task$data(cols = task$feature_names, rows = predict_ids))
                                               response_predict = invoke(forecast::forecast, self$model, xreg = newdata)
                                             } else {
                                               response_predict = invoke(forecast::forecast, self$model, h = length(predict_ids))
                                             }
                                             
                                             predict_mean = as.data.table(as.numeric(response_predict$mean))
                                             colnames(predict_mean) = task$target_names
                                             fitted.mean = self$fitted_values(fitted_ids)
                                             colnames(fitted.mean) = task$target_names
                                             response = rbind(fitted.mean, predict_mean)
                                             if (self$predict_type == "se") {
                                               predict_se = as.data.table(as.numeric(
                                                 ci_to_se(width = response_predict$upper[, 1] - response_predict$lower[, 1],
                                                          level = response_predict$level[1])
                                               ))
                                               colnames(predict_se) = task$target_names
                                               fitted_se = as.data.table(
                                                 sapply(task$target_names, function(x) rep(sqrt(self$model$sigma2), length(fitted_ids)), simplify = FALSE)
                                               )
                                               se = rbind(fitted_se, predict_se)
                                             }
                                           } else {
                                             response = self$fitted_values(fitted_ids)
                                             if (self$predict_type == "se") {
                                               se = as.data.table(
                                                 sapply(task$target_names, function(x) rep(sqrt(self$model$sigma2), length(fitted_ids)), simplify = FALSE)
                                               )
                                             }
                                           }
                                           
                                           list(response = response, se = se)
                                         }
                                       )
)

#' @include aaa.R
learners[["forecast.persistence_arima"]] = LearnerRegrForecastPersistenceArima

Example using the included airpassengers task:

library(mlr3verse)
library(mlr3temporal)
library(R6)
library(mlr3misc)
library(checkmate)
library(data.table)

source("aaa.R")
source("helper.R")
source("LearnerRegrForecastPersistenceArima.R")
source("LearnerRegrForecastPersistenceShift.R")
source("zzz.R")


task <- tsk("airpassengers")

l_arima <- LearnerRegrForecastPersistenceArima$new()
l_arima$param_set$set_values(period = 10)
l_arima$train(task, row_ids = 1:134)
l_arima$predict(task, row_ids = 135:144)
(forecast_arima <- l_arima$forecast(h = 40, task = task))

l_shift <- LearnerRegrForecastPersistenceShift$new()
l_shift$param_set$set_values(period = 10)
l_shift$train(task, row_ids = 1:134)
l_shift$predict(task, row_ids = 135:144)
(forecast_shift <- l_shift$forecast(h = 40, task = task))

library(ggplot2)
ggplot() +
  geom_line(aes(x = task$row_ids, y = task$data()$target)) +
  geom_line(data = as.data.table(forecast_arima)[,c(1,3)], aes(x = row_ids, y = target), linetype = "dashed") +
  labs(x = "row_ids", y = "target", caption = "dashed: prediction")

MarkusLeyser, Aug 05 '24 11:08