Problem caching instances of torch modules and datasets
Caching a chunk that creates an instance of a torch module or of a torch dataset yields an `external pointer is not valid` error when the instance is used in another chunk.
Example with torch module:
```{r, cache=TRUE}
library(torch)
lin <- nn_linear(2, 3)
# torch_save(lin, "lin.pt")
```
```{r}
# lin <- torch_load("lin.pt")
x <- torch_randn(2)
lin$forward(x)
```
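The commented `torch_save()`/`torch_load()` lines hint at the workaround discussed below; written out, a working version of the module example looks like this (a sketch, with the file name `lin.pt` chosen arbitrarily):

```{r, cache=TRUE}
lin <- nn_linear(2, 3)
# Serialize the module so later chunks do not depend on the cached pointer
torch_save(lin, "lin.pt")
```

```{r}
# Restore a fresh, valid instance instead of using the cached object
lin <- torch_load("lin.pt")
x <- torch_randn(2)
lin$forward(x)
```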
Example with torch dataset:
```{r, cache=TRUE}
ds_gen <- dataset(
  initialize = function() {
    self$x <- torch_tensor(1:10, dtype = torch_long())
  },
  .getitem = function(index) {
    self$x[index]
  },
  .length = function() {
    length(self$x)
  }
)
ds <- ds_gen()
```
```{r}
ds[1:3]
```
If there is no cache, the chunks run without problems. However, once a cache exists, trying to access the cached instance of the module or of the dataset fails with:
`Error in cpp_tensor_dim(x$ptr) : external pointer is not valid`
This might be due to the fact that the R torch package relies on reference classes (R6 and/or R7), and it could be related to issue #2176. In any case, caching would be useful for trained instances of a module, or for instances of datasets whose initialization involves a lot of processing.
At the moment, the only alternative is to save the torch model in the cached chunk with torch_save() and to load it back in the uncached chunk with torch_load() (see the commented lines in the chunks above). However, as far as I know, there is no equivalent method to save and load torch datasets.
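One possible stopgap for datasets, assuming the expensive part is building the underlying tensors rather than constructing the dataset object itself, is to cache only the tensors with torch_save() and re-create the dataset in an uncached chunk. This is only a sketch; `ds_gen2` and `ds_x.pt` are illustrative names:

```{r, cache=TRUE}
# Cache only the expensive data; torch_save() handles plain tensors
x_data <- torch_tensor(1:10, dtype = torch_long())  # stand-in for costly preprocessing
torch_save(x_data, "ds_x.pt")
```

```{r}
# Rebuild the dataset from the saved tensor; construction itself is cheap
ds_gen2 <- dataset(
  initialize = function(path) {
    self$x <- torch_load(path)
  },
  .getitem = function(index) {
    self$x[index]
  },
  .length = function() {
    length(self$x)
  }
)
ds <- ds_gen2("ds_x.pt")
ds[1:3]
```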
@atusy do you think https://github.com/yihui/knitr/pull/2340 could also solve this?
It seems broader to me, and probably more related to the reticulate caching mechanism.
@gavril0 did you open an issue about this in reticulate? knitr's Python engine lives there.