
Speeding up predictions


Hello,

Thank you for developing BERTax! It looks like a really great tool for taxonomic classification of sequences that are typically difficult to classify with tools that rely on big databases.

I was interested to see whether BERTax could be used to classify metagenomic sequencing reads, but it seems it would be quite a bit slower than k-mer-based methods such as Centrifuge and Kraken2, even with GPU acceleration:

  • 16 CPU threads (Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz): 6 reads/s
  • Nvidia Quadro RTX 5000 (Driver Version: 470.63.01; CUDA Version: 11.4): 20 reads/s

Are there any plans to optimize BERTax for performing predictions on larger inputs?

In PR https://github.com/peterk87/bertax/pull/1 I tried to make the BERTax code a little more efficient on large inputs (reads in FASTQ), but I'm not familiar with Keras or TensorFlow, so I'm not sure how best to optimize that code. The call to model.predict seems to be taking the most time by far.
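One common way to cut down per-call overhead in Keras is to accumulate the tokenized chunks from many reads and run a single batched model.predict call, rather than predicting read by read. A minimal sketch of that idea, where `model` and the pre-tokenized chunk arrays are placeholders and not the actual BERTax interface:

```python
import numpy as np

def predict_batched(model, token_arrays, batch_size=512):
    """Run one batched predict call over many pre-tokenized chunks.

    token_arrays: list of equally shaped token arrays, one per chunk.
    A single model.predict call amortizes graph-dispatch overhead
    over all chunks instead of paying it once per read.
    """
    batch = np.stack(token_arrays)  # shape: (n_chunks, seq_len)
    return model.predict(batch, batch_size=batch_size)
```

The results would then need to be split back out per read, since each read may contribute several chunks to the batch.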

For example, for a read of length 6092 split into 5 chunks:

  • seq2tokens: 0.792363 ms
  • process_bert_tokens_batch: 1.096281 ms
  • model.predict: 67.773608 ms
  • writing output: 1.32 ms

Total elapsed time: 70.986515 ms. Timings were obtained with time.time_ns. Although some optimizations may be possible for input processing and output formatting, most of the time (>95%) is spent in model.predict.
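For reference, per-step timings like the ones above can be collected with a small wrapper around time.time_ns; this is just a sketch (the wrapped function names are whatever steps you want to measure, not part of BERTax itself):

```python
import time

def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_ms)."""
    start = time.time_ns()
    result = fn(*args, **kwargs)
    elapsed_ms = (time.time_ns() - start) / 1e6  # ns -> ms
    return result, elapsed_ms

# usage, e.g.: predictions, ms = timed(model.predict, batch)
```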

I noticed that in the bertax-visualize script, the Keras model is converted into a PyTorch model:

https://github.com/f-kretschmer/bertax/blob/ae8cc568a2e66692e7663025906fda0016aa8b52/bertax/visualize.py#L29

I haven't tested whether using PyTorch and a converted model would speed up predictions. Alternatively, maybe the Keras model could be converted to a TensorFlow Lite model to reduce the per-call overhead of model.predict, as described in the following blog post:

https://micwurm.medium.com/using-tensorflow-lite-to-speed-up-predictions-a3954886eb98
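For illustration, the conversion described in that post might look roughly like the sketch below, assuming the model loads as a standard tf.keras.Model; keras-bert custom layers may well need extra handling, so this is a starting point rather than a tested recipe:

```python
import numpy as np
import tensorflow as tf

def to_tflite_predict_fn(keras_model):
    """Convert a Keras model to TFLite and return a predict function."""
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    tflite_model = converter.convert()  # serialized flatbuffer

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    def predict(x):
        # set input, run the interpreter, read the output tensor
        interpreter.set_tensor(inp["index"], x.astype(np.float32))
        interpreter.invoke()
        return interpreter.get_tensor(out["index"])

    return predict
```

Note that the TFLite interpreter's input shape is fixed at conversion time, so variable batch sizes would need resizing or padding.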

Unfortunately, I'm only familiar with NumPy, not with Keras, TensorFlow, or PyTorch. I have a bit of experience using Cython and Numba to accelerate Python code, but those may not be appropriate here.

Any speed-ups (or ideas for how to achieve them) would be greatly appreciated and would allow BERTax to be used on a wider range of datasets!

Thanks! Peter

peterk87 · Aug 30 '21

Hello Peter,

Many thanks for your tests and suggestions! I haven't looked into runtime optimization much so far, so there are definitely improvements to be made. I didn't know about TensorFlow Lite; that seems like a promising starting point, although I'm not sure how well custom models (keras-bert) can be converted. Thanks again, I'll look into it!

Fleming

f-kretschmer · Aug 31 '21