imagen-pytorch icon indicating copy to clipboard operation
imagen-pytorch copied to clipboard

Same Image whatever the input text on conditionnal train

Open axel588 opened this issue 2 years ago • 1 comments

Hello,

I was able to train a model on custom random minecraft texture: image But the issue is that it's generating the same image whatever the input text is. If I use this code (with embedding ) image It's the same issue.

and from the dataset itself the images are different : image

The dataset code :

import torch
import multiprocessing
from PIL import Image
from torch.utils.data import Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torchvision import transforms
from tqdm import tqdm


class MCDataset(Dataset):
    def __init__(self, dataset_dict, is_train=True, is_skip=False):
        super().__init__()

        
        self.dataset = dataset_dict["train"]
        self.is_train = is_train
        self.transform = transforms.Compose([
            transforms.Resize(16),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x if x.shape[0] == 4 else torch.cat([x, torch.full_like(x[:1], 255)]))
        ])
        # Convert to RGBA
        
        
        for i, sample in enumerate(self.dataset):
            if i % 10000 == 0:
              print("Images :"+str(i))
            image = sample["image"]
            if not isinstance(image, Image.Image):
                image = Image.open(image)
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            self.dataset[i]["image"] = image

        # Split train and validation sets
        if is_train:
            self.dataset = self.dataset[:int(0.98 * len(self.dataset))]
        else:
            self.dataset = self.dataset[int(0.98 * len(self.dataset)):]
        print('Preparing text encoding')
        # Text encoding
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base").to(self.device)
        self.model.eval()
        print('Finished preparing text encoding')
        self.texts = []
        u = 0

        for text in tqdm(self.dataset["text"]):
            self.texts.append(text)

    

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset["image"][idx]
        image = self.transform(image)
        enc = self.tokenizer(self.texts[idx], return_tensors="pt", padding="max_length",
            max_length=512).to(device)

        # forward pass through encoder only
        output = self.model.encoder(
            input_ids=enc["input_ids"].to(self.device), 
            attention_mask=enc["attention_mask"].to(self.device), 
            return_dict=True
        )
        # get the final hidden states
        emb = output.last_hidden_state.cpu()
        
        return image, emb.squeeze(0)

axel588 avatar Feb 16 '23 18:02 axel588

Hi @axel588 Did you resolve issue ? I am experiencing same thing

gauravbyte avatar Jul 11 '23 13:07 gauravbyte