Pytorch-Correlation-extension
3D images
How can this extension work on 3D images?
You probably need to extend the C++/CUDA code for that. It shouldn't be too hard, but it will be very verbose.
Essentially, you go from a 4D cost volume (x, y, shift x and shift y) to a 6D cost volume (you add the z and shift z dimensions). The loops must be changed accordingly, as well as the expected tensor dims and the border checks.
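To make the 6D layout concrete, here is a minimal pure-NumPy reference sketch, not the extension's code. It assumes kernel size 1, stride 1, an odd patch size, no normalization, and zeros wherever a shift lands outside the volume. The function name `correlation_3d_naive` and the `(B, shift z, shift y, shift x, z, y, x)` output layout are hypothetical choices for illustration. The extra z loop and its border check are exactly what the C++/CUDA kernels would need.

```python
import numpy as np

def correlation_3d_naive(in1, in2, patch_size):
    """Hypothetical reference: 6D cost volume per batch item.

    in1, in2: arrays of shape (B, C, D, H, W).
    Returns shape (B, P, P, P, D, H, W), indexed as
    [b, shift z, shift y, shift x, z, y, x].
    Assumes kernel_size=1, stride=1, odd patch_size, zero padding.
    """
    B, C, D, H, W = in1.shape
    r = patch_size // 2
    out = np.zeros((B, patch_size, patch_size, patch_size, D, H, W),
                   dtype=np.result_type(in1, in2))
    for dz in range(patch_size):
        for dy in range(patch_size):
            for dx in range(patch_size):
                for z in range(D):
                    z2 = z + dz - r
                    if not 0 <= z2 < D:  # new border check along z
                        continue
                    for y in range(H):
                        y2 = y + dy - r
                        if not 0 <= y2 < H:
                            continue
                        for x in range(W):
                            x2 = x + dx - r
                            if not 0 <= x2 < W:
                                continue
                            # dot product over channels, for every batch item
                            out[:, dz, dy, dx, z, y, x] = np.sum(
                                in1[:, :, z, y, x] * in2[:, :, z2, y2, x2],
                                axis=1)
    return out
```

Going from 2D to 3D adds one spatial loop and one shift loop, which is why the number of output dimensions grows from 4 to 6 (plus batch).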
What you can do, though, is decompose the problem into 2D slices and use the 2D correlation here, then restack the (z * shift z) 4D slices together at the end. This can be done in pure Python, but it will be much slower than rewriting it in C++.
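A minimal sketch of that slice-and-restack idea, under stated assumptions: `correlation_2d` below is a hypothetical pure-NumPy stand-in for the extension's 2D correlation (kernel size 1, stride 1, odd patch size, zero padding, no normalization), and `correlation_3d_by_slices` is a hypothetical helper. For every depth index `z` and depth shift `dz`, it correlates slice `z` of the first volume with slice `z + dz - r` of the second, and writes the resulting 4D block `(B, shift y, shift x, y, x)` into the 6D output.

```python
import numpy as np

def correlation_2d(in1, in2, patch_size):
    """Hypothetical stand-in for a 2D correlation op.

    in1, in2: (B, C, H, W) -> output (B, P, P, H, W).
    """
    B, C, H, W = in1.shape
    r = patch_size // 2
    # zero-pad in2 so shifted windows outside the image contribute 0
    padded = np.pad(in2, ((0, 0), (0, 0), (r, r), (r, r)))
    out = np.empty((B, patch_size, patch_size, H, W),
                   dtype=np.result_type(in1, in2))
    for dy in range(patch_size):
        for dx in range(patch_size):
            # padded[y + dy] corresponds to in2[y + dy - r]
            shifted = padded[:, :, dy:dy + H, dx:dx + W]
            out[:, dy, dx] = np.sum(in1 * shifted, axis=1)
    return out

def correlation_3d_by_slices(in1, in2, patch_size):
    """Build the 6D cost volume from (z * shift z) 2D correlations.

    in1, in2: (B, C, D, H, W) -> output (B, P, P, P, D, H, W).
    """
    B, C, D, H, W = in1.shape
    r = patch_size // 2
    out = np.zeros((B, patch_size, patch_size, patch_size, D, H, W),
                   dtype=np.result_type(in1, in2))
    for z in range(D):
        for dz in range(patch_size):
            z2 = z + dz - r
            if not 0 <= z2 < D:
                continue  # shifted slice is outside the volume: keep zeros
            # one 4D block (B, shift y, shift x, H, W) per (z, shift z) pair
            out[:, dz, :, :, z] = correlation_2d(
                in1[:, :, z], in2[:, :, z2], patch_size)
    return out
```

In practice you would swap `correlation_2d` for the extension's own 2D call. The Python double loop over `z` and `dz` is where the slowdown comes from; one possible mitigation (untested here) is to stack all valid `(z, z2)` slice pairs along the batch dimension so the 2D op runs once on a larger batch.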
Thank you very much for your advice. I'll try it.