mlr3temporal
mlr3temporal copied to clipboard
Proposal to add "forecast.persistence" learner for seasonal persistence prediction
Description:
I am interested in contributing to the `mlr3temporal` package and would like to propose the addition of a simple "forecast.persistence" learner. (This would be my first contribution to the mlr3 project. 🙂)
Motivation and Idea: I am using mlr3 for a benchmarking study on the prediction of photovoltaic power output. In the context of photovoltaics, seasonal persistence prediction is a common (naive) baseline for evaluating forecasting models. This approach uses observations from the latest seasonal cycle for prediction. In contrast, regular persistence prediction would forecast the last available value. For example, in my case, a seasonal persistence prediction would correspond to the photovoltaic power output from 24 hours ago. While this approach looks simple, it can be surprisingly hard to beat.
Both regular and seasonal persistence predictions are valuable baselines for many time series modeling applications. Therefore, I see a benefit in providing a convenient and simple "forecast.persistence" learner in `mlr3temporal`.
Implementation: There are two viable options for implementing this simple learner:
1. Using the `data.table` package:
   - Pros: No additional dependencies; core functionality can be realized using `shift()`.
   - Cons: Limited functionality.
2. Using the `forecast` package:
   - Pros: The (seasonal) persistence prediction can be represented as a suitably specified ARIMA model (seasonal differencing of order 1 with the appropriate period). It also allows for further functionality expansion, in particular with `forecast::findfrequency(x)`, which determines the dominant frequency of a time series.
   - Cons: Requires a package that is not distributed with R itself. Considering that an ARIMA learner (`forecast.arima`) is already implemented in `mlr3temporal`, I personally believe that option 2 is still viable.
In both cases, it is necessary to define how the learner behaves when a complete period is not present in the training data, i.e., when the lag order is greater than the length of the data. Hence, the existing forecast.arima learner is not sufficient.
Request:
I have attached a few lines of concept code for reference (modified versions of `forecast.arima` and `forecast.average`, respectively). I would appreciate it if you could provide feedback on whether such a contribution would be welcomed by the maintainers.
Thank you for your time.
LearnerRegrForecastPersistenceShift:
#' @title Lag-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_shift
#'
#' @description
#' (Seasonal) persistence model: the prediction for a time step is the value
#' observed `period` steps earlier.
#' In-sample fitted values are computed with [data.table::shift] from package
#' \CRANpkg{data.table}; forecasts repeat the last `period` training values
#' cyclically. Positions for which no observation `period` steps back exists
#' are filled with the mean of the training target.
#'
#' @templateVar id forecast.persistence_shift
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceShift = R6::R6Class("LearnerRegrForecastPersistenceShift",
  inherit = LearnerForecast,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        period = p_int(default = 1, lower = 1, tags = "train")
      )

      super$initialize(
        id = "forecast.persistence_shift",
        feature_types = "numeric",
        predict_types = c("response"),
        # Only data.table::shift() is used; the forecast package is not needed.
        packages = "data.table",
        param_set = ps,
        properties = c("univariate"),
        man = "mlr3temporal::mlr_learners_regr.persistence_shift"
      )
    },

    #' @description
    #' Returns forecasts after the last training instance.
    #'
    #' @param h (`numeric(1)`)\cr
    #'   Number of steps ahead to forecast. Default is 10.
    #'
    #' @param task ([Task]).
    #'
    #' @param newdata ([data.frame()])\cr
    #'   Ignored
    #'
    #' @return [Prediction].
    forecast = function(h = 10, task, newdata = NULL) {
      h = assert_int(h, lower = 1, coerce = TRUE)

      # Step k ahead repeats element ((k - 1) mod period) + 1 of the last cycle.
      last_values = self$model$last_values
      values = last_values[(seq_len(h) - 1L) %% length(last_values) + 1L]

      response = as.data.table(values)
      colnames(response) = task$target_names

      # The truth is unknown for future time steps; a zero placeholder is stored.
      truth = copy(response)
      truth[, colnames(truth) := 0]

      end_id = self$date_span$end$row_id
      PredictionForecast$new(task,
        response = response, truth = truth,
        row_ids = (end_id + 1):(end_id + h)
      )
    }
  ),

  private = list(

    # Stores the shifted in-sample values and the last cycle of the target.
    .train = function(task) {
      span = range(task$date()[[task$date_col]])
      self$date_span = list(
        begin = list(time = span[1], row_id = task$row_ids[1]),
        end = list(time = span[2], row_id = task$row_ids[task$nrow])
      )

      pv = self$param_set$get_values(tags = "train")
      # Scalar fallback to the declared default when `period` was not set.
      period = if (is.null(pv$period)) self$param_set$default$period else pv$period

      # Work in double precision so the (generally fractional) mean used as
      # fill value has the same type as the series.
      x = as.numeric(task$data(cols = task$target_names)[[1L]])
      n = length(x)
      mean_x = mean(x)

      # Values the first `period` forecast steps refer to, in forecast order.
      # If the training series is shorter than one period, the earliest steps
      # point before the start of the data and therefore receive the fill
      # value; the observed values follow afterwards.
      last_values = c(
        rep(mean_x, times = max(0L, period - n)),
        x[seq_len(n) > n - period]
      )

      list(
        "fill_value" = mean_x,
        "period" = period,
        "fitted" = data.table::shift(x, n = period, type = "lag", fill = mean_x),
        "last_values" = last_values
      )
    },

    # Predicts by row id: rows up to the end of the training span get the
    # in-sample fitted value, later rows continue the last cycle.
    # NOTE(review): assumes the training rows were consecutive row ids from
    # date_span$begin to date_span$end -- confirm against the task backends used.
    .predict = function(task) {
      model = self$model
      ids = task$row_ids
      begin_id = self$date_span$begin$row_id
      end_id = self$date_span$end$row_id

      response = rep(NA_real_, length(ids))
      in_sample = ids <= end_id

      # Position inside the training series; rows before the training span
      # have no fitted value and stay NA.
      pos = ids[in_sample] - begin_id + 1L
      pos[pos < 1L] = NA_integer_
      response[in_sample] = model$fitted[pos]

      # Number of steps past the training end, mapped cyclically onto the
      # last cycle so that arbitrarily distant rows remain valid forecasts.
      steps = ids[!in_sample] - end_id
      response[!in_sample] = model$last_values[(steps - 1L) %% length(model$last_values) + 1L]

      list("response" = response)
    }
  )
)

