
How to use pretrained models

Open • expectopatronum opened this issue 5 years ago • 2 comments

Your question here

I don't understand how the pretrained models (e.g. musdb+slakhv0_TG3EvX6.pth) should be used. What is the model class for each of the models? I can only tell that the first two should be Deep Clustering models (if I recall correctly, I managed to use them some time ago), but I can't figure out what model types the other models are.

Could you please point me to the relevant documentation or provide some usage examples?

What you tried

I tried loading the model in the following ways:

model = DeepClustering(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepAudioEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepMaskEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))

but each of them raises a different error when I call model.run().

Also, the usage description in the External File Zoo seems outdated:

import nussl
nussl.utils.print_available_audio_files()

Here utils should be efz_utils. The same applies to model_path = nussl.utils.download_trained_model('example.model'), which should be nussl.efz_utils.download_trained_model('example.model').

Thanks a lot & best regards Verena

expectopatronum • Jan 20 '21 14:01

Hi Verena,

Sorry for the late response, but I can help you use the pretrained models. The 100.00 MiB models were used in the paper https://arxiv.org/abs/2010.12650, and they are all used in the same way. All of them are recurrent deep clustering models (https://arxiv.org/pdf/1508.04306.pdf).
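For context, a deep clustering network does not output masks or audio directly. It maps every time-frequency (TF) bin of the mixture STFT to an embedding vector, and the masks come from clustering those embeddings (for example with k-means) into one cluster per source. Below is a minimal pure-Python sketch of that inference step, assuming the embeddings have already been flattened into one vector per TF bin. The function names kmeans and masks_from_embeddings are hypothetical and not part of the nussl API; nussl's own DeepClustering separator is the place to look for the real implementation.

```python
from typing import List, Sequence

Vector = Sequence[float]


def _sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], k: int, iters: int = 20) -> List[int]:
    """Plain Lloyd's k-means; returns a cluster label for every point.

    Centroids are seeded deterministically (farthest-point seeding:
    the first point, then repeatedly the point farthest from all chosen
    centroids), so results are reproducible.
    """
    centroids = [list(points[0])]
    while len(centroids) < k:
        farthest = max(points, key=lambda p: min(_sq_dist(p, c) for c in centroids))
        centroids.append(list(farthest))

    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each TF bin joins its nearest centroid
        labels = [min(range(k), key=lambda c: _sq_dist(p, centroids[c])) for p in points]
        # Update step: each centroid moves to the mean of its members
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def masks_from_embeddings(embeddings: List[Vector], num_sources: int) -> List[List[int]]:
    """Turn per-TF-bin embeddings into one binary mask per source."""
    labels = kmeans(embeddings, num_sources)
    return [[1 if lab == s else 0 for lab in labels] for s in range(num_sources)]


# Toy example: four TF bins with 2-D embeddings, two near (0, 1) and two near (1, 0)
emb = [(0.0, 1.0), (0.1, 0.9), (1.0, 0.0), (0.9, 0.1)]
masks = masks_from_embeddings(emb, num_sources=2)
print(masks)  # [[1, 1, 0, 0], [0, 0, 1, 1]]
```

Each TF bin ends up in exactly one mask, which is the binary-mask behaviour described in the deep clustering paper; multiplying each mask with the mixture STFT and inverting it would give the separated sources.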

The following code should load the model.

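from nussl.ml import SeparationModel
import torch

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load("path/to/checkpoint.pth", map_location="cpu")
# Rebuild the network architecture from the stored config, then load the trained weights
model = SeparationModel(checkpoint["config"])
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # switch to inference mode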

You should now have a loaded model! Note that this model expects STFTs computed with a window length of 512 samples.
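As a quick sanity check on why the window length matters: a one-sided STFT of a real signal with a window of N samples has N // 2 + 1 frequency bins. So, assuming the model was trained on one-sided STFT magnitudes, a 512-sample window gives 257 features per frame, and audio analysed with any other window length will not match the network's input size. The helper names below are hypothetical, for illustration only.

```python
def num_freq_bins(window_length: int) -> int:
    """Number of frequency bins in a one-sided STFT of a real signal."""
    if window_length <= 0:
        raise ValueError("window_length must be positive")
    return window_length // 2 + 1


def check_window_length(window_length: int, expected_features: int) -> None:
    """Raise if an STFT window would not match the model's input size."""
    got = num_freq_bins(window_length)
    if got != expected_features:
        raise ValueError(
            f"window_length={window_length} gives {got} bins, "
            f"but the model expects {expected_features}"
        )


print(num_freq_bins(512))  # 257
check_window_length(512, 257)  # matches, so nothing is raised
```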

Let me know if you have any other questions. This page in the documentation should also help with using this model: https://nussl.github.io/docs/tutorials/training.html

EDIT: fixed code

abugler • Jan 24 '21 22:01

Hi! Thanks a lot for your response and the link to your paper, I wasn't aware of it. I am looking forward to reading it!

Which nussl version should I be using? I am using 1.1.3 (I think it was the most recent one when I downloaded the models three weeks ago), and I am getting the following error:

ValueError: Expected keys ['connections', 'modules', 'name', 'output'], got ['connections', 'modules', 'output']

on the line model = SeparationModel(checkpoint["config"]).
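The error says the installed SeparationModel expects a top-level 'name' key that this checkpoint's config does not contain. One possible workaround, assuming the missing 'name' is the only incompatibility, is to add it to the config before building the model. This is a sketch, not a fix confirmed by the nussl maintainers: the helper add_missing_name is hypothetical, the default value 'DeepClustering' is an arbitrary label, and the code assumes the stored config is either a dict or a JSON string.

```python
import copy
import json
from typing import Any, Dict, Union

# Key set reported in the ValueError above (nussl 1.1.3)
EXPECTED_KEYS = ["connections", "modules", "name", "output"]


def add_missing_name(config: Union[Dict[str, Any], str],
                     name: str = "DeepClustering") -> Dict[str, Any]:
    """Return a copy of a SeparationModel config with a 'name' key added.

    Assumption: 'name' is the only key newer nussl versions require;
    any other mismatch is reported instead of silently patched.
    """
    if isinstance(config, str):
        config = json.loads(config)
    patched = copy.deepcopy(config)  # leave the checkpoint's dict untouched
    patched.setdefault("name", name)  # keep an existing name if present
    got = sorted(patched.keys())
    if got != EXPECTED_KEYS:
        raise ValueError(f"Expected keys {EXPECTED_KEYS}, got {got}")
    return patched


# With the checkpoint loaded via torch.load (not run here):
# model = SeparationModel(add_missing_name(checkpoint["config"]))
```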

Thanks and best regards Verena

expectopatronum • Feb 10 '21 09:02