Freezing weights
- the problem here is that the trained weights do not live in the PipeOpTorch objects
- PipeOpTorch should probably keep a link to the nn_module's parameters (torch_tensors have reference semantics, more or less; anticipating problems with cloning here...). It should have a hyperparameter like 'fix_weights', in which case the new nn_module it creates reuses these weights and keeps them fixed (see the sketch after this list)
- maybe also have a 'relative learning rate' hyperparameter
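A minimal sketch of what a 'fix_weights'-style hyperparameter could do internally, using plain torch only; the helper name and the PipeOpTorch integration are made up here, not existing API:

```r
library(torch)

# sketch only: freeze all parameters of an nn_module by disabling gradient
# tracking; a 'fix_weights'-style hyperparameter could do this when the
# PipeOpTorch builds its nn_module (that integration is assumed, not implemented)
freeze_module = function(module) {
  for (p in module$parameters) {
    p$requires_grad_(FALSE)  # in-place: tensor no longer tracks gradients
  }
  invisible(module)
}

n = nn_linear(10, 1)
freeze_module(n)
n$parameters$weight$requires_grad  # should now be FALSE
```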
Maybe relevant: it is possible to set group-specific parameters for the optimizer:
```r
library(torch)
n = nn_linear(10, 1)
y = torch_randn(16, 1)
x = torch_randn(16, 10)
params = list(
  list(params = n$parameters$weight, lr = 0),
  list(params = n$parameters$bias, lr = 0.1)
)
weight = torch_clone(n$parameters$weight)
bias = torch_clone(n$parameters$bias)
o = optim_sgd(params, lr = 0.1)
o$zero_grad()
y_hat = n(x)
loss = nnf_mse_loss(y_hat, y)
loss$backward()
o$step()
#> NULL
weight - n$parameters$weight
#> torch_tensor
#> 0 0 0 0 0 0 0 0 0 0
#> [ CPUFloatType{1,10} ][ grad_fn = <SubBackward0> ]
bias - n$parameters$bias
#> torch_tensor
#> 0.01 *
#> 4.3876
#> [ CPUFloatType{1} ][ grad_fn = <SubBackward0> ]
```
Created on 2022-09-20 by the reprex package (v2.0.1)
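Building on that reprex, a hedged sketch of how a 'relative learning rate' hyperparameter could translate into optimizer parameter groups; `base_lr`, `lr_factor` and the module names are illustrative assumptions, not existing mlr3torch parameters:

```r
library(torch)

# hypothetical mapping: per-module learning-rate factors -> optimizer groups
# (a factor of 0 effectively freezes that module's weights)
base_lr = 0.1
lr_factor = c(encoder = 0, head = 1)

modules = list(encoder = nn_linear(10, 5), head = nn_linear(5, 1))

param_groups = lapply(names(modules), function(nm) {
  list(params = modules[[nm]]$parameters, lr = base_lr * lr_factor[[nm]])
})

o = optim_sgd(param_groups, lr = base_lr)
```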
Regarding the pre-initialized weights: we could consider saving the PipeOpModule that a PipeOpTorch generates in the $state of that PipeOpTorch, or alternatively as a hyperparameter. The PipeOpModule contains the torch nn_module and therefore the weights, which can be pre-initialized. Currently PipeOpTorchModel resets the weights of its inputs; that would need to be avoided in this case.
The difference between a hyperparameter and $state would be that the latter would (need to?) make use of the hotstart API. It is unclear how stable that API is and how flexible, in particular when some of the modules are pre-trained and others are not (so only some of the $states are present). Using a hyperparameter is therefore probably the better option here.
We'd want a neat API for transferring pre-trained model weights to parts of graphs, similar to what we want for https://github.com/mlr-org/mlr3pipelines/issues/538
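For plain torch modules the weight transfer itself is already straightforward via state dicts; a PipeOpTorchModel that is told not to reset weights could do something along these lines with the network stored in the PipeOpModule (the mlr3torch side of this is an assumption, only the torch calls below are real API):

```r
library(torch)

# copy pre-trained weights into a freshly constructed module of the same shape
pretrained = nn_linear(10, 1)
fresh = nn_linear(10, 1)
fresh$load_state_dict(pretrained$state_dict())
```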
Idea:
- `$state` saves the `PipeOpModule` by reference
- setting the module hyperparameter fixes it; `$.train()` then only checks whether the `PipeOpModule` is compatible with the input
- hyperparameters for freezing weights / setting the learning rate could be hyperparameters of `PipeOpTorch` or `PipeOpModule`. They should probably belong to `PipeOpModule`, so they can be modified during training, e.g. by a callback
- maybe give `PipeOpModule` or `PipeOpTorch` an option to define tags; a callback could then address all weights that have a certain tag with a convenience function (see the sketch below)
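A purely hypothetical sketch of the tag idea: neither `tags` nor `freeze_tagged()` exist in mlr3torch; this is plain R over torch modules to show what such a convenience function, called from a callback, could look like.

```r
library(torch)

# hypothetical helper: freeze all parameters of modules carrying a given tag
freeze_tagged = function(modules, tags, tag) {
  for (nm in names(modules)) {
    if (tag %in% tags[[nm]]) {
      for (p in modules[[nm]]$parameters) p$requires_grad_(FALSE)
    }
  }
  invisible(modules)
}

modules = list(encoder = nn_linear(10, 5), head = nn_linear(5, 1))
tags = list(encoder = "pretrained", head = character())

# e.g. called by a callback at the start of training
freeze_tagged(modules, tags, "pretrained")
```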