imagen-pytorch

Code for conditional training on a dataset with images and text

Open BIG-PIE-MILK-COW opened this issue 2 years ago • 8 comments

Is there any code for conditional training on a dataset that has images and text?

BIG-PIE-MILK-COW, Sep 02 '22 14:09

I created a modified Dataset class that returns (image, text_embedding, text_mask), but I'm still getting errors. I wonder what the shape of the embedding and mask must be.

Rorical, Sep 02 '22 15:09

The shape of the embedding may be (256, 768).

BIG-PIE-MILK-COW, Sep 02 '22 15:09

After reading the code, I concluded that the size-768 dimension is the embedding of each token (roughly, each word), and the size-256 dimension is the sequence length: texts shorter than 256 tokens are padded up to it. The mask tells you which positions are padding.
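To make those shapes concrete, here is a minimal PyTorch sketch, assuming a batch of per-token embeddings of width 768 padded to 256 positions. The helper build_text_mask is hypothetical, and the convention it uses (True for real tokens, False for padding) is our assumption about how imagen-pytorch reads text masks, so check it against your version.

```python
import torch

def build_text_mask(lengths: torch.Tensor, max_len: int = 256) -> torch.Tensor:
    """Return a (batch, max_len) bool mask: True for real tokens, False for padding."""
    positions = torch.arange(max_len)              # (max_len,)
    return positions[None, :] < lengths[:, None]   # broadcasts to (batch, max_len)

# Hypothetical batch of two texts with 3 and 5 tokens, embedding width 768
lengths = torch.tensor([3, 5])
embeds = torch.randn(2, 256, 768)
mask = build_text_mask(lengths)

# Zero out padded positions so they carry no signal
embeds = embeds.masked_fill(~mask[..., None], 0.0)

print(embeds.shape)          # torch.Size([2, 256, 768])
print(mask.shape)            # torch.Size([2, 256])
print(mask[0, :5].tolist())  # [True, True, True, False, False]
```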

Rorical, Sep 02 '22 16:09

> I created a modified Dataset class that returns (image, text_embedding, text_mask), but I'm still getting errors. I wonder what the shape of the embedding and mask must be.

Read the embedding size from the tensor your encoder returns (the last dimension of its .size()) and pass it as text_embed_dim to both the Imagen class and the Unet class, for example:

My last Unet config, for an encoder that produced 2048-dimensional embeddings:

```python
imagen = ElucidatedImagenConfig(
    dim_mults=(1, 2, 3, 4),
    layer_attns=(False, True, True, True),
    # ... (rest of the config not shown)
```
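A minimal sketch of that check, assuming your encoder returns a tensor shaped (batch, tokens, dim). The function text_batch_problems is hypothetical; the value read from .size(-1) is what you would pass as text_embed_dim to both constructors.

```python
from typing import List
import torch

def text_batch_problems(text_embeds: torch.Tensor, text_masks: torch.Tensor,
                        text_embed_dim: int) -> List[str]:
    """Return a list of shape mismatches for a text-conditioning batch (empty if OK)."""
    if text_embeds.dim() != 3:
        return [f"expected (batch, tokens, dim), got {tuple(text_embeds.shape)}"]
    problems = []
    batch, tokens, dim = text_embeds.shape
    if dim != text_embed_dim:
        problems.append(f"embedding dim {dim} != text_embed_dim {text_embed_dim}")
    if tuple(text_masks.shape) != (batch, tokens):
        problems.append(f"mask shape {tuple(text_masks.shape)} != {(batch, tokens)}")
    return problems

# Random stand-in for an encoder output; read the dimension straight from it
encoder_output = torch.randn(4, 60, 2048)
text_embed_dim = encoder_output.size(-1)   # pass this to both Imagen and Unet
masks = torch.ones(4, 60, dtype=torch.bool)
print(text_batch_problems(encoder_output, masks, text_embed_dim))  # []
```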

You must also make sure every sample in a batch has exactly the same number of tokens. You can use the collate_fn argument of PyTorch DataLoaders to do this, or pad everything to the length of the longest text in your dataset.
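A minimal collate_fn sketch for the per-batch option, assuming each dataset sample is an (image, text_embedding) pair where the embedding has shape (tokens, dim) with a varying token count. It returns the same (image, text_embedding, text_mask) order used in the Dataset described earlier in the thread; the name pad_text_collate is hypothetical.

```python
from typing import List, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_text_collate(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
    """Stack images and zero-pad text embeddings to the longest text in the batch."""
    images, embeds = zip(*batch)
    images = torch.stack(images)                                 # (B, C, H, W)
    text_embeds = pad_sequence(list(embeds), batch_first=True)   # (B, T_max, D)
    lengths = torch.tensor([e.size(0) for e in embeds])
    # True marks real tokens, False marks padding
    text_masks = torch.arange(text_embeds.size(1))[None, :] < lengths[:, None]
    return images, text_embeds, text_masks

# Usage: DataLoader(dataset, batch_size=16, collate_fn=pad_text_collate)
batch = [(torch.zeros(3, 64, 64), torch.randn(7, 768)),
         (torch.zeros(3, 64, 64), torch.randn(12, 768))]
images, text_embeds, text_masks = pad_text_collate(batch)
print(images.shape, text_embeds.shape, text_masks.shape)
# torch.Size([2, 3, 64, 64]) torch.Size([2, 12, 768]) torch.Size([2, 12])
```

Padding per batch keeps tensors small when most captions are short; padding the whole dataset to one length is simpler but wastes compute on padding.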

Nodja, Sep 02 '22 17:09

> Is there any code for conditional training on a dataset that has images and text?

Assuming each image has a text file next to it with the exact same name but a .txt extension, you can use this modified version of the __getitem__ of the dataset in

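```python
# Modified __getitem__ for the Dataset class (the module needs `from PIL import Image`)
def __getitem__(self, index):
    path = self.paths[index]
    img = Image.open(path)
    # Caption file: same name as the image, with a .txt extension
    with open(path.with_suffix('.txt'), 'r', encoding='utf-8') as f:
        text = f.read()
    return self.transform(img), text
```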

Then, assuming you're using the example from the dataloader section of the README, pass dl_tuple_output_keywords_names=('images', 'texts') to the trainer, since the default expects text_embeds rather than raw texts. In that example, you would replace the trainer code with this:

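```python
from imagen_pytorch import ImagenTrainer

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'texts'), # map __getitem__ outputs to forward() keywords
)
```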

This makes the trainer pass the first value returned by __getitem__ to the images keyword of the Imagen class's forward function, and the second value to the texts keyword. The class then encodes the texts to embeddings for you, at the cost of slower training, especially if you switch to a much bigger encoder (the default is t5-base).
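The mapping is purely positional: the trainer zips the keyword names with whatever the DataLoader yields (the traceback later in this thread shows model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))). A plain-Python sketch of the idea; batch_to_kwargs is a hypothetical name, not a trainer method.

```python
from typing import Any, Dict, Sequence

def batch_to_kwargs(keyword_names: Sequence[str], batch: Any) -> Dict[str, Any]:
    """Pair each element of a DataLoader batch with a keyword name, by position."""
    # A batch may arrive as a list or a tuple; wrap anything else as a 1-tuple
    batch_tuple = tuple(batch) if isinstance(batch, (list, tuple)) else (batch,)
    # zip stops at the shorter sequence, so unmatched names or values are dropped silently
    return dict(zip(keyword_names, batch_tuple))

kwargs = batch_to_kwargs(('images', 'texts'), ("<image tensor>", ['pine', 'birch boat']))
print(kwargs)
# {'images': '<image tensor>', 'texts': ['pine', 'birch boat']}
```

Because nothing checks the lengths, a Dataset that returns values in a different order than the names silently feeds the wrong keyword.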

Nodja, Sep 02 '22 18:09

I forced the T5 model to pad the text embedding to length 60, which is enough for my dataset. Now my Dataset returns a text embedding of shape (60, 768) and a text mask of shape (60,). This works perfectly.
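The comment above does not show how the padding was forced. One alternative, sketched here under the assumption that you already have an unpadded (tokens, 768) embedding per text, is to pad or truncate it after encoding inside the Dataset. This is a swapped-in approach, not necessarily what was done above, and pad_or_truncate is a hypothetical helper; note that it drops any tokens beyond max_len.

```python
from typing import Tuple
import torch

def pad_or_truncate(embed: torch.Tensor, max_len: int = 60) -> Tuple[torch.Tensor, torch.Tensor]:
    """Force a (tokens, dim) embedding to (max_len, dim) and return a (max_len,) bool mask."""
    tokens, dim = embed.shape
    keep = min(tokens, max_len)
    out = embed.new_zeros(max_len, dim)       # zero padding
    out[:keep] = embed[:keep]                 # copy the (possibly truncated) real tokens
    mask = torch.zeros(max_len, dtype=torch.bool, device=embed.device)
    mask[:keep] = True                        # True = real token, False = padding
    return out, mask

embed, mask = pad_or_truncate(torch.randn(17, 768))
print(embed.shape, mask.shape, int(mask.sum()))
# torch.Size([60, 768]) torch.Size([60]) 17
```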

Rorical, Sep 03 '22 03:09

> dl_tuple_output_keywords_names=('images', 'texts'),

Hey, this doesn't work for me. My dataset returns a transformed image and a string, as in return self.transform(img), text, and I get this error:

```
BeartypeCallHintParamViolation            Traceback (most recent call last)
<ipython-input-14-bed965b0c30d> in <module>
     42 while c<50000: # Should converge in < 5000 steps
---> 44       loss = trainer.train_step(unet_number = 1)
     45       avg_loss = w_avg * avg_loss + (1 - w_avg) * loss

/usr/local/lib/python3.8/dist-packages/imagen_pytorch/ in train_step(self, unet_number, **kwargs)
    610             self.prepare()
    611         self.create_train_iter()
--> 612         loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
    613         self.update(unet_number = unet_number)
    614         return loss

/usr/local/lib/python3.8/dist-packages/imagen_pytorch/ in step_with_dl_iter(self, dl_iter, **kwargs)
    628         dl_tuple_output = cast_tuple(next(dl_iter))
    629         model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
--> 630         loss = self.forward(**{**kwargs, **model_input})
    631         return loss

/usr/local/lib/python3.8/dist-packages/imagen_pytorch/ in inner(model, *args, **kwargs)
    134         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
--> 136         out = fn(model, *args, **kwargs)
    137         return out
    138     return inner

/usr/local/lib/python3.8/dist-packages/imagen_pytorch/ in forward(self, unet_number, max_batch_size, *args, **kwargs)
    981         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
    982             with self.accelerator.autocast():
--> 983                 loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
    984                 loss = loss * chunk_size_frac

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7fc8fb8129d0> in forward(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_140501522573248, __beartype_getrandbits, *args, **kwargs)

BeartypeCallHintParamViolation: @beartyped imagen_pytorch.imagen_pytorch.Imagen.forward() parameter texts=('pine', 'birch boat') violates type hint typing.List[str], as tuple ('pine', 'birch boat') not instance of list.
```
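The error says texts arrived as a tuple while Imagen.forward is type-hinted to take List[str]. That appears to come from PyTorch's default collate function, which transposes (image, str) samples with zip, so the string column stays a tuple. One possible workaround, sketched here and not tested against imagen-pytorch itself, is a custom collate_fn that returns the captions as a list. The name texts_as_list_collate is hypothetical.

```python
from typing import List, Tuple
import torch

def texts_as_list_collate(batch: List[Tuple[torch.Tensor, str]]) -> Tuple[torch.Tensor, List[str]]:
    """Stack images into one tensor and keep the captions as a plain list of str."""
    images, texts = zip(*batch)               # zip yields tuples
    return torch.stack(images), list(texts)   # list() satisfies a List[str] type hint

# Usage: dl = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True,
#                                         collate_fn=texts_as_list_collate)
images, texts = texts_as_list_collate([(torch.zeros(3, 64, 64), 'pine'),
                                       (torch.zeros(3, 64, 64), 'birch boat')])
print(images.shape, texts)
# torch.Size([2, 3, 64, 64]) ['pine', 'birch boat']
```

How the custom DataLoader reaches the trainer depends on your imagen-pytorch version; ImagenTrainer may offer add_train_dataloader, or forward extra DataLoader keyword arguments such as collate_fn from add_train_dataset, so check the trainer source you have installed.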

axel578, Feb 17 '23 09:02

Have you found a solution?

xiaoxiaodadada, Jun 25 '24 07:06