
Trained on COCO

Open fomalhautb opened this issue 4 years ago • 11 comments

I trained a 6-layer, 8-head model on the COCO dataset (15 epochs, 82k images, 1x Tesla V100, batch size 12, no CLIP filtering). The results, conditioned on the text "A car on the street", look as follows: [8 generated samples] There is more or less some structure in the samples, but they do not match the description "A car on the street" well (I can only recognize a "car" in a few of them). Maybe the model and the dataset are still too small to learn the text conditioning.
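For reference, a setup with those dimensions would look roughly like this in dalle_pytorch, following the library's README. This is only a sketch, not the actual training script used here, and the dVAE hyperparameters and vocabulary size below are placeholders:

from dalle_pytorch import DiscreteVAE, DALLE

# dVAE hyperparameters are placeholders, not the values used in the experiment above
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

# 6 layers, 8 heads, as described above; other numbers are placeholders
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8
)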

fomalhautb avatar Feb 19 '21 12:02 fomalhautb

Maybe try "An image of a car on the street"??

Thierryonre avatar Feb 19 '21 14:02 Thierryonre

Doesn't look bad at all! You need to start scaling the dataset and the size of the attention net. Just think of the difference between a 6-layer GPT and a 64-layer one.

lucidrains avatar Feb 19 '21 16:02 lucidrains

Did you try the COCO notebook from @mrconter1? https://github.com/lucidrains/DALLE-pytorch/issues/15. It has a bug where it mixes up the image ID with the annotation ID.

Use image_id = int(info['image_id']) instead of image_id = int(info['id']). The bug shuffles the annotations across the images!
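For anyone hitting the same issue, the fix boils down to keying captions by the annotation's image_id rather than by its own id. A minimal sketch (the file path and the captions_by_image name are only examples):

import json

with open('annotations/captions_train2014.json') as f:   # example path
    coco_json = json.load(f)

captions_by_image = {}
for info in coco_json['annotations']:
    image_id = int(info['image_id'])    # id of the image this caption describes
    # image_id = int(info['id'])        # wrong: this is the annotation's own id
    captions_by_image.setdefault(image_id, []).append(info['caption'])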

vigyanik avatar Mar 07 '21 17:03 vigyanik

Did you try the COCO notebook from @mrconter1? https://github.com/lucidrains/DALLE-pytorch/issues/15. It has a bug where it mixes up the image ID with the annotation ID.

Use image_id = int(info['image_id']) instead of image_id = int(info['id'])

No, I wrote my own training script. But I may have also made the same mistake. This is my code for the COCO dataset:

import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class COCODataset(Dataset):
    def __init__(self, annot_path, image_path, image_size=256):
        self._image_path = image_path
        
        with open(annot_path) as file:
            json_file = json.load(file)
        
        self._image_size = image_size
        self._metadata = json_file['images']
        self._captions = {entry['image_id']: entry for entry in json_file['annotations']}
    
    def __getitem__(self, index):
        metadata = self._metadata[index]
        caption = self._captions[metadata['id']]
        image_path = os.path.join(self._image_path, metadata['file_name'])
        image = Image.open(image_path).convert('RGB')
        image = self._crop_image(image)
        x = np.asarray(image) / 127.5 - 1  # scale pixel values to [-1, 1]
        return torch.Tensor(x).permute(2, 0, 1), caption['caption']
    
    def _crop_image(self, image):
        width, height = image.size
        min_length = min(width, height)
        
        # center crop
        left = (width - min_length)/2
        top = (height - min_length)/2
        right = (width + min_length)/2
        bottom = (height + min_length)/2
        image = image.crop((left, top, right, bottom))
        
        # resize
        image = image.resize((self._image_size, self._image_size))
        
        return image
    
    def __len__(self):
        return len(self._metadata)

Is the problem you mentioned present in my code?

fomalhautb avatar Mar 07 '21 17:03 fomalhautb

No, I wrote my own training script. But I may have also made the same mistake. This is my code for the COCO dataset: [...]

        self._captions = {entry['image_id']: entry for entry in json_file['annotations']} # This line here

self._captions = {int(entry['image_id']): entry for entry in json_file['annotations']}

Hey, I'm about to dig into COCO myself. Is this the mistake being referenced?
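One quick way to check whether captions end up paired with the right pictures is to resolve an annotation's image_id against the images list and compare the file name with the caption. A sketch (the path below is just an example):

import json

with open('annotations/captions_train2014.json') as f:   # example path
    coco_json = json.load(f)

images_by_id = {img['id']: img for img in coco_json['images']}

ann = coco_json['annotations'][0]
# image_id points at the image being described; the annotation's own id lives
# in a separate id space and must not be used for this lookup
print(images_by_id[int(ann['image_id'])]['file_name'], '->', ann['caption'])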

afiaka87 avatar Mar 15 '21 18:03 afiaka87

Did you train the dVAE by yourself or use the pre-trained model by openai?

GuoxingY avatar Apr 08 '21 01:04 GuoxingY

Did you train the dVAE by yourself or use the pre-trained model by openai?

There is also the 1024VQGAN from the taming-transformers research available. This post is older and may use the OpenAI one - not sure.
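If anyone wants to try it, plugging in the pretrained VQGAN looked roughly like the sketch below. It assumes the VQGanVAE1024 wrapper that dalle_pytorch shipped for the taming-transformers checkpoint, and the other numbers are placeholders:

from dalle_pytorch import VQGanVAE1024, DALLE

# wraps the pretrained 1024-token VQGAN from taming-transformers
vae = VQGanVAE1024()

dalle = DALLE(
    dim = 512,                # placeholder
    vae = vae,                # no dVAE training needed
    num_text_tokens = 10000,  # placeholder
    text_seq_len = 256,
    depth = 6,
    heads = 8
)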

afiaka87 avatar Apr 08 '21 11:04 afiaka87

Did you train the dVAE by yourself or use the pre-trained model by openai?

There is also the 1024VQGAN from the taming-transformers research available. This post is older and may use the OpenAI one - not sure.

Well, the post is actually older than the OpenAI release; I trained the dVAE myself. I also tested VQGAN after that: it uses significantly fewer resources and the generations look sharper and clearer.

fomalhautb avatar Apr 09 '21 12:04 fomalhautb

Well, the post is actually older than the OpenAI release; I trained the dVAE myself. I also tested VQGAN after that: it uses significantly fewer resources and the generations look sharper and clearer.

I also tried to train DALL-E with a dVAE I trained myself, but got worse results. Could you share some details about training the dVAE (or VQGAN) and the DALL-E? Also, is the COCO dataset you used the 2014 train split?

GuoxingY avatar Apr 09 '21 15:04 GuoxingY

I also tried to train DALL-E with a dVAE I trained myself, but got worse results. Could you share some details about training the dVAE (or VQGAN) and the DALL-E? Also, is the COCO dataset you used the 2014 train split?

I can share my code. It uses an old version of DALLE and it is very ugly, but you can still use it as a reference. About the dataset: yes, I used the 2014 train split.
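(Not the code mentioned above, but for anyone following along: the basic two-stage training calls in dalle_pytorch look roughly like this, per the library's README. Data loading, tokenization, and optimizers are omitted, and all hyperparameters are placeholders.)

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(image_size = 256, num_layers = 3, num_tokens = 8192,
                  codebook_dim = 512, hidden_dim = 64)

# stage 1: train the dVAE on images alone
images = torch.randn(4, 3, 256, 256)        # stand-in batch
vae_loss = vae(images, return_loss = True)
vae_loss.backward()

# stage 2: train DALL-E on (text, image) pairs using the trained dVAE
dalle = DALLE(dim = 512, vae = vae, num_text_tokens = 10000,
              text_seq_len = 256, depth = 6, heads = 8)

text = torch.randint(0, 10000, (4, 256))    # stand-in token ids
dalle_loss = dalle(text, images, return_loss = True)
dalle_loss.backward()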

fomalhautb avatar Apr 09 '21 17:04 fomalhautb

This is the result from COCO 2014 train, with almost the same DALL-E structure, after 1 epoch.

I notice that after 1 epoch the loss doesn't seem to decrease and the images are still ugly. Can I get better results with more epochs?

[attached image: caronstreet]

SeungyounShin avatar Nov 15 '21 11:11 SeungyounShin