InvertibleNetworks.jl

Wavelet squeeze is not GPU friendly

Open · yl4070 opened this issue 2 years ago · 3 comments

Hi, in my testing, wavelet_squeeze causes a scalar indexing warning, and benchmarking shows it takes almost 2 s to complete on a GPU array. Is there any way to work around this?

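using InvertibleNetworks, Flux, CUDA, BenchmarkTools  # Flux provides `gpu`, CUDA.jl the GPU backend

X = rand(Float32, 16, 16, 16, 10) |> gpu
@btime wavelet_squeeze($X)  # `$X` interpolates X so global-variable access isn't timed
# 1.988 s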

yl4070, Mar 20 '22 03:03

wavelet_squeeze uses an external wavelet transform library that doesn't support GPUs, so it shouldn't be used in GPU workflows. If you want to run networks and layers that involve squeezing on the GPU, use our Haar implementation instead: a custom implementation of the wavelet transform with a Haar basis that supports GPUs. Check Haar_squeeze and invHaar_unsqueeze.
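To illustrate why a hand-written Haar transform can stay on the GPU, here is a minimal sketch. It is not the package's Haar_squeeze/invHaar_unsqueeze: the function names `haar_squeeze_sketch` and `haar_unsqueeze_sketch` are hypothetical, and the order in which the four sub-bands are stacked along the channel dimension is an assumption. The point is that a 2x2 orthonormal Haar squeeze needs only strided slicing, broadcasting and `cat`, which GPU array types such as CUDA.jl's `CuArray` run as whole-array kernels, rather than the per-element indexing that triggers the scalar indexing warning.

```julia
# Illustrative 2x2 orthonormal Haar squeeze/unsqueeze built only from strided
# slicing, broadcasting and `cat` (no per-element loops).
# NOTE: hypothetical helpers, not InvertibleNetworks.jl's Haar_squeeze;
# the channel ordering (LL, LH, HL, HH) is an assumption.

# (nx, ny, nc, nb) -> (nx ÷ 2, ny ÷ 2, 4nc, nb); nx and ny must be even.
function haar_squeeze_sketch(X::AbstractArray{<:Real,4})
    a = X[1:2:end, 1:2:end, :, :]   # entry (1,1) of each 2x2 block
    b = X[2:2:end, 1:2:end, :, :]   # entry (2,1)
    c = X[1:2:end, 2:2:end, :, :]   # entry (1,2)
    d = X[2:2:end, 2:2:end, :, :]   # entry (2,2)
    LL = (a .+ b .+ c .+ d) ./ 2    # approximation coefficients
    LH = (a .- b .+ c .- d) ./ 2    # detail along the first dimension
    HL = (a .+ b .- c .- d) ./ 2    # detail along the second dimension
    HH = (a .- b .- c .+ d) ./ 2    # diagonal detail
    return cat(LL, LH, HL, HH; dims=3)
end

# Exact inverse: (nx, ny, 4nc, nb) -> (2nx, 2ny, nc, nb).
function haar_unsqueeze_sketch(Y::AbstractArray{<:Real,4})
    nx, ny, c4, nb = size(Y)
    nc = c4 ÷ 4
    LL = Y[:, :, 1:nc, :]
    LH = Y[:, :, nc+1:2nc, :]
    HL = Y[:, :, 2nc+1:3nc, :]
    HH = Y[:, :, 3nc+1:4nc, :]
    X = similar(Y, 2nx, 2ny, nc, nb)
    # `X[idx] .= rhs` writes in place through a view of X.
    X[1:2:end, 1:2:end, :, :] .= (LL .+ LH .+ HL .+ HH) ./ 2
    X[2:2:end, 1:2:end, :, :] .= (LL .- LH .+ HL .- HH) ./ 2
    X[1:2:end, 2:2:end, :, :] .= (LL .+ LH .- HL .- HH) ./ 2
    X[2:2:end, 2:2:end, :, :] .= (LL .- LH .- HL .+ HH) ./ 2
    return X
end

# Usage on the CPU; the same calls should also accept a CuArray (e.g. `X |> gpu`).
X = rand(Float32, 16, 16, 16, 10)
Y = haar_squeeze_sketch(X)        # size(Y) == (8, 8, 64, 10)
Xr = haar_unsqueeze_sketch(Y)     # Xr ≈ X, since the transform is orthonormal
```

Because each of the four output bands is an orthonormal combination of the four pixels in a 2x2 block, the squeeze preserves the squared norm and is inverted exactly by recombining the bands, which is what makes it usable as an invertible layer.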

mloubout, Mar 20 '22 03:03

Thanks. Is this package still in active development? Can I expect new features, layers, or networks to be implemented?

yl4070, Mar 20 '22 03:03

Yes, this package is still very active. We haven't come across many new networks or layers to add, but we still work with this package daily for our research, and updates and improvements will be added as time allows.

mloubout, Mar 20 '22 03:03