
`dataset`'s `[` implementation has inconsistent output shapes when `$.getitem()` is implemented

Open • sebffischer opened this issue on Apr 17 '25 • 4 comments

In the example below, for the dataset that implements `$.getitem()`, the `[` method returns an element without a batch dimension for an index of length 1 and includes the batch dimension otherwise, whereas the dataset that implements `$.getbatch()` always includes it. I think this should be consistent: `[` should always return the batch dimension.

```r
library(torch)

ds_batch = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getbatch = function(i) {
    self$x[i, .., drop = FALSE]
  },
  .length = function() nrow(self$x)
)()

print(ds_batch[1L]$shape)
#> [1]  1 10
print(ds_batch[1:2]$shape)
#> [1]  2 10

ds_item = dataset("item",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() nrow(self$x)
)()

print(ds_item[1L]$shape)
#> [1] 10
print(ds_item[1:2]$shape)
#> [1]  2 10
```

Created on 2025-04-17 with reprex v2.1.1

sebffischer, Apr 17 '25 07:04
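One possible workaround with the current API, assuming every item has the same shape: define `.getbatch()` in terms of per-item access and combine the items with `torch_stack()`. Since `[` uses `.getbatch()` when it is available (as the `ds_batch` output above suggests), the result then always carries the batch dimension. This is a sketch, not an officially recommended pattern.

```r
library(torch)

# Workaround sketch: implement .getbatch() by stacking individual items so
# that `[` always returns a leading batch dimension. Assumes every item has
# the same shape; otherwise torch_stack() will fail.
ds_stacked <- dataset("stacked",
  initialize = function() {
    self$x <- torch_randn(100, 10)
  },
  .getitem = function(i) {
    # A single item, without a batch dimension.
    self$x[i, ]
  },
  .getbatch = function(i) {
    # Stack the items along a new first (batch) dimension.
    torch_stack(lapply(i, function(j) self$.getitem(j)), dim = 1)
  },
  .length = function() nrow(self$x)
)()

print(ds_stacked[1L]$shape)
#> [1]  1 10
print(ds_stacked[1:2]$shape)
#> [1]  2 10
```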

OK, I realized that this happens because `[.dataset` simply passes whatever indices it receives to `$.getitem()`.

I am not sure what the correct behavior is here, but I think the current implementation is inconsistent.

One suggestion would be to make `[.dataset` throw an error when more than one index is provided for datasets that implement only `$.getitem()`. We can't simply concatenate the results along the first dimension, because the returned tensors might have different shapes.

Also, for consistency, I think `[.dataset` should include the batch dimension when it is called with a single index on a dataset that implements `$.getitem()`.

sebffischer, Apr 17 '25 10:04
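The two suggestions above can be sketched as a standalone helper instead of a change to torch's own `[.dataset`. The name `index_dataset()` is hypothetical, and the sketch assumes that `ds$.getbatch` is `NULL` for a dataset that does not implement `.getbatch()`.

```r
library(torch)

# Hypothetical helper (not part of the torch package) illustrating the
# proposed semantics of `[` for datasets.
index_dataset <- function(ds, i) {
  i <- as.integer(i)
  if (!is.null(ds$.getbatch)) {
    # .getbatch() is expected to return the batch dimension already.
    return(ds$.getbatch(i))
  }
  if (length(i) != 1L) {
    # Items may have different shapes, so they cannot always be concatenated.
    stop("Indexing with more than one index requires `.getbatch()`.", call. = FALSE)
  }
  # Add a leading batch dimension of size 1 for consistency with .getbatch().
  ds$.getitem(i)$unsqueeze(1)
}

ds_item <- dataset("item",
  initialize = function() {
    self$x <- torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i, ]
  },
  .length = function() nrow(self$x)
)()

print(index_dataset(ds_item, 1L)$shape)
#> [1]  1 10
```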

I agree with your second suggestion: `[` should include the batch dimension when called with a single index on a dataset that only implements `.getitem()`. We could also implement `[[` to extract a single element by index via `.getitem()`.

dfalbel, Apr 17 '25 11:04

But the question remains whether `ds[1:2]` should throw an error if `ds` implements only `$.getitem()`. The individual tensors might have different shapes, so it's not always possible to `torch_cat()` them.

sebffischer, Apr 17 '25 11:04
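A small illustration of why `torch_cat()` cannot serve as a general fallback, using made-up tensors of different lengths (for example, variable-length sequences) rather than a real dataset:

```r
library(torch)

# Two items of a hypothetical dataset whose elements differ in length;
# each is given a leading batch dimension.
a <- torch_randn(3)$unsqueeze(1)  # shape 1 x 3
b <- torch_randn(5)$unsqueeze(1)  # shape 1 x 5

# Concatenating along the batch dimension fails because the remaining
# dimensions do not match.
result <- tryCatch(torch_cat(list(a, b), dim = 1), error = function(e) e)
print(inherits(result, "error"))
#> [1] TRUE

# Items with identical shapes can be concatenated without problems.
c_ok <- torch_cat(list(a, torch_randn(1, 3)), dim = 1)
print(c_ok$shape)
#> [1] 2 3
```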

Yes, maybe a simpler solution is to throw an error if the dataset only implements `.getitem()`, but then I don't think we should include the batch dimension in this case. Maybe we should only allow `[[` when `.getitem()` is implemented, and reserve `[` for `.getbatch()`.

dfalbel, Apr 17 '25 12:04
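This last proposal could look roughly like the following. `get_element()` and `get_batch()` are hypothetical stand-ins for `[[` and `[`; they are written as plain functions so the sketch does not override any S3 methods that torch itself registers.

```r
library(torch)

# Hypothetical stand-in for `[[`: extract one unbatched element via .getitem().
get_element <- function(ds, i) {
  stopifnot(length(i) == 1L)
  ds$.getitem(as.integer(i))
}

# Hypothetical stand-in for `[`: only allowed when .getbatch() is implemented.
get_batch <- function(ds, i) {
  if (is.null(ds$.getbatch)) {
    stop("`[` requires `.getbatch()`; use `[[` for single elements.", call. = FALSE)
  }
  ds$.getbatch(as.integer(i))
}

ds_item <- dataset("item",
  initialize = function() {
    self$x <- torch_randn(100, 10)
  },
  .getitem = function(i) self$x[i, ],
  .length = function() nrow(self$x)
)()

ds_batch <- dataset("batch",
  initialize = function() {
    self$x <- torch_randn(100, 10)
  },
  .getbatch = function(i) self$x[i, .., drop = FALSE],
  .length = function() nrow(self$x)
)()

print(get_element(ds_item, 1L)$shape)
#> [1] 10
print(get_batch(ds_batch, 1:2)$shape)
#> [1]  2 10
```

Keeping `[[` unbatched mirrors how `[[` extracts a single element from an R list, while `[` keeps the container-like (batched) result.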