DALLE-pytorch
Allowing for custom trained VQGAN during DALLE training
It would be nice to allow for custom-trained VQGAN models when training DALLE.
Right now the library hard-codes usage of the pretrained ImageNet model; letting the user specify the model to be used (along with its params) in train_dalle.py would be great.
P.S. The download scripts seem to place the model in .cache, which, in my tests, bypasses the custom model one can pass to the VQGAN1024 class. They would probably also need to be covered by the option.
I have this running in my own tests. If it's difficult to find time for it, I could try a PR, but my approach is very hacky.
@lucidrains is this supported? I've never trained my own VAE, but it shouldn't be too tough to allow for VQGANs trained via the taming-transformers method, right?
@afiaka87 yup, should be doable! Let me think about the interface - perhaps something like --vqvae-path and --vqvae-config-path,
or perhaps, to stay simple, I should just default the VQVAE config to be the one with 1024 tokens.
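Not the actual interface yet - just a rough sketch of how those two flags could be wired into train_dalle.py, assuming a hypothetical version of VQGanVAE1024 that accepts a checkpoint and config path:
import argparse
from dalle_pytorch import VQGanVAE1024

parser = argparse.ArgumentParser()
parser.add_argument('--vqvae-path', type = str, default = None, help = 'path to a custom trained VQGAN checkpoint (.ckpt)')
parser.add_argument('--vqvae-config-path', type = str, default = None, help = 'path to the matching VQGAN config (.yaml)')
args = parser.parse_args()

if args.vqvae_path is not None:
    # hypothetical constructor arguments - the current class takes no paths
    vae = VQGanVAE1024(vqgan_model_path = args.vqvae_path, vqgan_config_path = args.vqvae_config_path)
else:
    # default: the pretrained imagenet VQGAN with 1024 tokens
    vae = VQGanVAE1024()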
I've been working on this a bit - kobiso and others are fond of using Hugging Face's tokenizers library. I think drop-in support for a parameter like --bpe_path="<vocab-file.json>"
would be nice as well.
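Roughly what I have in mind - a sketch using the Hugging Face tokenizers library (only the tokenizers calls themselves are the real API; the wiring into the training script is an assumption):
from tokenizers import Tokenizer

# --bpe_path would point at a tokenizer JSON trained with the Hugging Face tokenizers library
tokenizer = Tokenizer.from_file('hg_bpe_cc12m.json')   # e.g. the sample vocab linked further down

ids = tokenizer.encode('a photo of a dog on the beach').ids   # token ids for the text conditioning
text = tokenizer.decode(ids)                                  # back to a string
vocab_size = tokenizer.get_vocab_size()                       # could feed DALLE's num_text_tokens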
@afiaka87 yup, I can work on that too! About to push a WIP up, actually.
@lucidrains https://github.com/lucidrains/DALLE-pytorch/pull/193
If you can work from this then go for it - if you have a better implementation in mind, let me know.
Here's a sample tokenizer to work with - perhaps include it if you think it's a good idea:
wget https://www.dropbox.com/s/uie7is0dyuxqmk0/hg_bpe_cc12m.json
Not a permanent host, FYI.
@lucidrains
Seems to work in the latest! But for some reason, when I decode, it isn't concatenating the subwords.
Yep, having that issue as well - did you figure it out?
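Not sure if this is the exact cause here, but if the tokenizer JSON was saved without a decoder attached, the tokenizers library joins the subword pieces with spaces on decode(); setting a decoder may be all that's needed. A rough sketch (which decoder is right depends on how the vocab was trained):
from tokenizers import Tokenizer, decoders

tokenizer = Tokenizer.from_file('hg_bpe_cc12m.json')

# without a decoder, decode() leaves the subword pieces separated by spaces
tokenizer.decoder = decoders.BPEDecoder()    # plain BPE vocab with an end-of-word suffix
# tokenizer.decoder = decoders.ByteLevel()   # use this instead for a byte-level BPE vocab

ids = tokenizer.encode('concatenating the subwords').ids
print(tokenizer.decode(ids))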
In my case, the pretrained VQGAN model is significantly worse than a custom-trained VQGAN model, so I guess many people may have this requirement. For now, I work around it by simply renaming my custom-trained model so it gets picked up in place of the pretrained one.
I would say that for my purposes, and for anyone who has been actively using this code, a custom VQVAE that uses the VAE configuration already provided would be extremely useful.
Putting
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
import torch

model_filename = 'your_model.ckpt'
config_filename = 'your_model.yaml'
config = OmegaConf.load(config_filename)
model = VQModel(**config.model.params)
state = torch.load(model_filename, map_location = 'cpu')['state_dict']
# load the custom weights into the VQGAN
model.load_state_dict(state, strict = False)
there https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/vae.py#L135 works for training,
but it actually results in confusion. The following should also be supported (see the sketch further down):
- at save time: the VQGAN weights should be saved in the DALLE model, similarly to what is done for the VAE
- in generation.py
- when resuming a training run
What is worse, if at generation time you use the pretrained VQGAN with a DALLE trained on the custom VQGAN, it actually seems to generate something that is not completely broken, which I guess can be explained by some of the structure staying the same. This makes things pretty confusing.
Let's implement this properly.
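For the save/reload part, something along these lines is what I mean - the vqgan_* keys and path arguments are made up for illustration (the 'hparams' / 'weights' keys loosely mirror what train_dalle.py already saves), and dalle_params / dalle / args come from the training script:
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

# at save time: record which VQGAN this DALLE was trained against
save_obj = {
    'hparams': dalle_params,
    'vqgan_model_path': args.vqvae_path,            # hypothetical CLI args, as proposed above
    'vqgan_config_path': args.vqvae_config_path,
    'weights': dalle.state_dict(),
}
torch.save(save_obj, './dalle.pt')

# in generation.py / when resuming: rebuild the exact same VQGAN from the recorded paths
loaded = torch.load('./dalle.pt', map_location = 'cpu')
config = OmegaConf.load(loaded['vqgan_config_path'])
vae = VQModel(**config.model.params)
state = torch.load(loaded['vqgan_model_path'], map_location = 'cpu')['state_dict']
vae.load_state_dict(state, strict = False)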
For people who have successfully trained a VQGAN: do you experience an increasing quantization loss over time?
This is now done.