#' @include aaa.R
learners[["forecast.persistence_shift"]] = LearnerRegrForecastPersistenceShift
LearnerRegrForecastPersistenceArima:
#' @title Arima-based Persistence Forecast Learner
#'
#' @name mlr_learners_regr.persistence_arima
#'
#' @description
#' Persistence model expressed as an ARIMA model.
#' Calls [forecast::Arima] from package \CRANpkg{forecast} with a seasonal
#' order of `c(0, 1, 0)` and the configured `period` to return a (seasonal)
#' persistence forecast. The model is fitted on the target only; task features
#' are not used as regressors.
#'
#' @templateVar id forecast.persistence_arima
#' @template learner
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerRegrForecastPersistenceArima = R6::R6Class("LearnerRegrForecastPersistenceArima",
  inherit = LearnerForecast,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        period = p_int(default = 1, lower = 1, tags = "train")
      )

      super$initialize(
        id = "forecast.persistence_arima",
        feature_types = "numeric",
        predict_types = c("response", "se"),
        packages = "forecast",
        param_set = ps,
        properties = c("univariate"),
        man = "mlr3temporal::mlr_learners_regr.persistence_arima"
      )
    },

    #' @description
    #' Returns forecasts after the last training instance.
    #'
    #' @param h (`numeric(1)`)\cr
    #'   Number of steps ahead to forecast. Default is 10.
    #'
    #' @param task ([Task]).
    #'
    #' @param newdata ([data.frame()])\cr
    #'   Ignored. The model is fitted without external regressors, so the
    #'   horizon is always determined by `h`.
    #'
    #' @return [Prediction].
    forecast = function(h = 10, task, newdata = NULL) {
      h = assert_int(h, lower = 1, coerce = TRUE)

      # `.train()` fits the model without `xreg`, so no regressors are passed
      # here either; this keeps the horizon equal to `h` and consistent with
      # the row ids generated below.
      forecast = invoke(forecast::forecast, self$model, h = h)

      response = as.data.table(as.numeric(forecast$mean))
      colnames(response) = task$target_names

      # Standard error derived from the width of the first reported interval.
      se = as.data.table(as.numeric(
        ci_to_se(width = forecast$upper[, 1] - forecast$lower[, 1], level = forecast$level[1])
      ))
      colnames(se) = task$target_names

      # The truth is unknown for future time steps; a zero placeholder is stored.
      truth = copy(response)
      truth[, colnames(truth) := 0]

      end_id = self$date_span$end$row_id
      PredictionForecast$new(task,
        response = response, se = se, truth = truth,
        row_ids = (end_id + 1):(end_id + h)
      )
    }
  ),

  private = list(

    # Fits a purely seasonally differenced ARIMA model on the target.
    .train = function(task) {
      span = range(task$date()[[task$date_col]])
      self$date_span = list(
        begin = list(time = span[1], row_id = task$row_ids[1]),
        end = list(time = span[2], row_id = task$row_ids[task$nrow])
      )

      pv = self$param_set$get_values(tags = "train")
      # Explicit fallback to the declared default when `period` was not set.
      period = if (is.null(pv$period)) self$param_set$default$period else pv$period
      seasonal = list(order = c(0L, 1L, 0L), period = period)

      invoke(forecast::Arima,
        y = task$data(
          rows = task$row_ids,
          cols = task$target_names
        ),
        seasonal = seasonal,
        include.mean = FALSE)
    },

    # Returns fitted values for rows inside the training span and forecasts
    # for rows after it.
    # NOTE(review): like the training code, this assumes row ids are
    # consecutive integers in time order -- confirm for the backends used.
    .predict = function(task) {
      end_id = self$date_span$end$row_id
      fitted_ids = task$row_ids[task$row_ids <= end_id]
      predict_ids = setdiff(task$row_ids, fitted_ids)
      want_se = self$predict_type == "se"

      fitted_mean = self$fitted_values(fitted_ids)

      # In-sample standard error: constant residual standard deviation,
      # one column per target.
      fitted_se = NULL
      if (want_se) {
        se_cols = rep(
          list(rep(sqrt(self$model$sigma2), length(fitted_ids))),
          length(task$target_names)
        )
        names(se_cols) = task$target_names
        fitted_se = as.data.table(se_cols)
      }

      if (length(predict_ids) == 0L) {
        return(list(response = fitted_mean, se = fitted_se))
      }

      # Forecast up to the most distant requested row and pick the steps that
      # correspond to the requested row ids, so rows that do not start right
      # after the training end still receive the matching forecast step.
      # No `xreg` is passed because the model was fitted without regressors.
      steps = predict_ids - end_id
      response_predict = invoke(forecast::forecast, self$model, h = max(steps))

      predict_mean = as.data.table(as.numeric(response_predict$mean)[steps])
      colnames(predict_mean) = task$target_names
      colnames(fitted_mean) = task$target_names
      response = rbind(fitted_mean, predict_mean)

      se = NULL
      if (want_se) {
        predict_se = as.data.table(as.numeric(
          ci_to_se(width = response_predict$upper[, 1] - response_predict$lower[, 1],
            level = response_predict$level[1])
        )[steps])
        colnames(predict_se) = task$target_names
        se = rbind(fitted_se, predict_se)
      }

      list(response = response, se = se)
    }
  )
)

