PyTorch-Wavelet-Toolbox
3D Matrix Wavelet Decomposition
I have a specific application where the sparse-matrix representation of the DWT is really useful for 3D signals. I'd like to propose it as an enhancement.
Yes, we currently do not support sparse matrix transforms in 3D. I agree it would be very nice to have.
This is going to be a big project. The last time I checked, PyTorch did not provide a sparse QR solver. We already have a Gram-Schmidt alternative implemented, but it is not implemented in CUDA and is slow. It works in the 2D case, but I expect the 3D matrices to be bigger and sparser. Personally, this is something I will look at when I have an extended block of time available.
As always, contributions are welcome.
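For readers unfamiliar with the orthogonalization step mentioned above, here is a minimal dense sketch of modified Gram-Schmidt in PyTorch. It is not the toolbox's implementation; the function name `gram_schmidt_rows` is hypothetical, and it assumes the rows of the input are linearly independent. It mainly illustrates why the approach is slow: every row is processed sequentially against all previous ones.

```python
import torch


def gram_schmidt_rows(matrix: torch.Tensor) -> torch.Tensor:
    """Orthonormalize the rows of a dense matrix (modified Gram-Schmidt).

    Hypothetical helper for illustration; assumes linearly independent rows.
    """
    ortho_rows = []
    for row in matrix:
        v = row.clone()
        # Remove the components along every previously accepted row.
        for q in ortho_rows:
            v = v - torch.dot(q, v) * q
        # Normalize what is left to unit length.
        ortho_rows.append(v / torch.linalg.norm(v))
    return torch.stack(ortho_rows)


torch.manual_seed(0)
a = torch.rand(4, 6, dtype=torch.float64)
q = gram_schmidt_rows(a)
# The rows of q should now be orthonormal, so q @ q.T is close to identity.
print(torch.allclose(q @ q.T, torch.eye(4, dtype=torch.float64)))
```

The nested Python loop is the bottleneck; a sparse QR solver would replace it with a single factorization.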
Dear @johnryan465, after installing the most recent release candidate via
pip install git+ssh://[email protected]/v0lta/[email protected]
you should be able to run the following:
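import numpy as np
import ptwt
import pywt
import torch

# Batch of two random 32x32x32 volumes in double precision.
data = torch.rand(2, 32, 32, 32, dtype=torch.float64)
# Two-level forward transform using sparse matrices.
matrixfwt = ptwt.MatrixWavedec3(pywt.Wavelet("haar"), level=2)
mat_coeff = matrixfwt(data)
# Inverse transform.
matrixifwt = ptwt.MatrixWaverec3(pywt.Wavelet("haar"))
reconstruction = matrixifwt(mat_coeff)
# Check that the reconstruction matches the input.
print(np.allclose(reconstruction.numpy(), data.numpy()))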
which computes and inverts a separable, sparse-matrix, boundary-wavelet 3D DWT.
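To illustrate what "separable" means here, the following is a rough sketch, not the toolbox's internals: one sparse single-level Haar analysis matrix is applied along each of the three spatial axes in turn, and the transform is inverted with the transposed matrix. The helpers `haar_matrix` and `apply_along_axis` are hypothetical names. Haar filters have length two and fit the signal exactly, so no boundary treatment is needed in this sketch; longer wavelets are where boundary filters and their orthogonalization would come in.

```python
import torch


def haar_matrix(n: int) -> torch.Tensor:
    """Sparse orthogonal single-level Haar analysis matrix (n must be even)."""
    half = n // 2
    k = torch.arange(half)
    # First half of the rows: lowpass, second half: highpass.
    rows = torch.cat([k, k, half + k, half + k])
    cols = torch.cat([2 * k, 2 * k + 1, 2 * k, 2 * k + 1])
    signs = torch.tensor([1.0, 1.0, 1.0, -1.0], dtype=torch.float64)
    vals = 2.0 ** -0.5 * signs.repeat_interleave(half)
    return torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (n, n)).coalesce()


def apply_along_axis(mat: torch.Tensor, x: torch.Tensor, axis: int) -> torch.Tensor:
    """Multiply a square sparse matrix onto one axis of a dense tensor."""
    x = x.movedim(axis, 0)
    shape = x.shape
    y = torch.sparse.mm(mat, x.reshape(shape[0], -1))
    return y.reshape(shape).movedim(0, axis)


torch.manual_seed(0)
data = torch.rand(2, 8, 8, 8, dtype=torch.float64)
analysis = haar_matrix(8)

# Forward: transform depth, height, and width one after another.
coeff = data
for axis in (-3, -2, -1):
    coeff = apply_along_axis(analysis, coeff, axis)

# Inverse: the matrix is orthogonal, so its transpose undoes each step.
rec = coeff
for axis in (-3, -2, -1):
    rec = apply_along_axis(analysis.t(), rec, axis)

print(torch.allclose(rec, data))
```

Because each axis is handled by its own small matrix, no dense 3D operator ever has to be built, which is the main appeal of the separable formulation.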