Chain rules for FFT plans via AdjointPlans
An rfft can be written as PF, where F is the n × n Fourier transform and P is a projection operator that removes the redundant information due to conjugate symmetry. Because of P, the adjoint of a real FFT (real inverse FFT) requires a special scaling before (after) applying the backwards transformation. As discussed in https://github.com/JuliaMath/AbstractFFTs.jl/issues/63, this motivates supporting the Base.adjoint operation for plans to simplify the writing of backward rules for AD.
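To make the "special scaling" concrete, here is a small numeric check in plain Python (not Julia, and not the AbstractFFTs.jl API). It is only a sketch: the naive O(n²) transforms and the helper names `rfft`, `brfft`, `rfft_adjoint`, and `real_inner` are written for this illustration. Under the real inner product Re Σ conj(a_k) b_k, the adjoint of rfft comes out as "halve every coefficient that has a conjugate partner, then apply the unnormalized backward real transform".

```python
import cmath
import random


def rfft(x):
    """Naive real FFT: the first n//2 + 1 DFT coefficients of a real sequence."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n // 2 + 1)]


def brfft(y, n):
    """Unnormalized backward real transform: rebuild the full spectrum by
    conjugate symmetry, apply the backward DFT, and keep the real part."""
    full = list(y) + [y[n - k].conjugate() for k in range(len(y), n)]
    return [sum(full[k] * cmath.exp(2j * cmath.pi * j * k / n) for k in range(n)).real
            for j in range(n)]


def rfft_adjoint(y, n):
    """Adjoint of rfft (illustrative): scale first, then transform backwards.
    Index 0, and the last index when n is even, have no conjugate partner and
    are left unscaled; every other entry is halved."""
    m = n // 2 + 1
    last_unpaired = n % 2 == 0
    scaled = [y[k] if k == 0 or (last_unpaired and k == m - 1) else y[k] / 2
              for k in range(m)]
    return brfft(scaled, n)


def real_inner(a, b):
    """Real inner product Re(sum(conj(a_k) * b_k))."""
    return sum((complex(u).conjugate() * complex(v)).real for u, v in zip(a, b))


random.seed(0)
n = 8
x = [random.uniform(-1, 1) for _ in range(n)]
y = [complex(random.uniform(-1, 1), random.uniform(-1, 1)) for _ in range(n // 2 + 1)]
# Adjoint identity <y, rfft(x)> == <rfft_adjoint(y), x>; the gap should be rounding error only.
print(abs(real_inner(y, rfft(x)) - real_inner(rfft_adjoint(y, n), x)))
```

For the real inverse transform the same argument runs the other way, which is why the scaling there is applied after the transform rather than before it.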
The following functions must be implemented by backends in order for output_size(p::Plan) and AdjointPlan to work:
- projection_style(p::Plan), which can be either :none, :real, or :real_inv.
- irfft_dim(p::Plan), only for those plans with the :real_inv projection style, which gives the original length of the halved dimension.
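As a rough illustration of why these two pieces of information are enough to determine an output size, here is a sketch in plain Python. It is not the package's implementation: the function `output_size`, its arguments, and the string-valued styles are hypothetical stand-ins for the traits described in this PR message, and it assumes the halved dimension is the first transformed dimension (0-based here, unlike Julia).

```python
def output_size(input_size, style, region, irfft_dim=None):
    """Hypothetical helper: output shape of a plan from its input shape.

    input_size: tuple of input dimensions
    style:      "none", "real", or "real_inv" (stand-ins for the projection styles)
    region:     transformed dimensions; the first one is assumed to be the halved one
    irfft_dim:  original length of the halved dimension, needed only for "real_inv"
    """
    size = list(input_size)
    halved = region[0]
    if style == "none":
        # Complex-to-complex transforms keep the shape unchanged.
        pass
    elif style == "real":
        # A real FFT keeps only the non-redundant half of the spectrum.
        size[halved] = size[halved] // 2 + 1
    elif style == "real_inv":
        # The halved length alone cannot tell n = 2m - 2 from n = 2m - 1,
        # so the original length has to be supplied.
        if irfft_dim is None:
            raise ValueError("irfft_dim is required for the real_inv style")
        if size[halved] != irfft_dim // 2 + 1:
            raise ValueError("input size is inconsistent with irfft_dim")
        size[halved] = irfft_dim
    else:
        raise ValueError(f"unknown projection style: {style!r}")
    return tuple(size)


print(output_size((8, 3), "real", (0,)))
print(output_size((5, 3), "real_inv", (0,), irfft_dim=9))
```

The "real_inv" branch shows why irfft_dim is needed at all: an input of length 5 along the halved dimension is compatible with an original length of either 8 or 9.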
Using the adjoint plan, we can simplify the writing of backward rules. I test the adjoint plans both directly and indirectly, through tests of the rrules.
NB: The interface has changed since the initial PR message. See the updated implementation docs in the PR for accurate info.
Codecov Report
Patch coverage: 100.00% and project coverage change: +4.13 :tada:
Comparison is base (b5109aa) 87.32% compared to head (e137ae3) 91.45%.
:exclamation: Current head e137ae3 differs from pull request most recent head e601347. Consider uploading reports for the commit e601347 to get more accurate results
Additional details and impacted files
| Coverage Diff | master | #67 | +/- |
|---|---|---|---|
| Coverage | 87.32% | 91.45% | +4.13% |
| Files | 3 | 3 | |
| Lines | 213 | 281 | +68 |
| Hits | 186 | 257 | +71 |
| Misses | 27 | 24 | -3 |
| Impacted Files | Coverage Δ |
|---|---|
| ext/AbstractFFTsChainRulesCoreExt.jl | 100.00% <100.00%> (ø) |
| src/definitions.jl | 83.33% <100.00%> (+9.29%) :arrow_up: |
This should be ready for another review (with https://github.com/JuliaMath/AbstractFFTs.jl/pull/69 as a dependency)
I've got test_frule and test_rrule to work for ScaledPlans too (and fixed some bugs with the rule in the process), with some minimal type piracy (see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256). So things should be well-tested now.
I guess this is still waiting for #78? I think it would be good to also make ChainRulesCore a weak dependency on new Julia versions (see, e.g., how it is done in SpecialFunctions).
@gaurav-arya you might want to rebase after https://github.com/JuliaMath/AbstractFFTs.jl/commit/3a3f0e4ebc3a9ab81bfb655003759ffa9c6bd1bb.
Question: what would the projection style be for real-to-real DCTs/DSTs in FFTW.jl?
I updated the PR (and fixed a few issues with types and functions that were not available in the extension).
Since in-place plans are not supported (or at least not tested?) currently, maybe a final thing to add would be to check in the ChainRules definitions that y = P * x is not aliased with x, and throw a descriptive error otherwise.
I don't have the time right now to revisit this PR, but if it looks good and would be helpful, please feel free to fix anything remaining and merge. Thanks!
@sethaxen since you were interested in differentiation rules as well (#103), I guess it might be valuable if you would have a look at the PR before I go ahead and merge it?
Sure @devmotion, will take a quick look tonight.
Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?
> Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?
I'd argue no, they should be kept. Zygote also defined rules for fft etc. and plan_fft, but IMO the main argument is that fft etc. are used for one-shot computations of FFTs whereas plan_fft etc. are intended for repeated calculations. So downstream packages might want to implement fft etc. in an optimized way knowing that there's no amortization, users might prefer to call fft etc. if they only apply it once, and keeping separate rules is the only way for the rules to distinguish between both use cases.