Add possibility to simplify but keep nodes that are partially unary + partially coalescent
This was discussed in https://github.com/tskit-dev/tskit/discussions/2089, and some code providing a workaround is in https://github.com/tskit-dev/what-is-an-arg-paper/issues/80.
Also note that issue https://github.com/tskit-dev/tskit/issues/1120 discusses how we might avoid a large increase in the number of parameters to simplify when focussing on particular nodes. I'm not sure if that is relevant to the current case: the semantics of the parameters passed to simplify require some thought, I reckon.
Some options:
- keep_coalescent_unary
- keep_unary_that_coalesce
- keep_locally_unary (as in, keep nodes that are locally unary, but with greater arity elsewhere)
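To make "locally unary" concrete, here is a minimal pure-Python sketch (no tskit needed) that labels each parent node by how its number of children varies along the sequence. The toy edge list and the names `arity_by_interval` and `classify` are made up for illustration and are not part of any tskit API.

```python
from collections import defaultdict

# Hypothetical toy edge table: (left, right, parent, child). Samples are 0, 1, 2.
EDGES = [
    (0, 2, 3, 0),
    (0, 1, 3, 1),
    (0, 2, 4, 3),
    (0, 1, 4, 2),
    (1, 2, 5, 1),
    (1, 2, 5, 2),
    (1, 2, 4, 5),
]


def arity_by_interval(edges):
    """Map each parent node to {(left, right): number of children in that interval}."""
    breakpoints = sorted({x for left, right, _, _ in edges for x in (left, right)})
    arity = defaultdict(dict)
    for lo, hi in zip(breakpoints, breakpoints[1:]):
        counts = defaultdict(int)
        for left, right, parent, _ in edges:
            if left <= lo and hi <= right:  # edge covers this whole interval
                counts[parent] += 1
        for parent, n in counts.items():
            arity[parent][(lo, hi)] = n
    return dict(arity)


def classify(edges):
    """Label each parent node as unary, coalescent, or a mixture of both."""
    labels = {}
    for node, per_interval in arity_by_interval(edges).items():
        counts = set(per_interval.values())
        if counts == {1}:
            labels[node] = "unary wherever present"
        elif 1 in counts:
            labels[node] = "locally unary"
        else:
            labels[node] = "coalescent wherever present"
    return labels


labels = classify(EDGES)
print(labels)
```

In this toy table, node 3 has two children on [0, 1) but only one on [1, 2), so it is the kind of node the proposed option would keep across its whole span.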
Here is a plot to show the effect of running the workaround versus the normal simplify. It looks like I was wrong when chatting to @petrelharp: on standard tree sequences we nearly always get compression when leaving the "unary-when-on-coalescent_nodes" in. Unless I've made some error here:
```python
import msprime
import numpy as np
import tskit
import matplotlib.pyplot as plt


def simplify_keeping_unary_in_coal(ts, map_nodes=False):
    """
    Keep the unary regions of nodes that are coalescent at least somewhere in the tree seq
    Temporary hack until https://github.com/tskit-dev/tskit/issues/2127 is addressed
    """
    tables = ts.dump_tables()
    # remove existing individuals. We will reinstate them later
    tables.individuals.clear()
    tables.nodes.individual = np.full_like(tables.nodes.individual, tskit.NULL)

    # Nodes that survive a normal simplify are the ones we want to keep
    _, node_map = ts.simplify(map_nodes=True)
    keep_nodes = np.where(node_map != tskit.NULL)[0]
    # Add an individual for each coalescent node, so we can run
    # simplify(keep_unary_in_individuals=True) to leave the unary portions in.
    for u in keep_nodes:
        i = tables.individuals.add_row()
        tables.nodes[u] = tables.nodes[u].replace(individual=i)
    node_map = tables.simplify(keep_unary_in_individuals=True, filter_individuals=False)

    # Reinstate individuals
    tables.individuals.clear()
    for i in ts.individuals():
        tables.individuals.append(i)
    val, inverted_map = np.unique(node_map, return_index=True)
    inverted_map = inverted_map[val != tskit.NULL]
    tables.nodes.individual = ts.tables.nodes.individual[inverted_map]
    if map_nodes:
        return tables.tree_sequence(), node_map
    else:
        return tables.tree_sequence()


sample_size = [10, 100, 1000, 10000]
num_reps = 100
# Rows are replicates, columns are sample sizes
byte_diff = np.zeros((num_reps, len(sample_size)))
edge_diff = np.zeros((num_reps, len(sample_size)))
for j, sz in enumerate(sample_size):
    for i, ts in enumerate(msprime.sim_ancestry(
        sz,
        sequence_length=5e8,
        recombination_rate=1e-8,
        record_full_arg=True,
        num_replicates=num_reps
    )):
        ts_simplified = ts.simplify()
        ts_simplified_keep_unary = simplify_keeping_unary_in_coal(ts)
        # Positive values mean that removing the unary regions costs extra space
        byte_diff[i, j] = ts_simplified.nbytes - ts_simplified_keep_unary.nbytes
        edge_diff[i, j] = ts_simplified.num_edges - ts_simplified_keep_unary.num_edges

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
plt.subplots_adjust(wspace=0.5)
ax1.violinplot(byte_diff)
ax1.set_xticks(range(1, len(sample_size) + 1))
ax1.set_xticklabels(sample_size)
ax1.set_ylabel("Extra bytes when unary removed")
ax1.set_xlabel("Sample size")

ax2.violinplot(edge_diff)
ax2.set_xticks(range(1, len(sample_size) + 1))
ax2.set_xticklabels(sample_size)
ax2.set_ylabel("Extra edges when unary removed")
ax2.set_xlabel("Sample size")
```
Thanks, @hyanwong - interesting that it's nearly always, not strictly always. (I guess we know why, though?)
See also discussion here: https://github.com/tskit-dev/tskit/discussions/2089#discussioncomment-2005450
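As a possible intuition for the "nearly always, not strictly always" observation (not something confirmed in this thread), here is a toy sketch in plain Python that counts edges with and without the locally unary regions. The edge lists and the helpers `trees`, `squash` and `remove_unary` are hypothetical and only mimic what simplify does to unary nodes; they are not tskit functions.

```python
from collections import defaultdict

SAMPLES = {0, 1, 2}

# Hypothetical case A: node 3 is unary on [1, 2) but keeps the same parent (4)
# and the same child (0) as in the region where it is coalescent.
EDGES_A = [
    (0, 2, 3, 0), (0, 1, 3, 1), (0, 2, 4, 3), (0, 1, 4, 2),
    (1, 2, 5, 1), (1, 2, 5, 2), (1, 2, 4, 5),
]

# Hypothetical case B: node 3 is unary on [1, 2) with a different parent (6)
# and a different child (2) from the region where it is coalescent.
EDGES_B = [
    (0, 1, 3, 0), (0, 1, 3, 1), (0, 1, 4, 3), (0, 1, 4, 2),
    (1, 2, 5, 0), (1, 2, 5, 1), (1, 2, 3, 2), (1, 2, 6, 3), (1, 2, 6, 5),
]


def trees(edges):
    """Yield ((left, right), {child: parent}) for each interval between breakpoints."""
    breakpoints = sorted({x for l, r, _, _ in edges for x in (l, r)})
    for lo, hi in zip(breakpoints, breakpoints[1:]):
        yield (lo, hi), {c: p for l, r, p, c in edges if l <= lo and hi <= r}


def squash(pieces):
    """Merge pieces with the same (parent, child) whose intervals abut."""
    merged = []
    for lo, hi, p, c in sorted(pieces, key=lambda e: (e[2], e[3], e[0])):
        if merged and merged[-1][2:] == (p, c) and merged[-1][1] == lo:
            merged[-1] = (merged[-1][0], hi, p, c)
        else:
            merged.append((lo, hi, p, c))
    return merged


def remove_unary(edges, samples):
    """Splice out non-sample nodes wherever they have exactly one child."""
    pieces = []
    for (lo, hi), parent_of in trees(edges):
        num_children = defaultdict(int)
        for p in parent_of.values():
            num_children[p] += 1
        unary = {u for u, n in num_children.items() if n == 1 and u not in samples}
        for child, parent in parent_of.items():
            if child in unary:
                continue  # this node is spliced out of this tree
            while parent is not None and parent in unary:
                parent = parent_of.get(parent)  # climb past unary ancestors
            if parent is not None:
                pieces.append((lo, hi, parent, child))
    return squash(pieces)


for name, edges in [("A", EDGES_A), ("B", EDGES_B)]:
    print(name, "kept:", len(edges), "removed:", len(remove_unary(edges, SAMPLES)))
```

In case A, keeping the unary region lets the edges above and below node 3 span both trees, so the table has 7 edges rather than 8. In case B the unary stretch shares nothing with its neighbours, so keeping it costs 9 edges rather than 8. Whether configurations like case B account for the rare exceptions in the simulations above is an assumption, not something checked here.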