k-diffusion
FID scores from paper
Hello, in the original k-diffusion paper the authors report FID scores for CIFAR-10 in the low single digits (e.g. 1.8). However, the FID scores logged by this repo are all far higher: around 27, 36, 48.
Is the difference due to the FID calculation strategy? It's hard to imagine that the FID is off by over an order of magnitude...
Got the same issue on CIFAR-10, with these settings (batch size 1024):
step,fid,kid
10000,54.75194549560547,0.024161577224731445
20000,38.80967330932617,0.01055598258972168
30000,35.451629638671875,0.008167028427124023
40000,32.956600189208984,0.005323648452758789
50000,32.06228256225586,0.004504680633544922
I was curious about this too... quoting https://arxiv.org/pdf/1801.01401.pdf:

"FID scores can only be compared to one another with the same value of n."

If you compare the CIFAR-10 test and train sets (same distribution!) at n=2000, you'd get an FID around 30. 2000 is the default evaluate-n for k-diffusion training, so I think it's expected that the logged FID values will be 30 or higher.

The EDM paper reports results with n_fake=50000 and n_real="all available"; that FID will be much lower than any n=2000 FID.
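The n-dependence is easy to reproduce with a toy example. A minimal sketch (plain NumPy/SciPy, with random Gaussian vectors standing in for Inception features; this is an illustration, not the repo's actual FID code): even when both sets are drawn from the exact same distribution, the estimated FID is well above zero at small n and shrinks as n grows.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid(x, y):
    """Frechet distance between Gaussians fitted to two feature sets."""
    mu1, mu2 = x.mean(axis=0), y.mean(axis=0)
    s1 = np.cov(x, rowvar=False)
    s2 = np.cov(y, rowvar=False)
    covmean = sqrtm(s1 @ s2).real  # discard tiny imaginary parts
    return float(((mu1 - mu2) ** 2).sum() + np.trace(s1 + s2 - 2 * covmean))

rng = np.random.default_rng(0)
d = 64  # feature dimension (Inception features would be 2048-d)
fids = {}
for n in (200, 2000, 20000):
    # Both sets come from the SAME distribution, yet estimated FID > 0.
    x = rng.standard_normal((n, d))
    y = rng.standard_normal((n, d))
    fids[n] = fid(x, y)
    print(f"n={n:>5}  same-distribution FID ~ {fids[n]:.4f}")
```

The estimator's bias scales roughly like 1/n, which is why a logged n=2000 FID sits on a high floor that an n=50000 FID does not.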
To check the 50k FID value for k-diffusion, I tried training the default CIFAR-10 config in this repo for 320k steps (checkpoint), and sampled 50k images:
mkdir -p cifar10_fake_samples
python3 sample.py --checkpoint model_cifar10_last.pth --prefix cifar10_fake_samples/sample --config configs/config_cifar10.json -n 50000
I then followed the evaluation steps from the EDM repo:
cd ..; git clone https://github.com/NVlabs/edm; cd edm
torchrun --standalone --nproc_per_node=1 fid.py calc --images=../k-diffusion/cifar10_fake_samples/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
This gave me an FID value of ~4.6, which seems to be on the right order of magnitude:
Calculating statistics for 50000 images...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [01:05<00:00, 12.01batch/s]
Calculating FID...
4.5546
I expect more training (and stronger dropout / augmentation) is needed to match the ~2.0 FID values reported in the EDM paper without overfitting.
https://github.com/crowsonkb/k-diffusion/pull/78
I am having similar issues; the logged FID scores are quite high.