hover_net
Optimised inference pipeline
Hi all,
First of all, thanks for this amazing project. Our lab is really happy to be able to use it, and I've learned a lot from working with it.
While applying HoverNet to our dataset, I ran into a pretty significant obstacle. We use an HPC cluster equipped with GPUs to run inference across our fairly large dataset (several thousand WSIs). With the current inference script (`run_infer.py`), a single WSI ran for multiple days before being killed by the scheduler.
When reading through the code, I noticed that the inference script (`run_infer.py`) makes quite extensive use of memory-mapped numpy files. Memory-mapping used this way carries significant overhead, which makes the script virtually unusable on network filesystems like the one on our HPC cluster.
I definitely wanted to use your project, so I ended up refactoring the inference script a bit. Please find the result in this PR. I also made a few additional optimisations that should improve runtime and simplify the code somewhat. The new `infer_simple.py` should be able to completely replace the previous inference code, and it only imports the parts of the existing codebase that it needs (i.e. the `HoVerNet` model definition and the `process` postprocessing routine).
- Inference now happens in the following stages: the WSI is divided into patches, and these patches are then grouped into square 'chunks' that overlap with each other. For each chunk, the model produces feature maps by running its forward() pass on all patches, and these are written to disk as h5 files. After all chunks have been inferred, each chunk is postprocessed individually, in parallel. Border artefacts are resolved by simply discarding the cells located in each chunk's overlapping padding area (as was helpfully explained in #102).
- I used Python's brilliant Shapely package to allow for versatile manipulation of patch, chunk and mask shapes.
- Patches are no longer cached on disk, and are read as-needed by the data loader from their coordinates.
- Currently it only outputs nuclear centroids, but it should be trivial to extend.
- It does not take the slide's MPP into account. I'm not sure how that's done in the current inference script and would love some feedback here.
- There are also some more TODO comments in there regarding smaller implementation details that you may wish to look over.
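The chunking and border handling described above can be sketched in plain Python roughly as follows. This is a minimal illustration under stated assumptions, not the code in `infer_simple.py`: `chunk_grid` and `keep_core_centroids` are hypothetical names, each chunk is assumed to consist of a 'core' (the cores tile the slide without gaps or overlap) grown by a fixed `padding` on every side, nuclei are represented by centroids in slide pixel coordinates, and `padding` is assumed large enough that any nucleus whose centroid lies in a core is fully segmented within the padded chunk. Shapely boxes could replace the plain tuples used here.

```python
from typing import Iterator, List, Tuple

# (x0, y0, x1, y1), treated as half-open: x0 <= x < x1, y0 <= y < y1
Box = Tuple[int, int, int, int]
Point = Tuple[float, float]


def chunk_grid(width: int, height: int, core: int, padding: int) -> Iterator[Tuple[Box, Box]]:
    """Yield (core_box, padded_box) pairs covering a width x height slide.

    Hypothetical helper: core boxes tile the slide seamlessly; each padded
    box is its core grown by `padding` px on every side and clipped to the
    slide, so neighbouring padded boxes overlap by up to 2 * padding.
    """
    for y0 in range(0, height, core):
        for x0 in range(0, width, core):
            x1, y1 = min(x0 + core, width), min(y0 + core, height)
            padded = (max(x0 - padding, 0), max(y0 - padding, 0),
                      min(x1 + padding, width), min(y1 + padding, height))
            yield (x0, y0, x1, y1), padded


def keep_core_centroids(centroids: List[Point], core_box: Box) -> List[Point]:
    """Discard centroids that fall in the chunk's padding, i.e. outside its core."""
    x0, y0, x1, y1 = core_box
    return [(x, y) for x, y in centroids if x0 <= x < x1 and y0 <= y < y1]


if __name__ == "__main__":
    # Toy slide: pretend the model detects every nucleus inside a padded chunk.
    nuclei = [(10, 10), (99.5, 50), (100, 100), (150, 199), (250, 30)]
    kept = []
    for core_box, (px0, py0, px1, py1) in chunk_grid(300, 200, core=100, padding=16):
        seen = [(x, y) for x, y in nuclei if px0 <= x < px1 and py0 <= y < py1]
        kept.extend(keep_core_centroids(seen, core_box))
    print(sorted(kept) == sorted(nuclei))
```

Treating the core boxes as half-open means a centroid lying exactly on a shared core edge is assigned to only one chunk, so the merged output contains no edge duplicates.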
Our lab is using this script 'in production'. I'm happy to polish this PR so it better fits your vision for this project or to explain my design decisions some more. If you prefer to use the current scripts instead that's also perfectly fine with me, of course! I'm just happy to have been able to use your work.
Hi @jjhbw ,
Thanks a lot for this - we really appreciate it. We are pleased that you and your lab are enjoying working with the code. Also thanks for the detailed description. We will pull down the PR and test it before giving some feedback later this week.
Great! Would love to hear your feedback. Note that I've only really changed 'plumbing' code; the model and postprocessing routines are untouched, as the diff shows. Just let me know if anything is unclear.
Hi @jjhbw, I have been busy with other things. I ended up rolling another version that does not do the caching. I looked at your version a while back; here are some of my concerns. I haven't tested your code, so they may be unfounded.
Your code relies solely on the chunk grid defined here https://github.com/vqdang/hover_net/blob/e0191055c62f8f5c579ea205071d53200bbf497a/infer_simple.py#L140 and here https://github.com/vqdang/hover_net/blob/e0191055c62f8f5c579ea205071d53200bbf497a/infer_simple.py#L380, which can be considered a seamless tiling of the chunk output. Now, if we check the postprocessing here, https://github.com/vqdang/hover_net/blob/e0191055c62f8f5c579ea205071d53200bbf497a/infer_simple.py#L298, this code removes instances within a margin (of the padded size) of each chunk. From my understanding, when viewing the entire WSI, those regions may not contain nuclei. If they do contain nuclei, those on the margin line/tile cross-section will still be removed. I suggest you perform a visual check on a small WSI for both missing and duplicated nuclei.
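The duplicate half of such a check can also be automated alongside the visual inspection, by looking for centroids reported by different chunks that lie within a few pixels of each other. The sketch below is a hypothetical helper (`find_cross_chunk_duplicates` is not part of hover_net), assuming each chunk's output has already been reduced to centroids in slide coordinates and that `radius` is a user-chosen tolerance smaller than the typical spacing between distinct nuclei. It does not detect missing nuclei; that needs a reference, for example manual annotations.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def find_cross_chunk_duplicates(
    per_chunk: Dict[int, List[Point]], radius: float
) -> List[Tuple[int, Point, int, Point]]:
    """Return pairs of centroids from *different* chunks within `radius` px.

    Hypothetical helper: centroids are bucketed into square cells of side
    `radius`, so each point only needs comparing against the 3x3
    neighbourhood of cells around it instead of every other point.
    """
    grid = defaultdict(list)  # (cell_x, cell_y) -> [(chunk_id, point), ...]
    for chunk_id, points in per_chunk.items():
        for p in points:
            cell = (math.floor(p[0] / radius), math.floor(p[1] / radius))
            grid[cell].append((chunk_id, p))

    pairs = []
    for (cx, cy), members in grid.items():
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                # .get() avoids inserting new keys while iterating the dict
                for cid_b, b in grid.get((cx + dx, cy + dy), []):
                    for cid_a, a in members:
                        # cid_a < cid_b skips same-chunk pairs and reports
                        # each cross-chunk pair exactly once
                        if cid_a < cid_b and math.dist(a, b) <= radius:
                            pairs.append((cid_a, a, cid_b, b))
    return pairs


if __name__ == "__main__":
    per_chunk = {0: [(98.0, 50.0), (10.0, 10.0)], 1: [(99.0, 50.5), (150.0, 40.0)]}
    print(find_cross_chunk_duplicates(per_chunk, radius=4.0))
```

If the padding-discard scheme works as intended, this should return an empty list on real output; any pair it does return points at a location worth inspecting visually.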
Hi @vqdang, thanks for looking over the PR.
The chunks overlap each other by a fixed-size strip of padding. For each chunk, the cells within the padding area are discarded. This redundancy should ensure that cells near the borders of a chunk are always part of another chunk, so duplicates can be safely discarded. Each cell should be present only once in the final set.
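The exactly-once argument can be checked numerically along one axis. In this minimal sketch (`cores_containing` and `windows_containing` are hypothetical helpers, and the core size and padding are arbitrary example values), cores are half-open intervals `[k*core, (k+1)*core)` that tile the axis, and each padded window extends its core by `padding` on both sides. A position near a chunk border lies in two padded windows, so the nucleus there is detected twice, but it lies in exactly one core, so it is kept once. Since a 2-D chunk core is the product of an x-interval and a y-interval, the same holds for centroids in the plane. This assumes the padding is wide enough that a nucleus centred in a core is fully segmented within its padded chunk.

```python
from typing import List


def cores_containing(x: float, core: int, n_chunks: int) -> List[int]:
    """Indices k whose core interval [k*core, (k+1)*core) contains x."""
    return [k for k in range(n_chunks) if k * core <= x < (k + 1) * core]


def windows_containing(x: float, core: int, padding: int, n_chunks: int) -> List[int]:
    """Indices k whose padded window [k*core - padding, (k+1)*core + padding) contains x."""
    return [k for k in range(n_chunks)
            if k * core - padding <= x < (k + 1) * core + padding]


core, padding, n_chunks = 256, 32, 4
for x in range(core * n_chunks):
    # every position on the axis is owned by exactly one core...
    assert len(cores_containing(x, core, n_chunks)) == 1
# ...even though positions near a border are seen by two padded windows
print(windows_containing(250, core, padding, n_chunks))  # [0, 1]
print(cores_containing(250, core, n_chunks))  # [0]
```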
The image below may help explain the concept a bit (sorry about its low resolution). The chunk boundaries are shown in red and the tile boundaries in green. I will ignore the imperfect tissue segmentation for the purposes of this discussion.
I also performed a more thorough investigation by exporting the nuclei found in neighbouring chunks to QuPath. I'll re-do it and share a few screenshots.
Hey, no movement on this?