Remove general_stat caching
Here's a first pass at removing caching to reduce memory consumption in the general_stat framework. I've only tweaked the branch-mode C function, and the Python tests are passing.
Without caching, however, ts.genetic_relatedness is slower: for the script below, the run takes about 10s with caching and 20s without. When I increased n_samples to 1000, it was 58s with caching and 110s without.
I may have introduced some unnecessary computation, so I'd appreciate a sense-check!
import itertools
import time

import msprime
import numpy as np
import tskit
from memory_profiler import profile

fp = open("memprof_cache.log", "a+")


@profile(stream=fp)
def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    x = ts.genetic_relatedness(
        sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    )
    return x


seed = 42
n_samples = 500
ts = msprime.simulate(
    n_samples,
    length=1e6,
    recombination_rate=1e-8,
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=seed,
)

# One sample set per diploid individual (pairs of consecutive sample nodes).
n_ind = int(ts.num_samples / 2)
sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_ind)]
n = len(sample_sets)
indexes = [
    (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
]

start_time = time.time()
x = genetic_relatedness_matrix(ts, sample_sets, indexes, "branch")
end_time = time.time()
print(end_time - start_time)
Codecov Report
Merging #1937 (0b96818) into main (315e47a) will decrease coverage by 13.58%. The diff coverage is 100.00%.
@@            Coverage Diff            @@
##             main    #1937       +/-  ##
==========================================
- Coverage   93.15%   79.57%   -13.59%
==========================================
  Files          27       27
  Lines       25085    24959      -126
  Branches     1109     1107        -2
==========================================
- Hits        23369    19862     -3507
- Misses       1682     5036     +3354
- Partials       34       61       +27
| Flag | Coverage Δ |
|---|---|
| c-tests | 92.20% <100.00%> (+<0.01%) ↑ |
| lwt-tests | 89.14% <ø> (ø) |
| python-c-tests | 68.29% <ø> (ø) |
| python-tests | ? |
Flags with carried forward coverage won't be shown.
| Impacted Files | Coverage Δ |
|---|---|
| c/tskit/trees.c | 94.19% <100.00%> (+0.01%) ↑ |
| python/tskit/cli.py | 0.00% <0.00%> (-95.94%) ↓ |
| python/tskit/formats.py | 6.08% <0.00%> (-93.66%) ↓ |
| python/tskit/vcf.py | 9.78% <0.00%> (-90.22%) ↓ |
| python/tskit/drawing.py | 10.41% <0.00%> (-89.02%) ↓ |
| python/tskit/text_formats.py | 12.62% <0.00%> (-87.38%) ↓ |
| python/tskit/combinatorics.py | 13.70% <0.00%> (-85.66%) ↓ |
| python/tskit/stats.py | 35.13% <0.00%> (-64.87%) ↓ |
| python/tskit/trees.py | 43.96% <0.00%> (-53.84%) ↓ |
| python/tskit/util.py | 59.56% <0.00%> (-40.44%) ↓ |
| ... and 4 more | |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 315e47a...0b96818.
Any idea how much this drops RAM use? For me, the case that motivated opening the issue on RAM use was computing the pairwise distance matrix. It blew up to 7+ GB for 1000s of nodes.
We had a discussion about this earlier and @brieuclehmann is going to run it through memory-profiler. Assuming we're on the right track memory-reduction-wise, the plan is to add an option to general_stat to disable caching like this.
Have used memory-profiler to check RAM usage (see updated script above). Oddly, there doesn't appear to be too much of a difference with and without caching, but I'm not 100% sure I'm reading the output correctly (or have set up the profiling properly).
Is memory-profiler measuring actual process memory, or just memory occupied by Python objects? The caching is happening on the C side, and so may not be visible? If you are on Linux, then this may be more relevant:
/usr/bin/time -f "%e %M" python script.py
The second number in the output will be the peak RAM use in KB.
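If it helps to cross-check from within Python, the standard-library resource module (Unix only; not something used in the thread, just a minimal sketch) reports the same peak resident set size, including C-side allocations:

import resource

# Peak RSS of this process so far, including memory allocated on the C side.
# Units are KB on Linux but bytes on macOS.
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print("peak RSS:", peak)

Calling this at the end of the script should agree roughly with the %M figure from /usr/bin/time.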
Is memory-profiler measuring actual process memory,
Yes, it takes periodic snapshots using OS routines, so it's the "real" process footprint.
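If the line-by-line snapshots are awkward to read, memory_profiler's memory_usage function can report the peak over the whole call instead. A minimal sketch using the names from the script above (the 0.05 s sampling interval is an arbitrary choice, and depending on the memory_profiler version max_usage=True returns either a float or a one-element list):

from memory_profiler import memory_usage

# Sample whole-process memory (MiB) every 50 ms while the call runs
# and keep only the maximum.
peak = memory_usage(
    (genetic_relatedness_matrix, (ts, sample_sets, indexes, "branch")),
    interval=0.05,
    max_usage=True,
)
print("peak MiB:", peak)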
Pasting in the profiles for ease:
Filename: test_time.py
Line # Mem usage Increment Occurences Line Contents
============================================================
9 65.5 MiB 65.5 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, mode):
11 65.5 MiB 0.0 MiB 1 n = len(sample_sets)
12 67.6 MiB -46.8 MiB 31379 indexes = [
13 67.6 MiB -44.6 MiB 31376 (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
14 ]
15 67.6 MiB 0.0 MiB 1 K = np.zeros((n, n))
16 74.9 MiB 7.3 MiB 2 K[np.triu_indices(n)] = ts.genetic_relatedness(
17 67.6 MiB 0.0 MiB 1 sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
18 )
19 74.9 MiB 0.0 MiB 1 K = K + np.triu(K, 1).transpose()
20 74.9 MiB 0.0 MiB 1 return K
Filename: test_time.py
Line # Mem usage Increment Occurences Line Contents
============================================================
9 64.6 MiB 64.6 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, mode):
11 64.6 MiB 0.0 MiB 1 n = len(sample_sets)
12 66.7 MiB -69.3 MiB 31379 indexes = [
13 66.7 MiB -67.1 MiB 31376 (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
14 ]
15 66.7 MiB 0.0 MiB 1 K = np.zeros((n, n))
16 66.7 MiB -18.0 MiB 2 K[np.triu_indices(n)] = ts.genetic_relatedness(
17 66.7 MiB 0.0 MiB 1 sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
18 )
19 48.8 MiB -17.9 MiB 1 K = K + np.triu(K, 1).transpose()
20 48.8 MiB 0.0 MiB 1 return K
There isn't a huge difference either way here - maybe run this on a larger example? Also, do just one thing on the line where we call genetic_relatedness; the assignment stuff on the LHS will be complicating things (so x = ts.genetic_relatedness(...)).
I simplified the code slightly (see edits above) so that it only profiles the ts.genetic_relatedness call. Here are the results with n_samples = 500 across 3 runs of the same script. Pretty similar overall usage between the two (not sure what's going on in line 11 of each?). Going to run again now with n_samples = 2000
With caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 67.3 MiB 67.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 58.3 MiB -9.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 58.3 MiB 0.0 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 67.3 MiB 67.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 74.5 MiB 7.2 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 74.5 MiB 0.0 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 67.3 MiB 67.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 21.0 MiB -46.3 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 21.0 MiB 0.0 MiB 1 return x
Without caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 68.6 MiB 68.6 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 56.6 MiB -12.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 56.7 MiB 0.0 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 68.1 MiB 68.1 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 52.1 MiB -16.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 52.1 MiB 0.0 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 68.3 MiB 68.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 75.4 MiB 7.1 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 75.4 MiB 0.0 MiB 1 return x
Hm - notes:
- update_node_summary and update_running_sum can be combined into one (I think), eliminating some extra copying around of results
- it's a bit hard to know how well it's possible to do, since for mode="branch" we don't have something convenient to compare it to (although we could use Caoqi's eGRM, I guess?) - looking at mode="site" we could compare to the timings that Gregor reported
- have you computed how much less memory we expect it to be using? Like, what's the difference in memory usage for the two calloc()'s with num_samples = 1000 and num_nodes equal to whatever it is?
Thanks @petrelharp! Agreed with your point about combining update_node_summary and update_running_sum - is it worth doing this now, or shall I wait until we're sure we want to include a no-caching option?
For expected change in memory usage, I guess it should be a num_nodes-fold decrease between the two callocs, so I'm pretty surprised that they are so similar. Is there a way to check where in the C code memory is being used? For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.
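One way to see where the C code is allocating (a suggestion, not something tried in the thread): heap-profile the whole process with Valgrind's massif tool, which attributes allocations to C call stacks. It is slow on a Python process but does work:

valgrind --tool=massif python script.py
ms_print massif.out.<pid>

The ms_print output shows a timeline of heap usage with the dominant allocation sites, so the two calloc()'s in trees.c should show up directly if they account for much of the footprint.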
And here are the memory-profiler results for n_samples = 2000; TL;DR, largely the same again.
With caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 111.1 MiB 111.1 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 31.0 MiB -80.1 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 31.1 MiB 0.1 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 110.9 MiB 110.9 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 33.9 MiB -77.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 34.0 MiB 0.1 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 111.3 MiB 111.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 27.8 MiB -83.4 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 27.9 MiB 0.1 MiB 1 return x
Without caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 109.9 MiB 109.9 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 56.3 MiB -53.6 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 56.4 MiB 0.1 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 110.0 MiB 110.0 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 54.0 MiB -56.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 54.1 MiB 0.1 MiB 1 return x
Line # Mem usage Increment Occurences Line Contents
============================================================
9 110.8 MiB 110.8 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 31.6 MiB -79.2 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 31.7 MiB 0.1 MiB 1 return x
For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.
Ah, I see what's happening here - your tree sequence is quite short, so there aren't that many nodes. Try increasing the sequence length or recombination rate so that you have at least 100K trees. Then you should see some difference.
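For example, a minimal tweak to the simulation above (assuming the other parameters stay as in the original script; this matches the settings used in the next comment, where it gives ~250K trees and ~150K nodes):

ts = msprime.simulate(
    500,
    length=1e6,
    recombination_rate=1e-7,  # 10x the original rate
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=42,
)
print(ts.num_trees, ts.num_nodes)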
Ah OK, now with n_samples = 500 and recombination_rate = 1e-7, corresponding to ~250K trees and ~150K nodes, we see a slight improvement without caching.
With caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 215.7 MiB 215.7 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 26.6 MiB -189.1 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 26.7 MiB 0.1 MiB 1 return x
Without caching:
Line # Mem usage Increment Occurences Line Contents
============================================================
9 197.3 MiB 197.3 MiB 1 @profile(stream=fp)
10 def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
11 36.3 MiB -161.0 MiB 1 x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
12 36.4 MiB 0.1 MiB 1 return x
Still surprisingly little...
ping @jeromekelleher
Ah - I don't think this is computing the right value @brieuclehmann. Tests are failing because we're getting different numbers.
I'll think about how to do this again.
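One quick way to pin down the discrepancy while reworking this (just a sketch, not something from the thread; the file name reference.npy is arbitrary): save a reference result from a released tskit build on a small, fixed simulation, then compare the patched build against it.

import msprime
import numpy as np

# Small fixed example so the reference is cheap to regenerate.
ts = msprime.simulate(20, recombination_rate=1, random_seed=1)
sample_sets = [(0, 1), (2, 3), (4, 5)]
indexes = [(0, 1), (0, 2), (1, 2)]
x = ts.genetic_relatedness(
    sample_sets, indexes, mode="branch", proportion=False, span_normalise=False
)

# With the released tskit: np.save("reference.npy", x)
# With the patched build, compare against the saved reference:
ref = np.load("reference.npy")
print(np.allclose(x, ref))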