
Lazy Tensor Extension Ideas

Open sebffischer opened this issue 2 years ago • 1 comments

The lazy_tensor datatype is currently restrictive in various ways:

  1. Currently there can only be a single preprocessing graph for the whole tensor-column.
  2. Currently it is assumed that there is only a single dataset per lazy tensor column. It is for example not possible to add two lazy tensor columns and merge their preprocessing graphs using an addition operator. While this might not be that common / useful, we might want to add it later.

Some thoughts about 2: To achieve this, we need to merge both the graphs and the datasets. Merging the graphs is already implemented and works analogously to merging ModelDescriptors. Merging the datasets is not implemented yet, but should not be that difficult.

To enable this, we would need a dataset_collection, which is essentially something like the class below, except that we should treat the special case where all datasets have a .getbatch() method and implement the collection's .getbatch() method accordingly. It is important to take the (input) caching into account, because there could be multiple dataset collections pointing to the same dataset.

Another challenge is that we only want to merge those graphs where it is really necessary (i.e. they share computation). All inputs of a graph need to be of the same shape (i.e. either a single tensor or the whole batch). If we merged graphs too aggressively, tensors might get processed row-wise even though they could be processed batch-wise, thereby slowing everything down. Currently we sidestep this problem and merge only those graphs that have the same dataset. Because no merge operations are currently supported, this is not really restrictive.
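The "merge only graphs that share a dataset" rule from the paragraph above can be illustrated with a small base-R sketch. The `columns` list and `merge_groups()` are hypothetical stand-ins: each lazy tensor column is reduced to a name and a dataset hash, which is not how mlr3torch actually represents it.

```r
# Hypothetical stand-ins: each "lazy tensor column" is reduced to a name and
# the hash of the dataset it points to. Real lazy tensors carry more state.
columns = list(
  list(name = "x1", dataset_hash = "abc"),
  list(name = "x2", dataset_hash = "abc"),
  list(name = "x3", dataset_hash = "def")
)

# Group column names by dataset hash. Only columns within the same group
# are candidates for having their preprocessing graphs merged.
merge_groups = function(columns) {
  hashes = vapply(columns, function(col) col$dataset_hash, character(1))
  names = vapply(columns, function(col) col$name, character(1))
  # keep groups in order of first appearance and drop the hash names
  unname(split(names, factor(hashes, levels = unique(hashes))))
}

groups = merge_groups(columns)
```

Here `x1` and `x2` end up in one group because they share a dataset, while `x3` stays on its own and would keep its own graph.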

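# Sketch: combine the datasets behind several lazy tensors into one torch dataset.
# Assumes torch (dataset), mlr3misc (map, map_int, set_names) and the dd() accessor
# used in the original, which is taken to return a lazy tensor's data descriptor.
make_dataset_collection = function(lts) {
  dataset("dataset_collection",
    initialize = function(lts) {
      datasets = map(lts, function(lt) dd(lt)$dataset)
      # keep each underlying dataset only once, identified by its hash
      duplicated = duplicated(map(lts, function(lt) dd(lt)$dataset_hash))
      self$datasets = datasets[!duplicated]
      # row ids into each dataset (assumption: the first entry of every
      # lazy tensor element is its integer id)
      self$ids = map(lts[!duplicated], function(lt) map_int(lt, 1L))
      # prefix output names with a short hash so they stay unique across datasets
      self$names = Reduce(c, map(lts[!duplicated], function(lt) {
        paste0(substring(dd(lt)$dataset_hash, 1, 5), ".", names(dd(lt)$dataset_shapes))
      }))
    },
    .getitem = function(i) {
      batch = list()
      for (j in seq_along(self$datasets)) {
        idx = self$ids[[j]][i]
        if (is.null(self$datasets[[j]]$.getbatch)) {
          new = self$datasets[[j]]$.getitem(idx)
        } else {
          # .getbatch() returns a batch of size 1 here, so drop the batch dimension
          new = map(self$datasets[[j]]$.getbatch(idx), function(x) x$squeeze(1L))
        }
        batch = append(batch, new)
      }
      set_names(batch, self$names)
    }
  )(lts)
}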
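
The caching concern mentioned above (several dataset collections pointing to the same dataset) could be handled by a cache keyed on dataset hash and index. The following is only a minimal base-R sketch of that idea: `make_cached_getitem()` and `toy_getitem()` are hypothetical names, no torch objects are involved, and this is not how mlr3torch implements its input cache.

```r
# Hypothetical cache that could be shared by all collections,
# keyed by "<dataset_hash>:<index>". It never evicts entries.
make_cached_getitem = function(cache = new.env(parent = emptyenv())) {
  function(dataset_hash, getitem, i) {
    key = paste0(dataset_hash, ":", i)
    if (!exists(key, envir = cache, inherits = FALSE)) {
      # first access: read from the dataset and remember the result
      assign(key, getitem(i), envir = cache)
    }
    get(key, envir = cache, inherits = FALSE)
  }
}

# Toy dataset accessor that counts how often it is actually read.
reads = 0L
toy_getitem = function(i) {
  reads <<- reads + 1L
  list(x = i * 2)
}

cached_getitem = make_cached_getitem()
a = cached_getitem("abc", toy_getitem, 3L)  # read from the dataset
b = cached_getitem("abc", toy_getitem, 3L)  # served from the cache
```

Because both calls use the same hash and index, the toy dataset is only read once. A real implementation would also need to bound the cache, for example by clearing it after each batch.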

sebffischer, Oct 26 '23 07:10

Some WIP is here: https://github.com/mlr-org/mlr3torch/tree/feat/lazy_tensor_extension

sebffischer, Jan 23 '24 08:01

only do this if we really need this, closing for now

sebffischer, Jun 13 '24 15:06