How to generate images after training

Open oo92 opened this issue 2 years ago • 2 comments

Hi. I am using the provided code below:

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,            # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4          # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)

I want to know how to get the generated images for the prompts in texts out as image files. Right now, once training is done, nothing is written out.

oo92 avatar Aug 12 '22 02:08 oo92

you should be able to use the torchvision library to save every tensor in the images list

korakoe avatar Aug 12 '22 09:08 korakoe

How can I do that? Can you show me a demo?

oo92 avatar Aug 12 '22 22:08 oo92

from torchvision.utils import save_image

for idx, image in enumerate(images):
    save_image(image, f"sample_{idx}.png")

edit: you should also be able to use the image's .save function as well, provided the samples are PIL images rather than tensors (a tensor has no .save method)

korakoe avatar Aug 13 '22 03:08 korakoe
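
As a small extension of the loop above, the files could be named after the prompts so it is clear which text produced which sample. This is only a sketch: prompt_to_filename and its max_len and ext parameters are hypothetical helpers invented here, not part of imagen-pytorch or torchvision, and it assumes the samples come back in the same order as texts.

```python
import re


def prompt_to_filename(index: int, prompt: str, max_len: int = 40, ext: str = "png") -> str:
    """Build a file name such as 'sample_3_hello-world.png' from a prompt (hypothetical helper)."""
    # Lowercase, then collapse every run of non-alphanumeric characters into one hyphen.
    slug = re.sub(r"[^a-z0-9]+", "-", prompt.lower()).strip("-")
    # Truncate long prompts; strip again so the slug never ends in a hyphen.
    slug = slug[:max_len].rstrip("-")
    # Fall back to the bare index when nothing usable is left of the prompt.
    return f"sample_{index}_{slug}.{ext}" if slug else f"sample_{index}.{ext}"


texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet',
]

# One file name per prompt, in the same order as the prompts.
filenames = [prompt_to_filename(i, t) for i, t in enumerate(texts)]
```

Each entry of filenames could then be passed to save_image in place of f"sample_{idx}.png", for example by looping over zip(images, filenames).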

[image attachment]

I got this result. Am I missing something?

kilik128 avatar Aug 22 '22 06:08 kilik128

The model isn't trained. Either wait for the pretrained models or train it yourself on something like a small subset of LAION.

korakoe avatar Aug 29 '22 02:08 korakoe

Is there a pretrained model, and if so, how do I use it?

liulwx avatar May 13 '23 09:05 liulwx