#' @include aaa.R
learners[["forecast.persistence_arima"]] = LearnerRegrForecastPersistenceArima
Example using the included airpassengers task:
# Example script: trains both persistence learners on the airpassengers task,
# predicts the held-out rows and forecasts 40 steps ahead.
library(mlr3verse)
library(mlr3temporal)
library(R6)
library(mlr3misc)
library(checkmate)
library(data.table)
library(ggplot2)

# Concept code is sourced from local files rather than an installed package.
source("aaa.R")
source("helper.R")
source("LearnerRegrForecastPersistenceArima.R")
source("LearnerRegrForecastPersistenceShift.R")
source("zzz.R")

task <- tsk("airpassengers")

# ARIMA-based persistence learner
l_arima <- LearnerRegrForecastPersistenceArima$new()
l_arima$param_set$set_values(period = 10)
l_arima$train(task, row_ids = 1:134)
l_arima$predict(task, row_ids = 135:144)
(forecast_arima <- l_arima$forecast(h = 40, task = task))

# Shift-based persistence learner; the class defined above is
# LearnerRegrForecastPersistenceShift (there is no ...PersistenceLag class).
l_shift <- LearnerRegrForecastPersistenceShift$new()
l_shift$param_set$set_values(period = 10)
l_shift$train(task, row_ids = 1:134)
l_shift$predict(task, row_ids = 135:144)
(forecast_shift <- l_shift$forecast(h = 40, task = task))

# Observed series (solid) against the ARIMA-based forecast (dashed).
ggplot() +
  geom_line(aes(x = task$row_ids, y = task$data()$target)) +
  geom_line(
    data = as.data.table(forecast_arima)[, c(1, 3)],
    aes(x = row_ids, y = target),
    linetype = "dashed"
  ) +
  labs(x = "row_ids", y = "target", caption = "dashed: prediction")