
MelRoformer parameters from paper

ZFTurbo opened this issue 1 year ago • 16 comments

I'm trying to reproduce the model from the paper, but I've had no luck.

My current settings, which give me a batch size of only 2 on 48 GB of memory:

  dim: 192
  depth: 8
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  num_bands: 60
  dim_head: 64
  heads: 8
  attn_dropout: 0.1
  ff_dropout: 0.1
  flash_attn: True
  dim_freqs_in: 1025
  sample_rate: 44100  # needed for mel filter bank from librosa
  stft_n_fft: 2048
  stft_hop_length: 512
  stft_win_length: 2048
  stft_normalized: False
  mask_estimator_depth: 2
  multi_stft_resolution_loss_weight: 1.0
  multi_stft_resolutions_window_sizes: !!python/tuple
  - 4096
  - 2048
  - 1024
  - 512
  - 256
  multi_stft_hop_size: 147
  multi_stft_normalized: False

As input I give 8 seconds of 44100 Hz audio, so the length is 352800 samples.

I run my model through torchinfo:

from torchinfo import summary

# model is the MelBandRoformer built from the config above; input shape is (batch, channels, samples)
summary(model, input_size=(1, 2, 352768))

The report is:

==============================================================================================================
Layer (type:depth-idx)                                       Output Shape              Param #
==============================================================================================================
MelBandRoformer                                              [1, 2, 352768]            56,503,768
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─BandSplit: 1-2                                             [1, 690, 60, 192]         --
│    └─ModuleList: 2                                         --                        --
....
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─ModuleList: 1                                              --                        --
│    └─MaskEstimator: 2-9                                    [1, 690, 7916]            --
==============================================================================================================
Total params: 69,102,468
Trainable params: 69,102,404
Non-trainable params: 64
Total mult-adds (G): 8.35
==============================================================================================================
Input size (MB): 2.82
Forward/backward pass size (MB): 703.40
Params size (MB): 232.17
Estimated Total Size (MB): 938.40
==============================================================================================================

From the report I would expect a batch size of more than 48 to fit, but in the end I can only use a batch size of 2.

GPU memory usage for batch = 2: [image]

To follow the paper I would need to increase dim to 384, depth to 12, and decrease stft_hop_length to 441 (to get a 10 ms hop). In that case the batch size will be only 1, or the model won't fit in memory at all )
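
For reference, a rough sketch of the paper-scale settings. This assumes the constructor keyword names mirror the YAML keys above and that the package exposes MelBandRoformer under that import path; treat it as an illustration, not the exact paper recipe.

from bs_roformer import MelBandRoformer  # assumed import path

# Paper-scale hyperparameters: dim 384, depth 12, 10 ms hop at 44.1 kHz.
# All other values are carried over from my current config above.
model = MelBandRoformer(
    dim=384,                 # current config: 192
    depth=12,                # current config: 8
    stereo=True,
    num_stems=1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    num_bands=60,
    dim_head=64,
    heads=8,
    flash_attn=True,
    sample_rate=44100,
    stft_n_fft=2048,
    stft_hop_length=441,     # 10 ms at 44.1 kHz; current config: 512
    stft_win_length=2048,
)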

Any ideas how to deal with such high memory usage?

ZFTurbo avatar Dec 04 '23 18:12 ZFTurbo

There are a bunch of techniques.

Gradient accumulation is what you can use immediately, but also look into gradient checkpointing, etc.
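
A minimal sketch of gradient accumulation with a plain PyTorch loop. The assumption here is that the model's forward returns the loss when a target is passed; gradient checkpointing would additionally wrap the inner transformer blocks with torch.utils.checkpoint.

def train_epoch(model, loader, optimizer, accum_steps=8, device="cuda"):
    # Gradient accumulation: call backward() on every micro-batch but only step
    # the optimizer every `accum_steps` batches, so the effective batch size is
    # micro_batch_size * accum_steps while peak memory stays at the micro-batch level.
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, (mix, target) in enumerate(loader):
        # assumed: forward returns the training loss when a target is given
        loss = model(mix.to(device), target=target.to(device))
        (loss / accum_steps).backward()  # scale so accumulated grads average over the window
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)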

lucidrains avatar Dec 04 '23 18:12 lucidrains

I use gradient accumulation now. But in the paper the authors report that the model fits in 32 GB with batch size = 6.

ZFTurbo avatar Dec 04 '23 18:12 ZFTurbo

how many GPUs do you have?

lucidrains avatar Dec 04 '23 18:12 lucidrains

Ah, well, I can't help you with that, but this is a common issue and you ought to be able to figure it out with the resources online.

lucidrains avatar Dec 04 '23 18:12 lucidrains

the biggest gun I can bring out would be reversible networks, which could work for this architecture. maybe at a later date. for now, maybe accumulate gradients and wait it out?

lucidrains avatar Dec 04 '23 18:12 lucidrains

are you using mixed precision with flash attention turned on?

lucidrains avatar Dec 04 '23 18:12 lucidrains

I use mixed precision, but I'm not sure about Flash Attention.

ZFTurbo avatar Dec 04 '23 18:12 ZFTurbo

how many GPUs do you have?

I usually train on 1 GPU )

ZFTurbo avatar Dec 04 '23 18:12 ZFTurbo

Definitely turn on flash attention if you have the right hardware.

It's a flag on init.
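
A rough sketch of what that looks like together with mixed precision. flash_attn is the same flag as in the YAML config above; the autocast part is standard PyTorch AMP, and the import path is an assumption.

import torch
from bs_roformer import MelBandRoformer  # assumed import path

# flash_attn=True asks the attention layers to use the fused flash-attention kernel,
# which needs recent hardware (Ampere or newer) and half precision to pay off.
model = MelBandRoformer(dim=192, depth=8, stereo=True, flash_attn=True).cuda()

# Standard mixed precision: run the forward pass under autocast so matmuls and
# attention execute in bf16 while the parameters stay in fp32.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    separated = model(torch.randn(1, 2, 352768, device="cuda"))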

lucidrains avatar Dec 04 '23 18:12 lucidrains

Wow! Reversible networks would be cool

@ZFTurbo Are we confident that MelRoFormer was trained on stereo stems? (If they actually trained on mono stems, that would halve their effective batch size.) Perhaps they didn't, because (a) Bandsplit-RNN didn't either, as far as I know, and (b) training individual models per stem suggests that, like Bandsplit-RNN, they didn't add the multi-stem/multi-channel functionality that @lucidrains did.

HTDemucs and its variants, on the other hand, trained all stems simultaneously and then fine-tuned on individual stems. This leads to models 4x the size, but I am a little surprised the authors didn't include this easy win, given how many GPUs they used for training.

turian avatar Jan 06 '24 03:01 turian

@ZFTurbo how do you collect the songs?

deyituo avatar Jan 15 '24 13:01 deyituo

@ZFTurbo Do you intend to release a trained model? I'm looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model.

carlosalberto2000 avatar Feb 19 '24 03:02 carlosalberto2000

@ZFTurbo Do you intend to release a trained model? I'm looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model.

I posted some weights here: https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main?tab=readme-ov-file#vocal-models

But the SDR metric is not really great.

ZFTurbo avatar Feb 19 '24 09:02 ZFTurbo

Thank you a lot! I really appreciate your work. So the crown currently still goes to "MDX23C for vocals + HTDemucs4 FT", right? How do you think they achieved such a high score in MDX'23?

carlosalberto2000 avatar Feb 19 '24 13:02 carlosalberto2000

How do you think they achieved such a high score in MDX'23?

Only BS-Roformer was evaluated by ByteDance-SAMI during MDX23; Mel-Roformer came after the contest.

jarredou avatar Feb 19 '24 17:02 jarredou

How do you think they achieved such a high score in MDX'23?

Only BS-Roformer was evaluated by ByteDance-SAMI during MDX23; Mel-Roformer came after the contest.

Oh, I see. I got the names mixed up. Are you aware of any trained BS-RoFormer model?

carlosalberto2000 avatar Feb 19 '24 22:02 carlosalberto2000