
Freezing weights

sebffischer opened this issue 3 years ago • 3 comments

  • The problem here is that the trained weights do not live in the PipeOpTorch objects.
  • PipeOpTorch should probably keep a link to the nn_module's parameters (torch_tensors have, in a sense, reference semantics, so the link would stay current; anticipating problems with cloning here...). It should also get a hyperparameter like 'fix_weights'; when set, the new nn_module it creates reuses these weights and they are frozen (see the sketch after this list).
  • Maybe also add a 'relative learning rate' hyperparameter.
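
A minimal sketch, at the plain torch level, of what such a 'fix_weights' mechanism could do internally (the helper freeze_module is a hypothetical name, not part of mlr3torch): set requires_grad to FALSE on the module's parameters, so that backward() no longer computes gradients for them.

library(torch)

# hypothetical helper: freeze all parameters of an nn_module in place
freeze_module = function(module) {
  for (p in module$parameters) {
    p$requires_grad_(FALSE)  # gradients are no longer computed for this tensor
  }
  invisible(module)
}

n = nn_linear(10, 1)
freeze_module(n)
sapply(n$parameters, function(p) p$requires_grad)  # weight and bias are both FALSE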

sebffischer · Sep 20 '22 08:09

Maybe relevant: it is possible to set group-specific options (such as the learning rate) for the optimizer:

library(torch)

n = nn_linear(10, 1)
y = torch_randn(16, 1)
x = torch_randn(16, 10)

# one parameter group per tensor: the weight is frozen via lr = 0, the bias is not
params = list(
  list(params = n$parameters$weight, lr = 0),
  list(params = n$parameters$bias, lr = 0.1)
)

# keep copies of the initial values to check what changes after the update
weight = torch_clone(n$parameters$weight)
bias = torch_clone(n$parameters$bias)

o = optim_sgd(params, lr = 0.1)

o$zero_grad()
y_hat = n(x)
loss = nnf_mse_loss(y_hat, y)
loss$backward()

o$step()
#> NULL

weight - n$parameters$weight  # unchanged, since lr = 0
#> torch_tensor
#>  0  0  0  0  0  0  0  0  0  0
#> [ CPUFloatType{1,10} ][ grad_fn = <SubBackward0> ]
bias - n$parameters$bias  # updated, since lr = 0.1
#> torch_tensor
#> 0.01 *
#>  4.3876
#> [ CPUFloatType{1} ][ grad_fn = <SubBackward0> ]

Created on 2022-09-20 by the reprex package (v2.0.1)
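
Building on this, a sketch of how the 'relative learning rate' idea from the first comment could map onto parameter groups; base_lr and relative_lr are hypothetical names, not existing hyperparameters:

library(torch)

n = nn_linear(10, 1)

base_lr = 0.1
relative_lr = 0.01  # hypothetical per-module multiplier on the base learning rate

params = list(
  list(params = n$parameters$weight, lr = base_lr * relative_lr),  # learns 100x slower
  list(params = n$parameters$bias, lr = base_lr)
)

o = optim_sgd(params, lr = base_lr)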

sebffischer · Sep 20 '22 08:09

Regarding the pre-initialized weights: we could consider saving the PipeOpModule that a PipeOpTorch generates in that PipeOpTorch's $state, or alternatively as a hyperparameter. The PipeOpModule contains the torch nn_module and therefore the weights, which can be pre-initialized. Currently PipeOpTorchModel resets the weights of its inputs; that would need to be avoided in this case.

The difference between a hyperparameter and $state would be that the latter would (need to?) make use of the hotstart API, and it is unclear how stable and how flexible that is, in particular when some of the modules are pre-trained and others are not (so only some of the $states are present). Using a hyperparameter is therefore probably best here.

We'd want a neat API for transferring pre-trained model weights to parts of graphs, similar to what we want for https://github.com/mlr-org/mlr3pipelines/issues/538.
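
At the torch level, such a transfer could boil down to copying state dicts between modules of matching architecture; a minimal sketch (the two nn_linear modules are stand-ins, the graph-level API does not exist yet):

library(torch)

pretrained = nn_linear(10, 5)  # stand-in for a module whose weights were trained elsewhere
fresh = nn_linear(10, 5)       # freshly constructed module with the same architecture

# copy the learned parameters over; parameter names and shapes must match
fresh$load_state_dict(pretrained$state_dict())

torch_equal(fresh$weight, pretrained$weight)  # TRUE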

mb706 · Sep 26 '22 10:09

Idea:

  • $state saves the PipeOpModule by reference.
  • Setting the module hyperparameter fixes it; $.train() then only checks whether the PipeOpModule is compatible with the input.
  • Hyperparameters for freezing weights / setting the learning rate could belong to either PipeOpTorch or PipeOpModule. They should probably belong to PipeOpModule, so that they can be modified during training, e.g. by a callback.
  • Maybe give PipeOpModule or PipeOpTorch an option to define tags; a callback could then address all weights carrying a certain tag through a convenience function (see the sketch below).
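
A sketch of what such a convenience function could look like at the torch level; since the tag mechanism does not exist yet, the 'tag' is approximated here by matching the nested parameter names:

library(torch)

# hypothetical helper: freeze all parameters whose nested name matches a pattern,
# standing in for "all weights that have a certain tag"
freeze_tagged = function(module, pattern) {
  params = module$parameters  # named list, e.g. "encoder.weight", "head.bias"
  hit = grepl(pattern, names(params))
  for (p in params[hit]) p$requires_grad_(FALSE)
  names(params)[hit]          # report which parameters were frozen
}

net = nn_module(
  initialize = function() {
    self$encoder = nn_linear(10, 5)
    self$head = nn_linear(5, 1)
  },
  forward = function(x) self$head(self$encoder(x))
)()

freeze_tagged(net, "^encoder")  # "encoder.weight", "encoder.bias"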

mb706 · Sep 26 '22 11:09