
Remove general_stat caching

Open · brieuclehmann opened this issue 3 years ago · 17 comments

Here's a first pass at removing caching to reduce memory consumption in the general_stat framework. I've only tweaked the branch C function, and the Python tests are passing.

Without caching, however, ts.genetic_relatedness is slower: for the script below, about 10 s with caching versus 20 s without. When I increased n_samples to 1000: 58 s with caching, 110 s without.

I may have introduced some unnecessary computation so would appreciate a sense-check!
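
To make the trade-off concrete, here's a rough Python paraphrase (mine, not the actual C code) of the per-node update with and without the cache; state[u] is the summed sample weights below node u, and f maps a state vector to the output statistics:

# Rough sketch only -- a paraphrase of the branch general_stat update,
# not the C implementation being changed here.

def update_with_cache(u, f, state, branch_length, node_summary, running_sum):
    # Cached: the old f(state[u]) is stored in node_summary, so removing
    # the old contribution is a lookup; f is called once per update.
    running_sum -= branch_length[u] * node_summary[u]
    node_summary[u] = f(state[u])
    running_sum += branch_length[u] * node_summary[u]
    return running_sum

def update_without_cache(u, f, old_state, new_state, branch_length, running_sum):
    # Uncached: no num_nodes x output_dim summary array, but f is called
    # twice, once for the old state and once for the new -- which would
    # explain the roughly 2x slowdown reported above.
    running_sum -= branch_length[u] * f(old_state)
    running_sum += branch_length[u] * f(new_state)
    return running_sum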

import itertools
import time
import msprime
import numpy as np
import tskit
from memory_profiler import profile

fp = open("memprof_cache.log", "a+")
@profile(stream=fp)
def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    return x

seed = 42
n_samples = 500
ts = msprime.simulate(
    n_samples,
    length=1e6,
    recombination_rate=1e-8,
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=seed,
)

n_ind = int(ts.num_samples / 2)
sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_ind)]

n = len(sample_sets)
indexes = [
    (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
]

start_time = time.time()
x = genetic_relatedness_matrix(ts, sample_sets, indexes, 'branch')
end_time = time.time()
print(end_time - start_time)

brieuclehmann · Nov 24 '21

Codecov Report

Merging #1937 (0b96818) into main (315e47a) will decrease coverage by 13.58%. The diff coverage is 100.00%.


@@             Coverage Diff             @@
##             main    #1937       +/-   ##
===========================================
- Coverage   93.15%   79.57%   -13.59%     
===========================================
  Files          27       27               
  Lines       25085    24959      -126     
  Branches     1109     1107        -2     
===========================================
- Hits        23369    19862     -3507     
- Misses       1682     5036     +3354     
- Partials       34       61       +27     
Flag             Coverage Δ
c-tests          92.20% <100.00%> (+<0.01%) ↑
lwt-tests        89.14% <ø> (ø)
python-c-tests   68.29% <ø> (ø)
python-tests     ?

Flags with carried forward coverage won't be shown.

Impacted Files                  Coverage Δ
c/tskit/trees.c                 94.19% <100.00%> (+0.01%) ↑
python/tskit/cli.py              0.00% <0.00%> (-95.94%) ↓
python/tskit/formats.py          6.08% <0.00%> (-93.66%) ↓
python/tskit/vcf.py              9.78% <0.00%> (-90.22%) ↓
python/tskit/drawing.py         10.41% <0.00%> (-89.02%) ↓
python/tskit/text_formats.py    12.62% <0.00%> (-87.38%) ↓
python/tskit/combinatorics.py   13.70% <0.00%> (-85.66%) ↓
python/tskit/stats.py           35.13% <0.00%> (-64.87%) ↓
python/tskit/trees.py           43.96% <0.00%> (-53.84%) ↓
python/tskit/util.py            59.56% <0.00%> (-40.44%) ↓
... and 4 more

Full report at Codecov. Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 315e47a...0b96818.

codecov[bot] · Nov 24 '21

Any idea how much this drops RAM use? For me, the case that motivated opening the issue on RAM use was computing the pairwise distance matrix. It blew up to 7+ GB for 1000s of nodes.

molpopgen · Nov 24 '21

We had a discussion about this earlier and @brieuclehmann is going to run it through memory-profiler. Assuming we're on the right track memory-wise, the plan is to add an option to general_stat to disable caching like this.
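
Something along these lines, perhaps (a sketch only; cache_node_summaries is an illustrative name, not an existing tskit argument):

import numpy as np

# Hypothetical API sketch: a flag on general_stat trading speed for memory.
# `cache_node_summaries` is an invented name for illustration only.
W = np.ones((ts.num_samples, 1))
sigma = ts.general_stat(
    W,
    lambda x: x,                  # summary function f
    output_dim=1,
    mode="branch",
    cache_node_summaries=False,   # hypothetical: disable the per-node cache
)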

jeromekelleher · Nov 24 '21

I've used memory-profiler to check RAM usage (see updated script above). Oddly, there doesn't appear to be much of a difference with and without caching, but I'm not 100% sure I'm reading the output correctly (or have set up the profiling properly).

Logs: memprof_cache.log, memprof_nocache.log

brieuclehmann · Nov 25 '21

Is memory-profiler measuring actual process memory, or just memory occupied by Python objects? The caching is happening on the C side, and so may not be visible? If you are on Linux, then this may be more relevant:

/usr/bin/time -f "%e %M" python script.py

The second number in the output will be the peak RAM use in KB.
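
Alternatively, from inside the process, Python's standard resource module reports the same peak figure (note ru_maxrss is in KiB on Linux but bytes on macOS):

import resource

# Peak resident set size of the current process so far.
# ru_maxrss is in KiB on Linux (bytes on macOS).
peak_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"peak RSS: {peak_kib} KiB")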

molpopgen · Nov 25 '21

Is memory-profiler measuring actual process memory,

Yes, it takes periodic snapshots using OS routines, so it's the "real" process footprint.
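
For a single peak number rather than a line-by-line table, memory-profiler's memory_usage helper can also wrap the call directly (return type details vary a little between versions):

from memory_profiler import memory_usage

# Sample the process RSS every 0.1 s while the function runs; with
# max_usage=True, return the peak instead of the full trace.
# (Best run on the undecorated function from the script above.)
peak = memory_usage(
    (genetic_relatedness_matrix, (ts, sample_sets, indexes, "branch")),
    interval=0.1,
    max_usage=True,
)
print(peak, "MiB")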

jeromekelleher · Nov 26 '21

Pasting in the profiles for ease:

Filename: test_time.py
  
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     65.5 MiB     65.5 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, mode):
    11     65.5 MiB      0.0 MiB           1       n = len(sample_sets)
    12     67.6 MiB    -46.8 MiB       31379       indexes = [
    13     67.6 MiB    -44.6 MiB       31376           (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
    14                                             ]
    15     67.6 MiB      0.0 MiB           1       K = np.zeros((n, n))
    16     74.9 MiB      7.3 MiB           2       K[np.triu_indices(n)] = ts.genetic_relatedness(
    17     67.6 MiB      0.0 MiB           1           sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    18                                             )
    19     74.9 MiB      0.0 MiB           1       K = K + np.triu(K, 1).transpose()
    20     74.9 MiB      0.0 MiB           1       return K

Filename: test_time.py
  
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     64.6 MiB     64.6 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, mode):
    11     64.6 MiB      0.0 MiB           1       n = len(sample_sets)
    12     66.7 MiB    -69.3 MiB       31379       indexes = [
    13     66.7 MiB    -67.1 MiB       31376           (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
    14                                             ]
    15     66.7 MiB      0.0 MiB           1       K = np.zeros((n, n))
    16     66.7 MiB    -18.0 MiB           2       K[np.triu_indices(n)] = ts.genetic_relatedness(
    17     66.7 MiB      0.0 MiB           1           sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    18                                             )
    19     48.8 MiB    -17.9 MiB           1       K = K + np.triu(K, 1).transpose()
    20     48.8 MiB      0.0 MiB           1       return K

jeromekelleher · Nov 26 '21

There isn't a huge difference either way here - maybe run this on a larger example? Also, do just one thing on the line where we call genetic_relatedness; the assignment stuff on the LHS will be complicating things (so just x = ts.genetic_relatedness(...)).

jeromekelleher · Nov 26 '21

I simplified the code slightly (see edits above) so that it only profiles the ts.genetic_relatedness call. Here are the results with n_samples = 500 across 3 runs of the same script. Overall usage is pretty similar between the two (not sure what's going on in line 11 of each?). Going to run again now with n_samples = 2000.

With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     58.3 MiB     -9.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     58.3 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     74.5 MiB      7.2 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     74.5 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     21.0 MiB    -46.3 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     21.0 MiB      0.0 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.6 MiB     68.6 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     56.6 MiB    -12.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     56.7 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.1 MiB     68.1 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     52.1 MiB    -16.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     52.1 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.3 MiB     68.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     75.4 MiB      7.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     75.4 MiB      0.0 MiB           1       return x

brieuclehmann · Nov 29 '21

Hm - notes:

  • update_node_summary and update_running_sum can be combined into one (I think), eliminating some extra copying around of results
  • it's a bit hard to know how well it's possible to do, since for mode="branch" we don't have anything convenient to compare against (although we could use Caoqi's eGRM, I guess?); for mode="site" we could compare to the timings that Gregor reported
  • have you computed how much less memory we expect it to be using? Like, what's the difference in memory usage for the two calloc()'s with num_samples = 1000 and num_nodes equal to whatever it is?

petrelharp · Nov 29 '21

Thanks @petrelharp! Agreed with your point about combining update_node_summary and update_running_sum - is it worth doing this now, or shall I wait until we're sure we want to include a no-caching option?

For the expected change in memory usage, I guess it should be a num_nodes-fold decrease between the two callocs, so I'm pretty surprised that they're so similar. Is there a way to check where in the C code memory is being used? For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.
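
For what it's worth, a naive back-of-envelope under the assumption that the cached buffer holds one double per node per output dimension (a guess at the layout, to be checked against the actual calloc calls in trees.c):

# Back-of-envelope only; assumes (unverified) one 8-byte double per node per
# output dimension when cached vs one per output dimension when not.
num_nodes = 6068                 # reported above for n_samples = 2000
result_dim = 1000 * 1001 // 2    # len(indexes) for 1000 diploid sample sets
saving = (num_nodes - 1) * result_dim * 8
print(saving / 2**30, "GiB")     # ~22.6 GiB -- far more than the profiles
                                 # show, so either this layout guess is wrong
                                 # or the calloc'd pages are mostly untouched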

brieuclehmann · Dec 01 '21

And here are the memory-profiler results for n_samples = 2000. TL;DR: largely the same again.

With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    111.1 MiB    111.1 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     31.0 MiB    -80.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     31.1 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.9 MiB    110.9 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     33.9 MiB    -77.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     34.0 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    111.3 MiB    111.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     27.8 MiB    -83.4 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     27.9 MiB      0.1 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    109.9 MiB    109.9 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     56.3 MiB    -53.6 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     56.4 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.0 MiB    110.0 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     54.0 MiB    -56.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     54.1 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.8 MiB    110.8 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     31.6 MiB    -79.2 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     31.7 MiB      0.1 MiB           1       return x

brieuclehmann · Dec 01 '21

For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.

Ah, I see what's happening here - your tree sequence is quite short, so there aren't that many nodes. Try increasing the sequence length or recombination rate so that you have at least 100K trees. Then you should see some difference.
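
For example (these are the parameters used in the follow-up below, with 10x the original script's recombination rate):

import msprime

# Higher recombination rate -> many more trees (and hence nodes).
ts = msprime.simulate(
    500,
    length=1e6,
    recombination_rate=1e-7,
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=42,
)
print(ts.num_trees, ts.num_nodes)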

jeromekelleher · Dec 02 '21

Ah OK: now with n_samples = 500 and recombination_rate = 1e-7, corresponding to ~250K trees and ~150K nodes, we see a slight improvement without caching.

With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    215.7 MiB    215.7 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     26.6 MiB   -189.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     26.7 MiB      0.1 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    197.3 MiB    197.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     36.3 MiB   -161.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     36.4 MiB      0.1 MiB           1       return x

brieuclehmann · Dec 02 '21

Still surprisingly little...

jeromekelleher · Dec 03 '21

ping @jeromekelleher

petrelharp · Feb 08 '22

Ah - I don't think this is computing the right value @brieuclehmann. Tests are failing because we're getting different numbers.

I'll think about how to do this again.

jeromekelleher · Feb 22 '22