# tsinfer — batch job ancestor matching.
#
# After some hairy snakemake deliberation, this is the strawman pipeline for
# batch ancestor matching. It requires the following tsinfer methods:
#   match_ancestors_batch_init            - creates a folder with metadata
#   match_ancestors_batch_group           - matches a group locally, writes ts to folder
#   match_ancestors_batch_group_init      - creates a folder and writes metadata on partitions for a group
#   match_ancestors_batch_group_partition - matches a partition of ancestors for a group
#   match_ancestors_batch_group_finalise  - uses the partitions to write a ts for the group
#   match_ancestors_batch_finalise        - writes the final ts
#
# Original proposal by @jeromekelleher.
import json
import re

from pathlib import Path
# The number of ancestors needed in a group to trigger partitioning
BATCH_THRESHOLD = 100
# The number of groups to process in one job when doing local matching
NUM_GROUPS_ONE_JOB = 10
rule all:
input: 'ancestors.ts'
checkpoint match_ancestors_init:
input: 'ancestors.zarr'
output: 'match_wip/metadata.json'
run:
#tsinfer.match_ancestors_batch_init(input[0], working_dir="match_wip"))
#Dump some dummy groupings
groupings = []
for i, ancestors in enumerate(range(0,1000,10)):
groupings.append(list(range(ancestors, ancestors+10)))
for i, ancestors in enumerate(range(1000,5000,1000)):
groupings.append(list(range(ancestors, ancestors+1000)))
md = {'ancestor_groups':groupings}
with open(output[0], 'w') as f:
json.dump(md, f)
# Load ancestor groupings from metadata for snakemake
def ancestor_groupings(wildcards):
checkpoint_output = checkpoints.match_ancestors_init.get(**wildcards)
with open('match_wip/metadata.json') as f:
md = json.load(f)
return md['ancestor_groups']
# Load the number of partitions for a group
def num_partitions(wildcards):
checkpoint_output = checkpoints.match_ancestors_large_group_init.get(**wildcards)
with open(f'match_wip/batch_{wildcards.group}_wip/metadata.json') as f:
metadata = json.load(f)
return metadata["num_partitions"]
# This function decides if a group should be processed in a single job or partitioned
def match_ancestor_group_input(wildcards):
groupings = ancestor_groupings(wildcards)
group_index = int(wildcards.group)
# If the group is a large one then the inputs will be the partitions
if len(groupings[group_index]) > BATCH_THRESHOLD:
return expand(
'match_wip/batch_{group}_wip/partition-{partition}.json',
partition=range(num_partitions(wildcards)), allow_missing=True
)
# This group is small enough to do locally
# search back until we find a group that is large enough to require partitioning, or we reach the start, or we have enough groups
for i in range(group_index, max(group_index-NUM_GROUPS_ONE_JOB, 0), -1):
if len(groupings[i]) > BATCH_THRESHOLD:
return 'match_wip/ancestors_{i}.ts'
if group_index-NUM_GROUPS_ONE_JOB > 0:
return f'match_wip/ancestors_{group_index-NUM_GROUPS_ONE_JOB}.ts'
else:
return 'match_wip/metadata.json'
rule match_ancestors_group:
input:
match_ancestor_group_input
output:
'match_wip/ancestors_{group}.ts'
run:
# Use the input to determine if we are processing a set of groups or finalising
if "partition" in input[0]:
#tsinfer match_ancestors_batch_group_finalise("match_wip", group=group)
print(f"Finalise group {wildcards.group}")
else:
output_group = int(wildcards.group)
print(input[0])
if "metadata" in input[0]:
input_group = -1
else:
input_group = int(re.match(r'match_wip/ancestors_(\d+).ts', input[0]).group(1))
for group in range(input_group+1, output_group+1):
#tsinfer.match_ancestors_batch_group("match_wip", group=group)
print(f"Local Match group {group}")
Path(output[0]).touch()
checkpoint match_ancestors_large_group_init:
input:
lambda wildcards: f'match_wip/ancestors_{int(wildcards.group)-1}.ts'
output:
'match_wip/batch_{group}_wip/metadata.json'
run:
#tsinfer.match_ancestors_batch_group_init("match_wip", group=wildcards.group)
print(f"Init large group {wildcards.group}")
# Write some dummy data
with open(f'match_wip/batch_{wildcards.group}_wip/metadata.json', 'w') as f:
json.dump({"num_partitions": 10}, f)
rule match_ancestors_large_group_partition:
input:
'match_wip/batch_{group}_wip/metadata.json'
output:
'match_wip/batch_{group}_wip/partition-{partition}.json'
run:
#tsinfer.match_ancestors_batch_group_partition("match_wip", group=wildcards.group, partition=wildcards.partition)
print(f"Match group {wildcards.group} partition {wildcards.partition}")
Path(output[0]).touch()
def last_ancestor_group(wildcards):
groupings = ancestor_groupings(wildcards)
return len(groupings)-1
rule match_ancestors_final:
input:
lambda wildcards: f'match_wip/ancestors_{last_ancestor_group(wildcards)}.ts'
output:
'ancestors.ts'
run:
#tsinfer.match_ancestors_batch_finalise("match_wip")
print("Finalise")
Path(output[0]).touch()