SpliceAI
SpliceAI Batching Support
This PR adds batching support so predictions run in larger batches, for better GPU utilization and speed. See the comments in the README for speed comparisons.
- Adds new command line parameters, `--prediction-batch-size` and `--tensorflow-batch-size`, to support batching variants and improve prediction throughput on a GPU
- Adds a `VCFPredictionBatch` class that manages collecting VCF records, placing them in batches based on the encoded tensor size. Once the batch size is reached, predictions are run in batches and the output is written back in the original order, reassembling the annotations for each VCF record. Each VCF record has a lookup key for where each of its ref/alt encodings sits within its batch, so the class knows where to grab the results during reassembly
- Breaks out code in the existing `get_delta_scores` method into reusable methods shared by the batching code and the original source code. This way the batching code can use the same logic inside that method while the original version is preserved
- Adds batch utility methods that split up what was previously all done in `get_delta_scores`: `encode_batch_record` handles the first half, taking in the VCF record and generating one-hot encoded matrices for the ref/alt alleles; `extract_delta_scores` handles the second half, reassembling the annotations based on the batched predictions
- Adds test cases that run a small file against a generated FASTA reference to verify the results are the same with no batching and with different batch sizes
- Slightly modifies the entrypoint of the code to allow for easier unit testing, by making it possible to pass in what would normally come from the argparser
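The collect-by-tensor-size, predict-in-batches, reassemble-in-order flow described above can be sketched roughly as follows. This is an illustrative toy, not the PR's actual code: `PredictionBatcher`, `encode_record`, and the stub `predict` are hypothetical stand-ins for `VCFPredictionBatch`, `encode_batch_record`, and the TensorFlow model call.

```python
from collections import defaultdict

def encode_record(record):
    # Stand-in for one-hot encoding a ref/alt pair; the encoding's size
    # determines which batch bucket the record falls into.
    return [ord(base) for base in record]

def predict(batch):
    # Stub model: in the real code this is a TensorFlow predict() call
    # that benefits from running on many encoded records at once.
    return [sum(tensor) for tensor in batch]

class PredictionBatcher:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.buckets = defaultdict(list)  # tensor size -> list of (key, tensor)
        self.results = {}                 # key -> prediction, for reassembly

    def add(self, key, record):
        tensor = encode_record(record)
        bucket = self.buckets[len(tensor)]
        bucket.append((key, tensor))
        if len(bucket) >= self.batch_size:
            # Batch for this tensor size is full: run predictions now.
            self._flush(len(tensor))

    def _flush(self, size):
        keys, tensors = zip(*self.buckets.pop(size))
        for key, score in zip(keys, predict(list(tensors))):
            self.results[key] = score

    def finish(self, keys):
        # Flush any remaining partial batches, then emit results in the
        # caller's original record order using the stored lookup keys.
        for size in list(self.buckets):
            self._flush(size)
        return [self.results[key] for key in keys]

batcher = PredictionBatcher(batch_size=2)
records = ["ACGT", "TTTT", "AC"]
for i, record in enumerate(records):
    batcher.add(i, record)
scores = batcher.finish(range(len(records)))
```

The key design point mirrored here is that records are bucketed by encoded size (so each batch is a rectangular tensor) while a per-record key preserves the original VCF output order regardless of when each batch actually ran.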
@kishorejaganathan What do you think? Can we get this merged?
@kazmiekr That looks like a super useful fork. Do you think you could refactor the really old Keras call to `load_model` in utils.py? It currently imports `load_model` from Keras directly; it should be `from tensorflow.keras.models import load_model`.