PyTorch-Wavelet-Toolbox
Support for Differentiable CWT
Hello!
I was wondering if it would be possible to support a differentiable version of the ptwt.continuous_transform.cwt function. I see that internally, the function converts everything to numpy arrays, and so it's not able to handle input tensors with gradients attached to them.
This would be very useful for my case, where I'm using CWT scalograms to compute a similarity score/loss between signals. I understand that several other transforms support gradients, e.g. wavedec and waverec, which work fantastically in ML pipelines I've tested, so I was hoping that such functionality could be extended to the continuous transform as well.
Cheers!
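As a point of reference, a minimal sketch (assuming a recent ptwt where wavedec accepts a wavelet name as a string) of the gradient flow the discrete transforms already provide, and that this issue asks for in the cwt:

import torch
import ptwt

# Toy signal with gradients enabled.
x = torch.randn(1, 256, requires_grad=True)

# Discrete wavelet decomposition returns a list of coefficient tensors.
coeffs = ptwt.wavedec(x, "db4", level=3)

# A toy objective on the coefficients; gradients reach the input signal.
loss = sum(c.pow(2).sum() for c in coeffs)
loss.backward()
print(x.grad.shape)  # torch.Size([1, 256])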
Dear @david-andrew, You are right, there is still a lot of NumPy code in the cwt module. We have a test to check the GPU functionality, but backprop into the NumPy code for the continuous wavelet is impossible. I would love to allow backprop into the cwt. I may be able to make time for this feature this autumn. Of course, contributions are welcome.
@v0lta I might be interested in taking a look, since I would probably be implementing my own version in the meantime anyway. What do you think the solution might entail? For example, is it just a case of adjusting the function to replace instances of numpy with torch, making sure the inputs/outputs match, and checking that gradients can flow through, or do you think a solution would be more involved?
I'm also not familiar with this library's development process/norms, so it would be good to know if there's anything to be aware of.
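As a rough, hypothetical before/after illustration (not code from the toolbox) of the kind of swap such a port involves: NumPy calls force a detach from the autograd graph, while the torch equivalents keep it intact.

import numpy as np
import torch

x = torch.randn(128, requires_grad=True)

# NumPy route: requires an explicit detach, so no gradients can flow back.
spec_np = np.fft.fft(x.detach().numpy())

# torch route: the same operation stays inside the autograd graph.
spec_pt = torch.fft.fft(x)
spec_pt.abs().sum().backward()
print(x.grad is not None)  # True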
@david-andrew I would expect the development to turn out a lot like you are describing it.
Regarding the development workflow, the idea is to follow the best practices established by the Python community. The ptwt library uses nox to automate testing. You can run the test pipeline locally with nox -s test. Similarly, nox -s lint checks the code style with flake8. You do not have to fix everything yourself: nox -s format resolves many formatting issues for you. Finally, you can double-check your typing with nox -s typing. If you add support for a new feature, I would expect tests for it in the tests folder.
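For orientation, a hypothetical noxfile sketch (not the project's actual noxfile) showing roughly what a session such as nox -s test dispatches to:

import nox


@nox.session(name="test")
def test(session):
    # Install the package under test plus pytest, then run the test suite.
    session.install(".")
    session.install("pytest")
    session.run("pytest", "tests")


@nox.session(name="lint")
def lint(session):
    # Check the code style with flake8.
    session.install("flake8")
    session.run("flake8", "src", "tests")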
Regarding the linting, I have pinned flake8 to version four in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/commit/cc13922e090fbd76f0150ebd718bcb57d77fbb54. I suggest we adopt the newest version once the problem described in https://github.com/tholo/pytest-flake8/issues/87 has been sorted out.
Dear @david-andrew ,
I have a first working prototype; see https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/diff_cwt/examples/continuous_signal_analysis/adaptive_cwt.py. The cwt now allows backprop into the wavelets. I am still looking for a reasonable cost function for the example, so that part is still missing. However, I think the diff_cwt branch may already be worth a closer look.
I think we have no established cost function yet. The cost function won't become part of the toolbox, to keep experimental features out. The new cwt function from the diff_cwt branch supports gradient descent and should be usable.
Until the next release, the new cwt can be installed via the command below.
pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@diff_cwt
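A minimal usage sketch for the branch, under the assumptions that torch tensors are accepted for both the signal and the scales and that the call signature mirrors pywt.cwt, i.e. cwt(data, scales, wavelet); whether gradients reach the raw signal, the wavelet parameters, or both may depend on the state of the branch:

import torch
import ptwt

# A toy 1D signal with gradients enabled.
signal = torch.sin(torch.linspace(0.0, 50.0, 1024)).requires_grad_()
scales = torch.arange(1, 33, dtype=torch.float32)

# Continuous wavelet transform with a Mexican-hat wavelet.
coeffs, freqs = ptwt.continuous_transform.cwt(signal, scales, "mexh")

# A toy objective; gradients flow back into the input signal.
coeffs.abs().sum().backward()
print(signal.grad.shape)  # torch.Size([1024])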
Awesome! I really appreciate you putting this all together!
I was able to run the example you mentioned, and it looks like it works great! I presume you haven't looked into GPU/CUDA support yet, since the scale computations still use numpy. But this looks like an awesome first step.
I have added a unit test for GPU support with differentiable continuous wavelets here: https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/8a00d6693a825f1b9805a11c358334c64bbc9b97/tests/test_cwt.py#L111. Differentiable cwt computations on graphics cards should now work as expected.
@david-andrew does it work for your use case now?
Gave it a shot, and it sort of looks like it works. I was able to get it to run a single training iteration, but then it crashes after encountering tensors on both the CPU and the GPU.
Here's a minimal version of what I'm doing:
import torch
from torch import nn
import numpy as np

import ptwt
from ptwt.continuous_transform import _ComplexMorletWavelet


class ScalogramLoss(nn.Module):
    """Complex Continuous Wavelet Transform Loss"""

    def __init__(self, wavelet, octaves=9, octave_divs=24, alpha=0.5, eps=1e-8):
        super().__init__()
        self.scales = 2**(torch.arange(octave_divs*octaves)/octave_divs + 1)
        self.wavelet = wavelet
        self.alpha = alpha
        self.eps = eps

    def forward(self, x, x_hat):
        S, _ = ptwt.continuous_transform.cwt(x, self.scales, self.wavelet)
        S_hat, _ = ptwt.continuous_transform.cwt(x_hat, self.scales, self.wavelet)
        S, S_hat = S.abs(), S_hat.abs()  # take the magnitude of the complex wavelet transform
        linear_term = nn.functional.l1_loss(S, S_hat)
        log_term = nn.functional.l1_loss((S + self.eps).log2(), (S_hat + self.eps).log2())
        return self.alpha * linear_term + (1 - self.alpha) * log_term


def main():
    duration = 10  # seconds
    fs = 44100
    sig = np.sin(np.arange(int(fs*duration))*2*np.pi*440/fs)
    sig = torch.Tensor(sig).cuda()

    # reconstruct signal starting from random noise
    sig_hat = torch.randn_like(sig, requires_grad=True).cuda()

    wavelet = _ComplexMorletWavelet(name='cmor0.5-0.5').cuda()
    optim = torch.optim.Adam([sig_hat], lr=1e-3)
    loss_fn = ScalogramLoss(wavelet=wavelet)

    iterations = 500
    for it in range(iterations):
        loss = loss_fn(sig, sig_hat)
        loss.backward()
        optim.step()
        print(f'{it}: {loss.item()}')


if __name__ == "__main__":
    main()
and the error I'm getting:
$ python test2.py
0: 1.4774491040433757
Traceback (most recent call last):
File "test2.py", line 52, in <module>
main()
File "test2.py", line 46, in main
loss.backward()
File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
I haven't had enough time to figure out if it's a problem with the ptwt implementation, or something else on my end. But exciting to see it work for at least a single iteration! I'm not sure I have a ton of time to focus on this, but I'll keep poking around at it when I can.
Dear @david-andrew ,
I think I found the problem on my end. I was moving the wavelet module to the CPU to run the pywt code for the axis labels; afterward, it has to move back to the GPU. Commit https://github.com/v0lta/PyTorch-Wavelet-Toolbox/commit/f9dbff38a795d4b4e771d804858ef76f01c634ff fixes the problem. Your example runs on my machine now. The commit is now in the wavedec2-improved-errors branch, which merges diff_cwt with the changes suggested in issue https://github.com/v0lta/PyTorch-Wavelet-Toolbox/issues/40.
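For context, a tiny standalone illustration of this class of bug (not the actual ptwt code): a module moved to the CPU for NumPy-only work clashes with CUDA inputs on the next call unless it is moved back afterwards.

import torch
from torch import nn

layer = nn.Linear(4, 4).cuda()
x = torch.randn(2, 4, device="cuda")

layer.cpu()         # e.g. dropped to the CPU to run NumPy-only helper code
# ... NumPy work would happen here ...
layer.to(x.device)  # the fix: restore the original device afterwards

out = layer(x)      # succeeds because parameters and input match again
print(out.device)   # cuda:0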
For testing
pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@wavedec2-improved-errors
should do the job.
Awesome, yeah looks to be working great on my side as well!