Add optimisers
This PR seeks to add optimiser functionality to the code.
- [x] FTorch
  - [x] Expose optimisers in FTorch
  - [x] Write unit tests to cover these
  - [x] Update `pages/online.md`
- [ ] Exercise:
  - [ ] README - WIP
  - [x] requirements
  - [x] Python version of the exercise
  - [x] Fortran version of the exercise
  - [x] Make use of `torch_tensor_mean` - see #240
  - [x] Make use of
  - [x] Plotting
  - [ ] Give it a number rather than $n$
Notes whilst in progress:
- It is handy to implicitly initialise a gradient of ones in `.backward()` by contracting a (loss) tensor into a scalar (see the sketch below). Should this be added to the autograd example?
- `ftorch.F90` is going to grow with this. Perhaps now is the time to break it apart into sub-modules? I'm thinking of having `ftorch_optim` as a module to hold these.
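For reference, a minimal libtorch-level sketch of the scalar-contraction trick (plain C++, not the FTorch wrapper):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // x requires grad, so autograd records operations on it.
  auto x = torch::ones({3}, torch::requires_grad());
  auto y = x * 2.0;

  // y.backward() alone would need an explicit gradient argument because
  // y is not a scalar; contracting to a scalar lets autograd use an
  // implicit gradient of ones.
  auto loss = y.sum();
  loss.backward();

  std::cout << x.grad() << std::endl;  // [2, 2, 2]
  return 0;
}
```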
Starting by trying to simply bring across SGD.
There is a `torch::optim::Optimizer` class, but it is never directly used, instead being subclassed by `torch::optim::SGD` etc. For now I will just try to bring over `torch::optim::SGD`, and we can see if it is possible to generalise in future.
There is an associated `torch::optim::SGDOptions` class (subclassed from `torch::optim::OptimizerOptions` (or maybe `torch::optim::OptimizerCloneableOptions`?)) that is passed into the SGD class to provide options.
The other required input is a list of tensors: `std::vector<torch::Tensor>`, which I will use in the early examples.
More generally we can use `torch::nn::Module::parameters()` to provide this list, so we probably want to expose that method for models somewhere down the line.
We will need to expose the `.zero_grad()` method and the `.step()` method, which are common to all optimizers.
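A minimal C++ sketch of the pieces named above (the `std::vector<torch::Tensor>` input, `SGDOptions`, `.zero_grad()` and `.step()`), as a reference for what the wrapper needs to cover:

```cpp
#include <torch/torch.h>
#include <vector>

int main() {
  // A single learnable tensor, standing in for model parameters.
  auto w = torch::ones({2}, torch::requires_grad());

  // SGD is constructed from a std::vector<torch::Tensor> plus an
  // SGDOptions instance carrying the learning rate.
  std::vector<torch::Tensor> params{w};
  torch::optim::SGD opt(params, torch::optim::SGDOptions(0.1));

  auto loss = (w * w).sum();
  opt.zero_grad();   // clear any gradients from a previous iteration
  loss.backward();   // populate w.grad()
  opt.step();        // w <- w - lr * w.grad()
  return 0;
}
```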
It's still not clear what the unit tests should be. My plan is to write some C++ code, then rewrite it using ctorch code, then FTorch code.
- On reflection I think we want to make the default input a `vector`, and perhaps handle enforcing this on the FTorch side if possible?
- For model forward we require a user to supply tensors as an array of tensors, for which each `IValue` in C++ is pushed onto a vector. This can also be implemented.
- The action of casting an array of tensors to a vector probably needs to become its own function, as it seems to be reused (see the sketch below).
- Note it turns out that `IValue`s are a specific TorchScript thing, not Torch. Hopefully this does not foreshadow future difficulties.
- For model parameters we could either pass around a pointer to a vector of tensors as arises from the C++ `.parameters` and check the type, or manually extract into an array of tensors in Fortran. The former might be 'cleaner', but the latter would be useful should we ever want to probe things on the Fortran side...
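A sketch of what that reusable conversion could look like at the C++ layer; the function name here is hypothetical, not FTorch's actual API:

```cpp
#include <torch/torch.h>
#include <vector>
#include <cstddef>

// Hypothetical helper: the C layer receives tensors as a contiguous
// array plus a count, and libtorch wants a std::vector<torch::Tensor>.
std::vector<torch::Tensor> tensor_array_to_vector(const torch::Tensor* tensors,
                                                  std::size_t n) {
  return std::vector<torch::Tensor>(tensors, tensors + n);
}

int main() {
  torch::Tensor ts[2] = {torch::ones({2}), torch::zeros({2})};
  auto vec = tensor_array_to_vector(ts, 2);  // reusable conversion
  return 0;
}
```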
Other thoughts:
- There are various other optional arguments available to SGD beyond learning rate: https://pytorch.org/docs/stable/generated/torch.optim.SGD.html How many of these do we want to wrap? Realistically only a few are commonly used, and users could always extend (see the options sketch below).
- We used to have models and tensors returned by a call to a function, and the same could be done for optimizers. This had to be changed for tensors to allow for autograd, and I believe we decided to make models consistent. Therefore I have chosen the same for optimizers: declare the optimizer object, pass it to a call that sets up the optimizer, then use the optimizer. A function might be cleaner, however: `optimizer =` call to init optimizer, then use the optimizer.
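For reference, the optional SGD arguments are exposed in libtorch as chained setters on `SGDOptions`; wrapping even a subset (say, momentum and weight decay) would sit on top of something like:

```cpp
#include <torch/torch.h>
#include <vector>

int main() {
  auto w = torch::zeros({4}, torch::requires_grad());

  // SGDOptions exposes the optional arguments via chained setters;
  // only a few of these are commonly used in practice.
  auto opts = torch::optim::SGDOptions(0.01)
                  .momentum(0.9)
                  .weight_decay(1e-4)
                  .nesterov(true);

  std::vector<torch::Tensor> params{w};
  torch::optim::SGD opt(params, opts);
  return 0;
}
```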
Checkpointing here:
- Functions for:
  - creating an SGD optimizer
  - stepping an optimizer
  - zeroing gradients associated with an optimizer
- Python and Fortran examples to 'train' a single tensor to scale an input vector (see the sketch below)
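A C++ analogue of that 'train a single tensor' example (the Python/Fortran versions live in the exercise; this is just a sketch of the loop):

```cpp
#include <torch/torch.h>
#include <vector>
#include <iostream>

int main() {
  // 'Train' a single scaling tensor so that scale * x matches 2 * x.
  auto x = torch::tensor({1.0f, 2.0f, 3.0f});
  auto target = 2.0f * x;
  auto scale = torch::ones({1}, torch::requires_grad());

  std::vector<torch::Tensor> params{scale};
  torch::optim::SGD opt(params, torch::optim::SGDOptions(0.05));

  for (int i = 0; i < 200; ++i) {
    opt.zero_grad();
    // The loss (and hence the graph) is rebuilt each iteration, so no
    // retain_graph is needed in this particular loop.
    auto loss = (scale * x - target).pow(2).mean();
    loss.backward();
    opt.step();
  }
  std::cout << scale.item<float>() << std::endl;  // approaches 2.0
  return 0;
}
```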
Further steps:
- [ ] Implement more optimizers beyond SGD
- [ ] Tests for optimizer functionalities
- [ ] Better generation of input `parameters`
  - [ ] Including getting outputs of a model, though this may be another PR
- [ ] Docs, always docs
A useful thing to note is that if we want to set the values of a tensor (e.g. a loss, or a gradient) it must NOT have `requires_grad=.true.` or it will produce a segfault.
This has caught me out a couple of times, so we probably need some clear docs as to where users should and should not be enabling gradient calculations.
Good to know. It seems you need to use `requires_grad=.true.` only for the arrays you want to differentiate with respect to.
FYI I rebased this branch on top of `main` on branch `joe/optim` in case it's useful. Really cool functionality!
Rebased on main
Note: Found during this that one cannot use `tensor_a = tensor_b(requires_grad=.true.)` when `tensor_a` is only defined as a `torch_tensor`. This is because the `requires_grad` property gets copied across, and then we get errors resulting from trying to modify the values of a tensor with `requires_grad=.true.`. The solution is to declare `tensor_a` as empty.
See pFUnit tests for optimizers, specifically the zero_grad test.
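For reference, the libtorch-level behaviour underneath this (a C++ sketch, not the FTorch code itself): copying a handle carries `requires_grad` across, and in-place modification of a leaf tensor that requires grad then errors:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto b = torch::ones({2}, torch::requires_grad());
  auto a = b;  // the handle copy carries requires_grad across

  try {
    a += 1;  // in-place modification of a leaf tensor that requires grad
  } catch (const std::exception& e) {
    // "a leaf Variable that requires grad is being used in an in-place
    // operation" - the C++-level error behind the failure described above
    std::cout << e.what() << std::endl;
  }
  return 0;
}
```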
cc @jwallwork23
@jwallwork23 Take a look at https://github.com/Cambridge-ICCS/FTorch/pull/320/commits/e59fa97ad8436b952af7814ccdd7782f45a6affb for updated unit tests for zero_grad.
Note that in libtorch `zero_grad` does not set gradients to `0.0` but rather makes them 'undefined', ready for the next backward step. This means I check the `zero_grad` operation implicitly (I added a comment to try and explain this).
Secondly, I found that if you ever want to call `backward` more than once on the same tensor you need to use `retain_graph=.true.`, regardless of whether you zero gradients between operations or not - hence it being used for every call in the unit test. I confirmed this by removing it from the backward step in the integration test loop, which causes failure.
So perhaps we should consider `.true.` as a default, or at least make it clear in the docs that this is required for multiple use.
I haven't yet delved deep enough to understand what happens to the gradient after a backward call when the graph is not retained (see the sketch after the error below).
The error is:
```
[ERROR]: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
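A C++ sketch of both observations (the gradient left undefined by `zero_grad`, and `retain_graph` being needed for a second `backward` on the same graph):

```cpp
#include <torch/torch.h>
#include <vector>
#include <iostream>

int main() {
  auto x = torch::ones({2}, torch::requires_grad());
  std::vector<torch::Tensor> params{x};
  torch::optim::SGD opt(params, torch::optim::SGDOptions(0.1));

  auto loss = (x * x).sum();

  // Retain the graph so backward can run on this same loss again.
  loss.backward({}, /*retain_graph=*/true);
  std::cout << x.grad().defined() << std::endl;  // 1: gradient populated

  // zero_grad leaves the gradient undefined rather than zero-valued,
  // hence checking its effect implicitly in the unit tests.
  opt.zero_grad();
  std::cout << x.grad().defined() << std::endl;  // 0: gradient undefined

  // Legal only because the first call retained the graph; otherwise this
  // raises the "backward through the graph a second time" error above.
  loss.backward();
  return 0;
}
```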
Thanks for digging into this and for leaving helpful comments. It'd probably be worth putting similar comments in the source code for zero_grad, to make sure we don't get confused by this in the future?
Rebased on main to resolve conflicts.