
Image reconstruction via Transformer.

minimini-1 opened this issue 9 months ago · 2 comments

Hello, I have a question about image reconstruction via VAR. I want the transformer model to predict the ground-truth tokens, just as in training: obtain image tokens through the VQ encoder, interpolate the tokens, and finally feed them into the transformer (like inversion in diffusion models).

However, when I ran the code I wrote, the result differed from the original image. Did I miss something, or is this approach not feasible?

Here are the original image and the reconstructed image: [original image] [reconstructed image]

And here's my code.

gt_idx = vae.img_to_idxBl(img)                              # per-scale ground-truth token indices
tr_input_embed = quantize_local.idxBl_to_var_input(gt_idx)  # teacher-forcing input for the transformer

tr_input_embed's shape is [B, 679, 32]: with the default patch_nums (1, 2, 3, 4, 5, 6, 8, 10, 13, 16) the total token count is 680, the first token is replaced by SOS, and Cvae = 32. I implemented the following method in the VAR class.

def image_recon_forward(self, tr_input_embed, gt_start_emd):
    # Teacher forcing: feed the ground-truth token embeddings of all previous
    # scales and let the transformer predict every scale in parallel.
    bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)

    B = tr_input_embed.shape[0]

    with torch.cuda.amp.autocast(enabled=False):
        # Class index 1000 is the unconditional (class-free) label.
        sos = cond_BD = self.class_emb(torch.full((B,), 1000, device=tr_input_embed.device)).unsqueeze(1)
        sos = sos.expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
        if self.prog_si == 0: x_BLC = sos
        else: x_BLC = torch.cat((sos, self.word_embed(tr_input_embed.float())), dim=1)
        x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

    f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
    attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]  # block-causal mask over scales
    cond_BD_or_gss = self.shared_ada_lin(cond_BD)

    # Infer the compute dtype in a batch-free way (hack from the original VAR forward).
    temp = x_BLC.new_ones(8, 8)
    main_type = torch.matmul(temp, temp).dtype

    x_BLC = x_BLC.to(dtype=main_type)
    cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
    attn_bias = attn_bias.to(dtype=main_type)

    for i, b in enumerate(self.blocks):
        x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
    x_BLC = self.get_logits(x_BLC.float(), cond_BD)

    # Greedy decoding of all scales at once; the first token (predicted from
    # SOS alone) is overwritten with the ground-truth index.
    idx_Bl = x_BLC.argmax(dim=-1)
    idx_Bl[:, 0] = gt_start_emd.squeeze(1)
    h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)
    # Accumulate the per-scale residual feature maps into f_hat.
    for si, (st, en) in enumerate(self.begin_ends):
        pn = self.patch_nums[si]
        f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(
            si, len(self.patch_nums), f_hat,
            h_BChw[:, st:en, :].transpose(1, 2).reshape(B, self.Cvae, pn, pn))
    for b in self.blocks: b.attn.kv_caching(False)  # no-op here: caching was never enabled
    return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)  # map [-1, 1] to [0, 1]
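
As a sanity check that isolates the tokenizer from the transformer, one could also decode the ground-truth tokens directly through the VQ-VAE, bypassing the transformer entirely; if that round trip matches the original image, the gap must come from the transformer step. A minimal sketch, assuming the idxBl_to_img helper from the repo's models/vqvae.py:

# Sanity check (sketch): decode ground-truth tokens with the VQ-VAE only.
# Assumes VQVAE.idxBl_to_img(ms_idx_Bl, same_shape, last_one) from models/vqvae.py.
with torch.no_grad():
    gt_idx = vae.img_to_idxBl(img)                          # per-scale token indices
    vae_recon = vae.idxBl_to_img(gt_idx, same_shape=True, last_one=True)
    vae_recon = vae_recon.add_(1).mul_(0.5)                 # [-1, 1] -> [0, 1]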

Again, thanks for your great work!

minimini-1 commented on May 11, 2024

@minimini-1 Hi, I have the same problem as you. Did you solve it?

maggiesong7 commented on Jul 15, 2024

@minimini-1 @maggiesong7 Indeed, the way you try to reconstruct an image using VAR is incorrect. VAR formulates a next-scale prediction task in which each scale's prediction is conditioned on the predictions for the previous scales. In your implementation, each scale's prediction is conditioned on the previous scales' ground-truth tokens instead of on VAR's own previous predictions. For example, suppose the number 12 is tokenized into [SOS 8 3 1]. In your code:

VAR predicts 7 conditioned on [SOS]
VAR predicts 2 conditioned on [SOS 8]
VAR predicts 0 conditioned on [SOS 8 3]

Adding up [7 2 0] only gives 9, which is quite different from the target 12.
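
To make the arithmetic concrete, here is a toy sketch in plain Python of the example above (the specific numbers are illustrative, not actual model outputs):

# Toy illustration: teacher forcing + argmax need not reconstruct the target.
gt_residuals = [8, 3, 1]          # ground-truth residual decomposition: 8 + 3 + 1 == 12

# Teacher forcing: each scale is conditioned on the GROUND-TRUTH prefix, so a
# small error at every scale (here -1) is never compensated at later scales.
teacher_forced = [7, 2, 0]        # illustrative per-scale outputs
print(sum(teacher_forced))        # 9, not 12

# True autoregressive decoding: each scale sees the model's OWN earlier
# outputs, so later scales can absorb earlier errors (illustrative numbers):
self_conditioned = [7, 4, 1]
print(sum(self_conditioned))      # 12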

JeyesHan commented on Dec 13, 2024