DALLE-pytorch
Trained on COCO
I trained a 6-layer, 8-head model on the COCO dataset (15 epochs, 82k images, 1x Tesla V100, batch size 12, no CLIP filtering). The results look as follows (conditioned on the text "A car on the street"):
There is more or less some structure in the samples, but they do not match the description "A car on the street" well (I can only recognize a "car" in a few of them). Maybe the model and the dataset are still too small to learn the text conditioning.
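For anyone wanting to reproduce a comparable setup, here is a minimal sketch of a 6-layer, 8-head DALLE built with dalle_pytorch. The dVAE hyperparameters and the random tensors below are placeholders rather than the exact values used for this run, and argument names may differ slightly between library versions.

import torch
from dalle_pytorch import DiscreteVAE, DALLE

# placeholder dVAE config; the one used for this run was trained separately on the same images
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64,
)

# 6 transformer layers, 8 attention heads, as described in the post
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8,
)

# dummy batch just to show the training call (batch size 12, as in the post)
text = torch.randint(0, 10000, (12, 256))
images = torch.randn(12, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()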
Maybe try "An image of a car on the street"??
Doesn't look bad at all! You need to start scaling the dataset and the size of the attention net. Just think of the difference between a 6-layer GPT and a 64-layer one.
Did you try the COCO notebook from @mrconter1? https://github.com/lucidrains/DALLE-pytorch/issues/15. It has a bug where it mixes up image ID with annotation ID
use
image_id = int(info['image_id'])
instead of
image_id = int(info['id'])
This randomizes all annotations across the images!
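In other words, in the COCO captions JSON the entries under 'images' carry their own 'id', while the entries under 'annotations' refer to the image through 'image_id' (their 'id' is the annotation's own id). A minimal sketch of the correct lookup, with a hypothetical annotation path:

import json

with open('annotations/captions_train2014.json') as f:   # hypothetical path
    coco = json.load(f)

captions_by_image = {}
for info in coco['annotations']:
    image_id = int(info['image_id'])   # correct: the image this caption describes
    # image_id = int(info['id'])       # wrong: this is the annotation's own id
    captions_by_image.setdefault(image_id, []).append(info['caption'])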
No, I wrote my own training script. But I may have also made the same mistake. This is my code for the COCO dataset:
import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class COCODataset(Dataset):
    def __init__(self, annot_path, image_path, image_size=256):
        self._image_path = image_path
        with open(annot_path) as file:
            json_file = json.load(file)
        self._image_size = image_size
        self._metadata = json_file['images']
        # keyed by the image each annotation belongs to ('image_id'), not the annotation's own 'id'
        self._captions = {entry['image_id']: entry for entry in json_file['annotations']}

    def __getitem__(self, index):
        metadata = self._metadata[index]
        # image entries carry their own 'id', which matches the annotations' 'image_id'
        caption = self._captions[metadata['id']]
        image_path = os.path.join(self._image_path, metadata['file_name'])
        image = Image.open(image_path).convert('RGB')
        image = self._crop_image(image)
        x = np.asarray(image) / 127.5 - 1  # scale pixels to [-1, 1]
        return torch.Tensor(x).permute(2, 0, 1), caption['caption']

    def _crop_image(self, image):
        width, height = image.size
        min_length = min(width, height)
        # center crop
        left = (width - min_length) / 2
        top = (height - min_length) / 2
        right = (width + min_length) / 2
        bottom = (height + min_length) / 2
        image = image.crop((left, top, right, bottom))
        # resize
        image = image.resize((self._image_size, self._image_size))
        return image

    def __len__(self):
        return len(self._metadata)
Is the problem you mentioned in it?
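For reference, a minimal usage sketch of the dataset class above with a standard DataLoader (paths are hypothetical; the captions still need to be tokenized before being fed to DALLE):

from torch.utils.data import DataLoader

dataset = COCODataset(
    annot_path = 'annotations/captions_train2014.json',   # hypothetical paths
    image_path = 'train2014',
    image_size = 256,
)
loader = DataLoader(dataset, batch_size = 12, shuffle = True, num_workers = 4)

for images, captions in loader:
    # images: (12, 3, 256, 256) float tensor in [-1, 1]
    # captions: tuple of 12 raw caption strings (tokenize before passing to DALLE)
    break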
Hey, I'm about to dig into COCO myself. Is the mistake you're referring to in this line?
self._captions = {int(entry['image_id']): entry for entry in json_file['annotations']}
Did you train the dVAE yourself or use the pre-trained model from OpenAI?
There is also the 1024 VQGAN from the taming-transformers research available. This post is older and may use the OpenAI one - not sure.
Well, the post is actually older than the OpenAI release, so I trained the dVAE myself. I also tested VQGAN after that; it uses significantly fewer resources and the generations look sharper and clearer.
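For context, dalle_pytorch also ships a wrapper around the pretrained taming-transformers VQGAN, so no VAE training is needed at all. A minimal sketch, assuming the VQGanVAE1024 class from the README of that era (later releases renamed it to VQGanVAE):

from dalle_pytorch import VQGanVAE1024, DALLE   # VQGanVAE in newer releases

# pretrained 1024-codebook VQGAN from taming-transformers; weights should be downloaded on first use
vae = VQGanVAE1024()

dalle = DALLE(
    dim = 512,
    vae = vae,              # drop-in replacement for a self-trained dVAE
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8,
)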
I also tried to train DALLE with a dVAE I trained myself, but got worse results. Could you share some details about training the dVAE (or VQGAN) and the DALLE? Also, is the COCO dataset you used the 2014 train split?
I can share my code. It uses an old version of DALLE and it is very ugly, but you can still use it as a reference. About the dataset: yes, I used the 2014 train split.
This is the result from COCO 2014 train with almost the same DALLE structure, at 1 epoch.
I notice that after 1 epoch the loss doesn't seem to decrease and the images are still ugly. Can I get a better result with more epochs?