jax

When calculating the loss, the input data does not contain NaN, but the output contains NaN

Open · CZXIANGOvO opened this issue 1 year ago · 1 comment

Description

Please set the device (`final_device`) to `cuda:0` at the very beginning of the script.

import torch
import numpy as np
import os
import jax
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax

# Choose the torch device that every loaded tensor is moved to.
#
# CUDA_VISIBLE_DEVICES renumbers GPUs inside the process: with
# CUDA_VISIBLE_DEVICES="2,3", torch only sees cuda:0 and cuda:1. The previous
# code built "cuda:" + <physical id> (e.g. "cuda:2"), which names a device that
# does not exist in this process. Index by position among the *visible*
# devices instead, keeping the original choice of the second-to-last one.
if os.environ.get('CONTEXT_DEVICE_TARGET') == 'GPU':
    # An unset CUDA_VISIBLE_DEVICES previously raised KeyError; treat it as
    # "no explicit list" and fall through to the first device below.
    devices = [d for d in os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',') if d.strip()]
    # With fewer than two visible GPUs devices[-2] does not exist (the old code
    # raised IndexError); fall back to the first visible device, cuda:0.
    device = str(max(len(devices) - 2, 0))
    final_device = "cuda:" + device
else:
    final_device = 'cpu'


from network.cv.yolov4.yolov4_pytorch import YOLOV4CspDarkNet53_torch as yolov4_torch

def loss_yolo_jax():
    """Construct the JAX implementation of the YOLOv4 loss.

    The import is deferred to call time, so the project module providing
    ``yolov4loss_jax`` is only loaded when a loss object is requested.

    Returns:
        A new ``yolov4loss_jax`` instance.
    """
    from network.cv.yolov4.yolov4_pytorch import yolov4loss_jax as loss_cls
    return loss_cls()


# Load the nine cached network outputs saved as ./yolo_out[i][j].npy and move
# each onto final_device. They are grouped as a 3x3 nested tuple that mirrors
# the [i][j] file indices; this is the structure handed to calc_loss.
# Loading happens in the same order as before ([0][0], [0][1], ..., [2][2]).
# A dedicated comprehension replaces the old pattern of reusing `y_true_0` as
# scratch storage, which left `y_true_0` holding yolo_out[2][2] until reloaded.
yolo_out = tuple(
    tuple(
        torch.from_numpy(np.load(f'./yolo_out[{i}][{j}].npy')).to(final_device)
        for j in range(3)
    )
    for i in range(3)
)
# Keep the flat per-tensor names that the NaN checks refer to.
(
    (yolo_out1, yolo_out2, yolo_out3),
    (yolo_out4, yolo_out5, yolo_out6),
    (yolo_out7, yolo_out8, yolo_out9),
) = yolo_out


# Build the reference PyTorch YOLOv4 model. In this script its weights are only
# read back through state_dict() to seed the JAX parameters; the model's
# forward pass is never called here.
model_pt = yolov4_torch()
model_pt.train()  # training mode; affects layers such as BatchNorm/Dropout if the model has any
model_torch = model_pt.to(final_device)



def _npy_to_device(path):
    """Load the array stored at *path* and return it as a torch tensor on final_device."""
    return torch.from_numpy(np.load(path)).to(final_device)


# Remaining loss inputs, each read from a .npy file in the working directory.
y_true_0 = _npy_to_device('./y_true_0.npy')
y_true_1 = _npy_to_device('./y_true_1.npy')
y_true_2 = _npy_to_device('./y_true_2.npy')

gt_0 = _npy_to_device('./gt_0.npy')
gt_1 = _npy_to_device('./gt_1.npy')
gt_2 = _npy_to_device('./gt_2.npy')

input_shape_t = _npy_to_device('./input_shape_t.npy')

# Export every state_dict entry as a host (CPU) NumPy array, detached from autograd.
params_torch = {}
for param_name, tensor in model_torch.state_dict().items():
    params_torch[param_name] = tensor.detach().cpu().numpy()


