Zygote.jl
`cispi` has poor performance
Hi!
cispi seems to have type inferability issues:
using Zygote
f1(x) = sum(abs2, cispi.(x))
f2(x) = sum(abs2, cis.(x))
@time Zygote.gradient(f1, ones((1024,1024)))
# 1.121823 seconds (18.87 M allocations: 648.002 MiB, 31.55% gc time)
@time Zygote.gradient(f2, ones((1024,1024)))
# 0.066483 seconds (40 allocations: 168.001 MiB, 30.13% gc time)
Zygote uses ChainRules, doesn't it?
There is a rule implemented for sincospi, which cispi should end up calling.
Best,
Felix
It should use the rule, but I suspect something else is causing broadcasting to take a slow path here. You can guess this because that path requires a couple of allocations per element to store pullbacks, and multiplying that by roughly 1 million elements is not far off the reported allocation count.
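To make that estimate concrete, here is a small back-of-the-envelope calculation. It only uses the numbers from the `@time` output reported at the top of the thread; the variable names are made up for illustration.

```julia
# Numbers copied from the reported timing of Zygote.gradient(f1, ones((1024,1024))).
n_elements = 1024 * 1024        # entries in the input array
n_allocs   = 18.87e6            # "18.87 M allocations"
n_bytes    = 648.002 * 2^20     # "648.002 MiB" converted to bytes

# Divide by the element count to see how the cost scales per array entry.
allocs_per_element = n_allocs / n_elements
bytes_per_element  = n_bytes / n_elements

println("allocations per element ≈ ", round(allocs_per_element; digits = 1))
println("bytes per element ≈ ", round(bytes_per_element; digits = 1))
```

The allocation count comes out at a whole-number multiple of the array size (about 18 per element), which is what you would expect if something is being heap-allocated for every element rather than once per array.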
Ah, this is the same issue as https://github.com/FluxML/Zygote.jl/issues/961#issuecomment-871084267 because cis and cispi return complex numbers.
But cis has good performance?
And the linked issue was only for CUDA
I would not say it has good performance. Even though it allocates far less than cispi, both its memory use and its overall runtime are strictly worse than those of a function with a more "expensive" forward pass.
using BenchmarkTools, Zygote
manyexp(x) = exp(x) * exp(x) + exp(x)
x = ones(1024, 1024)
julia> @btime cis.($x);
10.829 ms (2 allocations: 16.00 MiB)
julia> @btime manyexp.($x);
17.025 ms (2 allocations: 8.00 MiB)
julia> @btime gradient(x -> sum(abs2, cis.(x)), $x);
78.728 ms (38 allocations: 160.00 MiB)
julia> @btime gradient(x -> sum(abs2, manyexp.(x)), $x);
22.156 ms (31 allocations: 40.00 MiB)
It just so happens that the pullback for cis is simple enough that it captures no values and can thus be inlined even on the slow path. The pullback for cispi is not so fortunate (I think because of the reverse and splat?) and so has to be materialized for every array element.
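As a rough illustration of the difference between a pullback that captures nothing and one that captures a value, here is a minimal sketch. The two closures are hypothetical stand-ins, not the actual ChainRules pullbacks for cis or cispi, and the sketch says nothing about when Zygote actually allocates; it only shows that a capture-free closure carries no state while a capturing one does.

```julia
# Hypothetical pullbacks for illustration only; NOT the ChainRules definitions.

# Captures nothing: the closure type has no fields, so there is no state to store.
pullback_no_capture = Δ -> 2Δ

# Captures `y`: every closure returned by the factory stores its own copy of `y`.
make_pullback(y) = Δ -> y * Δ
pullback_capture = make_pullback(3.0)

# Inspect how much state each closure carries.
println("fields without capture: ", fieldcount(typeof(pullback_no_capture)))
println("fields with capture:    ", fieldcount(typeof(pullback_capture)))
println("bytes without capture:  ", sizeof(pullback_no_capture))
println("bytes with capture:     ", sizeof(pullback_capture))
```

A closure with no fields is the same object for every array element, so nothing needs to be kept around per element. A closure with fields has to be stored somewhere until the backward pass, and if its type cannot be inferred that storage plausibly ends up on the heap for each element.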
> And the linked issue was only for CUDA
The linked comment discusses a general-purpose solution that works for the CPU path as well. I'm not sure what level of effort would be required to implement it.
I recently discovered https://github.com/JuliaDiff/ForwardDiff.jl/pull/583, which enables real -> complex function differentiation in ForwardDiff. @mcabbott do you think that could be adapted to what Zygote does for broadcast?
Closed in #1324
julia> f1(x) = sum(abs2, cispi.(x))
julia> f2(x) = sum(abs2, cis.(x))
julia> @btime Zygote.gradient(f1, $(ones((1024,1024))));
29.061 ms (34 allocations: 72.00 MiB)
julia> @btime Zygote.gradient(f2, $(ones((1024,1024))));
28.697 ms (34 allocations: 72.00 MiB)