
Allowing for custom trained VQGAN during DALLE training

Open TheodoreGalanos opened this issue 3 years ago • 14 comments

It would be nice to allow for custom trained VQGAN models when training DALLE.

Right now the library hard-codes usage of the pretrained ImageNet model; allowing the user to specify the model to be used (along with its params) in train_dalle.py would be great.

p.s. The download scripts seem to fetch the model into .cache which, in my tests, bypasses the custom model one can pass to the VQGAN1024 class. They would probably also need to be covered by the option.

I have this running in my own tests; if it's difficult to find time for it, I could try a PR, but my approach is very hacky.

TheodoreGalanos avatar Apr 13 '21 04:04 TheodoreGalanos

@lucidrains is this supported? I never train my own VAE but it shouldn't be too tough to allow for VQGANs trained via the taming-transformers method, right?

afiaka87 avatar Apr 13 '21 14:04 afiaka87

@afiaka87 yup, should be doable! let me think about the interface - perhaps something like --vqvae-path and --vqvae-config-path

or perhaps, to stay simple, i should just default the VQVAE config to be the one with 1024 tokens
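
a minimal sketch of that interface (flag names from the comment above; purely hypothetical wiring, not the repo's final implementation):

```python
# Hypothetical sketch of the proposed train_dalle.py flags;
# the shipped interface may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--vqvae-path', type=str, default=None,
                    help='path to a custom trained VQGAN checkpoint (.ckpt)')
parser.add_argument('--vqvae-config-path', type=str, default=None,
                    help='path to the matching taming-transformers config (.yaml)')

# in train_dalle.py this would be parser.parse_args() over sys.argv
args = parser.parse_args(['--vqvae-path', 'model.ckpt',
                          '--vqvae-config-path', 'model.yaml'])
# argparse maps dashes to underscores: args.vqvae_path, args.vqvae_config_path
```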

lucidrains avatar Apr 13 '21 22:04 lucidrains

I've been working on this a bit - kobiso and others are fond of using Hugging Face's tokenizers library. I think drop-in support for the parameter --bpe_path="<vocab-file.json>" would be nice as well.
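
a rough sketch of what that flag could look like (names here are assumptions, not the merged implementation):

```python
# Hypothetical --bpe_path wiring for train_dalle.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bpe_path', type=str, default=None,
                    help='path to a Hugging Face tokenizers vocab file (.json)')

args = parser.parse_args(['--bpe_path', 'vocab.json'])

if args.bpe_path is not None:
    # with the `tokenizers` package installed, the file loads via:
    #   from tokenizers import Tokenizer
    #   tokenizer = Tokenizer.from_file(args.bpe_path)
    pass
```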

afiaka87 avatar Apr 14 '21 00:04 afiaka87

@afiaka87 yup I can work on that too!

lucidrains avatar Apr 14 '21 01:04 lucidrains

> @afiaka87 yup I can work on that too!

About to push a wip up actually

afiaka87 avatar Apr 14 '21 02:04 afiaka87

@lucidrains https://github.com/lucidrains/DALLE-pytorch/pull/193

If you can work from this then go for it - if you have a better implementation in mind let me know.

afiaka87 avatar Apr 14 '21 02:04 afiaka87

Here's a sample tokenizer to work with - perhaps include if you think it's a good idea

wget https://www.dropbox.com/s/uie7is0dyuxqmk0/hg_bpe_cc12m.json

Not a permanent host fyi.

@lucidrains

afiaka87 avatar Apr 14 '21 03:04 afiaka87

> Here's a sample tokenizer to work with - perhaps include if you think it's a good idea
>
> wget https://www.dropbox.com/s/uie7is0dyuxqmk0/hg_bpe_cc12m.json
>
> Not a permanent host fyi.
>
> @lucidrains

seems to work in the latest! but for some reason, when i decode, it isn't concatenating the subwords

lucidrains avatar Apr 15 '21 05:04 lucidrains

> > Here's a sample tokenizer to work with - perhaps include if you think it's a good idea
> >
> > wget https://www.dropbox.com/s/uie7is0dyuxqmk0/hg_bpe_cc12m.json
> >
> > Not a permanent host fyi.
> >
> > @lucidrains
>
> seems to work in the latest! but for some reason, when i decode, it isn't concatenating the subwords

Yep having that issue as well - did you figure it out?
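
One likely cause (my guess, not confirmed in the thread): the Hugging Face tokenizers BPE model space-joins raw subword pieces on decode unless a decoder such as `decoders.BPEDecoder()` is attached to the tokenizer. A minimal sketch of the joining that decoder performs, assuming the vocab marks word ends with a `</w>` suffix:

```python
# Illustration of BPE subword joining; mirrors what attaching
# tokenizers.decoders.BPEDecoder(suffix='</w>') does on decode.
def join_subwords(pieces, suffix='</w>'):
    words, current = [], ''
    for piece in pieces:
        if piece.endswith(suffix):
            # end-of-word marker: close out the current word
            words.append(current + piece[:-len(suffix)])
            current = ''
        else:
            # mid-word piece: concatenate without a space
            current += piece
    return ' '.join(words)

print(join_subwords(['a', 'vo', 'cado</w>', 'arm', 'chair</w>']))
# → avocado armchair
```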

afiaka87 avatar Apr 15 '21 08:04 afiaka87

In my case, the pretrained VQGAN model is significantly worse than a custom trained one, so I guess many people may have this requirement. For now, I simply replace the name of the custom trained model.

Yuliang-Liu avatar Apr 21 '21 02:04 Yuliang-Liu

> @afiaka87 yup, should be doable! let me think about the interface - perhaps something like --vqvae-path and --vqvae-config-path
>
> or perhaps, to stay simple, i should just default the VQVAE config to be the one with 1024 tokens

I would say that for my purposes, and for anyone who has been actively using this code, a VQVAE with the existing VAE configuration already provided would be extremely useful.

EmaadKhwaja avatar Apr 24 '21 20:04 EmaadKhwaja

Putting

        from omegaconf import OmegaConf          # pip install omegaconf
        from taming.models.vqgan import VQModel  # from taming-transformers
        import torch

        model_filename = 'your_model.ckpt'
        config_filename = 'your_model.yaml'

        config = OmegaConf.load(config_filename)
        model = VQModel(**config.model.params)

        state = torch.load(model_filename, map_location = 'cpu')['state_dict']
        model.load_state_dict(state, strict = False)

there https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/vae.py#L135 works for training

but this alone results in confusion. The following should also be supported:

  • at save time: the vqgan weights should be saved in the dalle model, similarly as what is done for the vae
  • in generation.py
  • to resume a training

What is worse, if at generation time you use the pretrained VQGAN with a DALL-E trained on the custom VQGAN, it actually seems to generate something not completely broken, which I guess can be explained by some structure staying the same. This makes things pretty confusing.

Let's implement this properly.
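
a sketch of the save-time fix described above: bundle the custom VQGAN state into the DALL-E checkpoint so generation and training resumption can rebuild the matching VQGAN. Key names and the helper below are illustrative, not the repo's actual layout:

```python
# Hypothetical checkpoint layout carrying the custom VQGAN with the DALL-E
# weights; in practice the dict would go through torch.save(ckpt, path).
def build_checkpoint(dalle_state, vqgan_state, vqgan_params):
    return {
        'weights': dalle_state,      # DALL-E weights
        'vae': vqgan_state,          # custom VQGAN weights travel with the model
        'vae_params': vqgan_params,  # config needed to reinstantiate the VQGAN
    }

ckpt = build_checkpoint({'w': 1}, {'v': 2}, {'embed_dim': 256})
# at load time: rebuild the VQGAN from ckpt['vae_params'],
# then load ckpt['vae'] into it before loading ckpt['weights'] into DALL-E
```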

rom1504 avatar May 11 '21 21:05 rom1504

For people who successfully train VQGAN, do you experience increasing quantized loss over time?

richcmwang avatar May 13 '21 07:05 richcmwang

this is now done

rom1504 avatar Jun 16 '21 09:06 rom1504