imagen-pytorch
Code for conditional training on a dataset with images and text
Is there any code for conditional training on a dataset that has images and text?
I created a modified `Dataset` class that returns (image, text_embedding, text_mask), but I'm still getting errors. I wonder what the shapes of the embedding and the mask must be.
The shape of the embedding may be (256, 768). After reading the code I concluded that the dimension of size 768 is the per-word embedding, and the dimension of size 256 is the result of padding when the word count is less than 256. The text mask tells you which positions are padding.
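To sanity-check those shapes yourself, you can encode a caption with the library's own T5 helper. This is only a sketch: it assumes your installed version of t5_encode_text accepts a return_attn_mask flag.

# Sketch; the return_attn_mask flag is an assumption about your installed version of imagen-pytorch.
from imagen_pytorch.t5 import t5_encode_text

text_embeds, text_masks = t5_encode_text(['a photo of a cat'], name = 't5-base',
                                          return_attn_mask = True)
print(text_embeds.shape)  # e.g. torch.Size([1, seq_len, 768]) - 768-dim per-token embeddings
print(text_masks.shape)   # e.g. torch.Size([1, seq_len])      - False where the token is padding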
> I created a modified `Dataset` class that returns (image, text_embedding, text_mask), but I'm still getting errors. I wonder what the shapes of the embedding and the mask must be.
Get the tensor `.size()` from the encoder you're using and pass `text_embed_dim` to both the imagen class and the unet config. For example, here is my last config, which used an encoder with 2048-dimensional embeddings:
imagen = ElucidatedImagenConfig(
    unets=[
        dict(
            dim=192,
            dim_mults=(1, 2, 3, 4),
            text_embed_dim=2048,
            num_resnet_blocks=2,
            layer_attns=(False, True, True, True),
            memory_efficient=False,
            self_cond=True,
        ),
    ],
    image_sizes=(64,),
    cond_drop_prob=0.1,
    text_embed_dim=2048,
    num_sample_steps=50,
).create()
You also must make sure you pass the exact same number of tokens per batch. You can use the collate_fn feature of PyTorch dataloaders to do so (a sketch follows), or pad everything to the longest text in your dataset.
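A minimal collate_fn sketch, assuming your Dataset returns (image, text_embedding, text_mask) tuples (the function name is hypothetical):

# Hypothetical collate_fn: pads each item in the batch to the longest token count,
# so the text embeddings and masks can be stacked into single tensors.
import torch
import torch.nn.functional as F

def collate_text_embeds(batch):
    images, embeds, masks = zip(*batch)
    max_len = max(e.shape[0] for e in embeds)

    padded_embeds, padded_masks = [], []
    for e, m in zip(embeds, masks):
        pad = max_len - e.shape[0]
        padded_embeds.append(F.pad(e, (0, 0, 0, pad)))         # zero-pad the token dimension
        padded_masks.append(torch.cat([m, m.new_zeros(pad)]))  # padded positions are masked out

    return torch.stack(images), torch.stack(padded_embeds), torch.stack(padded_masks)

You would then pass it to the dataloader, e.g. DataLoader(dataset, batch_size=16, collate_fn=collate_text_embeds).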
> Is there any code for conditional training on a dataset that has images and text?
Assuming you have the texts next to the images, with the exact same file name but a .txt extension, you can use this modified version of `__getitem__` for the dataset in data.py:
def __getitem__(self, index):
    path = self.paths[index]
    img = Image.open(path)
    # read the caption from the .txt file that sits next to the image
    with open(path.with_suffix('.txt'), 'r', encoding='utf-8') as f:
        text = f.read()
    return self.transform(img), text
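If you would rather not edit data.py itself, a subclass sketch along the same lines could look like this. It assumes the Dataset class exported from imagen_pytorch.data with its folder/image_size constructor and the self.paths / self.transform attributes used above.

# Sketch; assumes imagen_pytorch.data.Dataset and its (folder, image_size=...) constructor.
from PIL import Image
from imagen_pytorch.data import Dataset

class ImageTextDataset(Dataset):
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        # caption lives next to the image, e.g. cat_001.jpg -> cat_001.txt
        with open(path.with_suffix('.txt'), 'r', encoding='utf-8') as f:
            text = f.read()
        return self.transform(img), text

dataset = ImageTextDataset('/path/to/training/images', image_size=64)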
Then, assuming you're using the example from the dataloader section of the readme, you have to pass `dl_tuple_output_keywords_names=('images', 'texts')` to the trainer, because the default expects text_embeds rather than raw texts. In that example you would replace the trainer code with this:
trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True,  # whether to split the validation dataset from the training
    dl_tuple_output_keywords_names = ('images', 'texts'),
).cuda()
This will make the trainer pass the first value returned by `__getitem__` to the `images` keyword of the imagen class's forward function, and the second value to the `texts` keyword. The class will then encode the texts into embeddings for you, at the cost of slower training, especially if you change the encoder to something much bigger (the default is t5-base).
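Putting this together, a rough sketch of the rest of the training loop, assuming the add_train_dataset and train_step methods shown in the readme:

# Rough sketch; assumes the readme's add_train_dataset / train_step API.
trainer.add_train_dataset(dataset, batch_size = 16)

for step in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    if step % 500 == 0:
        print(f'step {step}: loss {loss}')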
I forced the T5 model to pad the text embeddings to length 60, which is enough for my dataset. Now my `Dataset` returns a text embedding of size (60, 768) and a text mask of size (60,). This works perfectly.
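For reference, one way to get fixed-length embeddings like that is to call the Hugging Face T5 encoder directly with padding='max_length'. This is a sketch under that assumption; the exact hook into imagen-pytorch's own t5 helper may differ.

# Sketch using Hugging Face transformers directly; assumes t5-base and a fixed length of 60 tokens.
import torch
from transformers import T5Tokenizer, T5EncoderModel

MAX_TOKENS = 60
tokenizer = T5Tokenizer.from_pretrained('t5-base')
encoder = T5EncoderModel.from_pretrained('t5-base').eval()

def encode_text(text):
    enc = tokenizer(text, padding='max_length', truncation=True,
                    max_length=MAX_TOKENS, return_tensors='pt')
    with torch.no_grad():
        embeds = encoder(input_ids=enc.input_ids,
                         attention_mask=enc.attention_mask).last_hidden_state
    return embeds[0], enc.attention_mask[0].bool()   # (60, 768) embedding and (60,) mask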
> dl_tuple_output_keywords_names=('images', 'texts'),
My dataset returns a transformed image and a string, i.e. `self.transform(img), text`. Hey, this doesn't work for me:
BeartypeCallHintParamViolation Traceback (most recent call last)
<ipython-input-14-bed965b0c30d> in <module>
42 while c<50000: # Should converge in < 5000 steps
43
---> 44 loss = trainer.train_step(unet_number = 1)
45 avg_loss = w_avg * avg_loss + (1 - w_avg) * loss
46
/usr/local/lib/python3.8/dist-packages/imagen_pytorch/trainer.py in train_step(self, unet_number, **kwargs)
610 self.prepare()
611 self.create_train_iter()
--> 612 loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
613 self.update(unet_number = unet_number)
614 return loss
/usr/local/lib/python3.8/dist-packages/imagen_pytorch/trainer.py in step_with_dl_iter(self, dl_iter, **kwargs)
628 dl_tuple_output = cast_tuple(next(dl_iter))
629 model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
--> 630 loss = self.forward(**{**kwargs, **model_input})
631 return loss
632
/usr/local/lib/python3.8/dist-packages/imagen_pytorch/trainer.py in inner(model, *args, **kwargs)
134 kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
135
--> 136 out = fn(model, *args, **kwargs)
137 return out
138 return inner
/usr/local/lib/python3.8/dist-packages/imagen_pytorch/trainer.py in forward(self, unet_number, max_batch_size, *args, **kwargs)
981 for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
982 with self.accelerator.autocast():
--> 983 loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
984 loss = loss * chunk_size_frac
985
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7fc8fb8129d0> in forward(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_140501522573248, __beartype_getrandbits, *args, **kwargs)
BeartypeCallHintParamViolation: @beartyped imagen_pytorch.imagen_pytorch.Imagen.forward() parameter texts=('pine', 'birch boat') violates type hint typing.List[str], as tuple ('pine', 'birch boat') not instance of list.
Have you found a solution?