Optimisers.jl
Optimisers.jl defines many standard optimisers and utilities for learning loops.
This was discovered in https://github.com/SciML/NeuralPDE.jl/issues/533 as an issue that only showed itself as an incorrect gradient: the primal pass of what was being trained was in Float64, the reverse passes...
Through a typo, I found that Base exports [`iswritable`](https://docs.julialang.org/en/v1/base/io-network/#Base.iswritable) (one char difference). Were we to outsource this, possible candidates include [`ChainRulesCore.is_inplaceable_destination`](https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.is_inplaceable_destination) (used by `add!!`) and [`ArrayInterfaceCore.ismutable`](https://juliaarrays.github.io/ArrayInterface.jl/dev/api/#ArrayInterfaceCore.ismutable).
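A rough sketch of how such a predicate might gate in-place versus out-of-place updates (the names `canwrite` and `subtract!!` here are hypothetical stand-ins, not part of the package; the chosen check could be either of the candidates above):

```julia
# Hypothetical predicate: true only for arrays we know are safe to mutate.
canwrite(::DenseArray{<:AbstractFloat}) = true
canwrite(_) = false

# Sketch of an update step that mutates when allowed, copies otherwise.
function subtract!!(x, dx)
    if canwrite(x)
        x .-= dx       # mutate in place
    else
        x .- dx        # immutable leaf: return a new value
    end
end
```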
It seems like Optimisers only works with vectors.

```julia
using Functors

struct WithFloat
    val::Float64
end
@functor WithFloat
```

The call to `Optimisers.trainable(WithFloat(4.3))` returns `(val = 4.3,)`, but `destructure(WithFloat(4.3))` produces an...
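One possible workaround, sketched here only (the `WithVec` type is hypothetical, not an official recommendation), is to hold the scalar in a one-element vector so that `destructure` and `update!` see a mutable array leaf:

```julia
using Functors, Optimisers

# Hypothetical variant of the struct above: a 1-element vector stands in for the scalar.
struct WithVec
    val::Vector{Float64}
end
@functor WithVec

m = WithVec([4.3])
flat, re = Optimisers.destructure(m)   # flat == [4.3]; re(flat) rebuilds a WithVec
```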
We should give more prominence in the docs to using `Flux.@functor` and `Optimisers.trainable` to define the trainable parameters of custom types.
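The pattern the docs could highlight looks roughly like this (the `Affine` type is only illustrative):

```julia
using Functors, Optimisers

struct Affine
    W
    b
    activation          # not a parameter
end
@functor Affine

# Restrict optimisation to W and b; `activation` is ignored by setup/update!.
Optimisers.trainable(a::Affine) = (; W = a.W, b = a.b)
```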
In optimisers like AdamW, the learning rate and the weight decay are often tweaked, but the momentum decay values are not (see [PyTorch](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html), for example,...
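A sketch of that use case with the keyword form of `Optimisers.adjust!`, assuming the AdamW hyperparameters are exposed under the names `eta`, `beta` and `lambda`:

```julia
using Optimisers

model = (W = rand(3, 3), b = zeros(3))           # any Functors-compatible model
state = Optimisers.setup(Optimisers.AdamW(), model)

# Later in training: change the learning rate and weight decay,
# leaving the momentum decays (beta) untouched.
Optimisers.adjust!(state; eta = 1e-4, lambda = 1e-2)
```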
```julia
using Flux
using Functors
using Optimisers

struct Custom
    abc::Tuple
end

Functors.@functor Custom (abc,)

function (f::Custom)(x)
    x .* f.abc[1] .+ f.abc[2]
end

function Custom(; dim::Int)
    abc = (randn(Float32, dim), randn(Float32, dim))
    ...
```
This is simpler than the version in https://github.com/FluxML/Flux.jl/pull/969, as it has no special handling for momentum (not yet, at least). It's unusual in that I think it needs to be written...
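For reference, a rough sketch of a custom rule under this state-in / state-out design, following the documented `init`/`apply!` pattern (the `MyDescent` rule itself is just a toy):

```julia
using Optimisers

# A toy rule: plain gradient descent written as a pure rule.
struct MyDescent <: Optimisers.AbstractRule
    eta::Float64
end

Optimisers.init(o::MyDescent, x::AbstractArray) = nothing   # no per-parameter state

function Optimisers.apply!(o::MyDescent, state, x, dx)
    # Return the (unchanged) state and the step that update! will subtract.
    return state, o.eta .* dx
end
```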
Flux has model structs, and Zygote would return NamedTuple gradients for them. With FluxML/Functors.jl#1 we add the ability to handle gradients in Functors.jl - in other words, to do a "zipped"...
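A sketch of such a zipped walk using multi-argument `fmap`, assuming a NamedTuple model and a matching NamedTuple gradient:

```julia
using Functors

model = (W = [1.0 2.0; 3.0 4.0], b = [0.0, 0.0])
grad  = (W = [0.1 0.1; 0.1 0.1], b = [0.5, 0.5])

# Walk model and gradient together, pairing each parameter leaf
# with its corresponding gradient leaf.
updated = fmap((x, dx) -> x .- 0.01 .* dx, model, grad)
```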