
Support for Differentiable CWT

david-andrew opened this issue 3 years ago

Hello!

I was wondering if it would be possible to support a differentiable version of the ptwt.continuous_transform.cwt function. I see that, internally, the function converts everything to NumPy arrays, so it can't handle input tensors with gradients attached to them.

This would be very useful for my case where I'm using CWT scalograms for computing a similarity score/loss between signals. I understand that several other transforms support gradients e.g. wavedec and waverec, which work fantastically in ML pipelines I've tested, so I was hoping that such functionality could be extended to the continuous transform as well.
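For illustration, here is roughly the kind of pipeline I mean (a minimal sketch, not my actual code; the shapes and wavelet choice are just placeholders):

import torch
import ptwt

# Gradients already flow through the discrete transforms:
x = torch.randn(1, 64, requires_grad=True)
coeffs = ptwt.wavedec(x, "haar", level=2)   # list of coefficient tensors
loss = sum(c.abs().sum() for c in coeffs)   # toy scalar objective
loss.backward()
print(x.grad.shape)                         # gradient reaches the input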

Cheers!

david-andrew commented Aug 28 '22

Dear @david-andrew, You are right, there is still a lot of NumPy code in the cwt module. We have a test to check the GPU functionality, but backprop into the NumPy code for the continuous wavelet is impossible. I would love to allow backprop into the cwt. I may be able to make time for this feature this autumn. Of course, contributions are welcome.
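To illustrate why (a minimal sketch, not actual ptwt code): any hop into NumPy either raises an error or silently severs the autograd graph.

import torch

x = torch.linspace(0.0, 1.0, 8, requires_grad=True)
try:
    x.numpy()  # autograd refuses: NumPy cannot track gradients
except RuntimeError as err:
    print(err)  # "Can't call numpy() on Tensor that requires grad. ..."

# detach() makes the conversion legal, but cuts the graph:
y = torch.from_numpy(x.detach().numpy()) * 2.0
print(y.requires_grad)  # False -- backprop cannot reach x through y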

v0lta commented Aug 29 '22

@v0lta I might be interested in taking a look, since I would probably be implementing my own version in the meantime anyway. What do you think the solution might entail? Is it just a case of replacing instances of NumPy with torch, making sure the inputs and outputs match, and checking that gradients can flow through, or do you think a solution would be more involved?

I'm also not familiar with this library's development process and norms, so it would be good to know if there's anything to be aware of.

david-andrew commented Aug 30 '22

@david-andrew I would expect the development to turn out a lot like you describe.

Regarding the development workflow, the idea is to follow the best practices as established by the Python community. The ptwt library uses nox to automate testing. You can run the test pipeline locally using nox -s test. Similarly nox -s lint checks the code style using flake8. You do not have to fix everything yourself. nox -s format will fix many formatting issues for you. Finally, you can double-check your typing with nox -s typing. If you add support for a new feature, I would expect tests for it in the tests folder.
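In short, the local pipeline boils down to:

nox -s test    # run the test suite
nox -s lint    # flake8 style checks
nox -s format  # auto-fix many formatting issues
nox -s typing  # static type checks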

v0lta commented Aug 30 '22

Regarding the linting, I have hardcoded flake8 version four in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/commit/cc13922e090fbd76f0150ebd718bcb57d77fbb54 . I suggest we adopt the newest version once people have sorted out the problem described in https://github.com/tholo/pytest-flake8/issues/87 .

v0lta commented Aug 30 '22

Dear @david-andrew, I have a first working prototype. See https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/diff_cwt/examples/continuous_signal_analysis/adaptive_cwt.py : the cwt now allows backprop into wavelets. I am still looking for a reasonable cost function for the example, so that part is still missing. However, I think the diff_cwt branch may already be worth a closer look.

v0lta commented Oct 10 '22

We have no established cost function yet, and to keep experimental features out of the toolbox, the cost won't become part of it. The new cwt function from the diff_cwt branch supports gradient descent and should be usable; see the sketch below.
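A minimal usage sketch (the wavelet name and scale grid below are illustrative choices, not prescribed by the toolbox):

import torch
import ptwt
from ptwt.continuous_transform import _ComplexMorletWavelet

sig = torch.randn(1024, requires_grad=True)  # signal we want gradients for
scales = 2 ** torch.linspace(1, 5, 32)       # log-spaced scale grid
wavelet = _ComplexMorletWavelet(name="cmor0.5-0.5")
coeffs, _ = ptwt.continuous_transform.cwt(sig, scales, wavelet)
coeffs.abs().sum().backward()                # toy objective standing in for a real cost
print(sig.grad is not None)                  # True: gradients flow back to the signal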

v0lta commented Oct 10 '22

Until the next release, the new cwt can be installed via the command below.

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@diff_cwt

v0lta commented Oct 10 '22

Awesome! I really appreciate you putting this all together!

I was able to run the example you mentioned, and it looks like it works great! I presume you haven't looked into GPU/CUDA support yet, since it looks like the scale computations still use NumPy. But this is an awesome first step.

david-andrew commented Oct 10 '22

I have added a unit test for GPU support with differentiable continuous wavelets here: https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/8a00d6693a825f1b9805a11c358334c64bbc9b97/tests/test_cwt.py#L111 . Differentiable cwt computations on graphics cards should now work as expected.
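Roughly, the pattern the test checks looks like this (a sketch, not the test code itself; wavelet name and scale grid are illustrative):

import torch
import ptwt
from ptwt.continuous_transform import _ComplexMorletWavelet

sig = torch.randn(1024, device="cuda", requires_grad=True)
scales = 2 ** torch.linspace(1, 5, 32, device="cuda")
wavelet = _ComplexMorletWavelet(name="cmor0.5-0.5").cuda()
coeffs, _ = ptwt.continuous_transform.cwt(sig, scales, wavelet)
coeffs.abs().sum().backward()
assert sig.grad.is_cuda  # gradients were computed on the graphics card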

v0lta commented Oct 11 '22

@david-andrew does it work for your use case now?

v0lta commented Oct 11 '22

Gave it a shot, and it sort of looks like it works. I was able to get it to run a single training iteration, but then it crashes due to encountering tensors on both the CPU and the GPU.

Here's a minimal version of what I'm doing:

import torch
from torch import nn
import numpy as np
import ptwt
from ptwt.continuous_transform import _ComplexMorletWavelet


class ScalogramLoss(nn.Module):
    """Complex Continuous Wavelet Transform Loss"""
    def __init__(self, wavelet, octaves=9, octave_divs=24, alpha=0.5, eps=1e-8):
        super().__init__()
        self.scales = 2**(torch.arange(octave_divs*octaves)/octave_divs + 1)  # log-spaced scales covering the requested octaves
        self.wavelet = wavelet
        self.alpha = alpha
        self.eps = eps

    def forward(self, x, x_hat):
        S, _ = ptwt.continuous_transform.cwt(x, self.scales, self.wavelet)
        S_hat, _ = ptwt.continuous_transform.cwt(x_hat, self.scales, self.wavelet)
        S, S_hat = S.abs(), S_hat.abs() #take the magnitude of the complex wavelet transform

        linear_term = nn.functional.l1_loss(S, S_hat)
        log_term = nn.functional.l1_loss((S + self.eps).log2(), (S_hat + self.eps).log2())

        return self.alpha * linear_term + (1 - self.alpha) * log_term



def main():
    duration = 10 # seconds
    fs = 44100
    sig = np.sin(np.arange(int(fs*duration))*2*np.pi*440/fs)
    sig = torch.Tensor(sig).cuda()
    
    #reconstruct signal starting from random noise
    sig_hat = torch.randn_like(sig, requires_grad=True).cuda()

    wavelet = _ComplexMorletWavelet(name='cmor0.5-0.5').cuda()

    optim = torch.optim.Adam([sig_hat], lr=1e-3)
    loss_fn = ScalogramLoss(wavelet=wavelet)
    
    iterations = 500
    for it in range(iterations):
        optim.zero_grad()  # clear gradients accumulated by the previous step
        loss = loss_fn(sig, sig_hat)
        loss.backward()
        optim.step()
        print(f'{it}: {loss.item()}')
       

if __name__ == "__main__":
    main()

and the error I'm getting:

$ python test2.py
0: 1.4774491040433757
Traceback (most recent call last):
  File "test2.py", line 52, in <module>
    main()
  File "test2.py", line 46, in main
    loss.backward()
  File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I haven't had enough time to figure out whether it's a problem with the ptwt implementation or something else on my end. But it's exciting to see it work for at least a single iteration! I'm not sure I have a ton of time to focus on this, but I'll keep poking at it when I can.

david-andrew commented Oct 12 '22

Dear @david-andrew, I think I found the problem on my end. I was moving the wavelet module to the CPU to run the pywt code for the axis labels; afterward, it has to move back to the GPU. Commit https://github.com/v0lta/PyTorch-Wavelet-Toolbox/commit/f9dbff38a795d4b4e771d804858ef76f01c634ff fixes the problem, and your example runs on my machine now. The commit is now in the wavedec2-improved-errors branch, which is diff_cwt merged with the changes suggested in issue https://github.com/v0lta/PyTorch-Wavelet-Toolbox/issues/40 . For testing,

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@wavedec2-improved-errors

should do the job.
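For reference, the fix follows a device round-trip pattern, roughly like this (hypothetical helper names, assuming the wavelet module carries learnable parameters; the real change lives in the commit linked above):

import torch

def _run_on_cpu_and_restore(module: torch.nn.Module, cpu_only_fn):
    """Run NumPy/pywt-only code on the CPU, then restore the module's device."""
    device = next(module.parameters()).device  # remember where the module lives
    module.cpu()                               # pywt code needs CPU tensors
    result = cpu_only_fn(module)               # e.g. computing the axis labels
    module.to(device)                          # move back before further GPU work
    return result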

v0lta commented Oct 12 '22

Awesome, yeah, it looks to be working great on my side as well!

david-andrew commented Oct 12 '22