Optimisers.jl

Investigate using a different AD for tests

Open ToucheSir opened this issue 2 years ago • 5 comments

Zygote compile times make what should ordinarily be a pretty fast test suite rather sluggish. Perhaps we could borrow some functionality from ChainRulesTestUtils (CRTU): https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/v1.8.1/src/testers.jl#L224.

ToucheSir, Jul 03 '22 18:07

It would be nice to be quicker. But the tests aren't just about mathematical correctness; some also check that missing branches, represented by nothing, are handled correctly, etc. If https://github.com/JuliaDiff/Diffractor.jl/issues/66 ever happens, these should ideally be supplemented by checks that ZeroTangent works equally well.

mcabbott, Jul 03 '22 21:07

Would FiniteDifferences.jl (FD.jl) not generate NoTangent for those same branches? I've not tried it with nested structs.

ToucheSir, Jul 03 '22 22:07

It seems to make dense zeros:

julia> using FiniteDifferences

julia> grad(central_fdm(5, 1), x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Float32[0.9999863, 1.0000029], b = (Float32[0.0, 0.0], 0.0)),)

I wondered if Tracker might work well, but it doesn't seem to like NamedTuples.

For comparison:

julia> Zygote.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
((a = Fill(1.0f0, 2), b = nothing),)

julia> Diffractor.gradient(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = InplaceableThunk(ChainRules.var"#1547#1550"{Float32, Colon}(1.0f0, Colon()), Thunk(ChainRules.var"#1548#1551"{Float32, Colon, Vector{Float32}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1.0f0, Colon(), Float32[1.0, 2.0], ProjectTo{AbstractArray}(element = ProjectTo{Float32}(), axes = (Base.OneTo(2),))))),),)

julia> ans[1].b
ZeroTangent()

mcabbott, Jul 04 '22 00:07

@mcabbott Since you linked the issue in Yota, I guess you want to test this example there too. So let me save you a few minutes:

julia> using Yota

julia> grad(x -> sum(x.a), (a=[1,2f0], b=([3,4f0], 5.0)))[2][2:end]
(Tangent{NamedTuple{(:a, :b), Tuple{Vector{Float32}, Tuple{Vector{Float32}, Float64}}}}(a = Float32[1.0, 1.0],),)

julia> ans[1].b
ZeroTangent()

dfdx, Jul 09 '22 22:07

@mcabbott Funny that #105 explicitly says it won't close this. Maybe it was closed automatically because GitHub saw the "close #105" string.

Just checking whether there was a real intention to close this issue?

cossio, Dec 11 '22 22:12