VAR
Image reconstruction via Transformer.
Hello, I have a question about image reconstruction via VAR. I want the transformer model to predict the ground-truth tokens, just like in the training setup: obtain the image tokens with the VQ encoder, interpolate the tokens, and finally feed them into the transformer (like inversion in diffusion models).
However, when I implemented this, the reconstruction differed from the original image. Could I have missed something, or is this approach not feasible?
Here are the original image and the reconstructed image.
original image
recon image
And here's my code:
gt_idx = vae.img_to_idxBl(img)                              # per-scale ground-truth token indices
tr_input_embed = quantize_local.idxBl_to_var_input(gt_idx)  # teacher-forcing input for the transformer
tr_input_embed has shape [B, 679, 32], and I implemented the following method in the VAR class.
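As a side note on the 679: if I read the scale schedule correctly, it is just the total number of tokens over all scales minus the first 1x1 token, which is driven by the SOS embedding instead. A quick check (the patch_nums below are my assumption of the default 256x256 schedule, not taken from my run):

```python
# Quick sanity check of the 679 (assumes the default 256x256 VAR scale schedule).
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # assumed default patch_nums
L = sum(pn * pn for pn in patch_nums)           # 680 tokens over all scales
first_l = patch_nums[0] ** 2                    # 1 token at the first (1x1) scale
print(L - first_l)                              # 679 -> matches tr_input_embed.shape[1]
```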
def image_recon_forward(self, tr_input_embed, gt_start_emd):
    bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
    B = tr_input_embed.shape[0]
    with torch.cuda.amp.autocast(enabled=False):
        # start token from the class embedding (class id 1000 = the unconditional class)
        sos = cond_BD = self.class_emb(torch.tensor(1000).repeat(B).to(tr_input_embed.device)).unsqueeze(1)
        sos = sos.expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

        # teacher-forcing input: [SOS] followed by the embedded ground-truth tokens of the later scales
        if self.prog_si == 0:
            x_BLC = sos
        else:
            x_BLC = torch.cat((sos, self.word_embed(tr_input_embed.float())), dim=1)
        x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

    f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
    attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
    cond_BD_or_gss = self.shared_ada_lin(cond_BD)

    # hack: get the dtype the transformer blocks run in
    temp = x_BLC.new_ones(8, 8)
    main_type = torch.matmul(temp, temp).dtype
    x_BLC = x_BLC.to(dtype=main_type)
    cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
    attn_bias = attn_bias.to(dtype=main_type)

    AdaLNSelfAttn.forward
    for i, b in enumerate(self.blocks):
        x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
    x_BLC = self.get_logits(x_BLC.float(), cond_BD)

    # greedy tokens for all scales at once; the first (1x1) scale is replaced by its ground-truth index
    idx_Bl = x_BLC.argmax(dim=-1)
    idx_Bl[:, 0] = gt_start_emd.squeeze(1)
    h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)

    # accumulate the per-scale residual maps into f_hat (h_list is built but not used afterwards)
    h_list = []
    for si, (st, ed) in enumerate(self.begin_ends):
        pn = self.patch_nums[si]
        h_list.append(h_BChw[:, st:ed, :].reshape(B, int((ed - st) ** 0.5), int((ed - st) ** 0.5), self.Cvae).permute(0, 3, 1, 2))
        f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(
            si, len(self.patch_nums), f_hat,
            h_BChw[:, st:ed, :].transpose(1, 2).reshape(B, self.Cvae, pn, pn)
        )

    for b in self.blocks:
        b.attn.kv_caching(False)
    return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)
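For completeness, this is roughly how I call it (a minimal sketch; `var` stands for the VAR model instance, and I pass the ground-truth index of the 1x1 scale as `gt_start_emd`, which is my assumption of the right argument):

```python
# Rough usage sketch (names follow the snippets above; `var` is the VAR model instance).
gt_idx = vae.img_to_idxBl(img)                              # list of per-scale ground-truth indices
tr_input_embed = quantize_local.idxBl_to_var_input(gt_idx)  # [B, 679, 32] teacher-forcing input
recon = var.image_recon_forward(tr_input_embed, gt_idx[0])  # gt_idx[0]: index of the 1x1 scale
```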
Again, thanks for your great work!
@minimini-1 Hi, I have the same problem as you. Did you solve it?
@minimini-1 @maggiesong7 Indeed, the way you try to reconstruct an image with VAR is incorrect. VAR formulates a next-scale prediction task, where the prediction at the current scale is conditioned on the predictions at the previous scales. In your implementation, the current scale's prediction is conditioned on the previous scales' ground-truth tokens instead of VAR's own previous predictions. For example, suppose the number 12 is tokenized into [SOS 8 3 1]. In your code:
VAR predicts 7 conditioned on [SOS]
VAR predicts 2 conditioned on [SOS 8]
VAR predicts 0 conditioned on [SOS 8 3]
Adding up [7 2 0] only gives 9, which is quite different from the target 12.
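To make the difference concrete, here is a small numeric toy (pure illustration, not VAR code): a scalar target is encoded as coarse-to-fine residuals, and a slightly imperfect predictor fills in each scale. Conditioning every scale on the ground-truth prefix and then summing the model's own outputs (what the posted code effectively does) misses the target, whereas conditioning each scale on the model's own accumulated output (what VAR inference does) lets the finer scales partially absorb the earlier errors:

```python
# Toy illustration only (not VAR code): a scalar "image" is encoded as coarse-to-fine residuals.
grids = [8, 3, 1]   # residual step size of each toy "scale", coarse -> fine
target = 12         # the toy ground-truth "image"

def quantize(gap, grid):
    return (gap // grid) * grid

# Ground-truth multi-scale tokens: 12 -> [8, 3, 1]
gt, acc = [], 0
for g in grids:
    r = quantize(target - acc, g)
    gt.append(r)
    acc += r

def noisy_predict(prefix_sum, grid):
    """Imperfect 'transformer': predicts the quantized gap, but always one unit too small."""
    return max(quantize(target - prefix_sum, grid) - 1, 0)

# (a) What the posted code does: condition every scale on the GROUND-TRUTH prefix,
#     then sum the model's own (slightly wrong) predictions.
teacher_forced, acc = [], 0
for g, r_gt in zip(grids, gt):
    teacher_forced.append(noisy_predict(acc, g))
    acc += r_gt                             # prefix built from ground-truth residuals
print(teacher_forced, sum(teacher_forced))  # [7, 2, 0] -> 9, far from the target 12

# (b) What VAR inference does: condition every scale on its OWN accumulated prediction,
#     so the finer scales can partially absorb the earlier errors.
free_running, acc = [], 0
for g in grids:
    r = noisy_predict(acc, g)
    free_running.append(r)
    acc += r                                # prefix built from the model's own residuals
print(free_running, sum(free_running))      # [7, 2, 2] -> 11, much closer to 12
```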