Rotations.jl
Support automatic differentiation with Zygote
Supporting reverse-mode autodiff with Zygote requires two things:
- custom pullbacks (reverse-mode differentiation rules, AKA adjoints) for functions that internally mutate arrays (Zygote does not support array mutation)
- custom pullbacks for constructors
I think the latter is the main thing missing for Zygote support. Supporting it would require adding custom rrules for constructors using ChainRulesCore, which has essentially no dependencies.
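For the first kind of rule listed above, here is a minimal sketch that assumes ChainRulesCore is installed. `running_sum` is a hypothetical function, not part of Rotations.jl. It writes into a buffer inside a loop, which Zygote cannot differentiate through on its own. A hand-written `rrule` supplies the pullback, and Zygote uses `rrule` methods like this instead of tracing the mutating loop.

```julia
using ChainRulesCore

# Hypothetical function that mutates a buffer internally; Zygote cannot
# differentiate through the `out[i] = ...` assignments by itself.
function running_sum(x::AbstractVector)
    out = similar(x)
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]
        out[i] = acc
    end
    return out
end

# Custom pullback: since y_i = sum(x[1:i]), the cotangent of x_j is sum(ȳ[j:end]),
# i.e. a running sum taken in reverse order.
function ChainRulesCore.rrule(::typeof(running_sum), x::AbstractVector)
    y = running_sum(x)
    function running_sum_pullback(ȳ)
        Δ = unthunk(ȳ)
        x̄ = reverse(cumsum(reverse(Δ)))
        return NoTangent(), x̄  # no tangent for the function itself
    end
    return y, running_sum_pullback
end
```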
Here's an example that fails:
julia> using Rotations, Zygote
julia> foo(ω, v) = (RotationVec(ω...) * v)[1];
julia> ω, v = randn(3), randn(3)
([-0.5124874613220701, 0.27274002772526423, 1.0593705312514463], [-0.2329525748114141, -1.1183670007323072, -0.4878065893106537])
julia> foo(ω, v)
0.8904209135865043
julia> Zygote.gradient(foo, ω, v) # expected from finite differences: ([-0.18142454901446126, -0.16476417425397563, 0.8182962760715483], [0.47098619925782836, -0.8816660765248943, -0.028929735807022527])
ERROR: Need an adjoint for constructor RotationVec{Float64}. Gradient is of type Array{Float64,2}
In this case, the missing pullback is mathematically that of the exponential map exp: so(3) → SO(3).
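To make that pullback concrete, here is a dependency-free sketch that uses only the LinearAlgebra standard library. It implements Rodrigues' formula and its vector-Jacobian product using the standard identity ∂R/∂ω_i = [J e_i]× R, where J is the left Jacobian of SO(3). The names `hat`, `expmap` and `expmap_pullback` are hypothetical. This is not necessarily how Rotations.jl computes `RotationVec` internally.

```julia
using LinearAlgebra

# Skew-symmetric "hat" map: hat(ω) * u == cross(ω, u)
hat(ω) = [0 -ω[3] ω[2]; ω[3] 0 -ω[1]; -ω[2] ω[1] 0]

# Exponential map so(3) → SO(3) via Rodrigues' formula
function expmap(ω::AbstractVector)
    θ = norm(ω)
    K = hat(ω)
    if θ < 1e-4
        return I + K + K^2 / 2  # second-order Taylor expansion near ω = 0
    end
    return I + (sin(θ) / θ) * K + ((1 - cos(θ)) / θ^2) * K^2
end

# Pullback (vector-Jacobian product) of expmap: maps a 3×3 cotangent ΔR to ω̄,
# using ∂R/∂ω_i = hat(J[:, i]) * R with J the left Jacobian of SO(3).
function expmap_pullback(ω::AbstractVector, ΔR::AbstractMatrix)
    θ = norm(ω)
    K = hat(ω)
    R = expmap(ω)
    J = if θ < 1e-4
        I + K / 2 + K^2 / 6  # second-order Taylor expansion near ω = 0
    else
        I + ((1 - cos(θ)) / θ^2) * K + ((θ - sin(θ)) / θ^3) * K^2
    end
    # ω̄_i is the Frobenius inner product of ΔR with ∂R/∂ω_i
    return [sum(ΔR .* (hat(J[:, i]) * R)) for i in 1:3]
end
```

For `foo` in the example, the cotangent of R is zero except for its first row, which equals v. With the ω and v printed there, `expmap_pullback(ω, ΔR)` should agree with the finite-difference gradient with respect to ω quoted in the comment on the `Zygote.gradient` call.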
OK - I'm not at all familiar with Zygote. Would supporting it mean (a) adding Zygote as a dependency and (b) adding some new methods to some Zygote functions (like rrules)?
Nope! It would only require adding ChainRulesCore as a dependency. ChainRulesCore has a single dependency, which in turn has no dependencies, so it's extremely light. The rrule function is defined in ChainRulesCore, so support would only involve adding rrule methods for the constructors.
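Adding an `rrule` for a constructor means defining a method of `ChainRulesCore.rrule` whose first argument is the type. Here is a minimal sketch using a hypothetical 2-D analogue, `Rot2`, which is not part of Rotations.jl. Like `RotationVec`, it stores parameters but behaves as a matrix, so the cotangent Zygote passes to the pullback is a dense matrix. That matches the `Array{Float64,2}` gradient mentioned in the error above.

```julia
using ChainRulesCore

# Hypothetical 2-D stand-in for RotationVec (not part of Rotations.jl): it stores a
# parameter (an angle) but behaves as a matrix, so cotangents arrive as 2×2 arrays.
struct Rot2 <: AbstractMatrix{Float64}
    θ::Float64
end
Base.size(::Rot2) = (2, 2)
function Base.getindex(R::Rot2, i::Int, j::Int)
    @boundscheck checkbounds(R, i, j)
    s, c = sincos(R.θ)
    entries = (c, s, -s, c)  # column-major storage of [c -s; s c]
    return entries[i + 2 * (j - 1)]
end

# Custom pullback for the constructor Rot2(θ)
function ChainRulesCore.rrule(::Type{Rot2}, θ::Real)
    R = Rot2(θ)
    function Rot2_pullback(ΔR)
        Δ = unthunk(ΔR)        # dense 2×2 cotangent of the matrix entries
        s, c = sincos(θ)
        dR = [-s -c; c -s]     # ∂R/∂θ
        θ̄ = sum(Δ .* dR)       # Frobenius inner product ⟨Δ, ∂R/∂θ⟩
        return NoTangent(), θ̄  # no tangent for the type itself
    end
    return R, Rot2_pullback
end
```

A rule for `RotationVec` would have the same shape. It would need one such inner product per component of the rotation vector, with ∂R/∂θ replaced by the derivative of the exponential map.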
@oxinabox gave a nice talk on ChainRules at JuliaCon this year: https://www.youtube.com/watch?v=B4NfkkkJ7rs
Yeah OK, seems quite reasonable to me. (Pull requests are always welcome :wink:)