Zygote.jl

`cispi` has poor performance

roflmaostc opened this issue 3 years ago • 5 comments

Hi!

`cispi` seems to have type-inferability issues:

using Zygote

f1(x) = sum(abs2, cispi.(x))
f2(x) = sum(abs2, cis.(x))

@time Zygote.gradient(f1, ones((1024,1024)))
# 1.121823 seconds (18.87 M allocations: 648.002 MiB, 31.55% gc time)

@time Zygote.gradient(f2, ones((1024,1024)))
#   0.066483 seconds (40 allocations: 168.001 MiB, 30.13% gc time)

Zygote uses ChainRules, doesn't it? There is a rule implemented for `sincospi`, which should be called by `cispi`.
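For reference, here is a minimal hand-written sketch (plain Julia, Base only) of how a reverse-mode rule for `cispi` can be built on top of `sincospi`. It is not ChainRules' actual implementation, and the names `cispi_with_pullback` and `grad_abs2_cispi` are hypothetical. It assumes the cotangent of a complex output `z` is `∂L/∂real(z) + im * ∂L/∂imag(z)`, which makes the real cotangent of the input `real(conj(dy/dx) * Δ)`. Since `abs2(cispi(x)) == 1` for every real `x`, the gradient of `f1` should come out as zero.

```julia
# Hypothetical hand-written rule for cispi built on sincospi (not ChainRules' code).
# Assumed convention: the cotangent of a complex output z is
# ∂L/∂real(z) + im * ∂L/∂imag(z).
function cispi_with_pullback(x::Real)
    s, c = sincospi(x)            # sinpi(x) and cospi(x) from one call
    y = complex(c, s)             # same value as cispi(x)
    dydx = im * π * y             # d/dx cispi(x) = im * π * cispi(x)
    # Real input, complex output: project the complex cotangent back onto the reals.
    pullback(Δ::Complex) = real(conj(dydx) * Δ)
    return y, pullback
end

# Gradient of abs2(cispi(x)); under the convention above, the cotangent of abs2(z) is 2z.
function grad_abs2_cispi(x::Real)
    y, pb = cispi_with_pullback(x)
    return pb(2 * y)
end

x = 0.3
y, pb = cispi_with_pullback(x)
@show y ≈ cispi(x)                     # forward value matches Base.cispi
@show pb(1.0 + 0im) ≈ -π * sinpi(x)    # d/dx real(cispi(x)) = d/dx cospi(x)
@show pb(0.0 + 1im) ≈ π * cospi(x)     # d/dx imag(cispi(x)) = d/dx sinpi(x)
@show abs(grad_abs2_cispi(x)) < 1e-12  # |cispi(x)|^2 == 1, so the gradient vanishes
```

Note that the pullback closure here captures `dydx`, so each call produces a closure that carries data with it.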

Best,

Felix

roflmaostc, Aug 01 '22 16:08

It should use the rule, but I suspect something else is causing broadcasting to take a slow path here. You can guess this because that path needs a couple of allocations per element to store pullbacks, and multiplying that by the roughly 1 million elements of a 1024 × 1024 array gives a count not far off the reported amount.

ToucheSir, Aug 01 '22 17:08

Ah, this is the same issue as https://github.com/FluxML/Zygote.jl/issues/961#issuecomment-871084267, because `cis` and `cispi` return complex numbers.

ToucheSir, Aug 01 '22 22:08

But `cis` has good performance?

And the linked issue was only about CUDA.

roflmaostc, Aug 02 '22 07:08

I would not say it has good performance. Even though its memory allocations are better than those of `cispi`, both they and the overall runtime are strictly worse than those of a function with a more "expensive" forward pass:

using BenchmarkTools, Zygote

manyexp(x) = exp(x) * exp(x) + exp(x)
x = ones(1024, 1024)

julia> @btime cis.($x);
  10.829 ms (2 allocations: 16.00 MiB)

julia> @btime manyexp.($x);
  17.025 ms (2 allocations: 8.00 MiB)

julia> @btime gradient(x -> sum(abs2, cis.(x)), $x);
  78.728 ms (38 allocations: 160.00 MiB)

julia> @btime gradient(x -> sum(abs2, manyexp.(x)), $x);
  22.156 ms (31 allocations: 40.00 MiB)

It just so happens that the pullback for `cis` is simple enough that it captures no values and can therefore be inlined, even on the slow path. The pullback for `cispi` is not so fortunate (I think because of the `reverse` and splat?), so it has to be materialized for every array element.
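To make the "captures no values" point concrete, here is a toy illustration in plain Julia. It is not Zygote's broadcasting code, and `make_pullback_free` and `make_pullback_capturing` are hypothetical stand-ins for per-element pullbacks. A closure that captures nothing is an instance of a field-less singleton type, so an array holding one per element needs no element storage at all, whereas a closure that captures a value must store that value for every element.

```julia
# Toy stand-ins for per-element pullbacks (hypothetical, not Zygote internals).
make_pullback_free() = Δ -> 2Δ           # captures nothing
make_pullback_capturing(x) = Δ -> x * Δ  # captures x

pb_free = make_pullback_free()
pb_cap = make_pullback_capturing(1.0)

# The non-capturing closure's type has no fields, so its instances occupy zero bytes.
@show fieldcount(typeof(pb_free))
@show sizeof(pb_free)
# The capturing closure has to carry its Float64 along.
@show fieldcount(typeof(pb_cap))
@show sizeof(pb_cap)

# Storing one pullback per element of a 1024×1024 input:
n = 1024 * 1024
pbs_free = [make_pullback_free() for _ in 1:n]
pbs_cap = [make_pullback_capturing(Float64(i)) for i in 1:n]
@show sizeof(pbs_free)  # element data for the capture-free pullbacks
@show sizeof(pbs_cap)   # 8 bytes of element data per captured Float64
```

This only illustrates why a capture-free pullback is cheap to store; it does not reproduce the allocation counts Zygote reports above.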

> And the linked issue was only for CUDA

The linked comment discusses a general-purpose solution that works for the CPU path as well. I'm not sure what level of effort would be required to implement it.

ToucheSir, Aug 03 '22 04:08

I recently discovered https://github.com/JuliaDiff/ForwardDiff.jl/pull/583, which enables differentiation of real-to-complex functions in ForwardDiff. @mcabbott, do you think that could be adapted to what Zygote does for broadcasting?

ToucheSir, Aug 21 '22 01:08

Closed in #1324

julia> using BenchmarkTools, Zygote

julia> f1(x) = sum(abs2, cispi.(x))

julia> f2(x) = sum(abs2, cis.(x))

julia> @btime Zygote.gradient(f1, $(ones((1024,1024))));
  29.061 ms (34 allocations: 72.00 MiB)

julia> @btime Zygote.gradient(f2, $(ones((1024,1024))));
  28.697 ms (34 allocations: 72.00 MiB)

CarloLucibello, Jan 10 '23 17:01