
Cannot align FID with provided checkpoint

Open LiCHH opened this issue 1 year ago • 7 comments

Hello, I wrote a script based on demo_sample.ipynb to generate 50,000 samples and evaluated them with OpenAI's FID evaluation toolkit, but the metrics did not align. Could you help me identify the problem? The attached result screenshots are from the d20 and d30 checkpoints. The script is below:

################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
from tqdm import tqdm
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 20    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print('prepare finished.')

############################# 2. Sample with classifier-free guidance

# set args
seed = 1 #@param {type:"number"}
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}  # unused by this script
cfg = 1.5 #@param {type:"slider", min:1, max:10, step:0.1}
more_smooth = False # True for smoother output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = 25
os.makedirs('./samples_d20', exist_ok=True)  # ensure the output directory exists
for img_cls in tqdm(range(1000)):
    for i in range(50 // B):
        label_B = torch.tensor([img_cls] * B, device=device)
        with torch.inference_mode():
            with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # fp16 autocast; bfloat16 can also be used and may be faster
                recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed, more_smooth=more_smooth)
            bhwc = recon_B3HW.permute(0, 2, 3, 1).mul_(255).cpu().numpy()  # BCHW -> BHWC
        bhwc = bhwc.astype(np.uint8)
        for j in range(B):
            img = PImage.fromarray(bhwc[j])
            img.save(f"./samples_d20/sample_{img_cls * 50 + i * B + j}.png")
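For reference, OpenAI's evaluation suite (evaluator.py in the guided-diffusion repo) consumes a single .npz of samples rather than loose PNGs. A minimal packing sketch, assuming the folder layout from the script above and the ImageNet 256x256 reference batch named in the guided-diffusion README:

import numpy as np
import PIL.Image as PImage

# stack the 50,000 saved PNGs into one uint8 array of shape (N, 256, 256, 3)
arr = np.stack([np.asarray(PImage.open(f'./samples_d20/sample_{idx}.png'))
                for idx in range(50000)])
np.savez('samples_d20.npz', arr_0=arr)

# then, from the guided-diffusion evaluations folder:
#   python evaluator.py VIRTUAL_imagenet256_labeled.npz samples_d20.npz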

LiCHH · Jun 06 '24

@keyu-tian Thanks for your great work! Could you please help with this issue? Any insights would be helpful.

ma-xu · Jun 11 '24

+1. Thanks for the great work! Wondering if you would be willing to share your evaluation scripts?

Kumbong · Jul 03 '24

@LiCHH @ma-xu @Kumbong It seems the generation seed is replicated for every class:

recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed, more_smooth=more_smooth)

g_seed should be different for each batch. I would use:

    for i in range(50 // B):
        label_B = torch.tensor([img_cls] * B, device=device)
        with torch.inference_mode():
            with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):
                recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed + i, more_smooth=more_smooth)

ChenDRAG · Jul 04 '24

I tried your suggestion, but the results got worse. Since the randomness of c2i is already introduced by the rng, shouldn't changing the random seed be unnecessary? Or could you share your evaluation code with me? @ChenDRAG @keyu-tian

Tom-zgt · Jul 17 '24

You can try using B=50, since the random seed is shared across different batches: with B=25 and a fixed g_seed, the generator is re-seeded identically for both inner batches of a class, so each class yields only 25 unique samples.
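A quick way to check the duplication (a sketch; assumes deterministic kernels, so exact equality may depend on the backend): two calls with the same g_seed and labels should decode to the same batch.

# two batches with the same g_seed and labels decode identically,
# which is why 2 x B=25 gives only 25 unique images per class
label_B = torch.tensor([0] * 25, device=device)
a = var.autoregressive_infer_cfg(B=25, label_B=label_B, cfg=1.5, top_k=900, top_p=0.96, g_seed=1)
b = var.autoregressive_infer_cfg(B=25, label_B=label_B, cfg=1.5, top_k=900, top_p=0.96, g_seed=1)
print(torch.equal(a, b))  # expected: True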

sen-ye · Aug 06 '24

> You can try using B=50, since the random seed is shared across different batches.

Have you been able to reproduce the results in Table 1 of the paper? Could you share the inference script?

We use B=50 for each class and var_d16 for evaluation.

|            | FID  | IS     | Precision | Recall |
|------------|------|--------|-----------|--------|
| reported   | 3.60 | 257.5  | 0.85      | 0.48   |
| reproduced | 3.49 | 281.71 | 0.85      | 0.50   |

The main issue is the IS. The results on other metrics are similar to those in the paper. Thanks.

adreamwu · Aug 15 '24

This solved the issue for me: https://github.com/openai/guided-diffusion/issues/153

Kumbong · Oct 18 '24

Since everyone here has solved the problem, we will close the issue.

enjoyyi00 · Dec 13 '24

Hi, I encountered the same problem. Here is my setting: cfg=1.5, seed=42, top_k=900, top_p=0.96, bsz=50, generating 50 figures for each class. I get FID=4.24 and IS=328.19.

duyuxuan1486 · Mar 30 '25

@duyuxuan1486, which checkpoint are you using, and can you share your sampling code? (How are you assigning g_seed when generating each class?) I think you should use a different seed for each class.

Kumbong · Mar 30 '25

@Kumbong, the checkpoint is the provided d16 HF weight, with seed=42 for all classes. Here is the sampling code:

seed = args.seed #@param {type:"number"}
torch.manual_seed(seed)
random.seed(seed)
cfg = 1.5
top_k = 900
top_p = 0.96
class_labels = [i for i in range(1000) for _ in range(50)]
image_number = len(class_labels)
more_smooth = False # True for more smooth output
out_np = []
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # fp16 autocast; bfloat16 can also be used and may be faster
        B = args.batch_size
        for i in range(image_number // B):
            label_B = torch.tensor(class_labels[i*B : (i+1)*B], device=device, dtype=torch.long)
            recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=top_k, top_p=top_p, g_seed=seed, more_smooth=more_smooth)
            bhwc = recon_B3HW.permute(0, 2, 3, 1).mul_(255).cpu().numpy().astype(np.uint8)  # BCHW -> BHWC
            out_np.append(bhwc)
out_np = np.concatenate(out_np, axis=0)
np.savez(f'{args.npz_path}/{cfg}cfg_{top_k}tk_{top_p}tp_{seed}seed.npz', arr_0=out_np)

The sampling code saves the .npz directly without saving PNGs, and the FID is 3.47.
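A quick sanity check for the gap between the two routes (a sketch; the PNG folder name is a placeholder): PNG is lossless for uint8 RGB, so repacking the saved PNGs should match the directly saved .npz bit-for-bit. If it does, any remaining FID difference must come from sample ordering or a different preprocessing step, not the file format.

import numpy as np
import PIL.Image as PImage

direct = np.load('1.5cfg_900tk_0.96tp_42seed.npz')['arr_0']  # saved directly as .npz
repacked = np.stack([np.asarray(PImage.open(f'./samples_d16/sample_{i}.png'))
                     for i in range(len(direct))])           # repacked PNGs (placeholder folder)
print('bit-identical:', np.array_equal(direct, repacked))    # expect True if ordering matches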

duyuxuan1486 · Mar 30 '25

Do I need to change the seed when generating different classes?

duyuxuan1486 · Mar 30 '25

Could you set the seed to be the class ID of the class you are generating? That way the seed is different for each class.
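Something like this (a sketch combining the suggestions in this thread: one batch of B=50 per class, with g_seed set to the class ID; variable names follow the earlier scripts):

B = 50
for img_cls in tqdm(range(1000)):
    label_B = torch.tensor([img_cls] * B, device=device)
    with torch.inference_mode():
        with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):
            # g_seed = class ID, so each class draws from a different random stream
            recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg,
                                                      top_k=900, top_p=0.96,
                                                      g_seed=img_cls, more_smooth=more_smooth)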

Kumbong · Mar 30 '25

I set the seed equal to the class ID, and the FID is 3.38; the expected FID is 3.30.

duyuxuan1486 · Mar 30 '25

Are there any other changes that might cause performance degradation?

duyuxuan1486 · Mar 30 '25

I think that should be fine, or within a good enough error margin.

Kumbong · Mar 30 '25

Thanks for your response. VAR is nice work!

duyuxuan1486 · Mar 31 '25