
SpliceAI Batching Support

Open · kazmiekr opened this pull request 2 years ago · 2 comments

This PR adds batching support to run predictions in larger batches for better GPU utilization and speed. See the README for speed comparisons.

  • Adds new command-line parameters, --prediction-batch-size and --tensorflow-batch-size, to batch variants and make better use of the GPU during prediction (see the example invocation after this list)
  • Adds a VCFPredictionBatch class that manages collecting VCF records and placing them in batches keyed by encoded tensor size. Once a batch is full, predictions are run for the whole batch, and the output is written back in the original order, reassembling the annotations for each VCF record. Each VCF record has a lookup key recording where each of its ref/alt tensors sits within its batch, so the class knows where to grab the results during reassembly (a minimal sketch follows this list)
  • Breaks the code in the existing get_delta_scores method out into reusable methods shared by the batching code and the original code path, so the batching code can reuse the same logic while the original version keeps working
  • Adds batch utility methods that split up what get_delta_scores previously did in one place: encode_batch_record handles the first half, taking a VCF record and generating one-hot encoded matrices for the ref/alt alleles, while extract_delta_scores handles the second half, reassembling the annotations from the batched predictions
  • Adds test cases that run a small file against a generated FASTA reference to verify that the results are identical with no batching and with several different batch sizes
  • Slightly modifies the entrypoint so the code is easier to unit test, by allowing callers to pass in what would normally come from the argument parser (see the entrypoint sketch after this list)
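By way of illustration, the two new flags would slot into an otherwise ordinary SpliceAI invocation, e.g. `spliceai -I input.vcf -O output.vcf -R genome.fa -A grch37 --prediction-batch-size 16 --tensorflow-batch-size 64`. The flag names come from this PR; the file names and batch-size values here are arbitrary placeholders.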
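The sketch below illustrates the batching flow described above. Only the class name VCFPredictionBatch is taken from the PR description; the method names (add, finish), the predict_fn hook, and the keying scheme are assumptions about how such batching could be structured, not the PR's actual code.

```python
import numpy as np
from collections import defaultdict

class VCFPredictionBatch:
    """Sketch: queue one-hot encoded ref/alt tensors, group them by tensor
    size, run each full group through the model in a single call, and keep
    results keyed so the caller can reassemble annotations in the original
    VCF order."""

    def __init__(self, predict_fn, batch_size):
        self.predict_fn = predict_fn      # e.g. a wrapper around model.predict
        self.batch_size = batch_size
        self.pending = defaultdict(list)  # tensor length -> [(key, tensor), ...]
        self.results = {}                 # key -> prediction row

    def add(self, key, tensor):
        """Queue one encoded allele; key identifies the record and allele."""
        bucket = self.pending[tensor.shape[0]]
        bucket.append((key, tensor))
        if len(bucket) >= self.batch_size:
            self._flush(tensor.shape[0])

    def _flush(self, size):
        keys, tensors = zip(*self.pending.pop(size))
        predictions = self.predict_fn(np.stack(tensors))  # one batched GPU call
        for key, row in zip(keys, predictions):
            self.results[key] = row

    def finish(self):
        """Flush any partially filled batches and return all results."""
        for size in list(self.pending):
            self._flush(size)
        return self.results
```

Bucketing by encoded tensor size matters because tensors of different lengths cannot be stacked into one rectangular array for a single model call; since the PR batches on encoded tensor size, each GPU call stays uniform in shape.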
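And a sketch of the entrypoint change, under the same caveat: the names of the two new flags come from the PR, while the rest (the run helper, the -I/-O options as written here) is illustrative.

```python
import argparse

def run(args):
    """Hypothetical worker; in the real code this would drive annotation."""
    print(f'batching {args.prediction_batch_size} / {args.tensorflow_batch_size}')

def main(argv=None):
    # Accepting argv lets unit tests call main([...]) directly instead of
    # patching sys.argv; argv=None still falls back to the real CLI args.
    parser = argparse.ArgumentParser(description='SpliceAI (batched)')
    parser.add_argument('-I', dest='input_vcf', required=True)
    parser.add_argument('-O', dest='output_vcf', required=True)
    parser.add_argument('--prediction-batch-size', type=int, default=1)
    parser.add_argument('--tensorflow-batch-size', type=int, default=1)
    run(parser.parse_args(argv))

if __name__ == '__main__':
    main()
```

A test for the "same results at any batch size" claim can then call main twice over the same small VCF with different --prediction-batch-size values and compare the outputs, no subprocess required.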

kazmiekr · Dec 02 '21 19:12

@kishorejaganathan What do you think? Can we have this merged?

Hoeze · Apr 12 '22 10:04

@kazmiekr that looks like a super useful fork.

Do you think you could refactor the really old Keras call to load_model in utils.py?

Currently it imports load_model from standalone Keras (`from keras.models import load_model`); it should use the TensorFlow-bundled version instead (`from tensorflow.keras.models import load_model`).
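Concretely, the request amounts to a one-line change in utils.py (a sketch; the surrounding code is assumed):

```python
# Old, standalone-Keras import (deprecated):
# from keras.models import load_model

# Suggested replacement using the Keras bundled with TensorFlow:
from tensorflow.keras.models import load_model

# Usage is unchanged, e.g. (hypothetical model path):
# model = load_model('spliceai1.h5')
```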

sicotteh · Feb 14 '23 17:02