music-spectrogram-diffusion-pytorch
Multi-instrument Music Synthesis with Spectrogram Diffusion
An unofficial PyTorch implementation of the paper Multi-instrument Music Synthesis with Spectrogram Diffusion, adapted from the official codebase. We aim to improve the reproducibility of their work by providing training code and pre-trained models in PyTorch.
Data Preparation
Please download the datasets required for training; the expected paths are listed under the data section of the yaml configs.
Create Clean MIDI for URMP
The MIDI files in the URMP dataset mostly don't contain the correct program numbers. Use the clean_urmp_midi.py script to create a new set of MIDI files whose program numbers match the instruments named in the file names.
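The gist of the idea is sketched below; this is not the repo's clean_urmp_midi.py, and the pretty_midi dependency, the abbreviation map, and the assumption that MIDI tracks appear in file-name order are all assumptions made for illustration.

```python
# Sketch: rewrite URMP MIDI program numbers based on the instrument
# abbreviations in the file name (e.g. "01_Jupiter_vn_vc.mid" -> Violin, Cello).
# Assumes the pretty_midi package; the real clean_urmp_midi.py may differ.
import sys
from pathlib import Path

import pretty_midi

# URMP file-name abbreviation -> General MIDI instrument name (assumed mapping)
ABBR_TO_NAME = {
    "vn": "Violin", "va": "Viola", "vc": "Cello", "db": "Contrabass",
    "fl": "Flute", "ob": "Oboe", "cl": "Clarinet", "bn": "Bassoon",
    "sax": "Alto Sax", "tpt": "Trumpet", "hn": "French Horn",
    "tbn": "Trombone", "tba": "Tuba",
}

def clean_midi(in_path: Path, out_path: Path) -> None:
    midi = pretty_midi.PrettyMIDI(str(in_path))
    # Instrument abbreviations follow the piece name in URMP file names;
    # we assume the MIDI tracks are stored in the same order.
    abbrs = [part for part in in_path.stem.split("_") if part in ABBR_TO_NAME]
    for inst, abbr in zip(midi.instruments, abbrs):
        inst.program = pretty_midi.instrument_name_to_program(ABBR_TO_NAME[abbr])
        inst.is_drum = False
    out_path.parent.mkdir(parents=True, exist_ok=True)
    midi.write(str(out_path))

if __name__ == "__main__":
    src, dst = Path(sys.argv[1]), Path(sys.argv[2])
    for midi_file in src.rglob("*.mid"):
        clean_midi(midi_file, dst / midi_file.relative_to(src))
```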
Training
Small Autoregressive
python main.py fit --config cfg/ar_small.yaml
Diffusion, Small without Context
python main.py fit --config cfg/diff_small.yaml
Diffusion, Small with Context
python main.py fit --config cfg/diff_small.yaml --data.init_args.with_context true --model.init_args.with_context true
Diffusion, Base with Context
python main.py fit --config cfg/diff_base.yaml
Remember to change the path arguments under the data section of the yaml files to where you downloaded the datasets, or set them with the --data.init_args.*_path keywords on the command line. You can also set a path to null to omit that dataset. Note that URMP requires one extra path argument pointing to the cleaned MIDI files created above.
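For example, assuming hypothetical argument names maestro_path and musicnet_path (check the data section of your yaml for the exact keys), a command-line override that sets one dataset path and omits another could look like:
python main.py fit --config cfg/diff_small.yaml --data.init_args.maestro_path /data/maestro-v3.0.0 --data.init_args.musicnet_path null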
To adjust other hyperparameters, please refer to the LightningCLI documentation for more information.
Evaluating
The following command computes the Reconstruction and FAD metrics using embeddings from the VGGish and TRILL models and reports the averages over the whole test dataset.
python main.py test --config config.yaml --ckpt_path your_checkpoint.ckpt
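For reference, FAD is the Fréchet distance between Gaussians fitted to the embedding statistics of the real and synthesized audio. Below is a minimal NumPy/SciPy sketch of that formula, not necessarily the exact implementation used in this repo:

```python
# Sketch of the Fréchet Audio Distance (FAD) between two sets of embeddings
# (e.g. VGGish or TRILL frames); illustrative only, not the repo's code.
import numpy as np
from scipy import linalg

def frechet_audio_distance(real_emb: np.ndarray, fake_emb: np.ndarray) -> float:
    """real_emb, fake_emb: arrays of shape (num_frames, embedding_dim)."""
    mu_r, mu_f = real_emb.mean(axis=0), fake_emb.mean(axis=0)
    cov_r = np.cov(real_emb, rowvar=False)
    cov_f = np.cov(fake_emb, rowvar=False)
    # Fréchet distance between N(mu_r, cov_r) and N(mu_f, cov_f):
    # ||mu_r - mu_f||^2 + Tr(cov_r + cov_f - 2 (cov_r cov_f)^{1/2})
    covmean, _ = linalg.sqrtm(cov_r @ cov_f, disp=False)
    covmean = covmean.real  # drop tiny imaginary parts from numerical error
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r + cov_f - 2.0 * covmean))
```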
Inferencing
To synthesize audio from MIDI with trained models:
python infer.py input.mid checkpoint.ckpt config.yaml output.wav
Pre-Trained Models
We provide three pre-trained models corresponding to the diffusion baselines in the paper.
We trained them following the settings in the paper, except for the batch size, which we reduced to 8 due to limited computational resources. We evaluated these models with our codebase and summarize the results in the following table:
| Models | VGGish Recon | VGGish FAD | TRILL Recon | TRILL FAD |
|---|---|---|---|---|
| Small w/o Context | 2.48 | 0.49 | 0.84 | 0.08 |
| Small w/ Context | 2.44 | 0.59 | 0.68 | 0.04 |
| Base w/ Context | - | - | - | - |
| Ground Truth Encoded | 1.80 | 0.80 | 0.35 | 0.02 |
TODO
- [ ] Use MidiTok for tokenization.
- [ ] Use torchvggish for VGGish embeddings.
- [ ] Remove context encoder and use inpainting techniques for segment-by-segment generation, similar to https://github.com/archinetai/audio-diffusion-pytorch.
- [ ] Encoder-free Classifier-Guidance generation with MT3.