KernelFunctions.jl

Qualify pairwise call

Open Crown421 opened this issue 3 years ago • 10 comments

While trying to understand how to work with Coverage I was struck by some lines that stubbornly did not get covered. Initially I thought that was due to missing tests, but even adding those did not change the coverage.

After some effort, I have found that these functions don't get called because the default Distances.pairwise is already enough. I managed to get those lines covered by qualifying the pairwise call in the tests, but it seems they are not actually necessary? Maybe it makes sense to remove those definitions from src altogether?
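To make "qualifying the pairwise call" concrete, here is a minimal sketch. It assumes that KernelFunctions defines its own unexported `pairwise` for vectors of vectors (this may not hold in every version) and that the installed Distances version accepts vectors of vectors. The variable names are illustrative only.

```julia
using Distances
import KernelFunctions  # `import` so that no names are brought into scope

A = [rand(3) for _ in 1:5]
B = [rand(3) for _ in 1:7]

# Unqualified: resolves to the `pairwise` exported by Distances.
D_unqualified = pairwise(SqEuclidean(), A, B)

# Qualified: explicitly calls the method defined inside KernelFunctions
# (assumed to exist, see the note above).
D_qualified = KernelFunctions.pairwise(SqEuclidean(), A, B)

println(size(D_unqualified), " ", size(D_qualified))
```

If both calls return the same values, the unqualified call in a test never exercises the KernelFunctions definition, which would explain the missing coverage.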

Crown421 avatar Aug 17 '21 10:08 Crown421

Interesting. I'm intrigued to see precisely which lines have changed coverage in this PR (which I guess we'll see once the Format suggestions check passes?). If there is indeed no need for certain methods, I agree that it would make sense to remove them. With things like ColVecs / RowVecs, we definitely need something though, because Distances.jl doesn't know about their existence. Hard to say without seeing which lines now have coverage though.
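For readers unfamiliar with these wrappers, a short sketch of what ColVecs and RowVecs do, assuming a recent KernelFunctions release: they present a matrix as a vector of its columns or rows without copying, which is a type Distances.jl has no methods for.

```julia
using KernelFunctions

X = rand(3, 5)

xc = ColVecs(X)  # behaves like a vector of the 5 columns, each of length 3
xr = RowVecs(X)  # behaves like a vector of the 3 rows, each of length 5

println(length(xc), " ", length(xr))
println(xc[1] == X[:, 1], " ", xr[1] == X[1, :])
```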

willtebbutt avatar Aug 17 '21 10:08 willtebbutt

I have covered a few more lines, and run the formatter, so that should pass now, I hope.

I feel the solution could be to keep the tests in place, remove the qualifiers, and then also remove all non-covered lines. The tests will then make sure that the functionality stays.

Crown421 avatar Aug 17 '21 10:08 Crown421

After some digging around in Distances, I discovered this PR https://github.com/JuliaStats/Distances.jl/pull/194 , that I had previously missed.

It resolves the long-standing issue that we had whereby we couldn't add some methods that we needed to Distances.pairwise without committing type-piracy. It resolves this by adding the methods that we needed. I suspect this change means that, at some point in the (hopefully not too distant) future, we'll be able to dispense with KernelFunctions.pairwise. I suspect that this is why stuff now works.

What concerns me now is just AD-related stuff. Specifically, ChainRules doesn't have rules for Distances stuff (because ChainRules only contains rules for Base, Core, and the standard libraries), and Distances isn't receptive to accepting ChainRulesCore as a dependency (because it doesn't want to have to maintain AD-related functionality, which is reasonable). There's an ongoing discussion about glue packages that will hopefully get some kind of resolution in the short to medium term (https://github.com/JuliaLang/Pkg.jl/issues/1285), but it's not a thing yet, meaning that we can't fully dispense with KernelFunctions.pairwise at the minute.

This brings me around to what I think should be done:

  1. also add AD tests to the new methods that you've implemented, and
  2. then try to remove things, and see what breaks.

My suspicion is that the majority of the methods that we have will be necessary for AD to continue to work, and we're not currently testing that at all properly. The "correct" way to do that in the modern world is using https://juliadiff.org/ChainRulesTestUtils.jl/dev/#Testing-AD-systems .

This leaves a couple of ways forward:

  1. you extend this PR as per 1 and 2 above, or
  2. we merge this PR as-is, and open an issue about this, and deal with it at a later date.

I'm happy to go with either, it's just a question of what you've got the time / inclination to do.

willtebbutt avatar Aug 17 '21 13:08 willtebbutt

Very interesting. I have seen the issue on conditional dependencies; it does not seem likely to be resolved very soon.

In principle I am open to adding the proper tests, as a way to learn about AD. I have been using it, but don't really understand it in any way. In practice, I have no idea how difficult it will be. I have not added any methods, so I suppose it would mostly be writing new tests for existing functions (including figuring out which even need it).

Crown421 avatar Aug 17 '21 13:08 Crown421

> I have not added any methods, so I suppose it would mostly be writing new tests for existing functions (including figuring out which even need it).

Indeed -- to be honest, it would make sense to test all of the methods (once you've got a function to test one of them, testing the others should be straightforward, at least in principle).

I think the way to go about it is just going to be to use the method I linked above (definitely don't try to roll your own testing code for AD). I suspect you can just use the rule config defined here: https://github.com/FluxML/Zygote.jl/blob/78bb9a3cad52de6e7c9a590d0f8ac4b6014a73f4/src/compiler/chainrules.jl#L4

willtebbutt avatar Aug 17 '21 14:08 willtebbutt

I noticed that some recent changes in Distances make it necessary to analyze more carefully which methods are called. Distances.pairwise is not owned by Distances anymore. Instead it is owned by StatsBase (even though technically it is defined in StatsAPI). The dangerous part is that StatsBase defines a very general fallback definition of pairwise that does not exploit the structure of the metrics and inputs. It is difficult to spot potential performance problems, since (as intended) no errors or warnings are shown if the potentially slow fallback is hit. It even caused problems with Distances.PreMetric: for some inputs the fallback was used instead of the implementations in Distances (I don't remember all the details, but I added some fixes to Distances a while ago). So while tests might not fail if specific methods are removed, we should carefully evaluate whether the removal causes any performance regressions.
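One way to check which method a given call actually hits is to ask Julia directly. The following is a minimal sketch using `@which` from the InteractiveUtils standard library; the inputs are illustrative, and the module that gets reported will depend on the installed versions of Distances, StatsBase and StatsAPI.

```julia
using Distances
using InteractiveUtils  # provides `@which` outside the REPL

A = [rand(3) for _ in 1:5]
B = [rand(3) for _ in 1:7]

# Look up the method that this exact call dispatches to.
m = @which pairwise(SqEuclidean(), A, B)

println(m)                # method signature and source location
println(parentmodule(m))  # module in which that method is defined
```

If the reported module is not the one expected, the call is going through a different definition (possibly a generic fallback), and a benchmark of that call before and after removing methods would show whether performance is affected.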

devmotion avatar Aug 17 '21 18:08 devmotion

After some exploration, I notice two issues:

  1. Since I am using Julia 1.7, ReverseDiff (and probably also ForwardDiff) segfault whenever I actually get to interesting parts. I understand this has something to do with a new ChainRules version?
  2. I am not sure I know what should be tested. I am trying to write AD tests for pairwise, for different input types (Vector of vectors, ColVecs, RowVecs), but that seems to cause an issue with undefined zero. Am I actually on the right track?

Crown421 avatar Aug 17 '21 22:08 Crown421

> Since I am using Julia 1.7, ReverseDiff (and probably also ForwardDiff) segfault whenever I actually get to interesting parts. I understand this has something to do with a new ChainRules version?

Oh interesting. I have no idea what to do about that. Would suggest using 1.6 if it's not obvious how to fix.

> I am not sure I know what should be tested. I am trying to write AD tests for pairwise, for different input types (Vector of vectors, ColVecs, RowVecs), but that seems to cause an issue with undefined zero. Am I actually on the right track?

Hmmm could you provide a MWE?

willtebbutt avatar Aug 18 '21 17:08 willtebbutt

I think I have the same problem that makes the CI fail all the way through for Julia nightly. I have switched to 1.3 for now.

For development, I have started by looking at the following:

A = [rand(3) for _ in 1:5]
B = [rand(3) for _ in 1:7]
norm(pairwise(SqEuclidean(), A, B))

ReverseDiff.gradient(a->norm(pairwise(SqEuclidean(), a, B)), A)

which yields

ERROR: LoadError: MethodError: no method matching zero(::Type{Array{Float64,1}})

With my limited understanding, I have been able to find this Discourse discussion, but before going down that road, I want to make sure this is even what I should be looking at.

Crown421 avatar Aug 20 '21 16:08 Crown421

Sorry for the slow response @Crown421

I would suggest starting with Zygote, and ChainRulesTestUtils.

In particular, take a look at the docs on testing AD systems, and utilise the method implemented in Zygote together with Zygote's config. I personally generally find this to be the most debuggable bit of testing infrastructure.

I'm imagining something along the lines of

f(A, B) = pairwise(SqEuclidean(), A, B)
test_rrule(Zygote.ZygoteRuleConfig(), f, A, B; rrule_f=rrule_via_ad)

willtebbutt avatar Aug 25 '21 17:08 willtebbutt