
Results

NaxAlpha opened this issue 4 years ago · 57 comments

I have trained DiscreteVAE on the 128x128 FFHQ dataset, using this configuration:

from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    num_layers = 2,
    num_tokens = 4096,
    dim = 1024,
    hidden_dim = 256
)
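(The training loop itself is not shown here; below is a minimal sketch of the step, assuming the README-style forward pass that returns a reconstruction loss, with a hypothetical dataloader and optimizer.)

from torch.optim import Adam

opt = Adam(vae.parameters(), lr = 3e-4)        # learning rate is a guess, not taken from this run

for images in dataloader:                      # hypothetical loader yielding batches of 128x128 FFHQ crops
    loss = vae(images, return_loss = True)     # reconstruction loss of the discrete VAE
    opt.zero_grad()
    loss.backward()
    opt.step()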

Here are the results after 3 epochs (top original, bottom reconstructed):


NaxAlpha · Jan 07 '21 15:01

Even a smaller model works pretty well:

vae = DiscreteVAE(
    num_layers = 3,
    num_tokens = 4096,
    codebook_dim = 512,
    hidden_dim = 256,
)

Here are the samples:


Here is the (expected) loss after ~3 epochs:

NaxAlpha · Jan 08 '21 10:01

Are you inputting descriptions for the images, or just letting it randomly generate an image?

mrconter1 · Jan 08 '21 10:01

Are you inputting descriptions for the images, or just letting it randomly generate an image?

Not OP, but this is just the VQ-VAE, and the images are reconstructions, not samples. So the top image is the input, and the bottom image is the output of the VAE. The VQ-VAE is used to build the codebook, which the transformer will then use to generate an image from a description.
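To make the two stages concrete, here is a rough sketch using the dalle_pytorch names that appear later in this thread (the sizes are illustrative only, not from any of the runs above):

from dalle_pytorch import DiscreteVAE, DALLE

# Stage 1: the discrete VAE is trained on images alone; the pictures above are its reconstructions.
vae = DiscreteVAE(image_size = 128, num_layers = 3, num_tokens = 4096, codebook_dim = 512, hidden_dim = 256)

# Stage 2: the trained VAE turns each image into a short sequence of codebook indices,
# and the transformer (DALLE) models the text tokens followed by those image tokens.
dalle = DALLE(dim = 512, vae = vae, num_text_tokens = 10000, text_seq_len = 256, depth = 6, heads = 8)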

adrian-spataru · Jan 08 '21 10:01

Yes, these results are for the VAE. It took only ~30 min to 1 hr on Colab Pro (V100). I am in the process of training DALL-E; results should be ready soon!!!

NaxAlpha · Jan 08 '21 11:01

Would you mind sharing the Colab you have so far? :)

mrconter1 · Jan 08 '21 15:01

Sure! Here is the notebook so far. Also, there is an update: after the discussion in #12 and applying the fix, here are the results I got. They are not as promising as before, but it should actually work now (hopefully).

from dalle_pytorch import DiscreteVAE

NUM_LAYERS = 2
IMAGE_SIZE = 128

BATCH_SIZE = 32
NUM_TOKENS = 8192

EMB_DIM = 256
HID_DIM = 128

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

(Top is ground truth, the middle one is soft-decoded [via gumbel_softmax], and the bottom is hard-decoded [via argmax], which was only noise previously because of the bug)
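For reference, a sketch of how the two reconstruction rows can be produced, assuming DiscreteVAE's forward pass returns the gumbel-softmax reconstruction when called without return_loss:

soft_recon = vae(images)                                     # middle row: soft decode through gumbel_softmax
hard_recon = vae.decode(vae.get_codebook_indices(images))    # bottom row: hard decode from argmax codes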


NaxAlpha · Jan 08 '21 15:01

Thanks! I'm a noob but I tried to help: https://colab.research.google.com/drive/1KxG1iGBoKt2fLVH7uXG_vhvll2OlFkey?usp=sharing :)

mrconter1 · Jan 08 '21 16:01

Okay. Here is a fully working Colab, at least for VAE training. Thanks to NaxAlpha of course!

https://colab.research.google.com/drive/1KxG1iGBoKt2fLVH7uXG_vhvll2OlFkey?usp=sharing

After around 600 training pairs.

mrconter1 · Jan 08 '21 16:01

Here are results after a few hours of training of DALL-E:


Loss is still very high right now, but it's going down slowly.


NaxAlpha · Jan 08 '21 18:01

Which dataset are you using to train DALL-E? Don't you need text as well? Also, what are you training on? Do you have access to Google Colab Pro?

mrconter1 · Jan 08 '21 19:01

@NaxAlpha nice! I just realized, without text, this essentially becomes iGPT! (If that is what you are doing)

lucidrains · Jan 08 '21 20:01

@lucidrains Isn't iGPT at the pixel level, or close to pixel level (a.k.a. the 9-bit color palette), whereas DALL-E operates at the level of codebook vectors? In a sense, DALL-E works at the right level of abstraction (pixels and local features are too fine, and entire scenes are too coarse).

VIVelev · Jan 08 '21 22:01

@VIVelev Yup, you are correct! iGPT is pixel level, but clustered into 512 (9-bit) discrete tokens. That is equivalent to a 0-layer discrete VAE with a codebook of 512.

lucidrains · Jan 08 '21 22:01

@NaxAlpha I just added a temperature flag on the DiscreteVAE class, so you can control the hardness of the gumbel softmax during training! Just FYI!
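A minimal usage sketch, assuming the flag is a constructor keyword named temperature (check the repo for the exact name and default):

vae = DiscreteVAE(
    image_size = 128,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 256,
    hidden_dim = 128,
    temperature = 0.9,   # lower values make the gumbel-softmax sampling closer to a hard argmax
)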

lucidrains · Jan 09 '21 02:01

Awesome! Yeah, I am training it unconditionally (just 1 text token, which is random xD). Here are the results after 9 more hours:


I feel like it is going slower than I expected (might need to scale up the transformer). Here is the DALL-E configuration I am using:

from dalle_pytorch import DALLE, DiscreteVAE

NUM_LAYERS = 3
IMAGE_SIZE = 128

BATCH_SIZE = 16
NUM_TOKENS = 8192

EMB_DIM = 256
HID_DIM = 128

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

dalle = DALLE(
    dim = EMB_DIM,
    vae = vae,
    num_text_tokens = 1024,     # 1024 fixed latents (model should learn to ignore it)
    text_seq_len = 1,           # Acts like a latent variable
    depth = 16,
    heads = 24,
)
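A sketch of what one unconditional training step might look like with this setup (dataloader and opt are hypothetical; the single text token is drawn at random, as described above, and sampling uses the README-style generate_images call):

import torch

for images in dataloader:                                    # hypothetical loader of 128x128 FFHQ batches
    text = torch.randint(0, 1024, (images.shape[0], 1))      # 1 random token out of the 1024 "text" tokens
    loss = dalle(text, images, return_loss = True)
    loss.backward()
    opt.step()
    opt.zero_grad()

# sampling is then also driven by a random token
samples = dalle.generate_images(torch.randint(0, 1024, (4, 1)))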

NaxAlpha · Jan 09 '21 04:01

@NaxAlpha haha yea, they used 64 layers! Perhaps this could be tried on something small-scale, like CIFAR-sized images.

lucidrains · Jan 09 '21 04:01

Would it be possible that using more coherent text (instead of random) would also result in more coherent images?

mrconter1 · Jan 09 '21 08:01

@lucidrains wow! The temperature feature is awesome! I am gradually decreasing it from 5 to 0.05 over 5 epochs; convergence is really fast and the results look much better!!!
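One way such a schedule could be implemented (a sketch assuming the VAE exposes its gumbel temperature as a plain attribute that can be overwritten between epochs; num_epochs, dataloader and opt are hypothetical):

start_temp, end_temp, anneal_epochs = 5.0, 0.05, 5

for epoch in range(num_epochs):
    # exponential decay from 5 to 0.05 over the first 5 epochs, then held constant
    frac = min(epoch / anneal_epochs, 1.0)
    vae.temperature = start_temp * (end_temp / start_temp) ** frac
    for images in dataloader:
        loss = vae(images, return_loss = True)
        loss.backward()
        opt.step()
        opt.zero_grad()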

@mrconter1 Yes, using coherent text should help, but since I do not have any text for now, I am using just 1 token to make it work xD.

NaxAlpha · Jan 09 '21 10:01

@NaxAlpha

I created an image + desc fetcher. You can see it here. Could it be useful?

mrconter1 · Jan 09 '21 13:01

I just benchmarked my scraper on Google Colab Pro. It takes around 3.46 hours per 10,000 image+desc pairs. I will upload the data when I'm done.

mrconter1 · Jan 09 '21 15:01

@NaxAlpha Added reversible networks! https://github.com/lucidrains/DALLE-pytorch#scaling-depth Maybe depth will help!
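A sketch of how this would be switched on, assuming it is the reversible keyword on the DALLE constructor (the linked README section is the authoritative reference):

dalle = DALLE(
    dim = EMB_DIM,
    vae = vae,
    num_text_tokens = 1024,
    text_seq_len = 1,
    depth = 32,          # deeper stack; reversible layers avoid storing activations for every layer
    heads = 24,
    reversible = True,
)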

lucidrains · Jan 09 '21 16:01

Nevermind my scraper. Just use the COCO dataset. It has 500,000 images with descriptions for each one. It takes 10 minutes to download on Colab Pro. If anyone wants me to set up a Colab, just tell me what format you want the data in.
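For anyone who wants to try it, a sketch of loading COCO image+caption pairs with torchvision (paths are placeholders and pycocotools must be installed):

from torchvision import datasets, transforms

coco = datasets.CocoCaptions(
    root = 'train2017',                                  # folder of downloaded images
    annFile = 'annotations/captions_train2017.json',     # caption annotation file
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ]),
)

image, captions = coco[0]   # captions is a list of caption strings for that image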

mrconter1 · Jan 09 '21 17:01

hi everyone, thanks for all the amazing work and sharing results!

I have a really noobish question, I hope that's okay. What do we think the scale of the image+text pairs needs to be to get something useful? I want to train it on my specific domain (architecture) and I'll probably need to create custom datasets. Any idea of what scale and above is worth trying? Also, concerning the codebook, does it need to be built on a similar dataset, or is variety better?

Thanks in advance!

TheodoreGalanos · Jan 10 '21 12:01

@lucidrains Awesome, I have scaled the model; let's wait and see the results 😁.

The main problem right now is that the VAE output is not really great. When the temperature is high (>1), results look good, but when the temperature gets near 0.1, they become horrible. Ideally we want the temperature to be close to 0, because otherwise, no matter how good the language model is, the decoded output will be rough.

Below are the outputs, where the top row is ground truth, the middle row is the output of the VAE through gumbel softmax at different temperatures, and the last row is the output of the following code:

codes = vae.get_codebook_indices(images[:k])   # hard (argmax) codebook indices for the first k images
image = vae.decode(codes)                      # decode back to pixels from the discrete codes

@ temperature = 2.9

@ temperature = 1.8

@ temperature = 0.6

@ temperature = 0.1

BTW Here is the config that I am using:

from dalle_pytorch import DiscreteVAE

NUM_LAYERS = 3
IMAGE_SIZE = 128

BATCH_SIZE = 8
NUM_TOKENS = 8192

EMB_DIM = 1024
HID_DIM = 256

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

NaxAlpha · Jan 10 '21 15:01

Nevermind my scraper. Just use the COCO dataset. It has 500,000 images with descriptions for each one. It takes 10 minutes to download on Colab Pro. If anyone wants me to set up a Colab, just tell me what format you want the data in.

@mrconter1 how do you feed in the text descriptions with their corresponding images as input for DALL-E training? Would you mind sharing your Colab?

HenryHengZJ · Jan 10 '21 16:01

@NaxAlpha thanks for sharing your results! So, I have an end-to-end version on a different branch of the repository that could be tried, perhaps with an annealing schedule.

I'll also add resnet blocks to the VAE later today, per Aran's suggestion.

Keep us posted!

Edit - will also reread https://arxiv.org/abs/2012.09841 for insights

lucidrains · Jan 10 '21 17:01

I've just created the text+image pairs. Not sure how to feed them to DALL-E. I think you are supposed to tokenize the text.


mrconter1 · Jan 10 '21 17:01

I've just put together quick and dirty code to train DALL-E. Not directly usable for anyone, I am afraid. I am using a small dataset of 2000+ landscapes, for which I automatically generated captions into a text file. The script reads the image filenames and captions from the text file, builds a vocabulary and uses it to convert the text tokens into numeric indices.

There is not even a proper PyTorch dataset, just quick code to iterate through the data. So far, it appears to be learning: loss is decreasing and the generated images are starting to roughly resemble landscapes. (Sample image: dallevae-cdim256_epoch_20)

https://github.com/htoyryla/DALLE-pytorch/blob/main/trainDALLE.py

PS. The vocabulary class is missing from my repo. My code uses the one from this page: https://www.kdnuggets.com/2019/11/create-vocabulary-nlp-tasks-python.html
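For readers who do not want to dig through the linked script, here is a toy sketch of the approach described above (a minimal word-level vocabulary, not the actual Vocabulary class from the KDnuggets post; captions and the sequence length are hypothetical):

class Vocabulary:
    def __init__(self):
        self.word2index = {'<pad>': 0}

    def add_sentence(self, sentence):
        for word in sentence.lower().split():
            self.word2index.setdefault(word, len(self.word2index))

    def to_indices(self, sentence, seq_len):
        ids = [self.word2index.get(w, 0) for w in sentence.lower().split()][:seq_len]
        return ids + [0] * (seq_len - len(ids))          # pad / truncate to a fixed text_seq_len

vocab = Vocabulary()
for caption in captions:                                 # captions read from the text file
    vocab.add_sentence(caption)

# each caption then becomes a fixed-length list of token ids, which can be stacked into
# the text tensor passed to dalle(text, images, return_loss = True)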

htoyryla · Jan 10 '21 19:01

@htoyryla Oh wow, that is better than I thought for only 2000 images!

I added resnet blocks (https://github.com/lucidrains/DALLE-pytorch/commit/c5c56e287b2167606f6b9227f4cdab3622c14b6c), as suggested by @AranKomat.

lucidrains · Jan 10 '21 20:01


@htoyryla What was the codebook size you used?

lucidrains · Jan 10 '21 20:01