BS-RoFormer
Mel-RoFormer parameters from the paper
I'm trying to reproduce the paper model, but I've had no luck so far.
My current settings, which only let me fit a batch size of 2 in 48 GB of memory:
dim: 192
depth: 8
stereo: true
num_stems: 1
time_transformer_depth: 1
freq_transformer_depth: 1
num_bands: 60
dim_head: 64
heads: 8
attn_dropout: 0.1
ff_dropout: 0.1
flash_attn: True
dim_freqs_in: 1025
sample_rate: 44100 # needed for mel filter bank from librosa
stft_n_fft: 2048
stft_hop_length: 512
stft_win_length: 2048
stft_normalized: False
mask_estimator_depth: 2
multi_stft_resolution_loss_weight: 1.0
multi_stft_resolutions_window_sizes: !!python/tuple
- 4096
- 2048
- 1024
- 512
- 256
multi_stft_hop_size: 147
multi_stft_normalized: False
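For context, this is roughly how I build the model from this config (the file name is illustrative; yaml.unsafe_load is needed because of the !!python/tuple tag):

```python
import yaml
from bs_roformer import MelBandRoformer

# load the YAML config above; safe_load would reject the !!python/tuple tag
with open('config_mel_band_roformer.yaml') as f:  # illustrative file name
    cfg = yaml.unsafe_load(f)

# the config keys map directly onto the constructor arguments
model = MelBandRoformer(**cfg)
```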
As input I give 8 seconds at 44100 Hz, so the length is 352800 samples.
I run the model through torchinfo:
from torchinfo import summary
summary(model, input_size=(1, 2, 352768))  # (batch, channels, samples)
The report is:
==============================================================================================================
Layer (type:depth-idx) Output Shape Param #
==============================================================================================================
MelBandRoformer [1, 2, 352768] 56,503,768
├─ModuleList: 1-1 -- --
│ └─ModuleList: 2-1 -- 384
│ │ └─Transformer: 3-77 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-78 [690, 60, 192] (recursive)
│ └─ModuleList: 2-2 -- 384
│ │ └─Transformer: 3-79 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-80 [690, 60, 192] (recursive)
│ └─ModuleList: 2-3 -- 384
│ │ └─Transformer: 3-81 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-82 [690, 60, 192] (recursive)
│ └─ModuleList: 2-4 -- 384
│ │ └─Transformer: 3-83 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-84 [690, 60, 192] (recursive)
│ └─ModuleList: 2-5 -- 384
│ │ └─Transformer: 3-85 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-86 [690, 60, 192] (recursive)
│ └─ModuleList: 2-6 -- 384
│ │ └─Transformer: 3-87 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-88 [690, 60, 192] (recursive)
│ └─ModuleList: 2-7 -- 384
│ │ └─Transformer: 3-89 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-90 [690, 60, 192] (recursive)
│ └─ModuleList: 2-8 -- 384
│ │ └─Transformer: 3-91 [60, 690, 192] (recursive)
│ │ └─Transformer: 3-92 [690, 60, 192] (recursive)
├─BandSplit: 1-2 [1, 690, 60, 192] --
│ └─ModuleList: 2 -- --
....
├─ModuleList: 1 -- --
│ └─MaskEstimator: 2-9 [1, 690, 7916] --
==============================================================================================================
Total params: 69,102,468
Trainable params: 69,102,404
Non-trainable params: 64
Total mult-adds (G): 8.35
==============================================================================================================
Input size (MB): 2.82
Forward/backward pass size (MB): 703.40
Params size (MB): 232.17
Estimated Total Size (MB): 938.40
==============================================================================================================
From the report I would expect to fit a batch size of more than 48, but in practice I can only use a batch size of 2.
GPU memory usage for batch size = 2: [screenshot]
To follow the paper I would have to increase dim to 384, increase depth to 12, and decrease stft_hop_length to 441 (10 ms at 44.1 kHz). In that case the batch size would be only 1, or the model wouldn't fit in memory at all )
Any ideas how to deal with such high memory usage?
there are a bunch of techniques
gradient accumulation is what you can use immediately, but also look into gradient checkpointing etc
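a minimal sketch of accumulation, assuming the model returns a loss when given a target; model, optimizer, and train_loader are placeholders for whatever you already have in your training script:

```python
# gradient accumulation: run several small batches, step the optimizer once,
# so the effective batch size is micro batch size * accum_steps
accum_steps = 16

optimizer.zero_grad()

for step, (mix, target) in enumerate(train_loader):  # placeholder dataloader
    loss = model(mix, target = target)   # assumes the model computes the loss internally
    (loss / accum_steps).backward()      # scale so gradients average over the virtual batch

    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

gradient checkpointing (torch.utils.checkpoint) is complementary: it recomputes activations during the backward pass instead of storing them, trading compute for memory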
I already use gradient accumulation. But in the paper the authors report that the model fits in 32 GB with a batch size of 6.
how many GPUs do you have?
ah, well I can't help you with that, but this is a common issue and you ought to be able to figure it out with the resources online
the biggest gun I can bring out would be reversible networks, which could work for this architecture. maybe at a later date. for now, maybe accumulate gradients and wait it out?
are you using mixed precision with flash attention turned on?
I use mixed precision, but I'm not sure about Flash Attention.
how many GPUs do you have?
I usually train on 1 GPU )
def turn on flash attention if you have the right hardware
it's a flag on init
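something like this, with dims copied from your config just for illustration; iirc the flash path goes through pytorch's scaled_dot_product_attention, which needs half precision to use the fused kernel:

```python
import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 192,
    depth = 8,
    stereo = True,
    num_bands = 60,
    flash_attn = True    # the flag in question
).cuda()

mix = torch.randn(1, 2, 352768).cuda()   # (batch, channels, samples)

# pair it with autocast so the flash kernel can actually be used
with torch.autocast(device_type = 'cuda', dtype = torch.float16):
    out = model(mix)
```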
Wow! Reversible networks would be cool
@ZFTurbo Are we confident that Mel-RoFormer was trained on stereo stems? (If they trained on mono stems, that would halve their effective batch size.) Perhaps they didn't, because: a) Bandsplit-RNN didn't either, as far as I know, and b) training individual models per stem suggests that, like Bandsplit-RNN, they didn't add the multi-stem/multi-channel functionality that @lucidrains did.
HTDemucs and its variants, on the other hand, trained all stems simultaneously and then fine-tuned on individual stems. This leads to models 4x the size, but I'm a little surprised the authors didn't include this easy win, given how many GPUs they used for training.
@ZFTurbo how do you collect the songs?
@ZFTurbo Do you intend to release a trained model? I'm looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model.
I posted some weights here: https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main?tab=readme-ov-file#vocal-models
But the SDR metric is not that great.
Thank you a lot! I really appreciate your work. So the crown currently still goes to "MDX23C for vocals + HTDemucs4 FT", right? How do you think they achieved such a high score in MDX'23?
How do you think they achieved such a high score in MDX'23?
Only BS-RoFormer was evaluated by ByteDance-SAMI during MDX23; Mel-RoFormer came after the contest.
Oh, I see, I mixed up the names. Are you aware of any trained BS-RoFormer model?