`dataset`'s `[` implementation has inconsistent output shapes when `$.getitem()` is implemented
In the example below, for the dataset with the `$.getitem()` implementation, the `[` method returns an element without the batch dimension for an index of length 1, and otherwise includes the batch dimension. I think it would be better to make this consistent and always return the batch dimension.
```r
library(torch)

# Dataset that implements .getbatch(): the batch dimension is always kept.
ds_batch = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getbatch = function(i) {
    self$x[i, .., drop = FALSE]
  },
  .length = function() nrow(self$x)
)()
print(ds_batch[1L]$shape)
#> [1] 1 10
print(ds_batch[1:2]$shape)
#> [1] 2 10

# Dataset that implements only .getitem(): a single index drops the batch dimension.
ds_item = dataset("item",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() nrow(self$x)
)()
print(ds_item[1L]$shape)
#> [1] 10
print(ds_item[1:2]$shape)
#> [1] 2 10
```
Created on 2025-04-17 with reprex v2.1.1
OK, I realized that this is because `[.dataset` just calls into `$.getitem()` with whatever indices are provided.
I am not sure what the correct behavior here is, but I think the current implementation is somewhat inconsistent.
One suggestion would be to make `[.dataset` raise an error when more than one index is provided (for datasets that implement only `$.getitem()`). We can't just `torch_cat()` along the first dimension because the returned tensors might have varying shapes.
Also, for consistency, I think that `[.dataset` should include the batch dimension when called with a single index on a dataset that implements only `$.getitem()`.
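As a rough illustration of these two suggestions, here is a minimal sketch written as a standalone helper rather than as a change to `[.dataset`. The name `getitem_batched()` is hypothetical and not part of torch, and the sketch assumes that `.getitem()` returns a single tensor.

```r
library(torch)

# Dataset that implements only .getitem(), mirroring the reprex.
ds_item <- dataset("item",
  initialize = function() {
    self$x <- torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() nrow(self$x)
)()

# Hypothetical helper (not part of torch): reject multiple indices and
# always return the item with a leading batch dimension of size 1.
getitem_batched <- function(ds, i) {
  if (length(i) != 1L) {
    stop("Only a single index is supported for datasets that implement only .getitem().")
  }
  ds$.getitem(as.integer(i))$unsqueeze(1)
}

print(getitem_batched(ds_item, 1L)$shape)
```

With this behavior, a single index would yield a `1 x 10` tensor, matching what the `.getbatch()` dataset returns for `ds_batch[1L]`.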
I agree with your second suggestion: `[` should include the batch dimension when called with a single index on a dataset that only implements `.getitem()`. We could implement `[[` to extract a single element by index, using `.getitem()`.
But the question is still whether `ds[1:2]` should throw an error if `ds` implements only `$.getitem()`. The different tensors might have varying shapes, so it's not always possible to `torch_cat()` them.
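To make the varying-shapes concern concrete, here is a small sketch with two made-up items of different lengths. The tensors `a` and `b` are illustrative only and stand in for what `.getitem()` might return for two different indices.

```r
library(torch)

# Two items whose shapes differ, as could happen with variable-length data.
a <- torch_randn(3)
b <- torch_randn(5)

# Add a batch dimension to each item and try to concatenate along it.
# The sizes of the second dimension differ (3 vs 5), so torch_cat() fails.
result <- tryCatch(
  torch_cat(list(a$unsqueeze(1), b$unsqueeze(1)), dim = 1),
  error = function(e) "cannot concatenate items with different shapes"
)
print(result)
```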
Yes, maybe a simpler solution is to raise an error if the dataset only implements `.getitem()`. But then I don't think we should include the batch dimension in this case. Maybe just allow `[[` if `.getitem()` is implemented, and reserve `[` for `.getbatch()`.
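A minimal sketch of that split, using plain functions instead of actual `[[` and `[` methods so that nothing in torch is overridden. The names `get_item()` and `get_batch()` are hypothetical, and the check on `ds$.getbatch` assumes that the field is `NULL` when a dataset does not define `.getbatch()`.

```r
library(torch)

# Hypothetical stand-in for `[[`: exactly one index, no batch dimension added.
get_item <- function(ds, i) {
  stopifnot(length(i) == 1L)
  ds$.getitem(as.integer(i))
}

# Hypothetical stand-in for `[`: only available when .getbatch() is defined.
get_batch <- function(ds, i) {
  # Assumption: ds$.getbatch is NULL when the dataset does not define it.
  if (is.null(ds$.getbatch)) {
    stop("This dataset does not implement .getbatch(); use get_item() instead.")
  }
  ds$.getbatch(as.integer(i))
}

ds_batch <- dataset("batch",
  initialize = function() self$x <- torch_randn(100, 10),
  .getbatch = function(i) self$x[i, .., drop = FALSE],
  .length = function() nrow(self$x)
)()

ds_item <- dataset("item",
  initialize = function() self$x <- torch_randn(100, 10),
  .getitem = function(i) self$x[i],
  .length = function() nrow(self$x)
)()

print(get_batch(ds_batch, 1:2)$shape)
print(get_item(ds_item, 1L)$shape)
```

Under this design, `get_batch(ds_item, 1:2)` would raise an error instead of silently returning a tensor whose shape depends on the index length.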