imagen-pytorch
Same image generated regardless of the input text in conditional training
Hello,
I was able to train a model on a custom dataset of random Minecraft textures:
The problem is that it generates the same image regardless of the input text. Using this code (with precomputed text embeddings)
gives the same result.
However, the images in the dataset itself are different:
The dataset code:
import torch
import multiprocessing
from PIL import Image
from torch.utils.data import Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torchvision import transforms
from tqdm import tqdm
class MCDataset(Dataset):
    """Image/caption dataset that encodes each caption with T5 on the fly.

    Each item is ``(image, text_embedding)``:
      * ``image`` is the sample image converted to RGBA, resized with
        ``transforms.Resize(16)`` and turned into a tensor (4 channels).
      * ``text_embedding`` is the T5 encoder's last hidden state for the
        caption (padded/truncated to 512 tokens), with every position that
        corresponds to a padding token set to zero.

    Args:
        dataset_dict: mapping whose ``"train"`` entry supports ``len()``,
            slicing and column access by ``"image"`` / ``"text"``
            (presumably a Hugging Face ``DatasetDict`` -- TODO confirm).
        is_train: keep the first 98% of rows when True, the last 2% otherwise.
        is_skip: unused; kept for backward compatibility.
    """

    def __init__(self, dataset_dict, is_train=True, is_skip=False):
        super().__init__()
        dataset = dataset_dict["train"]
        self.is_train = is_train
        self.transform = transforms.Compose([
            transforms.Resize(16),
            transforms.ToTensor(),
            # Safety net: images are already converted to RGBA in
            # __getitem__, but add an opaque alpha channel if one is missing.
            transforms.Lambda(self._ensure_alpha),
        ])

        # Split train and validation sets (98% / 2%).
        # Slicing a Hugging Face Dataset yields a dict of column lists, so
        # len() of the slice would be the number of columns, not rows.
        # Keep the columns as explicit lists and size the dataset from them.
        split = int(0.98 * len(dataset))
        self.dataset = dataset[:split] if is_train else dataset[split:]
        self.images = list(self.dataset["image"])
        self.texts = list(self.dataset["text"])

        print('Preparing text encoding')
        # Text encoding
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # NOTE(review): the encoder runs inside __getitem__; on CUDA this does
        # not work with forked DataLoader workers (num_workers > 0). Consider
        # precomputing the embeddings once instead.
        self.tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base").to(self.device)
        self.model.eval()
        print('Finished preparing text encoding')

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        image = self.transform(self._to_rgba(self.images[idx]))

        # truncation=True keeps every embedding exactly 512 tokens long so
        # items can be collated even when a caption is longer than that.
        enc = self.tokenizer(self.texts[idx], return_tensors="pt",
                             padding="max_length", max_length=512,
                             truncation=True).to(self.device)

        # Forward pass through the encoder only. The embedding is returned
        # as data, so no autograd graph is needed.
        with torch.no_grad():
            output = self.model.encoder(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                return_dict=True,
            )

        # T5 produces non-zero hidden states for pad tokens. imagen-pytorch
        # appears to derive text masks from non-zero embeddings when none are
        # passed (NOTE(review): confirm for the installed version). Without
        # zeroing them, all 512 positions look like real tokens and the
        # caption gets drowned out by padding.
        emb = self._zero_padding(output.last_hidden_state, enc["attention_mask"])
        return image, emb.squeeze(0).cpu()

    @staticmethod
    def _to_rgba(image):
        """Return *image* as an RGBA PIL image, opening it first if it is not one."""
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        return image

    @staticmethod
    def _ensure_alpha(x):
        """Append a fully opaque alpha channel to a (C, H, W) tensor unless C == 4.

        ToTensor scales pixel values to [0, 1], so opaque alpha is 1.0, not 255.
        """
        if x.shape[0] == 4:
            return x
        return torch.cat([x, torch.ones_like(x[:1])])

    @staticmethod
    def _zero_padding(hidden, attention_mask):
        """Zero entries of *hidden* (..., seq, dim) where *attention_mask* (..., seq) is 0."""
        return hidden.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
Hi @axel588, did you resolve this issue? I am experiencing the same thing.