# Build the JAX loss and mirror the torch weights as JAX arrays.
# NOTE(review): every state_dict entry is cast to float32, including any
# integer buffers the model may have — confirm that calc_loss expects this.
loss_jax_fun = loss_yolo_jax()
params_jax = {name: jnp.array(value, dtype=jnp.float32) for name, value in params_torch.items()}
# Evaluate calc_loss and its gradient w.r.t. the first argument (params_jax)
# only; the remaining arguments are not differentiated.
# NOTE(review): yolo_out / y_true_* / gt_* / input_shape_t are torch tensors on
# final_device here, not JAX arrays — presumably calc_loss converts them; verify,
# since JAX cannot read CUDA torch tensors directly.
# To find which operation first produces a NaN, enable
# jax.config.update("jax_debug_nans", True) before this call.
loss_jax, jax_grads = jax.value_and_grad(loss_jax_fun.calc_loss)(params_jax, yolo_out, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape_t)


# Report whether any loss input contains NaN or +/-Inf. Inf must be checked
# too: an Inf input contains no NaN itself but readily turns into NaN inside
# the loss (e.g. 0 * inf or inf - inf).
#
# The previous version rebound each name (yolo_out1 = torch.isnan(yolo_out1).any()),
# which replaced the tensors with booleans; these checks leave the inputs intact.
_loss_inputs = {
    'yolo_out1': yolo_out1,
    'yolo_out2': yolo_out2,
    'yolo_out3': yolo_out3,
    'yolo_out4': yolo_out4,
    'yolo_out5': yolo_out5,
    'yolo_out6': yolo_out6,
    'yolo_out7': yolo_out7,
    'yolo_out8': yolo_out8,
    'yolo_out9': yolo_out9,
    'y_true_0': y_true_0,
    'y_true_1': y_true_1,
    'y_true_2': y_true_2,
    'gt_0': gt_0,
    'gt_1': gt_1,
    'gt_2': gt_2,
    'input_shape_t': input_shape_t,
}
for _name, _tensor in _loss_inputs.items():
    print(f'{_name};', torch.isnan(_tensor).any())
    print(f'{_name} inf;', torch.isinf(_tensor).any())

# The loss comes from the JAX implementation, so label it as such
# (it was previously printed as 'loss_torch_result').
print('loss_jax_result;', np.array(loss_jax))

[Screenshot 2024-09-13 204340]

System info (python version, jaxlib version, accelerator, etc.)

Code and data: https://drive.google.com/file/d/1-edrk7_sxSgdu7cmXQXf6JsT57xiG1Hb/view?usp=sharing

CZXIANGOvO avatar Sep 13 '24 12:09 CZXIANGOvO

Is there an MVC (minimal reproducible example)? This code doesn't run for me.

lockwo avatar Sep 15 '24 00:09 lockwo

> Is there an MVC? This code doesn't run for me.

Which part doesn't run? At the beginning you need to set `final_device` yourself; you can delete the following:

if "CONTEXT_DEVICE_TARGET" in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
    devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
    device = devices[-2]
    final_device = "cuda:" + device
else:
    final_device = 'cpu'

Translated with DeepL.com (free version)

CZXIANGOvO avatar Sep 18 '24 10:09 CZXIANGOvO

Hi @CZXIANGOvO – it's going to be hard to help with specifics here absent an MVC (also known as a minimal reproducible example). If you're able to re-work your example so that others can run it and see the same errors you are seeing, then we could offer specific guidance.

Absent that, though, in general it's not surprising to see NaN outputs for inputs without NaNs: it just means that you're calling some function in your model in a way that is undefined in floating-point arithmetic. Here's a simple example of this (in float32, `jnp.exp(100.0)` overflows to `inf`, and `0.0 * inf` is NaN):

>>> import jax.numpy as jnp

>>> def f(x, y):
...   return x * jnp.exp(y)

>>> f(1.0, 1.0)
Array(2.7182817, dtype=float32, weak_type=True)

>>> f(0.0, 100.0)
Array(nan, dtype=float32, weak_type=True)

More than likely, somewhere in your model you have an expression that is evaluating to NaN for reasons like this.

The best way to debug this is to start digging-in to your model to figure out exactly where this is coming from. One way to do this is to enable the jax_debug_nans flag, as described here: https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager

I hope that helps get you on the right path!

jakevdp avatar Sep 18 '24 17:09 jakevdp