KernelFunctions.jl

Known AD Failures

Open theogf opened this issue 4 years ago • 27 comments

Here is a list of the failures I observed in the tests added in #114 with the different AD backends (ForwardDiff.jl, Zygote.jl and ReverseDiff.jl):

  • [ ] FBMKernel : ForwardDiff/Zygote
  • [ ] GaborKernel : Zygote (ForwardDiff+ReverseDiff randomly)
  • [ ] MaternKernel : Zygote
  • [ ] SpectralMixture : All
  • [ ] PeriodicKernel : Zygote (need to define an adjoint) #531, ForwardDiff and ReverseDiff fail randomly (edge cases might need to be looked at) #389
  • [ ] Wiener : All
  • [x] GammaExponentialKernel : Zygote
  • [x] NeuralNetworkKernel : Zygote
  • [x] PiecewisePolynomialKernel : All
  • [x] PolynomialKernel : All
  • [x] GammaRationalQuadratic : All
  • [x] KernelProduct : Zygote
  • [x] KernelSum : Zygote
  • [x] FunctionTransform : Zygote
  • [ ] ChainTransform : Zygote #415

This is a good starting point for finding solutions.

theogf • May 16 '20

The Zygote problems with MaternKernel are caused by the fact that the partial derivative of besselk with respect to the first argument (the order) is defined as NaN in https://github.com/JuliaDiff/ChainRules.jl/blob/98c54587257b86cce6eb45f7870a75f897058d21/src/rulesets/packages/SpecialFunctions.jl#L46-L47 (and I assume the same problem exists for the other AD backends, since I get NaN for all of them when I try to run the commented-out AD tests). I guess one would have to implement https://dlmf.nist.gov/10.38 to fix it.
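For moderate arguments, a brute-force way to get a finite derivative with respect to the order is to differentiate the standard integral representation of $K_\nu$ (see DLMF §10.32) under the integral sign. The sketch below uses that swapped-in technique; it is not the DLMF 10.38 expressions and not the ChainRules rule. besselk_integral, dbesselk_dnu_integral and trapezoid are hypothetical helpers, and the defaults h = 0.05 and tmax = 20.0 are assumptions that are only reasonable when x is not tiny and nu is not large.

```julia
# Sketch: K_nu(x) and dK_nu(x)/dnu for x > 0 from the integral representation
#   K_nu(x) = ∫_0^∞ exp(-x cosh t) cosh(nu t) dt,
# differentiated under the integral sign:
#   dK_nu(x)/dnu = ∫_0^∞ t exp(-x cosh t) sinh(nu t) dt.
# Hypothetical helpers; not SpecialFunctions.besselk or its ChainRules rule.

# Trapezoidal rule on [0, tmax]. Both integrands are even in t and decay
# double-exponentially, so the rule converges very fast as h shrinks.
function trapezoid(f, h::Float64, tmax::Float64)
    ts = 0.0:h:tmax
    return h * (sum(f, ts) - (f(first(ts)) + f(last(ts))) / 2)
end

function besselk_integral(nu::Real, x::Real; h::Float64=0.05, tmax::Float64=20.0)
    x > 0 || throw(DomainError(x, "x must be positive"))
    return trapezoid(t -> exp(-x * cosh(t)) * cosh(nu * t), h, tmax)
end

function dbesselk_dnu_integral(nu::Real, x::Real; h::Float64=0.05, tmax::Float64=20.0)
    x > 0 || throw(DomainError(x, "x must be positive"))
    return trapezoid(t -> t * exp(-x * cosh(t)) * sinh(nu * t), h, tmax)
end
```

Unlike a NaN rule, this returns a finite value at every order, including integer nu. It can be sanity-checked against the closed form K_{1/2}(x) = sqrt(pi / (2x)) * exp(-x) and against a central finite difference in nu. Overflow for large nu (cosh(nu * t) exceeding Float64) and accuracy for very small x are not handled.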

devmotion • Jun 16 '20

Haha, writing these derivatives sounds like it would take a whole package about Bessel functions.

theogf • Jun 16 '20

@sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.

yebai • Jun 24 '20

BTW, I found a publication from 2016 with closed-form expressions for the derivatives of the Bessel functions with respect to the order. I opened an issue at https://github.com/JuliaDiff/ChainRules.jl/issues/208 to discuss how one would deal with the additional dependencies needed for their implementations (they contain hypergeometric functions).

devmotion • Jun 24 '20

We might want to refactor KernelSum and KernelProd (making them concretely typed and allowing both tuples and vectors of kernels similar to TensorProduct, and probably removing the weights in KernelSum) before fixing any AD issues there.
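The refactoring idea can be sketched with toy types: making the container of kernels a type parameter keeps the struct concretely typed whether it holds a Tuple or a Vector. KernelSumSketch and ToySqExp are hypothetical names for illustration, not KernelFunctions types, and the sketch deliberately has no weights.

```julia
# Hypothetical toy kernel (not a KernelFunctions type):
# k(x, y) = s * exp(-(x - y)^2 / 2).
struct ToySqExp
    s::Float64
end
(k::ToySqExp)(x, y) = k.s * exp(-abs2(x - y) / 2)

# Sketch of a concretely typed sum kernel: the container type Tk is a type
# parameter, so it can be a Tuple (heterogeneous kernels, fully inferable)
# or a Vector (homogeneous kernels, easy to grow). No separate weights.
struct KernelSumSketch{Tk}
    kernels::Tk
end
(k::KernelSumSketch)(x, y) = sum(kern -> kern(x, y), k.kernels)

kt = KernelSumSketch((ToySqExp(1.0), ToySqExp(2.0)))   # tuple of kernels
kv = KernelSumSketch([ToySqExp(1.0), ToySqExp(2.0)])   # vector of kernels
```

Both kt and kv have concrete types, so evaluating them does not need dynamic dispatch on an abstract field type.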

devmotion • Jun 24 '20

Agreed! There is also a general AD issue when using a Transform: the pullback on ColVecs and RowVecs returns a vector of vectors. Fixing this would tick off a good portion of the issues.
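Part of any fix would be converting a cotangent given as one vector per observation back into the matrix layout that the wrapper stores. The helpers below are hypothetical (they are not KernelFunctions, Zygote or ChainRules functions) and only sketch that conversion for column-wise (ColVecs-like) and row-wise (RowVecs-like) storage; how it would be wired into a pullback is not shown.

```julia
# Hypothetical helpers: turn a cotangent given as one vector per observation
# back into the matrix layout of the wrapper it came from.
# Column-wise storage: observation i is column i.
colvecs_cotangent(Δ::AbstractVector{<:AbstractVector}) = reduce(hcat, Δ)

# Row-wise storage: observation i is row i.
rowvecs_cotangent(Δ::AbstractVector{<:AbstractVector}) = permutedims(reduce(hcat, Δ))

Δ = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # three observations of dimension 2
```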

theogf • Jun 24 '20

> @sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.

Sorry for the late reply. I somehow didn't get a notification for this comment and only found it by chance while browsing the issues. I am looking into it.

sharanry • Aug 02 '20

The probable reason Zygote fails for FunctionTransform is the usage of Base.mapslices in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/0df9e8352d7d9f034f75de4d3639c0bb3b96c714/src/transform/functiontransform.jl#L19-L20.

Base.mapslices is mutating the array. Not sure why. https://github.com/JuliaLang/julia/pull/17266 should have fixed this.

julia> Zygote._pullback(x-> mapslices(x->sin.(x), x, dims=1), rand(3,3))[2](ones(3,3))
ERROR: Mutating arrays is not supported

sharanry • Aug 08 '20

mapslices still mutates a temporary array. The linked PR just ensures that

> mapslices never modifies the input array. It allocates temporary storage and copies each slice into it before calling the user-function.

devmotion • Aug 08 '20

> mapslices still mutates a temporary array. The linked PR just ensures that

Oh, makes sense. Do you see any other efficient way to apply a function transform to a matrix/ColVecs/RowVecs?

sharanry • Aug 08 '20

I see the following possibilities here:

  • define a custom adjoint for _map(::FunctionTransform, ...)
  • define a custom adjoint for mapslices in ChainRules (see https://github.com/FluxML/Zygote.jl/issues/92)
  • use SliceMap (see https://github.com/FluxML/Zygote.jl/issues/92)
  • rewrite the lines as
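# Apply t.f to each column (passed as a view) without mutating any array
function _map(t::FunctionTransform, x::ColVecs)
    vals = map(axes(x.X, 2)) do i
        t.f(view(x.X, :, i))
    end
    # ColVecs wraps a matrix, so stack the per-column results as columns
    return ColVecs(reduce(hcat, vals))
end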

(Zygote should support this automatically)
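The per-column rewrite in the last bullet can be tried outside KernelFunctions with a small stand-in type. MyColVecs and map_columns are hypothetical names that mirror the `_map` snippet above; the sketch only checks that the result matches mapslices, and does not check Zygote's behaviour.

```julia
# Hypothetical minimal stand-in for a ColVecs-like wrapper:
# the observations are the columns of X.
struct MyColVecs{TX<:AbstractMatrix}
    X::TX
end

# Apply f to every column without writing into any array: each column is
# passed as a view, `map` collects the results into a new Vector, and
# `reduce(hcat, ...)` stacks them as the columns of a new matrix.
function map_columns(f, x::MyColVecs)
    vals = map(axes(x.X, 2)) do i
        f(view(x.X, :, i))
    end
    return MyColVecs(reduce(hcat, vals))
end

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
Y = map_columns(c -> sin.(c), MyColVecs(X)).X
```

Because nothing is written in place, this avoids the "Mutating arrays is not supported" error that the temporary buffer inside mapslices triggers in Zygote.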

devmotion • Aug 08 '20

I just ran a quick benchmark of one other possibility, which would require us to define an adjoint for a generator. The methods you mentioned are probably better.

julia> @btime hcat(map(x->sin.(x), (eachslice(rand(1000,1000); dims=1)))...)
  16.586 ms (2015 allocations: 23.09 MiB)
julia> @btime mapslices(x->sin.(x), rand(1000,1000); dims=1)
  12.189 ms (7505 allocations: 23.18 MiB)

sharanry • Aug 08 '20

A bit off topic, but splatting probably impacts performance quite a bit, so it would be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)), so that the timings are unaffected by the calls of rand.
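Both suggestions can be tried in a small script: build the per-column results once, then compare splatting into hcat with reduce(hcat, ...), which Base implements with a specialized method for a vector of arrays. BenchmarkTools is not loaded here; the commented @btime line only shows where the $-interpolation would go, and no timings are claimed.

```julia
X = rand(1000, 1000)
cols = map(c -> sin.(c), eachcol(X))   # Vector of 1000 column vectors

A = hcat(cols...)        # splatting: a single call with 1000 arguments
B = reduce(hcat, cols)   # specialized method for a vector of vectors

# With BenchmarkTools loaded, interpolating X keeps `rand` out of the timing:
#   @btime reduce(hcat, map(c -> sin.(c), eachcol($X)))
```

All three constructions (splatting, reduce(hcat, ...) and mapslices with dims=1) produce the same matrix; they differ only in how the result is assembled.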

devmotion • Aug 08 '20

Thanks! I wasn't aware of this. However, this gave unexpected results for mapreduce:

julia> @btime mapslices(x->sin.(x), $(rand(1000,1000)); dims=1);
  10.581 ms (7503 allocations: 15.55 MiB)

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=1));
  914.970 ms (5002 allocations: 3.74 GiB)

sharanry • Aug 08 '20

Shouldn't you use eachslice(...; dims=2) or eachcol?

devmotion • Aug 08 '20

I don't think it makes much difference, performance-wise at least:

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=2));
  952.564 ms (5002 allocations: 3.74 GiB)

sharanry • Aug 08 '20

Can you check if the function is type-stable? I suspect it might not be, which would explain the number of allocations. The problem might be that it returns a different type if eachslice(...) is empty. Specifying an init keyword argument might help.
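Alongside type stability, the size of the allocation may be worth a back-of-the-envelope check. Assuming (a guess, not verified here) that mapreduce(x -> sin.(x), hcat, ...) folds from the left, the k-th hcat allocates a fresh 1000×k matrix of Float64, i.e. 8000k bytes, so the copying grows quadratically with the number of columns:

$$\sum_{k=1}^{1000} 8000\,k\ \text{bytes} = 8000 \cdot \frac{1000 \cdot 1001}{2}\ \text{bytes} = 4.004 \times 10^{9}\ \text{bytes} \approx 3.73\ \text{GiB}$$

This is close to the reported 3.74 GiB. If that is the cause, reduce(hcat, v) on a vector of vectors, which has a specialized method in Base that allocates the result once, should avoid it.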

devmotion • Aug 08 '20

Just for the record -- we should be using ChainRulesCore to define pullbacks for Zygote, and ChainRulesTestUtils to test those implementations -- see e.g. here for example usage.

Plans are in the works to transfer both Tracker and ReverseDiff over to use ChainRules at some point (we know how we're going to do it, just waiting for code to get written), so this will future-proof AD in the package.

willtebbutt • Aug 10 '20

I can't figure out how to define an adjoint for mapslices using only ChainRulesCore. If f is an anonymous function, how can we get its pullback (rrule) from ChainRulesCore, when in Zygote this can be done using gradient?

yiyuezhuo • Aug 29 '20

KernelFunctions doesn't use mapslices anymore, so custom adjoints for mapslices are no longer required for this project. Nevertheless, an implementation of an adjoint of mapslices in ChainRules requires a solution to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 AFAICT.

devmotion • Aug 29 '20

I see. I checked the source of the last release in order to backport TransformedKernel to Stheno, since Stheno doesn't support KernelFunctions, and found that mapslices code. But after thinking about how to implement it in ChainRules or Zygote, I just disabled a check in Stheno to re-enable the gradient of f.(ColVecs(X)), since I feel ChainRules or Zygote will not make too much difference.

yiyuezhuo • Aug 30 '20

The latest releases don't use mapslices, it was replaced in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/152. I guess you can use something similar instead of mapslices in Stheno as well.

devmotion • Aug 30 '20

Regarding FBMKernel not working with ForwardDiff: it seems to produce NaN values incorrectly. According to the ForwardDiff documentation, the fix for this is to "enable ForwardDiff's NaN-safe mode by setting the NANSAFE_MODE_ENABLED constant to true in ForwardDiff's source". Currently users cannot enable it dynamically [Issue].
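A toy dual-number type makes the likely mechanism visible. The fractional Brownian motion kernel contains the term $\lVert x - y \rVert^{2h}$, i.e. (squared distance)$^h$, and at coincident inputs the chain rule multiplies an infinite outer derivative by a zero inner derivative. ToyDual and nansafe_pow are hypothetical and only mimic the idea of ForwardDiff's NaN-safe mode, not its implementation; whether FBMKernel computes its distance term exactly this way is an assumption.

```julia
# Toy forward-mode dual number (hypothetical, NOT ForwardDiff's Dual):
# `v` holds the value and `d` the derivative with respect to the input.
struct ToyDual
    v::Float64
    d::Float64
end

Base.:-(a::ToyDual, b::Real) = ToyDual(a.v - b, a.d)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.v * b.v, a.d * b.v + a.v * b.d)

# Power rule d(u^p) = p * u^(p - 1) * u'. For 0 < p < 1 and u == 0 the factor
# u^(p - 1) is Inf, and Inf * 0 evaluates to NaN.
Base.:^(a::ToyDual, p::Real) = ToyDual(a.v^p, p * a.v^(p - 1) * a.d)

# Variant in the spirit of a NaN-safe mode: a zero incoming derivative stays zero.
nansafe_pow(a::ToyDual, p::Real) =
    ToyDual(a.v^p, iszero(a.d) ? 0.0 : p * a.v^(p - 1) * a.d)

hurst = 0.3                     # assumed Hurst parameter, 0 < h < 1
y = 0.3                         # second input, held fixed

# Squared distance (x - y)^2 at the coincident point x == y, seeded with dx = 1.
r = ToyDual(0.3, 1.0) - y       # value 0.0, derivative 1.0
s = r * r                       # value 0.0, derivative 2 * 0.0 * 1.0 = 0.0
plain = s^hurst                 # derivative 0.3 * Inf * 0.0 = NaN
safe = nansafe_pow(s, hurst)    # derivative kept at 0.0

# Away from the diagonal the plain rule matches d/dx |x - y|^(2h).
r2 = ToyDual(0.5, 1.0) - y
far = (r2 * r2)^hurst
```

Kernel matrices evaluate k(x, x) on the diagonal, so this coincident-point case is hit in every gradient of a kernelmatrix call.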

sharanry • Sep 10 '20

You can use https://github.com/JuliaDiff/ForwardDiff.jl/pull/451 if you do not want to edit the source code.

devmotion • Sep 10 '20

Hi, it just occurred to me to share this here: I recently finished a project for computing derivatives of besselk with respect to the order parameter, precisely for the purpose of fitting Matern covariances (example Matern kernel implementation here). The strategy that worked ended up being a re-implementation of besselk in Julia that admits fast and accurate AD derivatives with ForwardDiff.jl. The re-implemented besselk itself is not quite as accurate as the AMOS one linked in SpecialFunctions, but the derivatives are pretty accurate: not quite to double machine precision, but reasonably close. And very fast.

I'm not sure how helpful this is because the derivatives are at present pretty ForwardDiff-specific. I would guess that it would be possible to reach compatibility with other AD tools, perhaps at a slight cost of performance by eliminating some special branches in the current implementation, but I honestly don't understand how Zygote works at all so I can't promise it.

Anyways, just writing here in case it is helpful.

cgeoga • Feb 08 '22

I came across https://www.tandfonline.com/doi/pdf/10.1080/10652469.2016.1164156 a while ago; it contains closed-form expressions for the derivatives using, e.g., hypergeometric functions. In principle these could be used with other AD backends as well, but I don't know whether there are any numerical problems, how slow or fast the evaluation with HypergeometricFunctions would be, and whether SpecialFunctions would take a dependency on HypergeometricFunctions (I assume not, since it would introduce a circular dependency).

devmotion • Feb 08 '22

I also saw that paper and was interested in just using it before undertaking a more from-scratch approach. But there are a few challenges with using the representations in Santander. For one, as you point out, evaluating generalized hypergeometric functions like $_3F_4$ and $_2F_3$ is a task of comparable difficulty. I love HypergeometricFunctions.jl, but that's a lot of pressure to put on that package, which in my experience is very slow when the besselk argument is small (which is unfortunately where accurate derivatives matter the most). More importantly, though, the representation in Santander is hard in a bunch of edge cases. For example, when $\nu$ is an integer or near-integer, there are several problems, both with cancellations and with trig functions blowing up. If $\nu = 1 + 10^{-8}$ or so, that ostensibly exact equation might give literally zero digits of accuracy. The exact derivatives when $\nu + 1/2$ is an integer are particularly gnarly, and I've never seen them for any case besides $\nu = 1/2$.

Our project was enough of a hassle that we ended up writing a paper about it, and almost all the work was in handling the problems of $\nu$ being exactly or nearly an integer or half-integer. I don't think there's any way around a gnarly branching function to handle the derivatives in those cases. And if you look at our timings (Table 1 of the paper), it will probably be hard to come anywhere near those speeds at even comparable accuracy around those edge cases.

I've actually thought about asking the SpecialFunctions package folks if they'd be interested in some of our code being added to that package, but considering that we are a bit cavalier in giving up the last couple digits of accuracy I'm a bit concerned that it's not a great fit.

In any case, just posting here for your consideration. If somebody manages to implement them with exact expressions in a way that is tolerably fast and handles those edge cases, I'll be the first person to celebrate. In the meantime, though, I wouldn't be shocked if Zygote compatibility was possible. I just really don't know enough to conjecture about how much of a project it would be.

cgeoga • Feb 08 '22