tsinfer icon indicating copy to clipboard operation
tsinfer copied to clipboard

Batch job ancestor matching.

Open benjeffery opened this issue 1 year ago • 3 comments

After some hairy snakemake deliberation here is the strawman pipeline for ancestor matching.

It requires the following tsinfer methods: `match_ancestors_batch_init` — creates a folder with metadata; `match_ancestors_batch_group` — matches a group locally and writes a ts to the folder; `match_ancestors_batch_group_init` — creates a folder and writes metadata on partitions for a group; `match_ancestors_batch_group_partition` — matches a partition of ancestors for a group; `match_ancestors_batch_group_finalise` — uses the partitions to write a ts for the group; `match_ancestors_batch_finalise` — writes the final ts.

@jeromekelleher

import json
import re

from pathlib import Path

# Groups with more ancestors than this are partitioned and matched
# across multiple jobs instead of in a single local job.
BATCH_THRESHOLD = 100
# The maximum number of consecutive small groups to match in one local job
NUM_GROUPS_ONE_JOB = 10

# Final target of the workflow: the fully matched ancestors tree sequence.
rule all:
    input: 'ancestors.ts'

checkpoint match_ancestors_init:
    input: 'ancestors.zarr'
    output: 'match_wip/metadata.json'
    run:
        #tsinfer.match_ancestors_batch_init(input[0], working_dir="match_wip"))
        
        #Dump some dummy groupings
        groupings = []
        for i, ancestors in enumerate(range(0,1000,10)):
            groupings.append(list(range(ancestors, ancestors+10)))
        for i, ancestors in enumerate(range(1000,5000,1000)):
            groupings.append(list(range(ancestors, ancestors+1000)))
        md = {'ancestor_groups':groupings}
        with open(output[0], 'w') as f:
            json.dump(md, f)


# Load ancestor groupings from the init checkpoint's metadata.
# Calling checkpoints.<name>.get() is what tells snakemake this input function
# depends on the checkpoint; use its resolved output path rather than
# duplicating the path string here.
def ancestor_groupings(wildcards):
    metadata_path = checkpoints.match_ancestors_init.get(**wildcards).output[0]
    with open(metadata_path) as f:
        md = json.load(f)
    return md['ancestor_groups']

# Load the number of partitions for a large group from its init checkpoint.
# As above, use the checkpoint's resolved output path (which already has the
# {group} wildcard substituted) instead of rebuilding it by hand.
def num_partitions(wildcards):
    metadata_path = checkpoints.match_ancestors_large_group_init.get(**wildcards).output[0]
    with open(metadata_path) as f:
        metadata = json.load(f)
    return metadata["num_partitions"]

# Input function for match_ancestors_group: decides whether group {group} is
# finalised from partitions (large group) or matched locally, and which prior
# output this job must wait for.
def match_ancestor_group_input(wildcards):
    groupings = ancestor_groupings(wildcards)
    group_index = int(wildcards.group)
    # If the group is a large one then the inputs will be the partitions
    if len(groupings[group_index]) > BATCH_THRESHOLD:
        return expand(
            'match_wip/batch_{group}_wip/partition-{partition}.json',
             partition=range(num_partitions(wildcards)), allow_missing=True
            )
    # This group is small enough to do locally.
    # Search back until we find a group that is large enough to require
    # partitioning, or we reach the start, or we have enough groups.
    for i in range(group_index, max(group_index - NUM_GROUPS_ONE_JOB, 0), -1):
        if len(groupings[i]) > BATCH_THRESHOLD:
            # BUG FIX: this was a plain string, so the literal '{i}' reached
            # snakemake as an unresolved wildcard instead of the group index.
            return f'match_wip/ancestors_{i}.ts'
    if group_index - NUM_GROUPS_ONE_JOB >= 0:
        # Depend on the ts from NUM_GROUPS_ONE_JOB groups back so at most
        # NUM_GROUPS_ONE_JOB groups are matched in this job (was `> 0`,
        # which let NUM_GROUPS_ONE_JOB + 1 groups pile into one job).
        return f'match_wip/ancestors_{group_index - NUM_GROUPS_ONE_JOB}.ts'
    else:
        # No earlier group ts exists: depend directly on the init checkpoint.
        return 'match_wip/metadata.json'

# Produce the cumulative ancestors ts up to and including group {group}.
# The input function routes two distinct modes here, distinguished by
# inspecting the input path string:
#   - partition files        -> this is a large group: finalise its partitions
#   - a prior ts or metadata -> match a run of small groups locally
rule match_ancestors_group:
    input:
        match_ancestor_group_input
    output:
        'match_wip/ancestors_{group}.ts'
    run:
        # Use the input to determine if we are processing a set of groups or finalising
        if "partition" in input[0]:
            #tsinfer match_ancestors_batch_group_finalise("match_wip", group=group)
            print(f"Finalise group {wildcards.group}")
        else:
            output_group = int(wildcards.group)
            print(input[0])
            # metadata.json as input means no earlier group ts exists,
            # so start matching from group 0 (input_group + 1).
            if "metadata" in input[0]:
                input_group = -1
            else:
                # Recover the last completed group index from the input filename.
                # NOTE(review): relies on `re` being imported at module level —
                # no import of `re` is visible in this file; confirm.
                input_group = int(re.match(r'match_wip/ancestors_(\d+).ts', input[0]).group(1))
            # Match every group between the last completed one and this target.
            for group in range(input_group+1, output_group+1):
                #tsinfer.match_ancestors_batch_group("match_wip", group=group)
                print(f"Local Match group {group}")
        Path(output[0]).touch()


checkpoint match_ancestors_large_group_init:
    input: 
        lambda wildcards: f'match_wip/ancestors_{int(wildcards.group)-1}.ts'
    output:
        'match_wip/batch_{group}_wip/metadata.json'
    run:
        #tsinfer.match_ancestors_batch_group_init("match_wip", group=wildcards.group)
        print(f"Init large group {wildcards.group}")
        # Write some dummy data
        with open(f'match_wip/batch_{wildcards.group}_wip/metadata.json', 'w') as f:
            json.dump({"num_partitions": 10}, f)

# Match a single partition of a large group's ancestors — one job per
# partition, fanned out over the count written by the group's init checkpoint.
rule match_ancestors_large_group_partition:
    input: 
        'match_wip/batch_{group}_wip/metadata.json'
    output:
        'match_wip/batch_{group}_wip/partition-{partition}.json'
    run:
        #tsinfer.match_ancestors_batch_group_partition("match_wip", group=wildcards.group, partition=wildcards.partition)
        print(f"Match group {wildcards.group} partition {wildcards.partition}")
        Path(output[0]).touch()


def last_ancestor_group(wildcards):
    """Return the index of the final ancestor group from the checkpoint metadata."""
    return len(ancestor_groupings(wildcards)) - 1
    
# Final step: once the ts for the last ancestor group exists, write the
# completed ancestors tree sequence as the workflow's final output.
rule match_ancestors_final:
    input:
        lambda wildcards: f'match_wip/ancestors_{last_ancestor_group(wildcards)}.ts'
    output:
        'ancestors.ts'
    run:
        #tsinfer.match_ancestors_batch_finalise("match_wip")
        print("Finalise")
        Path(output[0]).touch()

benjeffery avatar May 22 '24 13:05 benjeffery