
Method to 'prune back' a tree sequence

Open jeromekelleher opened this issue 4 years ago • 18 comments

I want a function that will remove the topology after (forwards in time) a specific time, and mark each extant lineage as a sample. The intuition is that we're drawing a line across the trees, and sort of "pruning back" to this line. (This will allow us to generate the haplotypes extant at this time.)

Here's a simple implementation:

import tskit


def prune_leaves(ts, time):
    """
    Returns a tree sequence in which all the topology and mutation
    information after the specified time is removed. That is, the
    samples in the returned tree sequence will be the lineages
    extant at the specified time.
    """
    tables = ts.dump_tables()
    t = tables.nodes.time
        
    tables.edges.clear()
    samples = []
    edge_map = {}
    for edge in ts.edges():
        if t[edge.child] > time:
            tables.edges.add_row(edge.left, edge.right, edge.parent, edge.child)
        elif t[edge.child] <= time and t[edge.parent] > time:
            key = (edge.child, edge.parent)
            if key in edge_map:
                # If the same parent/child combination exists, then we should map
                # to the same ancestor node.
                u = edge_map[key]
            else:
                u = tables.nodes.add_row(
                    flags=tskit.NODE_IS_SAMPLE, time=time)
                samples.append(u)
                edge_map[key] = u
            tables.edges.add_row(edge.left, edge.right, edge.parent, u)
    tables.sort()
    tables.simplify(samples)
    return tables.tree_sequence()

There are probably some subtleties and corner cases I've missed here, but this is probably the basic idea.

I think this would be a useful addition to tskit. I don't know what we'd call it though, and I think it's worth keeping this in mind with the discussion going on over in #261 where we're chopping up the tree sequence in the space rather than time direction.

jeromekelleher avatar Aug 01 '19 15:08 jeromekelleher

This is great. prune_leaves is OK, but it's not specific to leaves. Other ideas:

  • truncate
  • pollard
  • time_slice

petrelharp avatar Aug 01 '19 16:08 petrelharp

If removing all nodes ancestral to a certain time is decapitating, then this would be to amputate or dismember the tree:

import tskit.gory_ops

molpopgen avatar Aug 01 '19 23:08 molpopgen

Maybe I'm not fully understanding here, but isn't this basically what map_ancestors/find_ancestors does? You'd basically just need to add a sample on each lineage at the right time (the census event you've talked about with me in person before - btw I'm happy to start working on this soon if it will be useful), and then apply map_ancestors with the ancestors being all nodes at this census time or earlier, and the samples being all the leaves.

gtsambos avatar Aug 04 '19 19:08 gtsambos

You're right, it's definitely closely related @gtsambos. It would be excellent if we can implement it using map/find_ancestors!

ps. Let's talk about the census event soon --- I do actually need this as it happens.

jeromekelleher avatar Aug 04 '19 21:08 jeromekelleher

Hi @jeromekelleher, I just did a mock-up of this idea in this gist

> then apply map_ancestors with the ancestors being all nodes at this census time or earlier, and the samples being all the leaves.

It's actually even simpler than this - you just need to set the census nodes as your samples, and all older nodes as ancestors. Then each leaf node in the output corresponds to a lineage present at the specified census time.

Btw, here's a Python function that puts nodes on trees at a given 'census' time: (Ni = input node table, Ei = input edge table, L = sequence length)

import tskit


def add_census(Ni, Ei, census_time, L):
    No = tskit.NodeTable()
    No.append_columns(time=Ni.time, flags=Ni.flags)
    Eo = tskit.EdgeTable()
    
    edge_list = []
    for ind, e in enumerate(Ei):
        edge_list.append(e)
        if ind + 1 == len(Ei) or Ei[ind + 1].parent != Ei[ind].parent:
            for e in edge_list:
                if Ni.time[e.parent] > census_time and Ni.time[e.child] < census_time:
                    v = No.add_row(time=census_time)
                    Eo.add_row(e.left, e.right, v, e.child)
                    Eo.add_row(e.left, e.right, e.parent, v)
                else:
                    Eo.add_row(e.left, e.right, e.parent, e.child)
            edge_list = []

    # Sort output.
    new_tables = tskit.TableCollection(sequence_length=L)
    new_tables.nodes.append_columns(time=No.time, flags=No.flags)
    new_tables.edges.append_columns(left=Eo.left, right=Eo.right,
                                    parent=Eo.parent, child=Eo.child)
    new_tables.sort()
    return new_tables.tree_sequence()

> ps. Let's talk about the census event soon --- I do actually need this as it happens.

Yes! 💻

gtsambos avatar Aug 04 '19 22:08 gtsambos

The major difference between my implementation and yours is that you treat any (parent, child) combo as the same lineage,

            key = (edge.child, edge.parent)
            if key in edge_map:
                # If the same parent/child combination exists, then we should map
                # to the same ancestor node.
                u = edge_map[key]

even where the edges with this combo span non-adjacent intervals of the sequences. Whereas my add_census function above treats each edge as a distinct lineage.

I guess there might be contextual reasons to prefer one or the other, but my way makes a bit more sense to me because it ensures that each distinct 'leaf' corresponds to no more than one distinct historical chromosome. I.e., a particular ancestor may have contributed two distinct segments to a particular descendant's chromosome, but that doesn't necessarily mean that both segments have passed through the same sequence of individuals (which is my usual interpretation of lineage).

gtsambos avatar Aug 04 '19 22:08 gtsambos
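To make the difference between the two policies concrete, here is a minimal pure-Python sketch on made-up data. The edge tuples, node times and the `count_census_nodes` helper are all hypothetical illustrations, not tskit API. It uses the strict inequality from add_census to decide whether an edge spans the census time.

```python
# Toy edges as (left, right, parent, child); times keyed by node ID.
# All values are made up for illustration.
node_time = {0: 0.0, 1: 5.0}
edges = [
    (0.0, 10.0, 1, 0),   # child 0 inherits [0, 10) from parent 1
    (20.0, 30.0, 1, 0),  # and, non-adjacently, [20, 30) from the same parent
]
census_time = 2.0


def count_census_nodes(edges, node_time, census_time, per_edge):
    """
    Count the census nodes needed for the edges spanning census_time.
    per_edge=True gives every spanning edge its own node; per_edge=False
    shares one node between all edges with the same (child, parent) pair.
    """
    keys = set()
    for index, (left, right, parent, child) in enumerate(edges):
        if node_time[child] < census_time < node_time[parent]:
            keys.add(index if per_edge else (child, parent))
    return len(keys)


print(count_census_nodes(edges, node_time, census_time, per_edge=False))  # 1
print(count_census_nodes(edges, node_time, census_time, per_edge=True))   # 2
```

Keying on (child, parent) yields one extant lineage here, while the per-edge policy yields two, one for each non-adjacent segment.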

This would be cool to have implemented. There are some subtleties that arise for handling mutations that don't occur in the case of decapitation. For any parent/child edge that you truncate to a time "in between", some of the mutations on the child nodes may need remapping to the new node. I'm thinking about forward simulations here, where the mutation's origin time may be a part of its metadata, and so you may know that it predates your census time but post-dates the first parent ancestral to that time. For many coalescent sims, (I think) you can justify simply binomially sampling any mutations and remapping them.

molpopgen avatar Aug 04 '19 23:08 molpopgen

> It's actually even simpler than this - you just need to set the census nodes as your samples, and all older nodes as ancestors.

Shouldn't you just set all the census-and-older nodes to be your samples?

> handling mutations

Oh, this is a good point.

petrelharp avatar Aug 05 '19 03:08 petrelharp

So, perhaps the thing to do is to add in the 'census' nodes separately then, as I can imagine this being useful without the pollarding step. Handling the mutations is tricky, and I think we'll have to make the handling pluggable somehow (a callback that decides if a mutation goes above or below the census node, or something), as there's no general way for tskit to know what time the mutation happens at.

jeromekelleher avatar Aug 05 '19 16:08 jeromekelleher

Exactly. For each parent/child combo spanning the census time, you can pass the parent time, child time, census time, and the index of the mutation on the child node to a function returning True if it should be "pushed up" to the new census node. You could even just pass in all indexes for all mutations on the child node, which may be more efficient. tskit could provide a default using numpy to make a random decision based on the proportion of time above the census node.

molpopgen avatar Aug 05 '19 16:08 molpopgen
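A minimal sketch of what such a default callback might look like. The name `push_up_mutations` and its signature are hypothetical and not part of tskit. It also assumes mutations are placed uniformly in time along the branch, which need not hold for forward simulations that record origin times.

```python
import numpy as np


def push_up_mutations(parent_time, child_time, census_time, mutation_ids, rng=None):
    """
    Hypothetical callback (not a tskit API). Returns a boolean array with one
    entry per mutation on the child node; True means the mutation is moved
    ("pushed up") to the new census node, i.e. it is treated as having arisen
    between the census time and the parent.
    """
    if not (child_time < census_time < parent_time):
        raise ValueError("census_time must lie strictly between child and parent times")
    rng = np.random.default_rng() if rng is None else rng
    # Proportion of the branch that lies above the census node. Assumes
    # mutations fall uniformly in time along the branch.
    p_above = (parent_time - census_time) / (parent_time - child_time)
    return rng.random(len(mutation_ids)) < p_above


# Branch from time 0 to time 10 with a census at time 4, so each mutation
# is pushed up with probability 0.6.
mask = push_up_mutations(10.0, 0.0, 4.0, [7, 8, 9], rng=np.random.default_rng(42))
print(mask)
```

Passing all mutation indexes at once keeps the decision vectorised. A caller with known origin times could swap in a deterministic function with the same signature.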

> So, perhaps the thing to do is to add in the 'census' nodes separately then, as I can imagine this being useful without the pollarding step.

Right; then we could just simplify.

petrelharp avatar Aug 05 '19 17:08 petrelharp

Related to #382

jeromekelleher avatar Sep 29 '20 15:09 jeromekelleher

It turns out we'd like to do the opposite of this (cutting off the top of a tree sequence) for some applications (it'd give us an easy way to do time windowed statistics).

petrelharp avatar May 03 '22 19:05 petrelharp

Note that cutting off the top of a tree sequence is less problematic than the bottom - we don't want to re-map sample nodes or anything; I'd say just discard anything above a certain time - however, we do have to figure out what to call the newly-created roots.

petrelharp avatar May 03 '22 19:05 petrelharp

Yes, cutting off the top is much easier and simpler since we don't have to worry about samples. I'll make a new function for that and an issue to track it. Since moving the samples up the tree is much more subtle and would need more options, I think this should remain its own function, as discussed here.

jeromekelleher avatar May 04 '22 08:05 jeromekelleher

I think the general way to do this sort of operation should be to

  1. add new nodes across at a given point in time (an operation we've called a census)
  2. remove everything above (or, below) those nodes

This "adding new nodes" step removes a lot of the possible weirdnesses. Note that even cutting off the top runs into issues about samples (since there may be samples older than that time). I think we don't want to be moving around times of nodes, and if we add new nodes we don't have to.

Also note that "everything" could include the nodes outside the time range or not; it could just mean removing the edges and mutations.

petrelharp avatar May 04 '22 17:05 petrelharp
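The two steps can be sketched on plain Python tuples. This is only an illustration under stated assumptions: the helpers `add_census_nodes` and `remove_edges` are hypothetical, edges are (left, right, parent, child) tuples, and "removing everything" here means removing edges only, leaving the nodes (and their times) untouched.

```python
def add_census_nodes(edges, node_time, census_time):
    """Step 1: split every edge that spans census_time at a new node."""
    node_time = dict(node_time)  # copy, so the caller's times are not modified
    next_id = max(node_time) + 1
    out = []
    for left, right, parent, child in edges:
        if node_time[child] < census_time < node_time[parent]:
            v = next_id
            next_id += 1
            node_time[v] = census_time
            out.append((left, right, v, child))
            out.append((left, right, parent, v))
        else:
            out.append((left, right, parent, child))
    return out, node_time


def remove_edges(edges, node_time, census_time, keep):
    """Step 2: keep only the edges on one side of the census nodes."""
    if keep == "older":
        # Child at or above the census time: the edge lies entirely above it.
        return [e for e in edges if node_time[e[3]] >= census_time]
    if keep == "younger":
        # Parent at or below the census time: the edge lies entirely below it.
        return [e for e in edges if node_time[e[2]] <= census_time]
    raise ValueError("keep must be 'older' or 'younger'")


# Made-up example: samples 0 and 1 coalesce in node 2, whose parent is node 3.
node_time = {0: 0.0, 1: 0.0, 2: 3.0, 3: 7.0}
edges = [(0, 10, 2, 0), (0, 10, 2, 1), (0, 10, 3, 2)]
split, times = add_census_nodes(edges, node_time, census_time=5.0)
print(remove_edges(split, times, 5.0, keep="older"))    # [(0, 10, 3, 4)]
print(remove_edges(split, times, 5.0, keep="younger"))
```

Because every spanning edge has been split in step 1, each remaining edge lies wholly on one side of the census time, so step 2 is a simple filter and no node times need to change.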

Yep, agreed. That's basically what I've done in #2240

jeromekelleher avatar May 04 '22 18:05 jeromekelleher

Note that you get quite different numbers of lineages when "pruning back" if you leave in the regions where a coalescent node is unary (or, I assume, if you use @petrelharp's extend_edges idea). I wonder if this should be noted in any function that we provide in the API?

import msprime
import tskit
import numpy as np

def node_is_coalescent(ts):
    is_coalescent = np.zeros(ts.num_nodes, dtype=bool)
    for tree in ts.trees():
        is_coalescent[tree.num_children_array[:-1] > 1] = True
    return np.where(is_coalescent)[0]

def keep_unary_regions_of_coalescent_nodes(ts):
    tables = ts.dump_tables()
    individual_arr = tables.nodes.individual
    for u in node_is_coalescent(ts):
        individual_arr[u] = tables.individuals.add_row()
    tables.nodes.individual = individual_arr
    tables.simplify(keep_unary_in_individuals=True, filter_nodes=False)
    tables.nodes.individual = ts.nodes_individual  # set the individuals back to the original
    tables.individuals.truncate(ts.num_individuals)
    return tables.tree_sequence()

def simplify_to_census(ts, census_time):
    tables = tskit.TableCollection(sequence_length=ts.sequence_length)
    tables.time_units = ts.time_units
    tables.nodes.append_columns(time=ts.nodes_time, flags=ts.nodes_flags)
    keep_nodes = list(np.where(ts.nodes_time == census_time)[0])
    edge_list = []
    for e in ts.edges():
        edge_list.append(e)
        if (
            e.id + 1 == ts.num_edges or 
            ts.edges_parent[e.id + 1] != ts.edges_parent[e.id]
        ):
            for e in edge_list:
                if ts.nodes_time[e.parent] > census_time and ts.nodes_time[e.child] < census_time:
                    v = tables.nodes.add_row(time=census_time)
                    keep_nodes.append(v)
                    tables.edges.add_row(e.left, e.right, v, e.child)
                    tables.edges.add_row(e.left, e.right, e.parent, v)
                else:
                    tables.edges.add_row(e.left, e.right, e.parent, e.child)
            edge_list = []

    # Sort output.
    tables.sort()
    tables.simplify(keep_nodes)
    return tables.tree_sequence()


ts = msprime.sim_ancestry(
    5, sequence_length=1e6, population_size=1e4, recombination_rate=1e-8,
    record_full_arg=True,
)
ts = keep_unary_regions_of_coalescent_nodes(ts)


print("With unary regions of coalescent nodes")
for t in np.insert(np.logspace(0, np.log10(ts.max_time), 10), 0, 0):
    nts = simplify_to_census(ts, t)
    print("At time", t, nts.time_units, nts.num_samples, "lineages")

print("\nFully simplified")
sts = ts.simplify()
for t in np.insert(np.logspace(0, np.log10(ts.max_time), 10), 0, 0):
    nts = simplify_to_census(sts, t)
    print("At", t, nts.time_units, nts.num_samples, "lineages")

Giving:

With unary regions of coalescent nodes
At time 0.0 generations 10 lineages
At time 1.0 generations 89 lineages
At time 3.805226480356486 generations 89 lineages
At time 14.47974856680621 generations 89 lineages
At time 55.09872267531488 generations 89 lineages
At time 209.66311855792648 generations 101 lineages
At time 797.815650690743 generations 118 lineages
At time 3035.869240451258 generations 183 lineages
At time 11552.170024664856 generations 256 lineages
At time 43958.62328343514 generations 208 lineages
At time 167272.51735814253 generations 0 lineages

Fully simplified
At 0.0 generations 10 lineages
At 1.0 generations 641 lineages
At 3.805226480356486 generations 641 lineages
At 14.47974856680621 generations 641 lineages
At 55.09872267531488 generations 641 lineages
At 209.66311855792648 generations 748 lineages
At 797.815650690743 generations 860 lineages
At 3035.869240451258 generations 1077 lineages
At 11552.170024664856 generations 1221 lineages
At 43958.62328343514 generations 472 lineages
At 167272.51735814253 generations 0 lineages

hyanwong avatar Jul 05 '23 12:07 hyanwong