Optimisers.jl
Investigate using a different AD for tests
Zygote compile times make what should ordinarily be a pretty fast test suite rather sluggish. Perhaps we could borrow some functionality from ChainRulesTestUtils (CRTU), which checks rules against finite differences: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/v1.8.1/src/testers.jl#L224.
It would be nice to be quicker. But the tests aren't just about mathematical correctness; some also check that missing branches, represented by nothing, are handled correctly, etc. If https://github.com/JuliaDiff/Diffractor.jl/issues/66 ever happens, these should ideally be supplemented by checks that ZeroTangent works equally well.
Would FiniteDifferences.jl (FD.jl) not generate NoTangent for those same branches? I've not tried it with nested structs.
It seems to produce dense zeros instead:
julia> using FiniteDifferences
julia> grad(central_fdm(5, 1), x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0)),)
I wondered if Tracker might work well, but it doesn't seem to like NamedTuples.
For comparison:
julia> Zygote.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Fill(1.0f0, 2), b = nothing),)
julia> Diffractor.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = InplaceableThunk(ChainRules.var"#1547#1550"{Float32, Colon}(1.0f0, Colon()), Thunk(ChainRules.var"#1548#1551"{Float32, Colon, Vector{Float32}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1.0f0, Colon(), Float32[1.0, 2.0], ProjectTo{AbstractArray}(element = ProjectTo{Float32}(), axes = (Base.OneTo(2),))))),),)
julia> ans[1].b
ZeroTangent()
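Given the outputs above (Zygote gives nothing for b, FiniteDifferences gives dense zeros), a test comparing the two would need to treat those as equivalent. Below is a rough sketch of that, loosely in the spirit of CRTU's finite-difference checks but not its API: the helpers fd_iszero, tangent_match and check_against_fd are hypothetical, and the tolerance atol=1e-3 is an assumption. Note that it still calls Zygote, so on its own it does nothing about compile times.

```julia
using Test, Zygote, FiniteDifferences

# Hypothetical helpers, not part of Optimisers.jl, Zygote, FiniteDifferences or CRTU.
# Zygote returns `nothing` for branches that get no gradient, while
# FiniteDifferences returns dense zeros there; these treat the two as equivalent.

# true if every number inside a (possibly nested) finite-difference tangent is ≈ 0
fd_iszero(d::Number; atol) = abs(d) <= atol
fd_iszero(d::AbstractArray{<:Number}; atol) = all(x -> abs(x) <= atol, d)
fd_iszero(d::Union{Tuple,NamedTuple}; atol) = all(x -> fd_iszero(x; atol=atol), d)

# compare a Zygote tangent `z` against a finite-difference tangent `d`
tangent_match(z::Nothing, d; atol) = fd_iszero(d; atol=atol)
tangent_match(z::Number, d::Number; atol) = isapprox(z, d; atol=atol)
tangent_match(z::AbstractArray{<:Number}, d::AbstractArray{<:Number}; atol) =
    size(z) == size(d) && isapprox(z, d; atol=atol)
tangent_match(z::Tuple, d::Tuple; atol) =
    length(z) == length(d) && all(map((a, b) -> tangent_match(a, b; atol=atol), z, d))
tangent_match(z::NamedTuple, d::NamedTuple; atol) =
    keys(z) == keys(d) && tangent_match(values(z), values(d); atol=atol)

# Zygote gradient vs. a 5-point central finite difference, as in the REPL examples
function check_against_fd(f, x; atol=1e-3)
    z = Zygote.gradient(f, x)
    d = FiniteDifferences.grad(central_fdm(5, 1), f, x)
    return tangent_match(z, d; atol=atol)
end

x = (a = [1, 2f0], b = ([3, 4f0], 5.0))
@test check_against_fd(x -> sum(x.a), x)
```

One design caveat: mapping nothing to zeros is lossy, since the comparison cannot tell whether Zygote correctly returned nothing or a genuine zero array. So the tests that deliberately exercise nothing branches would still need Zygote, or explicitly constructed gradients containing nothing.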
@mcabbott Since you linked the issue in Yota, I guess you want to test this example there too, so let me save you a few minutes:
julia> grad(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))[2][2:end]
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = Float32[1.0, 1.0],),)
julia> ans[1].b
ZeroTangent()
@mcabbott Funny that #105 explicitly says it won't close this. Maybe it was closed automatically because GitHub saw the "close #105" string.
Just checking: was there a real intention to close this issue?