new (?) statistic to efficiently calculate shared times between every pair of samples at every tree

Open mmosmond opened this issue 3 years ago • 13 comments

Hi all, I've been messing around with different ways to calculate the shared time between every pair of samples in many trees, which is a useful metric because it describes the covariance of characteristics of the samples (e.g., traits, locations) under Brownian motion. At first I couldn't figure out how to use an existing statistic to do this, so I used the general_stat function to make one of my own:

import numpy as np
def shared_times(ts):
    """Use general_stat function to calculate shared branch lengths between all samples at all trees.
    
    Parameters
    ts: tskit tree sequence
    
    Returns
    An (n x k x k) numpy array, with n the number of trees in the sequence and k the number of sample nodes.
    At each tree the (k x k) matrix gives the shared evolutionary times between each pair of sample nodes.
    """
    
    k = ts.num_samples #number of samples
    W = np.identity(k) #each sample i in [0,1,...,k-1] given row vector of weights with 1 in column i and 0's elsewhere
    def f(x): return (x.reshape(-1,1) * x).flatten() #determine which pairs of samples share the branch above a node, flattened

    return ts.general_stat(
        W, f, k**2, mode='branch', windows='trees', polarised=True, strict=False
    ).reshape(ts.num_trees, k, k)

For example:

import msprime
ts = msprime.sim_ancestry(samples=5, population_size=1e4, sequence_length=1e4, recombination_rate=1e-8, random_seed=1)
sts = shared_times(ts)
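
To see why the outer-product summary function gives shared times, here is a minimal sketch on a hand-built tree rather than a real tree sequence. The three-sample tree ((0,1),2), its branch lengths, and the `branches` list are all hypothetical and chosen only for illustration. Each branch carries an indicator vector x of the samples below it, f(x) marks every pair of samples that share that branch, and weighting by branch length and summing over branches gives the shared-time matrix, which is roughly what general_stat does per tree in branch mode.

```python
import numpy as np

def f(x):
    # Outer product of the indicator vector with itself, flattened:
    # entry (i, j) is 1 only if samples i and j both sit below the branch.
    return (x.reshape(-1, 1) * x).flatten()

# Hypothetical tree ((0,1),2): samples at time 0, MRCA of 0 and 1 at time 1,
# root at time 3. Each entry is (samples below the branch, branch length).
branches = [
    (np.array([1.0, 0.0, 0.0]), 1.0),  # branch above sample 0
    (np.array([0.0, 1.0, 0.0]), 1.0),  # branch above sample 1
    (np.array([0.0, 0.0, 1.0]), 3.0),  # branch above sample 2
    (np.array([1.0, 1.0, 0.0]), 2.0),  # branch above the MRCA of 0 and 1
]

k = 3
shared = sum(length * f(x) for x, length in branches).reshape(k, k)
print(shared)
```

The diagonal holds each sample's full path length to the root (3), and the only nonzero off-diagonal entry is the 2 time units that samples 0 and 1 share above their MRCA.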

The result can also be achieved by simply looping through trees and samples to calculate pairwise TMRCAs:

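def shared_times(ts):
    samples = ts.samples()  # sample node IDs (not assumed to be 0..k-1)
    k = len(samples)
    sts = np.zeros((ts.num_trees, k, k))
    for t, tree in enumerate(ts.trees()):
        T = tree.time(tree.root)  # assumes each tree has a single root
        for i in range(k):
            # a sample shares its whole path to the root with itself
            sts[t, i, i] = T - tree.time(samples[i])
            for j in range(i):
                st = T - tree.tmrca(samples[i], samples[j])
                sts[t, i, j] = st
                sts[t, j, i] = st
    return sts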

This is considerably slower when there are many trees, although it seems to be faster when there are few trees with many samples (e.g., ts = msprime.sim_ancestry(samples=200, population_size=1e4, random_seed=1)); I'm not sure why.

Coming back to this much later (today), I realized you can use the divergence statistic, with a few tweaks, to do the same:

def shared_times(ts):
    
    k = ts.num_samples
    
    # get 2*tmrcas
    sample_sets = [[i] for i in ts.samples()]
    indexes = [(i,j) for i in range(k) for j in range(i,k)] #indexes are positions in sample_sets; compare each sample with each other only once (entries of upper triangular)
    divs = ts.divergence(sample_sets=sample_sets, indexes=indexes, mode='branch', windows='trees')
    divs[np.isnan(divs)] = 0 #tmrcas with self are said to be nan, so convert to 0

    # convert to matrices
    divs_mat = np.zeros((ts.num_trees,k,k)) 
    for ix,(i,j) in enumerate(indexes):
        divs_mat[:,i,j] = divs[:,ix] #convert list to upper triangle
        divs_mat[:,j,i] = divs[:,ix] #fill in lower triangle symmetrically
    
    # convert to shared times
    sts = np.zeros(divs_mat.shape)
    for i,div in enumerate(divs_mat):
        sts[i] = (np.max(div) - div)/2 #convert from 2*tmrcas to shared times
    
    return sts

This latter method seems to be the fastest for small tree sequences (e.g., ts = msprime.sim_ancestry(samples=5, population_size=1e4, sequence_length=1e6, recombination_rate=1e-8, random_seed=1)) but loses this advantage as tree sequences get larger (e.g., ts = msprime.sim_ancestry(samples=100, population_size=1e4, sequence_length=1e6, recombination_rate=1e-8, random_seed=1)). It also uses >2x the RAM of the first method, which quickly becomes important with larger tree sequences.
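
The last step of the divergence-based version, converting 2*TMRCA values to shared times, can be checked by hand. The sketch below assumes that all samples are contemporaneous (time 0) and that each tree has a single root, so the largest entry of each tree's matrix is twice the root time. The two small matrices and the helper name `divergence_to_shared_times` are hypothetical and not part of tskit.

```python
import numpy as np

def divergence_to_shared_times(divs_mat):
    """Convert an (n, k, k) array of 2*TMRCA values into shared times.

    Assumes contemporaneous samples and one root per tree, so that the
    per-tree maximum equals twice the root time.
    """
    twice_root_time = divs_mat.max(axis=(1, 2), keepdims=True)
    return (twice_root_time - divs_mat) / 2

# Two hypothetical trees over samples 0, 1, 2 (diagonal already set to 0):
# tree 0 is ((0,1),2) with tmrca(0,1)=1 and root time 3,
# tree 1 is ((1,2),0) with tmrca(1,2)=2 and root time 5.
divs_mat = np.array([
    [[0.0, 2.0, 6.0], [2.0, 0.0, 6.0], [6.0, 6.0, 0.0]],
    [[0.0, 10.0, 10.0], [10.0, 0.0, 4.0], [10.0, 4.0, 0.0]],
])

sts = divergence_to_shared_times(divs_mat)
print(sts)
```

If samples were at different times, or a tree had several roots, the per-tree maximum would no longer be twice a single root time and this shortcut would not apply as written.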

I'm wondering if my general_stat matrix formulation would be useful as a new statistic of its own: related to the divergence statistic but calculated differently (I think), sometimes faster, sometimes less memory intensive, and semantically simpler (for creating matrices). If so, I guess I'd have to also think about what happens as you vary the options in general_stat, e.g., mode='node'. Anyway, just wanted to put this out there in case it is helpful/useful. Thanks for building all this stuff!

(This is my first time writing an issue so sorry if I've messed up somehow -- I'd be keen to hear how to do this better!)

mmosmond, Aug 09 '21 20:08