torchvision
adding `.getbatch()` method to `mnist_dataset` dataset generator improves performance markedly
The standard dataset generator for the MNIST dataset does not include a `.getbatch()`
method and, as a result, retrieving a batch is quite slow, at least on CPU.
# dataset root directory
dir <- "./dataset"
# download dataset
train_ds <- mnist_dataset(
dir,
download = TRUE,
transform = transform_to_tensor
)
# dataloader
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter <- train_dl$.iter()
microbenchmark::microbenchmark(b <- train_iter$.next())
The timings are:
Unit: milliseconds
expr min lq mean median uq max neval
b <- train_iter$.next() 45.5263 47.78455 52.74117 49.19825 52.95675 87.2219 100
As explained in the vignette, the dataloader uses the `.getitem()`
method iteratively to return a batch in the absence of a `.getbatch()`
method.
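To illustrate the difference, here is a minimal sketch with a hypothetical toy dataset (not part of torchvision): when a `.getbatch()` method is present, the dataloader passes the whole vector of batch indices in a single call, instead of calling `.getitem()` once per index and collating the results.

```r
library(torch)

# Hypothetical toy dataset: with .getbatch() defined, the dataloader
# hands over all batch indices at once, so a single tensor-slicing
# operation produces the whole batch.
toy_ds <- dataset(
  initialize = function() {
    self$data <- torch_randn(1000, 10)
  },
  .getbatch = function(index) {
    # `index` is a vector of indices, e.g. c(5, 12, 901, ...)
    list(x = self$data[index, ])
  },
  .length = function() {
    self$data$size(1)
  }
)()

toy_dl <- dataloader(toy_ds, batch_size = 32)
toy_iter <- toy_dl$.iter()
b <- toy_iter$.next()
dim(b$x)  # 32 10
```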
Interestingly, it seems that the `.getitem()`
method might be used as a `.getbatch()`
method without any change:
# mnist_dataset .getitem() method
> train_ds$.getitem
function (index)
{
img <- self$data[index, , ]
target <- self$targets[index]
if (!is.null(self$transform))
img <- self$transform(img)
if (!is.null(self$target_transform))
target <- self$target_transform(target)
list(x = img, y = target)
}
<environment: 0x000001527445be68>
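The reason this works unchanged is that R's array indexing accepts a vector of indices, so `self$data[index, , ]` returns a single image for a scalar index and a stacked batch for a vector of indices. A minimal base-R sketch (using a random array in place of the actual MNIST data):

```r
# self$data in mnist_dataset is an array of shape [n, 28, 28];
# indexing the first dimension with a vector yields a batch directly.
x <- array(rnorm(10 * 28 * 28), dim = c(10, 28, 28))

single <- x[1, , ]           # one image: 28 x 28
batch  <- x[c(1, 3, 5), , ]  # three images: 3 x 28 x 28
```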
It is easy to add a `.getbatch()`
method to the existing `mnist_dataset`
dataset generator:
# create a new dataset generator that extends mnist_dataset
mnist_dataset2 <- dataset(
inherit = mnist_dataset,
.getbatch = function(index) {
self$.getitem(index)
}
)
Let's measure the performance with this new dataset generator:
# create a dataset with the new dataset generator
train_ds2 <- mnist_dataset2(
dir,
download = TRUE,
transform = transform_to_tensor
)
# create a dataloder with the new dataset
train_dl2 <- dataloader(train_ds2, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader
train_iter2 <- train_dl2$.iter()
microbenchmark::microbenchmark(train_iter2$.next())
Unit: milliseconds
expr min lq mean median uq max neval
train_iter2$.next() 3.995601 4.328151 5.430246 4.601451 4.965501 11.7692 100
The new dataloader is almost 10 times faster!
That said, it seems that the new dataloader cannot be used in place of `train_dl`
in this example, which uses luz
to train the network:
fitted <- mnist_module %>%
setup(
loss = nn_cross_entropy_loss(),
optimizer = optim_adam,
metrics = list(
luz_metric_accuracy()
)
) %>%
fit(train_dl, epochs = 1, valid_data = test_dl)
It yields the error message:
expected input[1, 28, 128, 28] to have 1 channels, but got 28 channels instead
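I have not verified this, but the shape in the error suggests the batched images come out without the channel dimension the network expects (`[batch, 1, 28, 28]`). One possible workaround, sketched here as an untested assumption (the generator name `mnist_dataset3` is mine), is to insert the channel dimension explicitly in `.getbatch()`:

```r
library(torch)
library(torchvision)

# Untested sketch: extend mnist_dataset with a .getbatch() that
# adds the channel dimension before the batch reaches the network.
mnist_dataset3 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    batch <- self$.getitem(index)
    # [N, 28, 28] -> [N, 1, 28, 28] (torch dims are 1-indexed in R)
    batch$x <- batch$x$unsqueeze(2)
    batch
  }
)
```

Whether this resolves the luz error would need testing; the transform may also interact with the batched array in a way I have not traced.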
I don't have a PC with a GPU to test whether there is a similar improvement when the data are loaded on the GPU. I also wonder why the `.getbatch()`
method is not implemented more often, since it seems an easy way to improve performance. Though I did not investigate the origin of the error, the `luz::fit`
method should be able to accept a dataloader whose dataset has a `.getbatch()`
method.