mlr3torch icon indicating copy to clipboard operation
mlr3torch copied to clipboard

Fix/nn transformer block

Open cxzhang4 opened this issue 8 months ago • 1 comments

A sketch of the FT-Transformer graph.

cxzhang4 avatar Mar 21 '25 09:03 cxzhang4

#' @title Custom Function
#' @description
#' Applies a user-supplied function to the input tensor.
#' @section nn_module:
#' Calls the function given as parameter `fn` on the input tensor.
#' The function must accept a single [`torch_tensor`][torch::torch_tensor]
#' and return a `torch_tensor`.
#' @section Parameters:
#' * `fn` :: `function`\cr
#'   The function to apply to the input tensor.
#'
#' @templateVar id nn_fn
#' @template pipeop_torch_channels_default
#' @templateVar param_vals fn = function(x) x * 2
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @export
PipeOpTorchFn = R6Class("PipeOpTorchFn",
  inherit = PipeOpTorch,
  public = list(
    #' @description Creates a new instance of this [R6][R6::R6Class] class.
    #' @template params_pipelines
    initialize = function(id = "nn_fn", param_vals = list()) {
      # `fn` is an arbitrary function, hence an untyped (p_uty) parameter.
      # It is required for both shape inference and module construction.
      param_set = ps(
        fn = p_uty(tags = c("train", "required"), custom_check = checkmate::check_function)
      )
      super$initialize(
        id = id,
        param_set = param_set,
        param_vals = param_vals,
        # The module is built by `.make_module()`, so no generator is needed.
        module_generator = NULL
      )
    }
  ),
  private = list(
    .shapes_out = function(shapes_in, param_vals, task) {
      # Infer output shapes empirically:
      # 1. Replace unknown (NA) input dimensions with a placeholder.
      # 2. Apply `fn` to an example tensor of that shape.
      # 3. Mark output dimensions that depend on the placeholder as NA again.
      # The dependence is detected by probing with two different placeholder
      # values and comparing the resulting output shapes.
      fn = param_vals$fn
      lapply(shapes_in, function(shape_in) {
        na_dims = is.na(shape_in)
        probe = function(placeholder) {
          shape = shape_in
          shape[na_dims] = placeholder
          dim(fn(torch::torch_empty(shape)))
        }
        shape1 = probe(1L)
        if (!any(na_dims)) {
          return(shape1)
        }
        shape2 = probe(2L)
        if (length(shape1) != length(shape2)) {
          stopf("Output rank of `fn` must not depend on unknown input dimensions.")
        }
        # Dimensions that changed with the placeholder are data-dependent,
        # i.e. unknown at graph-construction time.
        shape1[shape1 != shape2] = NA_integer_
        shape1
      })
    },
    .make_module = function(shapes_in, param_vals, task) {
      # Wrap the user function in a proper nn_module so it composes with
      # the rest of the network like any other layer.
      fn = param_vals$fn
      module = torch::nn_module("nn_fn",
        initialize = function(fn) {
          self$fn = fn
        },
        forward = function(x) {
          self$fn(x)
        }
      )
      module(fn)
    }
  )
)

#' @include aaa.R
# Register the PipeOp under the id "nn_fn" in the mlr3torch dictionary
# so it can be retrieved via po("nn_fn").
register_po("nn_fn", PipeOpTorchFn)

sebffischer avatar Mar 21 '25 10:03 sebffischer