
Learning end2end with a neural network

Open jonnor opened this issue 2 years ago • 3 comments

Hi, thank you for this nice project.

Could one connect a neural network to the NMF module and learn them at the same time? Any example code or tips on how to do that? I am interested in using a convolutional neural network frontend on spectrogram data, to capture more complex activations than single stationary spectrogram frames.

jonnor avatar Aug 20 '21 16:08 jonnor

Hi @jonnor ,

> Could one connect a neural network to the NMF module and learn them at the same time? Any example code or tips on how to do that? I am interested in using a convolutional neural network frontend on spectrogram data, to capture more complex activations than single stationary spectrogram frames.

I do plan to add some examples as Jupyter notebooks, but I'm currently busy with other projects. Your application sounds totally doable to me, but you have to make sure that all the gradients passed from the loss to the NMF parameters are non-negative.

For example, if you want to train a model that predicts the activations while jointly learning a shared non-negative template, you can do something like this:


```python
import torch
from torch import nn
from torch import optim
from torchnmf.trainer import BetaMu
from torchnmf import NMF


# Pick an activation function so the network output is non-negative.
H = nn.Sequential(AnotherModel(), nn.Softplus())
# A shared non-negative template, learned jointly with the network.
W = NMF(W=(out_channels, in_channels))

optimizer = optim.Adam(H.parameters())
trainer = BetaMu(W.parameters())

for x, y in dataloader:
    # Optimize the NMF template with multiplicative updates.
    def closure():
        trainer.zero_grad()
        # Keep the network fixed while updating W.
        with torch.no_grad():
            h = H(x)
        return y, W(H=h)
    trainer.step(closure)

    # Optimize the neural net with ordinary gradient descent.
    h = H(x)
    predict = W(H=h)
    loss = ...  # you can use other types of loss here
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
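
(`AnotherModel`, `out_channels`, `in_channels`, and `dataloader` above are placeholders. As a purely illustrative sketch of what the frontend could be, a small CNN over spectrogram input might look like the following; the layer sizes and shapes are assumptions, not part of torchnmf. The output's last dimension should match the NMF rank, i.e. `in_channels`.)

```python
from torch import nn

# Hypothetical CNN frontend: maps a spectrogram of shape
# (batch, 1, n_freq, n_frames) to activations of shape
# (batch, n_frames, rank), where rank = in_channels above.
class AnotherModel(nn.Module):
    def __init__(self, rank=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, rank, kernel_size=3, padding=1),
        )

    def forward(self, x):
        z = self.conv(x)          # (batch, rank, n_freq, n_frames)
        z = z.mean(dim=2)         # pool over frequency -> (batch, rank, n_frames)
        return z.transpose(1, 2)  # (batch, n_frames, rank)
```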

yoyololicon avatar Aug 21 '21 01:08 yoyololicon

Hi @yoyololicon - thank you for the response and example code! In this framework, the loss would be something that compares the output of the NMF (decomposed and re-composed) against the target? Like RMS as a simple case, or a perceptual metric for something more advanced?

jonnor avatar Aug 31 '21 11:08 jonnor

> In this framework, the loss would be something that compares the output of the NMF (decomposed and re-composed) against the target? Like RMS as a simple case, or a perceptual metric for something more advanced?

@jonnor Yes, in the above code you are free to use those kinds of loss functions, not only the beta divergence. The NMF part is still trained with the beta divergence, though.
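
For instance, a minimal sketch of the elided `loss = ...` line, assuming a plain mean-squared reconstruction error against the target (a perceptual metric would slot in the same way):

```python
import torch.nn.functional as F

# One possible choice for the elided loss: plain MSE between the
# re-composed output and the target.
loss = F.mse_loss(predict, y)
```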

yoyololicon avatar Aug 31 '21 14:08 yoyololicon