muse-maskgit-pytorch

WIP: Adding training script

Open isamu-isozaki opened this issue 2 years ago • 26 comments

This is a work-in-progress PR. For now, the plan is:

  • [x] Add BaseTrainer
  • [x] Abstract away parts of the VQGANVAETrainer to the BaseTrainer
  • [x] Add MaskGitTrainer
  • [x] Make training script for vae in train_muse_vae.py
  • [x] Make training script for mask git trainer in train_maskgit.py
  • [ ] Do saving and loading using accelerate, as shown here
  • [ ] Add inference script in infer_muse.py
  • [x] Add text and image dataset integrated with huggingface datasets
  • [x] Tokenize text within the dataset
  • [ ] Test training scripts
  • [ ] Add support for loading partial weights from Google's MaskGit repo (without conditioning)
  • [ ] Add support for loading partial weights from Paella
  • [ ] Add support for ColossalAI (if I have time and a good way to do it)

There's quite a bit of detail missing here, but overall this is the idea. @lucidrains I'll ping you when I'm done!

isamu-isozaki avatar Feb 19 '23 23:02 isamu-isozaki

@isamu-isozaki Boss.

thank you!

lucidrains avatar Feb 21 '23 17:02 lucidrains

@lucidrains np! Will do some testing today. Feel free to ask for any changes btw!

isamu-isozaki avatar Feb 21 '23 17:02 isamu-isozaki

@ZeroCool940711 fixed some of the bugs so nearly done with a basic script!

isamu-isozaki avatar Feb 23 '23 18:02 isamu-isozaki

The VAE training script seems to be working, with wandb/TensorBoard integrated! Next up: trying out MaskGit.

isamu-isozaki avatar Feb 24 '23 03:02 isamu-isozaki

Got training for both base maskgit and vae working!

isamu-isozaki avatar Feb 24 '23 05:02 isamu-isozaki

I'll test a bit more and then move on to the other tasks above.

isamu-isozaki avatar Feb 24 '23 05:02 isamu-isozaki

> Got training for both base maskgit and vae working!

wow! :100: you should get in touch with huggingface, specifically @patrickvonplaten

they are looking at replicating this!

ps: people like you who can tinker and train a model to completion are in short supply :smile:

lucidrains avatar Feb 24 '23 18:02 lucidrains

@lucidrains haha thank you! Sent him a message. And thanks for the kind words. Wouldn't have been able to do it without your great work!

isamu-isozaki avatar Feb 24 '23 19:02 isamu-isozaki

Currently @ZeroCool940711 is testing if it properly trains in the long term

isamu-isozaki avatar Feb 24 '23 19:02 isamu-isozaki

@wnakano Added some code to support loading text from directories, plus support for multiple descriptions per image and a fix for the tokenizer. Training with text+image pairs from directories should now be much more manageable.
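The directory loading could look something like this sketch (file layout and function names are illustrative, not the PR's actual API): each image sits next to a `.txt` file with one description per line, and one description is sampled per example.

```python
import random
import tempfile
from pathlib import Path

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp"}

def load_image_text_pairs(root):
    """Pair every image under `root` with captions from a sibling .txt file.
    Each line of the .txt file is one alternative description."""
    pairs = []
    for img in sorted(Path(root).rglob("*")):
        if img.suffix.lower() not in IMAGE_EXTS:
            continue
        txt = img.with_suffix(".txt")
        if txt.exists():
            captions = [l.strip() for l in txt.read_text().splitlines() if l.strip()]
        else:
            captions = [""]
        pairs.append((img, captions))
    return pairs

def sample_caption(captions):
    """Pick one of the alternative descriptions at random."""
    return random.choice(captions)

# Tiny demo layout (paths are illustrative):
root = Path(tempfile.mkdtemp())
(root / "cat.png").write_bytes(b"")  # stand-in image file
(root / "cat.txt").write_text("a cat\na small cat sitting\n")

pairs = load_image_text_pairs(root)
img_path, captions = pairs[0]
```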

isamu-isozaki avatar Feb 25 '23 14:02 isamu-isozaki

He's next working on loading MaskGit weights from the Google repo, which will be pretty huge!

isamu-isozaki avatar Feb 25 '23 14:02 isamu-isozaki

Added some QoL features, such as saving a folder of image-text pairs as a huggingface dataset. Now @ZeroCool940711 is training the base MaskGit at 384x384 resolution on his INE dataset, and we'll later test on CIFAR-10 and MNIST as well!

isamu-isozaki avatar Feb 25 '23 17:02 isamu-isozaki

@ZeroCool940711 added resume paths for both MaskGit and the VAE, so you can resume training!

isamu-isozaki avatar Feb 25 '23 20:02 isamu-isozaki

This is a wandb run of the base MaskGit on MNIST. The prompt is "4", but we accidentally enabled image flipping for augmentation. Check out the wandb run here; it was trained by @ZeroCool940711 on 8 GB of GPU memory, and all the training parameters are there. He's also uploading checkpoints here.

isamu-isozaki avatar Feb 27 '23 22:02 isamu-isozaki

Main params: batch size: 400, image size: 64, scheduler: cosine_with_restarts, warmup steps: 1000, dim: 32, vq_codebook_size: 64, num_tokens: 64

isamu-isozaki avatar Feb 27 '23 23:02 isamu-isozaki

We got a 7 here

isamu-isozaki avatar Feb 27 '23 23:02 isamu-isozaki

It took less than 1K steps for it to generate the next number correctly: it was a 4 before and is now a 7. The first sample looked more like a 1 than a 7, but after around 600 more steps of training the samples started to look more and more like 7s.

ZeroCool940711 avatar Feb 27 '23 23:02 ZeroCool940711

We plan to move on to more complicated datasets like CIFAR-10 next, and hopefully test out super-res training too.

isamu-isozaki avatar Feb 27 '23 23:02 isamu-isozaki

wow, you kids are incredible. and fast!

maybe you'll be part of the next generation of AI engineers at Cerebral Valley 😂

lucidrains avatar Feb 28 '23 16:02 lucidrains

@lucidrains not as fast as you! And haha thanks I'll take that as a compliment.

isamu-isozaki avatar Feb 28 '23 16:02 isamu-isozaki

I feel out of place for some reason. Most of you guys have PhDs or are in college, and then there's this guy (me) who didn't even finish high school and doesn't even consider himself a programmer, working on this stuff with you all :V Well, I'm having fun with this, so that's probably what counts.

ZeroCool940711 avatar Feb 28 '23 16:02 ZeroCool940711

Wow amazing work here @isamu-isozaki! cc'ing @patil-suraj here as well

patrickvonplaten avatar Mar 01 '23 15:03 patrickvonplaten

Update: @wnakano added code for loading pretrained VAEs from taming-transformers and latent diffusion, so most likely less training will be needed!

On my end, I'm working with @dsbuddy, whom I worked with on the MedSegDiff training script, and I think we're close to finishing training the VQ-VAE on CIFAR-10; we'll move on to base MaskGit training on CIFAR-10 tomorrow.
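Loading pretrained VAE weights (and the partial-weight checklist items) generally boils down to filtering the checkpoint's state dict by matching name and shape, then loading with `strict=False`. A minimal sketch, with toy modules standing in for the real VAE:

```python
import torch
from torch import nn

def load_partial_state_dict(model, checkpoint_state):
    """Copy over only the weights whose name AND shape match `model`,
    leaving everything else (e.g. conditioning layers) freshly initialized."""
    own = model.state_dict()
    matched = {
        k: v for k, v in checkpoint_state.items()
        if k in own and own[k].shape == v.shape
    }
    model.load_state_dict(matched, strict=False)
    return sorted(matched)

# Toy stand-ins: the checkpoint shares an encoder with the target model
# but has a differently-shaped head, which gets skipped.
src = nn.ModuleDict({"enc": nn.Linear(8, 8), "head": nn.Linear(8, 2)})
dst = nn.ModuleDict({"enc": nn.Linear(8, 8), "head": nn.Linear(8, 5)})

loaded = load_partial_state_dict(dst, src.state_dict())
```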

isamu-isozaki avatar Mar 04 '23 06:03 isamu-isozaki

@patrickvonplaten Oh, awesome seeing you here! We tried our best to keep the code clean and in the huggingface format, so we hope you like it!

isamu-isozaki avatar Mar 04 '23 06:03 isamu-isozaki

Hi! Brief update: LAION seems pretty interested in improving the VQGAN first, so we are trying that out now! The main improvements we are thinking of for now are:

  1. Add an option for a ViT-VQGAN, inspired by the one here
  2. Add wavelet transforms, similar to simple diffusion (mainly from crowsonkb's scripts), so that high-frequency noise is removed
  3. Add an option for other timm models in place of the VGG discriminator

We are still reading research papers on VQGANs, so if anyone has recommendations for papers, let us know!
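For item 2 above, the wavelet idea can be sketched as a single-level orthonormal Haar decomposition (a numpy stand-in for what the training code would do on image tensors): the LL band keeps the low-frequency content, while the detail bands isolate the high-frequency part that simple diffusion downweights.

```python
import numpy as np

def haar_2d(x):
    """Single-level 2-D Haar transform of an (H, W) array with even H, W.
    Returns (LL, LH, HL, HH), each of shape (H//2, W//2)."""
    a = x[0::2, 0::2]  # top-left of each 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # low-pass: 2x-downsampled (scaled) average
    lh = (a + b - c - d) / 2  # detail: top rows vs bottom rows
    hl = (a - b + c - d) / 2  # detail: left columns vs right columns
    hh = (a - b - c + d) / 2  # diagonal detail
    return ll, lh, hl, hh

img = np.ones((4, 4))  # a flat image has no high-frequency content
ll, lh, hl, hh = haar_2d(img)
```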

isamu-isozaki avatar Mar 10 '23 17:03 isamu-isozaki

@lucidrains Sorry, I forgot to post an update here for close to 2 months, haha. I mainly moved to huggingface/open-muse here, and we are doing some pretty exciting stuff! We already got some pretty good results on ImageNet, so we are mainly scaling up and combining insights from Paella. Once that's done, I'd be happy to finish this PR!

isamu-isozaki avatar May 02 '23 02:05 isamu-isozaki