Add training support for EnCodec

Open ArthurZucker opened this issue 1 year ago • 15 comments

Feature request

Would be cool to add training support for the EnCodec model. Not entirely sure if we can easily make it compatible with Trainer, so this can be a good second issue I think.

Motivation

…

Your contribution

…

ArthurZucker avatar Jun 15 '23 09:06 ArthurZucker

Hi @ArthurZucker , I want to try this, please assign it to me. Thanks.

Swastyy avatar Jun 17 '23 03:06 Swastyy

Sure! Feel free to open a PR and ping me

ArthurZucker avatar Jun 22 '23 11:06 ArthurZucker

@Swastyy @ArthurZucker Let me know if you are looking for any support. I would also like to help with this if possible. Thanks!

hackyon avatar Jun 26 '23 12:06 hackyon

It seems like he did not link a PR, so feel free to sync and ping me for any help! Even a draft is good!

ArthurZucker avatar Jun 26 '23 13:06 ArthurZucker

Hi @ArthurZucker, can you let me know the overall changes that need to be made? I see the EnCodec model is already implemented in transformers, so what are the additional requirements to integrate it with Trainer?

rishabbala avatar Jun 26 '23 15:06 rishabbala

The idea is mostly to integrate the loss computation for the VQ-VAE! Trainer might not work since the model does not use attention, but the target should be to have the same loss as the original model!

ArthurZucker avatar Jun 27 '23 05:06 ArthurZucker

Thanks Arthur.

I read through the paper (https://arxiv.org/pdf/2210.13438.pdf) and the existing code, and here is my impression of the work breakdown. Does any of this make sense, or am I going in a totally wrong direction?

The loss function detailed in the paper (equation 4) is a combination of (1) the reconstruction loss (over the time and frequency domains), (2) the discriminative loss (which requires the discriminator), and (3) the VQ commitment loss (the quantizer loss).
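
For reference, my reading of equation 4 (the total generator objective) is roughly the following, where the $\lambda$'s are weighting hyperparameters, $\ell_t$/$\ell_f$ are the time/frequency reconstruction terms, $\ell_g$/$\ell_{feat}$ the adversarial and feature-matching terms (both folded into (2) above), and $\ell_w$ the commitment term:

$$
L_G = \lambda_t \cdot \ell_t(x, \hat{x}) + \lambda_f \cdot \ell_f(x, \hat{x}) + \lambda_g \cdot \ell_g(\hat{x}) + \lambda_{feat} \cdot \ell_{feat}(x, \hat{x}) + \lambda_w \cdot \ell_w(w)
$$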

(1) The reconstruction loss is computed using the original audio as the label: we basically need to apply certain time- and frequency-domain transformations to the input/output and compute the L1/L2 distances between them.
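
Here is a minimal sketch of how I imagine this term, assuming the multi-scale mel setup from the paper (window sizes 2^i for i in 5..11, hop = window/4, 64 mel bins); the function name and defaults are mine, not a final API:

```python
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram

def reconstruction_loss(x, x_hat, sample_rate=24_000):
    # Time-domain term: plain L1 between original and reconstructed waveforms.
    loss_t = F.l1_loss(x_hat, x)

    # Frequency-domain term: L1 + L2 over mel spectrograms at several scales.
    # (For the smallest windows, fewer mel bins may be needed in practice.)
    loss_f = 0.0
    scales = range(5, 12)  # window sizes 2**5 .. 2**11, as in the paper
    for i in scales:
        melspec = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=2 ** i,
            hop_length=2 ** i // 4,
            n_mels=64,
        ).to(x.device)
        s_x, s_x_hat = melspec(x), melspec(x_hat)
        loss_f = loss_f + F.l1_loss(s_x_hat, s_x) + F.mse_loss(s_x_hat, s_x)
    return loss_t, loss_f / len(scales)
```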

(2) The discriminative loss requires a discriminator. As far as I can tell, this hasn't been ported/implemented yet, and we'll need to do it if we want to compute the loss as stated in the paper (msstftd.py from facebookresearch). We'll need to hook up the discriminator somewhere in the training code (is there a pretrained discriminator available anywhere?). It's also unclear to me whether we can train the discriminator and the model/generator at the same time (I'm assuming not, and that we'll need to train one at a time).
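
If it helps, this is roughly how I'd expect the hinge-style adversarial terms to look once a discriminator is ported. The `discriminators` interface here (a list of modules, each returning `(logits, feature_maps)`) is an assumption on my side, not the actual msstftd.py API. On training both at once: my understanding is that GAN setups like this usually alternate generator and discriminator updates within the same training loop, rather than fully training one before the other:

```python
import torch.nn.functional as F

def discriminator_loss(discriminators, x, x_hat):
    # Hinge loss for the discriminators; detach x_hat so no gradient
    # flows back into the generator on the discriminator step.
    loss = 0.0
    for disc in discriminators:
        logits_real, _ = disc(x)
        logits_fake, _ = disc(x_hat.detach())
        loss = loss + F.relu(1 - logits_real).mean() + F.relu(1 + logits_fake).mean()
    return loss / len(discriminators)

def generator_adversarial_loss(discriminators, x, x_hat):
    # Adversarial term for the generator, plus the feature-matching term
    # computed over the discriminators' intermediate feature maps.
    loss_g, loss_feat = 0.0, 0.0
    for disc in discriminators:
        logits_fake, fmaps_fake = disc(x_hat)
        _, fmaps_real = disc(x)
        loss_g = loss_g + F.relu(1 - logits_fake).mean()
        for f_real, f_fake in zip(fmaps_real, fmaps_fake):
            loss_feat = loss_feat + F.l1_loss(f_fake, f_real) / f_real.abs().mean()
    return loss_g / len(discriminators), loss_feat / len(discriminators)
```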

(3) The VQ commitment loss is from the quantizer. It looks like it's summing up the losses across all the residual steps. Are we supposed to train the quantizer at the same time as the encoder/decoders? Or should we train them at different times?
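
A minimal sketch of what I have in mind for (3), assuming we can get the per-step residuals and selected codebook vectors out of the residual quantizer (the names here are illustrative, not the real module API). On the ordering question: if the codebooks are updated with an EMA rule as in SoundStream/VQ-VAE variants (which I believe is the case here), they can be trained at the same time as the encoder/decoder, with the commitment gradient only flowing into the encoder:

```python
import torch.nn.functional as F

def commitment_loss(residuals, quantized):
    # residuals[i]: encoder output/residual fed into quantizer step i
    # quantized[i]: codebook vector selected at step i
    # Stop-gradient on the codebook side so only the encoder is pulled
    # towards the codes; sum over all residual quantization steps.
    loss = 0.0
    for res, q in zip(residuals, quantized):
        loss = loss + F.mse_loss(res, q.detach())
    return loss
```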

In addition to the general loss function, the paper introduces a balancer (balancer.py) that weights the reconstruction, discriminative, and commitment losses differently. We would also need to port the balancer code if we want this behavior.
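
For completeness, my understanding of the balancer is that it rescales the gradient of each loss with respect to the decoder output so that each loss contributes a fixed share of the total gradient norm. A very rough sketch of the idea (the real balancer.py also keeps an EMA of the gradient norms, which I omit here; all names are mine):

```python
import torch

def balanced_backward(losses, weights, x_hat, ref_norm=1.0):
    # losses: dict of scalar loss tensors; weights: dict of their lambdas;
    # x_hat: the decoder output the losses were computed from.
    grads, norms = {}, {}
    for name, loss in losses.items():
        # Gradient of each loss w.r.t. the decoder output only.
        (g,) = torch.autograd.grad(loss, [x_hat], retain_graph=True)
        grads[name], norms[name] = g, g.norm()
    total_weight = sum(weights.values())
    out_grad = sum(
        ref_norm * (weights[n] / total_weight) * grads[n] / (norms[n] + 1e-8)
        for n in losses
    )
    # Single backward pass through the encoder with the re-balanced gradient.
    x_hat.backward(out_grad)
```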

hackyon avatar Jun 27 '23 09:06 hackyon

Makes sense to me! I think you can focus simply on returning the loss for each module. The order of training is not that important (when implementing the module-wise losses), since you don't need to train (just compare output losses) until you have everything!

For the discriminator, you can leave it in the training file! It should be pretty small, and that's usually how we do things 🤗! The order of training, and what is frozen when, should be in the paper/original codebase; I have not looked it up!

ArthurZucker avatar Jun 27 '23 13:06 ArthurZucker

I'll attempt to code up (3) VQ commitment loss first then. I'll reach out if I get stuck or run into any issues. Thanks!

hackyon avatar Jun 27 '23 15:06 hackyon

I added an initial draft here: https://github.com/huggingface/transformers/commit/4f697be0b62c4f3b0401ccbd00d1d46aac81906d

Can you take a look and let me know what you think? Thanks

hackyon avatar Jun 28 '23 09:06 hackyon

FYI I will be traveling in July, so won't be as available that month.

hackyon avatar Jun 30 '23 04:06 hackyon

Sure, would you mind opening a proper PR? It would be easier to test locally, visualize, and follow the changes!

ArthurZucker avatar Jun 30 '23 09:06 ArthurZucker

So cool! I reproduced the code and released it: https://github.com/NoFish-528/encodec-pytorch. If you have any questions, we can work through them together. @hackyon @ArthurZucker In this work I haven't added the balancer, it was too difficult for me... Hope you succeed!

ZhikangNiu avatar Aug 26 '23 10:08 ZhikangNiu