tskit

"matrix multiplication" statistic

Open petrelharp opened this issue 4 years ago • 25 comments

Here's a draft of the C code to compute covariances of arbitrary weighted sums of genotypes. Borrowing from @brieuclehmann's code, this should always be true:

import itertools

import msprime
import numpy as np

def genetic_relatedness_matrix(ts, sample_sets, mode):
    # Pairwise relatedness between sample sets, returned as a symmetric n x n matrix
    n = len(sample_sets)
    indexes = list(itertools.combinations_with_replacement(range(n), 2))
    K = np.zeros((n, n))
    K[np.triu_indices(n)] = ts.genetic_relatedness(
        sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    )
    K = K + np.triu(K, 1).transpose()  # mirror the strict upper triangle
    return K

ts = msprime.simulate(10, length=1e6, recombination_rate=1e-8, Ne=1e4, mutation_rate=1e-8)
W = np.random.normal(size=2 * ts.num_samples).reshape((ts.num_samples, 2))
# Relatedness between individual samples, then the quadratic form W^T K W
K = genetic_relatedness_matrix(ts, [[n] for n in ts.samples()], mode='site')
A = W.T @ K @ W
# The new statistic should give the same 2 x 2 matrix directly from the weights
B = ts.genetic_relatedness_weighted(W, indexes=[(0,0), (0, 1), (1, 0), (1,1)], mode='site', span_normalise=False).reshape((2,2))
assert np.allclose(A, B)

(... and, it is, happy day!)
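The property above is bilinearity: the weighted statistic for weight columns i and j should equal W[:, i] @ K @ W[:, j]. A NumPy-only sketch of why this lets us skip the samples x samples matrix, assuming for illustration that K is the plain centred genotype cross-product Z Zᵀ (tskit's site-mode relatedness has its own conventions for alleles and normalisation, so these numbers will not match tskit's). `relatedness_matrix` and `weighted_relatedness` are hypothetical names:

```python
import numpy as np

def relatedness_matrix(G):
    """Assumed K = Z Z^T, where Z is G (samples x sites) with each site's mean removed."""
    Z = G - G.mean(axis=0)
    return Z @ Z.T

def weighted_relatedness(G, W):
    """Cross-products of weighted sums of centred genotypes, without forming K."""
    Z = G - G.mean(axis=0)
    Y = Z.T @ W  # sites x k: weighted sum of centred genotypes at each site
    return Y.T @ Y

rng = np.random.default_rng(1)
G = rng.integers(0, 2, size=(10, 50)).astype(float)  # 10 samples, 50 biallelic sites
W = rng.normal(size=(10, 2))
K = relatedness_matrix(G)
A = W.T @ K @ W
B = weighted_relatedness(G, W)
print(np.allclose(A, B))  # True
```

Since Wᵀ(Z Zᵀ)W = (ZᵀW)ᵀ(ZᵀW), the cost scales with the number of weight columns rather than the number of samples squared.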

TODO: input checking, and tests. Tests should be straightforward, as we can just check for the property above, but hooking it into the testing code will take some doing.

Also TODO:

  • is this a good name? a good python API? I also thought about just making weights an argument to genetic_relatedness, which would call one method or the other depending on whether sample_sets or weights were present. This could be a general pattern, even.
  • maybe there's some tidying to do at the C level, in terms of refactoring the various general stats computing helpers? but this could be put off.

petrelharp • Mar 16 '21

Codecov Report

Merging #1246 (039a3b8) into main (555913c) will increase coverage by 3.50%. The diff coverage is 80.50%.

Note: Current head 039a3b8 differs from pull request most recent head 6c144d8. Consider uploading reports for the commit 6c144d8 to get more accurate results.


@@            Coverage Diff             @@
##             main    #1246      +/-   ##
==========================================
+ Coverage   89.89%   93.39%   +3.50%     
==========================================
  Files          30       28       -2     
  Lines       29043    27464    -1579     
  Branches     5672     1259    -4413     
==========================================
- Hits        26107    25649     -458     
- Misses       1667     1781     +114     
+ Partials     1269       34    -1235     
Flag             Coverage Δ
c-tests          92.26% <91.89%> (+5.86%) ↑
lwt-tests        89.05% <ø> (+8.91%) ↑
python-c-tests   71.44% <75.30%> (+4.39%) ↑
python-tests     98.67% <10.00%> (-0.35%) ↓


Impacted Files          Coverage Δ
python/tskit/trees.py   97.82% <10.00%> (-0.83%) ↓
c/tskit/trees.c         94.94% <91.89%> (+4.16%) ↑
python/_tskitmodule.c   91.12% <96.72%> (+2.71%) ↑

... and 20 files with indirect coverage changes



Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 555913c...6c144d8.

codecov[bot] • Mar 16 '21

Very cool!

> is this a good name? a good python API? I also thought about just making weights an argument to genetic_relatedness, which would call one method or the other depending on whether sample_sets or weights were present. This could be a general pattern, even.

I can't think of a better one, but I'd be +1 on adding a weights argument to genetic_relatedness. We'd need a different name though - is this the same "weight" as the W arg in the general_stat method? I can see this being confusing.

jeromekelleher • Mar 16 '21

> is this the same "weight" as the W arg in the general_stat method?

It is the same. So, maybe the argument should be called W.

> I think we should add a genetic_relatedness_matrix method as well as a convenience, but you probably had that in mind too?

Yeah, I guess? We still haven't gotten around to writing the divergence_matrix method either.
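The genetic_relatedness_matrix helper in the first comment assigns results for `itertools.combinations_with_replacement` pairs to `np.triu_indices` positions, which is only correct because both enumerate pairs in the same row-major order. A small sketch of that fill step on its own; `symmetric_from_upper` is a hypothetical name, and a random symmetric matrix stands in for relatedness values:

```python
import itertools

import numpy as np

def symmetric_from_upper(values, n):
    """Build an n x n symmetric matrix from values listed in the order of
    itertools.combinations_with_replacement(range(n), 2)."""
    K = np.zeros((n, n))
    K[np.triu_indices(n)] = values  # upper triangle incl. diagonal, row-major
    return K + np.triu(K, 1).T  # mirror the strict upper triangle

n = 4
pairs = list(itertools.combinations_with_replacement(range(n), 2))
# The fill relies on both enumerations using the same order.
assert pairs == [(int(i), int(j)) for i, j in zip(*np.triu_indices(n))]

rng = np.random.default_rng(0)
M = rng.normal(size=(n, n))
M = M + M.T  # stand-in symmetric "relatedness" matrix
values = [M[i, j] for i, j in pairs]
print(np.allclose(symmetric_from_upper(values, n), M))  # True
```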

petrelharp • Mar 16 '21

Just been playing around with this and it seems like there might be some numerical stability issues. I tried @petrelharp's code above with different seeds and occasionally get an AssertionError. Actually, the SUPER WEIRD thing is that sometimes it works with the same seed and sometimes it doesn't, i.e. I run exactly the same code twice and the first time it will be fine but the second time returns, e.g., B = [inf, inf ; inf, inf]. Has anyone seen anything like this before? This is on my MacBook btw.

import itertools

import msprime
import numpy as np

def genetic_relatedness_matrix(ts, sample_sets, mode):
    n = len(sample_sets)
    indexes = list(itertools.combinations_with_replacement(range(n), 2))
    K = np.zeros((n, n))
    K[np.triu_indices(n)] = ts.genetic_relatedness(
        sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    )
    K = K + np.triu(K, 1).transpose()  # mirror the strict upper triangle
    return K

seed = 42
ts = msprime.simulate(10, length=1e6, recombination_rate=1e-8, Ne=1e4, mutation_rate=1e-8, random_seed=seed)
np.random.seed(seed)
W = np.random.normal(size=2 * ts.num_samples).reshape((ts.num_samples, 2))
K = genetic_relatedness_matrix(ts, [[n] for n in ts.samples()], mode='site')
A = W.T @ K @ W
B = ts.genetic_relatedness_weighted(W, indexes=[(0,0), (0, 1), (1, 0), (1,1)], mode='site', span_normalise=False).reshape((2,2))
assert np.allclose(A, B)

It seems to work fine on rescomp (the Oxford BDI computing cluster, which runs Linux), though rescomp is using Python v3.7.10 compared to v3.8.6 on my MacBook. Also, the K matrices are ever so slightly different between the two machines.

Bizarre... I'll just develop on rescomp for now.

brieuclehmann • May 29 '21

OK, still some bizarre behaviour on rescomp. The following code performs matrix-vector multiplication using the same tree sequence as above. For comparison, mat_mul is the old general_stat-based code, while mat_mul_stat calls genetic_relatedness_weighted. Sometimes it returns the right output, sometimes not:

# Continues from the previous snippet (uses ts, seed and genetic_relatedness_matrix).
n_ind = ts.num_samples // 2
# Pair consecutive samples into diploid individuals
sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_ind)]
# Indicator weights: column i is 1 for the two samples of individual i
W_samples = np.array([[float(u in A) for A in sample_sets] for u in ts.samples()])
indexes = [(i, n_ind) for i in range(len(sample_sets))]
n = np.array([len(x) for x in sample_sets])
n_total = sum(n)

def mat_mul_stat(a):
    # New statistic: each indicator column paired with the extra column W_samples @ a
    W = np.c_[W_samples, W_samples @ a]
    return ts.genetic_relatedness_weighted(
        W, indexes=indexes, mode="site", span_normalise=False
    )

def mat_mul(a):
    # Old approach: the same quantity via general_stat with an explicit summary function
    W = np.c_[W_samples, W_samples @ a]
    wts = np.sum(W, axis=0)
    # summary function
    def f(x):
        mx = np.sum(x[0:n_ind]) / n_total
        return np.array(
            [(x[i] - wts[i] * mx) * (x[j] - wts[j] * mx) / 2 for i, j in indexes]
        )
    return ts.general_stat(W, f, len(indexes), span_normalise=False)

np.random.seed(seed)
a = np.random.randn(n_ind)
K = genetic_relatedness_matrix(ts, sample_sets, "site")
print(mat_mul_stat(a))
print(K @ a)
print(mat_mul(a))
print(mat_mul_stat(a))
print(mat_mul_stat(a))

Returns:

[  73.45122395 -205.96299977   53.02061499  435.5723698  -352.63344965]
[  73.45122395 -205.96299977   53.02061499  432.12461049 -352.63344965]
[  73.45122395 -205.96299977   53.02061499  432.12461049 -352.63344965]
[  41.41281634 3375.49081761 3669.81766139 4035.61057064 3232.49239152]
[-842.64551716  690.48841404  967.14364331 1339.59209569  545.65397605]
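Why mat_mul_stat should equal K @ a: if the weighted statistic is bilinear in its weight columns (the property asserted in the first comment), then pairing indicator column i with the extra column W_samples @ a gives row i of the individual-level matrix times a. A NumPy-only sketch, assuming a random positive semi-definite matrix in place of tskit's sample-level relatedness and a hypothetical bilinear stand-in `weighted` for genetic_relatedness_weighted:

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples, n_ind = 10, 5
# Stand-in sample-level relatedness matrix (symmetric positive semi-definite)
Z = rng.normal(size=(n_samples, 30))
K_samples = Z @ Z.T

# Indicator matrix: column i marks the two samples of individual i (as W_samples does)
S = np.zeros((n_samples, n_ind))
for i in range(n_ind):
    S[2 * i : 2 * i + 2, i] = 1.0

def weighted(W, indexes):
    """Bilinear stand-in: W[:, i] @ K_samples @ W[:, j] for each index pair."""
    return np.array([W[:, i] @ K_samples @ W[:, j] for i, j in indexes])

a = rng.normal(size=n_ind)
W = np.c_[S, S @ a]  # n_ind indicator columns plus one extra column
indexes = [(i, n_ind) for i in range(n_ind)]
K_ind = S.T @ K_samples @ S  # individual-level relatedness
matvec = weighted(W, indexes)
print(np.allclose(matvec, K_ind @ a))  # True
```

Entry i is S[:, i]ᵀ K (S a) = (Sᵀ K S a)ᵢ, so any disagreement with K @ a points at the implementation rather than the algebra.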

brieuclehmann • May 29 '21

Yep, I'm seeing that too. Bizarrely, only in an interactive shell, not at the command line. It must be that I've done something bad with memory management. I'm not sure how to track that down - can we run valgrind on the python tests?

petrelharp • May 30 '21

I'm also getting the problem when I run the above code in the command line via a script... Never used valgrind before but could look into it if that would help?

brieuclehmann • Jun 03 '21

> I'm not sure how to track that down - can we run valgrind on the python tests?

No, Python gets in the way. We'd have to reproduce the error as a C test case to be able to use valgrind.

I have a bit of time in the morning, I could take a look if someone sends me a script to reproduce?

jeromekelleher • Jun 03 '21

#!/usr/bin/python3
import msprime, tskit
import numpy as np
import itertools

def genetic_relatedness_matrix(ts, sample_sets, mode):
    n = len(sample_sets)
    indexes = [(n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)]
    K = np.zeros((n,n))
    K[np.triu_indices(n)] = ts.genetic_relatedness(sample_sets, indexes, mode = mode, proportion=False, span_normalise=False)
    K = K + np.triu(K,1).transpose()
    return K


ts = msprime.simulate(10, length=1e6, recombination_rate=1e-8, Ne=1e4, mutation_rate=1e-8, random_seed=42)

for seed in range(10):
    np.random.seed(seed)
    W = np.random.normal(size=2 * ts.num_samples).reshape((ts.num_samples, 2))
    K = genetic_relatedness_matrix(ts, [[n] for n in ts.samples()], mode='site')
    A = W.T @ K @ W
    B = ts.genetic_relatedness_weighted(W, indexes=[(0,0), (0, 1), (1, 0), (1,1)], mode='site', span_normalise=False).reshape((2,2))
    print('A:', A)
    print('B:', B)
    assert np.allclose(A, B)

petrelharp • Jun 03 '21

Thanks @petrelharp, this is a weird one. Reliably fails on the 10th seed, but not if we just run that seed on its own. I'll poke around, but it nearly has to be a memory management error.
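Results that change between identical calls in one process usually mean memory is being read before it is written. A hypothetical helper (not part of tskit) that calls a function repeatedly with the same arguments and reports the first call whose output differs:

```python
import numpy as np

def first_inconsistent_call(fn, *args, repeats=20, **kwargs):
    """Call fn repeatedly with identical arguments; return the index of the
    first call whose result differs from the first call, or None."""
    reference = np.asarray(fn(*args, **kwargs), dtype=float)
    for k in range(1, repeats):
        result = np.asarray(fn(*args, **kwargs), dtype=float)
        # equal_nan=True so a NaN present in both results is not counted as a change
        if not np.array_equal(result, reference, equal_nan=True):
            return k
    return None

x = np.arange(5.0)
print(first_inconsistent_call(np.cumsum, x))  # None

calls = []
def drifting(v):
    calls.append(1)
    return v * len(calls)  # deliberately changes on every call

print(first_inconsistent_call(drifting, x))  # 1
```

With the script above, a call might look like `first_inconsistent_call(ts.genetic_relatedness_weighted, W, indexes=[(0, 0), (0, 1), (1, 0), (1, 1)], mode="site", span_normalise=False)`.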

jeromekelleher • Jun 03 '21

I think it's worthwhile putting in some tests to cover the remaining code paths now: this will have to be done later anyway, and it might unearth the problem we're having here. Tests covering error paths in the Python C module are particularly important to exercise; subtle problems there can manifest themselves in all sorts of unexpected ways, and lots of time can be lost tracking them down.

jeromekelleher • Jun 03 '21

That sounds good. I'm pretty swamped right now, might you be able to have a go at the tests, @brieuclehmann?

petrelharp • Jun 03 '21

Yep, I'd be happy to! Please could you point me towards a relevant function/method that I could use as a template?

brieuclehmann • Jun 03 '21

Thanks @brieuclehmann but it's probably better if I put in a commit to test out the code paths I'm thinking about - the testing code is quite convoluted for this stuff, and it would be a good way for me to start getting up to speed on what the function is doing anyway. I'll have a go as soon as I can (hopefully tomorrow AM)

jeromekelleher • Jun 03 '21

I think that should do it - I think the issue was that the total_weights array was being malloced rather than calloced (and therefore never being initialised to 0). Most of the time the memory happened to be zero already so it worked.

I don't understand how valgrind missed this though...
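The same failure mode can be mimicked in NumPy: `np.empty` returns uninitialised memory much as `malloc` does, while `np.zeros` is the zero-filled counterpart, like `calloc`. A sketch of the accumulation pattern, assuming a buffer of per-column weight totals analogous to total_weights (the actual C code is not shown in this thread); `column_totals` is a hypothetical name:

```python
import numpy as np

def column_totals(W):
    """Sum each column of W with an explicit accumulator, the way a
    totals buffer would be filled in C."""
    total = np.zeros(W.shape[1])  # zero-filled, like calloc
    # total = np.empty(W.shape[1])  # uninitialised, like malloc: the bug pattern
    for row in W:
        total += row
    return total

W = np.random.default_rng(3).normal(size=(10, 3))
print(np.allclose(column_totals(W), W.sum(axis=0)))  # True
```

With `np.empty` the result is only correct when the reused memory happens to be zero, which matches the intermittent behaviour described above.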

jeromekelleher • Jun 04 '21

> I think the issue was that the total_weights array was being malloced rather than calloced (and therefore never being initialised to 0)

Doh! Good catch.

petrelharp • Jun 04 '21

I think we should try to get this merged @petrelharp - do you know what's needed to get it in? We don't need to have it fully documented, but it would be good to get it into main so we can experiment.

jeromekelleher • Oct 04 '21

It needs, at least:

  • better documentation for things like what indexes defaults to
  • tests for genetic_relatedness_weighted to be added to tests/test_tree_stats.py

The latter should be easy because we can just compare its output to what we'd get by matrix-multiplying the genetic_relatedness matrix by the weight matrix (as in the definition). Pretty quick, but not quick enough for me to have done it just now...
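The comparison described above can be packaged as a reusable check. A minimal sketch; `weighted_matches_matrix` is a hypothetical helper, and the demo uses NumPy stand-ins for the statistic rather than tskit:

```python
import itertools

import numpy as np

def weighted_matches_matrix(weighted_fn, K, W, **kwargs):
    """Check the defining property: weighted_fn(W, indexes) == W.T @ K @ W."""
    k = W.shape[1]
    # All ordered pairs (0, 0), (0, 1), ... in row-major order, so reshape((k, k)) lines up
    indexes = list(itertools.product(range(k), repeat=2))
    B = np.asarray(weighted_fn(W, indexes=indexes, **kwargs)).reshape((k, k))
    return np.allclose(W.T @ K @ W, B)

rng = np.random.default_rng(4)
Z = rng.normal(size=(6, 20))
K = Z @ Z.T
W = rng.normal(size=(6, 3))

def exact(W, indexes):
    return [W[:, i] @ K @ W[:, j] for i, j in indexes]

def broken(W, indexes):
    return [W[:, i] @ W[:, j] for i, j in indexes]  # ignores K

print(weighted_matches_matrix(exact, K, W))  # True
print(weighted_matches_matrix(broken, K, W))  # False
```

In a tskit test, `weighted_fn` would be `ts.genetic_relatedness_weighted` with `mode="site", span_normalise=False` passed through, and `K` would come from the genetic_relatedness_matrix helper in the first comment.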

petrelharp • Oct 07 '21

Do you think you could update this branch please @petrelharp? If you rebase and force push, then @brieuclehmann can open a PR against your fork with some tests, and that should finish things up.

As far as I can see it's mainly simple interface tests needed here now?

jeromekelleher • Dec 08 '21

Ping @petrelharp - I'd be happy to write some tests and more documentation in the coming days so that we can get this merged, but I think I'm not able to rebase the branch myself?

brieuclehmann • Feb 22 '22

Done!

petrelharp • Feb 22 '22

FYI I just pushed an update fixing the build problems.

jeromekelleher • Feb 23 '22

Would it be possible to rebase this onto main? I'd like to use genetic_relatedness_weighted and decapitate in the same script, i.e. to do PCA on a decapitated tree sequence.

(Sorry for dropping the ball on writing the tests & documentation for this - I will get round to it v soon!)
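For PCA, a matrix-vector product such as mat_mul_stat is enough to find leading eigenvectors without forming the full relatedness matrix. As one hedged illustration (the thread does not say which algorithm is intended), plain power iteration on a matvec callable; `top_eigenvector` is a hypothetical name and assumes the matrix is symmetric positive semi-definite:

```python
import numpy as np

def top_eigenvector(matvec, n, iters=500, tol=1e-10, seed=0):
    """Power iteration: estimate the leading eigenvalue and eigenvector of a
    symmetric PSD matrix that is only available through matvec(x) = K @ x."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=n)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        w = matvec(v)
        lam_new = v @ w  # Rayleigh quotient for the current unit vector
        v = w / np.linalg.norm(w)
        if abs(lam_new - lam) <= tol * max(1.0, abs(lam_new)):
            break
        lam = lam_new
    return lam_new, v

# Demo on an explicit matrix with known eigenvalues 5, 2, 1, 0.5
rng = np.random.default_rng(5)
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))
K = Q @ np.diag([5.0, 2.0, 1.0, 0.5]) @ Q.T
lam, v = top_eigenvector(lambda x: K @ x, 4)
print(lam)  # close to 5.0
```

With the earlier code, `matvec` would be `mat_mul_stat` and `n` would be `n_ind`. Power iteration is slow when the top eigenvalues are close; Lanczos-type solvers such as `scipy.sparse.linalg.eigsh` with a `LinearOperator` accept the same kind of callable.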

brieuclehmann • Jul 26 '22

Uh-oh, a segfault - ping @jeromekelleher?

petrelharp • Jul 26 '22

Looks like a simple compile error because of API changes we made for C 1.0. Should be fixed.

jeromekelleher • Jul 27 '22

Closing this in favour of #2785 (which includes these commits)

jeromekelleher • Jul 13 '23