AbstractFFTs.jl

Chain rules for FFT plans via AdjointPlans

Open gaurav-arya opened this issue 3 years ago • 3 comments

An rfft can be written as PF, where F is the n × n Fourier transform and P is a projection operator that removes the redundant information due to conjugate symmetry. Because of P, the adjoint of a real FFT (real inverse FFT) requires a special scaling before (after) applying the backwards transformation. As discussed in https://github.com/JuliaMath/AbstractFFTs.jl/issues/63, this motivates supporting the Base.adjoint operation for plans to simplify the writing of backward rules for AD.
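To make the scaling concrete, here is a minimal numerical sketch of the rfft case. It is written in plain Python with naive O(n²) transforms so that it is self-contained; it is not the PR's implementation and does not use the AbstractFFTs API. The helper names (`rfft`, `brfft`, `rfft_adjoint`, `real_inner`, `adjoint_gap`) are illustrative, and the sketch assumes the adjoint is taken with respect to the real inner product Re⟨a, b⟩. Under that assumption, the adjoint of rfft is the unnormalized backwards real transform applied after halving every output bin except the DC bin and, for even n, the Nyquist bin.

```python
import cmath
import random

def rfft(x):
    """Naive real FFT: the first n//2 + 1 bins of the length-n DFT of real x."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n // 2 + 1)]

def brfft(z, n):
    """Unnormalized backwards real FFT: Hermitian-extend z to length n, then sum."""
    m = n // 2 + 1
    out = []
    for j in range(n):
        acc = 0j
        for k in range(n):
            zk = z[k] if k < m else z[n - k].conjugate()
            acc += zk * cmath.exp(2j * cmath.pi * j * k / n)
        out.append(acc.real)  # the output of a real inverse transform is real
    return out

def rfft_adjoint(y, n):
    """Adjoint of rfft: scale first, then apply the backwards transform."""
    m = n // 2 + 1
    # Bins with a conjugate partner are counted twice by brfft, so halve them.
    # The DC bin (k == 0) and the Nyquist bin (2k == n, even n only) are not.
    scaled = [y[k] if (k == 0 or 2 * k == n) else 0.5 * y[k] for k in range(m)]
    return brfft(scaled, n)

def real_inner(a, b):
    """Real inner product Re<a, b> on real or complex vectors."""
    return sum((complex(u).conjugate() * complex(v)).real for u, v in zip(a, b))

def adjoint_gap(n, seed=0):
    """|<rfft(x), y> - <x, rfft_adjoint(y)>| for random x and y."""
    rng = random.Random(seed)
    x = [rng.uniform(-1, 1) for _ in range(n)]
    y = [complex(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(n // 2 + 1)]
    return abs(real_inner(rfft(x), y) - real_inner(x, rfft_adjoint(y, n)))
```

Calling `adjoint_gap(8)` and `adjoint_gap(7)` should give values at floating-point round-off level for both even and odd lengths. Replacing `rfft_adjoint` with a bare `brfft` breaks the identity, which is the scaling issue the paragraph above describes.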

The following functions must be implemented by backends in order for output_size(p::Plan) and AdjointPlan to work:

  • projection_style(p::Plan), which must be one of :none, :real, or :real_inv.
  • irfft_dim(p::Plan), only for those plans with :real_inv projection style, which gives the original length of the halved dimension.

Using the adjoint plan, we can simplify the writing of backward rules. I test the adjoint plans both directly and indirectly, through tests of the rrules.

NB: The interface has changed since the initial PR message. See the updated implementation docs in the PR for accurate info.

gaurav-arya avatar Jun 06 '22 22:06 gaurav-arya

Codecov Report

Patch coverage: 100.00% and project coverage change: +4.13 :tada:

Comparison is base (b5109aa) 87.32% compared to head (e137ae3) 91.45%.

:exclamation: Current head e137ae3 differs from pull request most recent head e601347. Consider uploading reports for the commit e601347 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #67      +/-   ##
==========================================
+ Coverage   87.32%   91.45%   +4.13%     
==========================================
  Files           3        3              
  Lines         213      281      +68     
==========================================
+ Hits          186      257      +71     
+ Misses         27       24       -3     
| Impacted Files | Coverage Δ |
|---|---|
| ext/AbstractFFTsChainRulesCoreExt.jl | 100.00% <100.00%> (ø) |
| src/definitions.jl | 83.33% <100.00%> (+9.29%) :arrow_up: |


codecov[bot] avatar Jun 06 '22 22:06 codecov[bot]

This should be ready for another review (with https://github.com/JuliaMath/AbstractFFTs.jl/pull/69 as a dependency)

gaurav-arya avatar Jul 01 '22 05:07 gaurav-arya

I've got test_frule and test_rrule to work for ScaledPlans too (and fixed some bugs with the rule in the process), with some minimal type piracy (see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256). So things should be well-tested now.

gaurav-arya avatar Aug 06 '22 22:08 gaurav-arya

I guess this is still waiting for #78? I think it would be good to also make ChainRulesCore a weak dependency on new Julia versions (see, e.g., how it is done in SpecialFunctions).

devmotion avatar Feb 25 '23 15:02 devmotion

@gaurav-arya you might want to rebase after https://github.com/JuliaMath/AbstractFFTs.jl/commit/3a3f0e4ebc3a9ab81bfb655003759ffa9c6bd1bb.

vpuri3 avatar Jun 26 '23 20:06 vpuri3

Question: what would the projection style be for real-to-real DCTs/DSTs in FFTW.jl?

vpuri3 avatar Jun 26 '23 20:06 vpuri3

I updated the PR (and fixed a few issues with types and functions that were not available in the extension).

Since in-place plans are not supported (or at least not tested?) currently, maybe a final thing to add would be to check in the ChainRules definitions that y = P * x is not aliased with x, and throw a descriptive error otherwise.

devmotion avatar Jun 30 '23 12:06 devmotion

I don't have the time right now to revisit this PR, but if it looks good and would be helpful, please feel free to fix anything remaining and merge. Thanks!

gaurav-arya avatar Jul 01 '23 13:07 gaurav-arya

@sethaxen since you were interested in differentiation rules as well (#103), I guess it might be valuable if you would have a look at the PR before I go ahead and merge it?

devmotion avatar Jul 04 '23 14:07 devmotion

Sure @devmotion, will take a quick look tonight.

sethaxen avatar Jul 04 '23 14:07 sethaxen

Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?

vpuri3 avatar Jul 04 '23 22:07 vpuri3

> Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?

I'd argue no, they should be kept. Zygote also defined rules for both fft etc. and plan_fft. But IMO the main argument is that fft etc. are used for one-shot computations of FFTs, whereas plan_fft etc. are intended for repeated calculations. So downstream packages might want to implement fft etc. in an optimized way, knowing that there's no amortization; users might rather call fft etc. if they only apply it once; and keeping separate rules is the only way for the rules to distinguish between the two use cases.

devmotion avatar Jul 04 '23 23:07 devmotion