How to use pretrained models
Your question here
I don't understand how the pretrained models (e.g. musdb+slakhv0_TG3EvX6.pth) should be used. What is the model class for each of the models? I can only see that the first two should be Deep Clustering models (if I recall correctly, I managed to use them some time ago), but I can't figure out which model types the other models are.
Could you point me to the relevant documentation or provide some usage examples please?
What you tried
I tried loading the model in the following ways:
model = DeepClustering(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepAudioEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepMaskEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
but each of them raises a different error when I call model.run().
Also, the usage description in the External File Zoo seems outdated:
import nussl
nussl.utils.print_available_audio_files()
Here utils should be efz_utils, and the same applies to model_path = nussl.utils.download_trained_model('example.model').
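For reference, the corrected External File Zoo snippet might look like the sketch below. It only swaps utils for efz_utils and adds print_available_trained_models(), which (as far as I know) lists the downloadable model names. 'example.model' is the placeholder name from the EFZ page and may not exist; efz_example is a hypothetical wrapper name, and calling it needs nussl plus network access.

```python
def efz_example(model_name="example.model"):
    """Corrected EFZ usage (hypothetical wrapper; needs nussl and network access)."""
    import nussl  # imported lazily so this file loads even without nussl installed

    # The External File Zoo helpers live in nussl.efz_utils, not nussl.utils
    nussl.efz_utils.print_available_audio_files()
    nussl.efz_utils.print_available_trained_models()
    # Downloads the model (if not cached) and returns its local path
    return nussl.efz_utils.download_trained_model(model_name)
```

Replace model_name with one of the names printed by print_available_trained_models().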
Thanks a lot & best regards Verena
Hi Verena,
Sorry for the late response, but I can help you use the pretrained models. The models that are 100.00 MiB in size were used in the paper https://arxiv.org/abs/2010.12650, and they are all used the same way. All of them are recurrent deep clustering models: https://arxiv.org/pdf/1508.04306.pdf
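To check what kind of network a given checkpoint holds (the model-type question in the original post), one option is to inspect the config stored next to the weights in checkpoint["config"]. This sketch assumes each entry in config["modules"] is a dict with an optional 'class' key naming the building block, which is how nussl's config builders appear to lay things out. describe_checkpoint and describe_checkpoint_file are hypothetical helper names, not nussl API.

```python
def describe_checkpoint(checkpoint):
    """Summarize a checkpoint dict: its keys, config keys, module classes, outputs."""
    config = checkpoint["config"]
    return {
        "checkpoint_keys": sorted(checkpoint.keys()),
        "config_keys": sorted(config.keys()),
        # Assumption: module specs are dicts; entries without 'class' are inputs
        "modules": {
            name: spec.get("class", "<input>")
            for name, spec in config["modules"].items()
        },
        "output": list(config["output"]),
    }


def describe_checkpoint_file(path):
    """Load a .pth file on the CPU and summarize it (hypothetical helper)."""
    import torch  # lazy import: only needed for real checkpoint files

    return describe_checkpoint(torch.load(path, map_location="cpu"))
```

For a deep clustering checkpoint one would expect an embedding-style module and an 'embedding' output; that is an expectation, not something verified against these particular files.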
The following code should load the model.
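from nussl.ml import SeparationModel
import torch
# Load on the CPU so this also works for checkpoints saved from a GPU
checkpoint = torch.load("path/to/checkpoint.pth", map_location="cpu")
# Rebuild the network from its stored config, then load the trained weights
model = SeparationModel(checkpoint["config"])
model.load_state_dict(checkpoint["state_dict"])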
You should now have a loaded model! Note that this model expects STFTs computed with a window length of 512 samples.
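A 512-sample window means each STFT frame has 512 // 2 + 1 = 257 frequency bins, which is the feature size such a model sees. The numpy sketch below only illustrates that shape; the 128-sample hop and the Hann window are assumptions (the thread does not state them), and it frames without padding, so frame counts will differ slightly from nussl's own STFT.

```python
import numpy as np

WINDOW_LENGTH = 512  # from the thread: 512-sample STFT window
HOP_LENGTH = 128     # assumption: hop size is not given in the thread


def magnitude_stft(audio, window_length=WINDOW_LENGTH, hop_length=HOP_LENGTH):
    """Return |STFT| with shape (n_frames, window_length // 2 + 1), no padding."""
    window = np.hanning(window_length)  # assumption: Hann window
    n_frames = 1 + (len(audio) - window_length) // hop_length
    frames = np.stack([
        audio[i * hop_length:i * hop_length + window_length] * window
        for i in range(n_frames)
    ])
    # rfft of each 512-sample frame yields 257 non-negative-frequency bins
    return np.abs(np.fft.rfft(frames, axis=-1))


mag = magnitude_stft(np.random.default_rng(0).standard_normal(16000))
print(mag.shape)  # (122, 257)
```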
Let me know if you have any other questions. Also, this page in the documentation will be useful in handling this model: https://nussl.github.io/docs/tutorials/training.html
EDIT: fixed code
Hi! Thanks a lot for your response and the link to your paper, I wasn't aware of it. I am looking forward to reading it!
Which nussl version should I be using? I am using 1.1.3 (I think that was the latest one when I downloaded the models three weeks ago), and I am getting the following error:
ValueError: Expected keys ['connections', 'modules', 'name', 'output'], got ['connections', 'modules', 'output']
on the line model = SeparationModel(checkpoint["config"]).
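One possible workaround, assuming the check only requires a top-level 'name' key to be present, is to add it to the stored config before building the model. The value 'DeepClustering' is a guess based on these being recurrent deep clustering models; add_missing_name and load_legacy_checkpoint are hypothetical helpers, and whether nussl uses 'name' for anything beyond validation is not confirmed here.

```python
import copy


def add_missing_name(config, name="DeepClustering"):
    """Return a copy of config with a top-level 'name' key added if it is missing."""
    patched = copy.deepcopy(config)  # leave the checkpoint's own dict untouched
    patched.setdefault("name", name)
    return patched


def load_legacy_checkpoint(path):
    """Hypothetical helper: build a SeparationModel from an older checkpoint."""
    import torch  # lazy imports so add_missing_name works without torch/nussl
    from nussl.ml import SeparationModel

    checkpoint = torch.load(path, map_location="cpu")
    model = SeparationModel(add_missing_name(checkpoint["config"]))
    model.load_state_dict(checkpoint["state_dict"])
    return model
```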
Thanks and best regards Verena