Pytorch-UNet

Predicted output is black.

Open akshay9396 opened this issue 1 year ago • 10 comments

I trained the model on the Carvana dataset. Below is a screenshot of the training. You can clearly see that the validation Dice score is constant for 5 epochs and the loss is "nan". I used one checkpoint.pth for prediction and got a black output. Could you please help me resolve this issue? [screenshot]

akshay9396 avatar Aug 02 '24 14:08 akshay9396

Have you solved it? I'm having the exact same problem

Feng2100 avatar Sep 26 '24 13:09 Feng2100

I'm having the exact same problem, too.

chenshans avatar Sep 29 '24 12:09 chenshans

I'm having the exact same problem, too.

berda-ak avatar Sep 30 '24 14:09 berda-ak

I'm having the exact same problem, too.

caigg188 avatar Oct 21 '24 01:10 caigg188

The reason is probably that torch.float16 is used in autocast() when AMP is enabled; change torch.float16 to torch.bfloat16. If you want to keep torch.float16, construct the grad_scaler with init_scale=4096 instead. This solved the issue for me. I referred to the following sites. This is my first time commenting on GitHub, so sorry if something is wrong or my English is poor. https://qiita.com/takeuchiseijin/items/909c48b57127a37fbd12 https://qiita.com/bowdbeg/items/71c62cf8ef891d164ecd

taipain avatar Nov 14 '24 02:11 taipain

So to solve this problem, we simply add a dtype to autocast, i.e. with torch.autocast(dtype=torch.bfloat16) in train.py. (At least that is what I gathered from your links.)

MorleyOlsen avatar Dec 29 '24 04:12 MorleyOlsen

@taipain thanks for the solution. Setting dtype=torch.bfloat16 in torch.autocast fixed both the "nan" loss and the black prediction.

Ralph-cong avatar May 11 '25 06:05 Ralph-cong

Thanks for the solution. Hi, could you please tell me which file this change goes in?

algwang avatar May 14 '25 06:05 algwang

In train.py:

before: with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):

after: with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp, dtype=torch.bfloat16):

taipain avatar May 14 '25 07:05 taipain
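For anyone who prefers the float16 alternative mentioned earlier instead, the other change would be where train.py constructs its scaler. A sketch under the assumption that the script builds a torch.cuda.amp.GradScaler (the variable name grad_scaler is taken from the earlier comment; verify the exact line against your copy of train.py):

```python
import torch

# float16 alternative: keep autocast's float16 dtype but start the loss
# scaler at 4096 instead of the default 65536, so the first iterations
# are less likely to overflow and poison the weights with NaN.
grad_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available(),
                                        init_scale=4096)
```

If overflow still occurs, the scaler shrinks the scale automatically on every inf/NaN gradient it detects; starting lower just avoids the initial burst of skipped optimizer steps.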

Hello, after I made the modification I still get the same error. When I added a validation set during training, the validation images were masked correctly, but when I use predict.py for testing I still get the same error. Can you explain why?

algwang avatar May 14 '25 11:05 algwang