
When running the CIFAR10 demo, NaN appears in the loss.

fupiao1998 opened this issue 11 months ago • 14 comments

Hello, I am running the CIFAR10 demo you provided, using the script below. The loss does not go down and eventually becomes NaN. Is there a solution to this problem?

torchrun --standalone --nproc_per_node=8 train.py \
    --model DiT-XL/2 \
    --batch_size 36 \
    --mixed_precision fp16 \
    --ckpt_every 10000 \
    --num_workers 12 \
    --enable_modulate_kernel \
    --enable_layernorm_kernel \
    --enable_flashattn

[screenshot: training log showing the loss becoming NaN]

fupiao1998 avatar Mar 01 '24 03:03 fupiao1998

You can use bf16 for mixed precision to avoid NaN.

oahzxl avatar Mar 01 '24 03:03 oahzxl
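
For reference, a minimal sketch of bf16 mixed precision in plain PyTorch. This is a generic example, not OpenDiT's actual training loop; the --mixed_precision bf16 flag is assumed to do something equivalent internally. bf16 keeps fp32's exponent range, so it overflows to inf/NaN far less easily than fp16:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(36, 1024, device="cuda")
target = torch.randn(36, 1024, device="cuda")

# Autocast the forward pass to bf16; gradients and optimizer state stay fp32,
# and unlike fp16 no GradScaler is needed.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), target)

loss.backward()
optimizer.step()
optimizer.zero_grad()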

Thank you so much for the quick reply, I will try your solution!

fupiao1998 avatar Mar 01 '24 03:03 fupiao1998

Hello, after switching to bf16, the NaN problem still occurs.

torchrun --standalone --nproc_per_node=8 train.py \
    --model DiT-XL/2 \
    --batch_size 36 \
    --ckpt_every 10000 \
    --num_workers 12 \
    --enable_modulate_kernel \
    --enable_layernorm_kernel \
    --enable_flashattn \
    --mixed_precision bf16 

[screenshot: training log with bf16 still showing NaN loss]

fupiao1998 avatar Mar 01 '24 08:03 fupiao1998

That's strange; we have not run into this problem. Here are some potential solutions, though we're not sure they will help: 1) increase the batch size, 2) decrease grad_clip, 3) disable the modulate kernel, 4) use gradient accumulation.

oahzxl avatar Mar 01 '24 08:03 oahzxl
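
For readers trying suggestions 2) and 4), here is a minimal generic PyTorch sketch of gradient clipping plus gradient accumulation. It is not OpenDiT's actual training loop; the dummy data and the 0.5 clip value are illustrative assumptions:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loader = [(torch.randn(36, 1024), torch.randn(36, 1024)) for _ in range(8)]  # dummy data

accum_steps = 4        # effective batch = accum_steps * micro-batch size
max_grad_norm = 0.5    # tighter than the common default of 1.0

for step, (x, target) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x.cuda()), target.cuda())
    (loss / accum_steps).backward()   # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        # Clip the global gradient norm before the optimizer step to damp spikes.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()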

Hello, I am running the CIFAR10 demo following the script in the README; I only increased the batch size and the number of GPUs and left everything else unchanged.

fupiao1998 avatar Mar 01 '24 09:03 fupiao1998

OK, I'm trying to reproduce your result.

oahzxl avatar Mar 01 '24 09:03 oahzxl

Thank you so much!

fupiao1998 avatar Mar 01 '24 09:03 fupiao1998

I have not yet reproduced your result, but I suspect the NaN comes from the class embedding. The class count defaults to 1000 (ImageNet), while CIFAR10 has only 10 classes, so as training continues the probability assigned to the other 990 classes becomes extremely small and causes the NaN.

oahzxl avatar Mar 01 '24 09:03 oahzxl

Maybe you can try setting --num_classes 10.

oahzxl avatar Mar 01 '24 09:03 oahzxl
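
A quick sanity check along the lines of the hypothesis above (a generic sketch, not an OpenDiT utility; the 1000-class default is taken from the comment): compare the number of distinct labels in the dataset with the num_classes the model is configured for, and pass --num_classes accordingly.

import torch
import torchvision

num_classes = 1000  # assumed ImageNet-style default, per the comment above

# CIFAR10 labels are stored in .targets as a list of ints in [0, 9].
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
labels = torch.tensor(dataset.targets)
distinct = labels.unique().numel()

if distinct != num_classes:
    print(f"dataset has {distinct} classes but the model expects {num_classes}; "
          f"pass --num_classes {distinct} to match")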

OK, I will try again after modifying it, thank you. In addition, I checked the log and found that NaN appears at epoch 5 during my training runs.

fupiao1998 avatar Mar 01 '24 09:03 fupiao1998

After setting --num_classes 10 and disabling the modulate kernel, training is now normal: NaN no longer appears at the step where it occurred before. Thank you for the helpful replies.

fupiao1998 avatar Mar 01 '24 10:03 fupiao1998

Thanks for using OpenDiT. We will update our README about this issue.

oahzxl avatar Mar 01 '24 10:03 oahzxl

I got the same problem when enabling the modulate kernel.

MikeChenfu avatar Mar 01 '24 18:03 MikeChenfu

I got the same problem when enabling the modulate kernel.

Me too.

SmileTAT avatar Mar 04 '24 07:03 SmileTAT