tskit
Method to 'prune back' a tree sequence
I want a function that will remove the topology after (forwards in time) a specific time, and mark each extant lineage as a sample. The intuition is that we're drawing a line across the trees, and sort of "pruning back" to this line. (This will allow us to generate the haplotypes extant at this time.)
Here's a simple implementation:
```python
import tskit


def prune_leaves(ts, time):
    """
    Returns a tree sequence in which all the topology and mutation
    information after the specified time is removed. That is, the
    samples in the returned tree sequence will be the lineages
    extant at the specified time.
    """
    tables = ts.dump_tables()
    t = tables.nodes.time
    tables.edges.clear()
    samples = []
    edge_map = {}
    for edge in ts.edges():
        if t[edge.child] > time:
            # The edge lies entirely above the cut: keep it unchanged.
            tables.edges.add_row(edge.left, edge.right, edge.parent, edge.child)
        elif t[edge.child] <= time and t[edge.parent] > time:
            # The edge crosses the cut: end it at a new sample node at `time`.
            key = (edge.child, edge.parent)
            if key in edge_map:
                # If the same parent/child combination exists, then we should map
                # to the same ancestor node.
                u = edge_map[key]
            else:
                u = tables.nodes.add_row(
                    flags=tskit.NODE_IS_SAMPLE, time=time)
                samples.append(u)
                edge_map[key] = u
            tables.edges.add_row(edge.left, edge.right, edge.parent, u)
    tables.sort()
    tables.simplify(samples)
    return tables.tree_sequence()
```
There are probably some subtleties and corner cases I've missed here, but this is probably the basic idea.
I think this would be a useful addition to tskit. I don't know what we'd call it though, and I think it's worth keeping this in mind with the discussion going on over in #261 where we're chopping up the tree sequence in the space rather than time direction.
This is great. `prune_leaves` is OK, but it's not specific to leaves. Other ideas:
- `truncate`
- `pollard`
- `time_slice`
If removing all nodes ancestral to a certain time is decapitating, then this would be to `amputate` or `dismember` the tree:
```python
import tskit.gory_ops
```
Maybe I'm not fully understanding here, but isn't this basically what `map_ancestors`/`find_ancestors` does? You'd basically just need to add a sample on each lineage at the right time (the `census` event you've talked about with me in person before - btw I'm happy to start working on this soon if it will be useful), and then apply `map_ancestors` with the `ancestors` being all nodes at this census time or earlier, and the `samples` being all the leaves.
You're right, it's definitely closely related @gtsambos. It would be excellent if we can implement it using `map/find_ancestors`!
ps. Let's talk about the census event soon --- I do actually need this as it happens.
Hi @jeromekelleher, I just did a mock-up of this idea in this gist

> then apply map_ancestors with the ancestors being all nodes at this census time or earlier, and the samples being all the leaves.

It's actually even simpler than this - you just need to set the census nodes as your `samples`, and all older nodes as `ancestors`. Then each leaf node in the output corresponds to a lineage present at the specified census time.
Btw, here's a Python function that puts nodes on trees at a given 'census' time:
(`Ni` = input node table, `Ei` = input edge table, `L` = sequence length)
```python
import tskit


def add_census(Ni, Ei, census_time, L):
    # Copy the input nodes; census nodes are appended to this table.
    No = tskit.NodeTable()
    No.append_columns(time=Ni.time, flags=Ni.flags)
    Eo = tskit.EdgeTable()
    edge_list = []
    for ind, e in enumerate(Ei):
        edge_list.append(e)
        # Flush the buffered edges once the last edge of this parent is reached.
        if ind + 1 == len(Ei) or Ei[ind + 1].parent != Ei[ind].parent:
            for e in edge_list:
                if Ni.time[e.parent] > census_time and Ni.time[e.child] < census_time:
                    # The edge spans the census time: split it at a new node.
                    v = No.add_row(time=census_time)
                    Eo.add_row(e.left, e.right, v, e.child)
                    Eo.add_row(e.left, e.right, e.parent, v)
                else:
                    Eo.add_row(e.left, e.right, e.parent, e.child)
            edge_list = []
    # Sort output.
    new_tables = tskit.TableCollection(sequence_length=L)
    new_tables.nodes.append_columns(time=No.time, flags=No.flags)
    new_tables.edges.append_columns(left=Eo.left, right=Eo.right,
                                    parent=Eo.parent, child=Eo.child)
    new_tables.sort()
    return new_tables.tree_sequence()
```
> ps. Let's talk about the census event soon --- I do actually need this as it happens.

Yes! 💻
The major difference between my implementation and yours is that you treat any (parent, child) combo as the same lineage,
```python
key = (edge.child, edge.parent)
if key in edge_map:
    # If the same parent/child combination exists, then we should map
    # to the same ancestor node.
    u = edge_map[key]
```
even where the edges with this combo span non-adjacent intervals of the sequence. Whereas my `add_census` function above treats each edge as a distinct lineage.
I guess there might be contextual reasons to prefer one or the other, but my way makes a bit more sense to me because it ensures that each distinct 'leaf' corresponds to no more than one distinct historical chromosome. I.e., a particular ancestor may have contributed two distinct segments to a particular descendant's chromosome, but that doesn't necessarily mean that both segments have passed through the same sequence of individuals (which is my usual interpretation of lineage).
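To make the difference between the two conventions concrete, here is a minimal pure-Python sketch. It does not use tskit: `Edge` and `count_census_nodes` are hypothetical names invented for this illustration, and the strict inequalities follow `add_census` (the `prune_leaves` version also counts children exactly at the cut time).

```python
from collections import namedtuple

# Hypothetical stand-in for an edge table row; this is not the tskit API.
Edge = namedtuple("Edge", ["left", "right", "parent", "child"])


def count_census_nodes(edges, node_time, census_time, per_edge):
    """Count the new nodes each convention would create at census_time.

    per_edge=False: one node per (child, parent) pair, as in prune_leaves.
    per_edge=True: one node per spanning edge, as in add_census.
    """
    spanning = [
        e for e in edges
        if node_time[e.parent] > census_time > node_time[e.child]
    ]
    if per_edge:
        return len(spanning)
    return len({(e.child, e.parent) for e in spanning})


# Node 0 is a sample at time 0; nodes 1 and 2 are ancestors at times 2 and 3.
node_time = {0: 0.0, 1: 2.0, 2: 3.0}
edges = [
    Edge(0, 10, 1, 0),   # child 0 inherits [0, 10) from parent 1
    Edge(10, 20, 2, 0),  # child 0 inherits [10, 20) from parent 2
    Edge(20, 30, 1, 0),  # child 0 inherits [20, 30) from parent 1 again
]
pair_count = count_census_nodes(edges, node_time, 1.0, per_edge=False)
edge_count = count_census_nodes(edges, node_time, 1.0, per_edge=True)
print(pair_count, edge_count)
```

Here the two non-adjacent segments inherited from parent 1 collapse into a single lineage under the (child, parent) keying, but stay separate when every edge is its own lineage.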
This would be cool to have implemented. There are some subtleties that arise for handling mutations that don't occur in the case of `decapitation`. For any parent/child edge that you truncate to a time "in between", some of the mutations on the child nodes may need remapping to the new node. I'm thinking about forward simulations here, where the mutation's origin time may be a part of its metadata, and so you may know that it predates your census time but post-dates the first parent ancestral to that time. For many coalescent sims, (I think) you can justify simply binomially sampling any mutations and remapping them.
> It's actually even simpler than this - you just need to set the census nodes as your samples, and all older nodes as ancestors.

Shouldn't you just set all the census-and-older nodes to be your samples?

> handling mutations

Oh, this is a good point.
So, perhaps the thing to do is to add in the 'census' nodes separately then, as I can imagine this being useful without the pollarding step. Handling the mutations is tricky, and I think we'll have to make the handling pluggable somehow (a callback that decides if a mutation goes above or below the census node, or something), as there's no general way for tskit to know what time the mutation happens at.
Exactly. For each parent/child combo spanning the census time, you can pass the parent time, child time, census time, and the index of the mutation on the child node to a function returning `True` if it should be "pushed up" to the new census node. You could even just pass in all indexes for all mutations on the child node, which may be more efficient. `tskit` could provide a default using `numpy` to make a random decision based on the proportion of time above the census node.
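A rough sketch of what such a pluggable callback might look like, under stated assumptions: none of these function names exist in tskit, and the standard-library `random` module stands in for `numpy` so the example has no dependencies. Each callback receives the parent time, child time, census time and the mutation indexes on the child node, and returns one boolean per mutation.

```python
import random


def fraction_above_census(parent_time, child_time, census_time):
    """Proportion of the parent-child branch that lies above the census time."""
    return (parent_time - census_time) / (parent_time - child_time)


def random_push_up(parent_time, child_time, census_time, mutation_ids, rng=random):
    """Hypothetical default: push each mutation up with probability equal to
    the proportion of the branch above the census node."""
    p = fraction_above_census(parent_time, child_time, census_time)
    return [rng.random() < p for _ in mutation_ids]


def known_time_push_up(mutation_time):
    """Hypothetical callback factory for when origin times are known, e.g.
    from forward-simulation metadata: push up mutations older than the census."""
    def callback(parent_time, child_time, census_time, mutation_ids):
        return [mutation_time[m] > census_time for m in mutation_ids]
    return callback


def split_mutations(parent_time, child_time, census_time, mutation_ids,
                    push_up=random_push_up):
    """Partition mutation ids into (above census node, left on child node)."""
    decisions = push_up(parent_time, child_time, census_time, mutation_ids)
    above = [m for m, up in zip(mutation_ids, decisions) if up]
    below = [m for m, up in zip(mutation_ids, decisions) if not up]
    return above, below


# Mutation 7 arose at time 5 and mutation 8 at time 1; the census is at time 4.
callback = known_time_push_up({7: 5.0, 8: 1.0})
above, below = split_mutations(10.0, 0.0, 4.0, [7, 8], push_up=callback)
print(above, below)
```

Passing all mutation indexes for a child node in one call, as suggested above, keeps the callback interface the same whether the decision is random or based on recorded times.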
> So, perhaps the thing to do is to add in the 'census' nodes separately then, as I can imagine this being useful without the pollarding step.

Right; then we could just `simplify`.
Related to #382
It turns out we'd like to do the opposite of this (cutting off the top of a tree sequence) for some applications (it'd give us an easy way to do time windowed statistics).
Note that cutting off the top of a tree sequence is less problematic than the bottom - we don't want to re-map sample nodes or anything; I'd say just discard anything above a certain time - however, we do have to figure out what to call the newly-created roots.
Yes, cutting off the top is much easier and simpler since we don't have to worry about samples. I'll make a new function for it and an issue to track it. Since moving the samples up the tree is much more subtle and would need more options, I think this should remain its own function, as discussed here.
I think the general way to do this sort of operation should be to
- add new nodes across at a given point in time (an operation we've called a `census`)
- remove everything above (or, below) those nodes
This "adding new nodes" step removes a lot of the possible weirdnesses. Note that even cutting off the top runs into issues about samples (since there may be samples older than that time). I think we don't want to be moving around times of nodes, and if we add new nodes we don't have to.
Also note that "everything" could include the nodes outside the time range or not; it could just mean removing the edges and mutations.
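The two-step approach can be sketched on plain edge tuples. This is only an illustration and not tskit code: `Edge`, `add_census_nodes` and `remove_edges` are hypothetical names, one census node is added per spanning edge (as in `add_census`), and "removing" here means dropping edges only, leaving node times untouched.

```python
from collections import namedtuple

# Hypothetical stand-in for an edge table row; this is not the tskit API.
Edge = namedtuple("Edge", ["left", "right", "parent", "child"])


def add_census_nodes(edges, node_time, census_time):
    """Step 1: split every edge that spans census_time at a new node."""
    new_time = dict(node_time)
    next_id = max(node_time) + 1
    out = []
    for e in edges:
        if node_time[e.parent] > census_time > node_time[e.child]:
            v = next_id
            next_id += 1
            new_time[v] = census_time
            out.append(Edge(e.left, e.right, v, e.child))
            out.append(Edge(e.left, e.right, e.parent, v))
        else:
            out.append(e)
    return out, new_time


def remove_edges(edges, node_time, census_time, side):
    """Step 2: drop the edges above or below the census time.

    After step 1 no edge strictly spans the census time, so every edge
    falls cleanly on one side of the cut.
    """
    if side == "above":
        return [e for e in edges if node_time[e.parent] <= census_time]
    if side == "below":
        return [e for e in edges if node_time[e.child] >= census_time]
    raise ValueError("side must be 'above' or 'below'")


# Samples 0 and 1 at time 0, their parent 2 at time 2, and its parent 3 at time 5.
node_time = {0: 0.0, 1: 0.0, 2: 2.0, 3: 5.0}
edges = [Edge(0, 10, 2, 0), Edge(0, 10, 2, 1), Edge(0, 10, 3, 2)]
split_edges, split_time = add_census_nodes(edges, node_time, 1.0)
recent_part = remove_edges(split_edges, split_time, 1.0, "above")
ancient_part = remove_edges(split_edges, split_time, 1.0, "below")
print(len(split_edges), len(recent_part), len(ancient_part))
```

Because the census nodes sit exactly on the cut, neither direction needs to change the time of an existing node, which is the point made above.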
Yep, agreed. That's basically what I've done in #2240
Note that you get quite different numbers of lineages when "pruning back" if you leave in the regions where a coalescent node is unary (or, I assume, if you use @petrelharp's `extend_edges` idea). I wonder if this should be noted in any function that we provide in the API?
```python
import msprime
import tskit
import numpy as np


def node_is_coalescent(ts):
    is_coalescent = np.zeros(ts.num_nodes, dtype=bool)
    for tree in ts.trees():
        is_coalescent[tree.num_children_array[:-1] > 1] = True
    return np.where(is_coalescent)[0]


def keep_unary_regions_of_coalescent_nodes(ts):
    tables = ts.dump_tables()
    individual_arr = tables.nodes.individual
    for u in node_is_coalescent(ts):
        individual_arr[u] = tables.individuals.add_row()
    tables.nodes.individual = individual_arr
    tables.simplify(keep_unary_in_individuals=True, filter_nodes=False)
    tables.nodes.individual = ts.nodes_individual  # set the individuals back to the original
    tables.individuals.truncate(ts.num_individuals)
    return tables.tree_sequence()


def simplify_to_census(ts, census_time):
    tables = tskit.TableCollection(sequence_length=ts.sequence_length)
    tables.time_units = ts.time_units
    tables.nodes.append_columns(time=ts.nodes_time, flags=ts.nodes_flags)
    keep_nodes = list(np.where(ts.nodes_time == census_time)[0])
    edge_list = []
    for e in ts.edges():
        edge_list.append(e)
        if (
            e.id + 1 == ts.num_edges or
            ts.edges_parent[e.id + 1] != ts.edges_parent[e.id]
        ):
            for e in edge_list:
                if ts.nodes_time[e.parent] > census_time and ts.nodes_time[e.child] < census_time:
                    v = tables.nodes.add_row(time=census_time)
                    keep_nodes.append(v)
                    tables.edges.add_row(e.left, e.right, v, e.child)
                    tables.edges.add_row(e.left, e.right, e.parent, v)
                else:
                    tables.edges.add_row(e.left, e.right, e.parent, e.child)
            edge_list = []
    # Sort output.
    tables.sort()
    tables.simplify(keep_nodes)
    return tables.tree_sequence()


ts = msprime.sim_ancestry(
    5, sequence_length=1e6, population_size=1e4, recombination_rate=1e-8,
    record_full_arg=True,
)
ts = keep_unary_regions_of_coalescent_nodes(ts)
print("With unary regions of coalescent nodes")
for t in np.insert(np.logspace(0, np.log10(ts.max_time), 10), 0, 0):
    nts = simplify_to_census(ts, t)
    print("At time", t, nts.time_units, nts.num_samples, "lineages")

print("\nFully simplified")
sts = ts.simplify()
for t in np.insert(np.logspace(0, np.log10(ts.max_time), 10), 0, 0):
    nts = simplify_to_census(sts, t)
    print("At", t, nts.time_units, nts.num_samples, "lineages")
```
Giving:
With unary regions of coalescent nodes
At time 0.0 generations 10 lineages
At time 1.0 generations 89 lineages
At time 3.805226480356486 generations 89 lineages
At time 14.47974856680621 generations 89 lineages
At time 55.09872267531488 generations 89 lineages
At time 209.66311855792648 generations 101 lineages
At time 797.815650690743 generations 118 lineages
At time 3035.869240451258 generations 183 lineages
At time 11552.170024664856 generations 256 lineages
At time 43958.62328343514 generations 208 lineages
At time 167272.51735814253 generations 0 lineages
Fully simplified
At 0.0 generations 10 lineages
At 1.0 generations 641 lineages
At 3.805226480356486 generations 641 lineages
At 14.47974856680621 generations 641 lineages
At 55.09872267531488 generations 641 lineages
At 209.66311855792648 generations 748 lineages
At 797.815650690743 generations 860 lineages
At 3035.869240451258 generations 1077 lineages
At 11552.170024664856 generations 1221 lineages
At 43958.62328343514 generations 472 lineages
At 167272.51735814253 generations 0 lineages