The output of the KAN model is a matrix full of NaNs from the second iteration onward
I replaced the MLP in NeRF with a KAN, but I get a NaN loss because of the KAN model's output.
Initialize the KAN:

```python
from kan import KAN

net = KAN(width=[84, 1, 4], grid=5, k=3, seed=0, device=device)
```
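Before anything else I want to rule out the model itself. A minimal sanity check (the random `probe` input is purely illustrative, not my real data) to confirm that the freshly initialized KAN produces finite outputs:

```python
import torch

# feed random inputs through the fresh KAN; if this already yields
# NaN/inf, the problem is the model itself, not the training loop
with torch.no_grad():
    probe = torch.rand(16, 84, device=device)
    out = net(probe)
assert torch.isfinite(out).all(), "KAN output is non-finite at initialization"
```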
Forward pass:

```python
# ---- KAN forward pass ----
device = pts_flat.device
pts_flat = pts_flat.to(device)
dir_flat = dir_flat.to(device)

# concatenate encoded positions and view directions -> (N, 84)
data = torch.cat((pts_flat, dir_flat), dim=-1)
model_out = net(data)
print(model_out)  # this print produces the output shown below

rgb = model_out[:, 0:3]
sigma = model_out[:, 3]
# --------------------------
```
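One thing I notice in the printed output below is that the raw KAN outputs include negative values. The original NeRF constrains the raw network outputs before volume rendering (sigmoid for `rgb`, ReLU for `sigma`); an unconstrained negative `sigma` can make the `exp(-sigma * delta)` terms in the rendering weights overflow, which would explain `loss=inf` in the very first iteration. A sketch of that constraint, assuming `render_rays` does not already apply one:

```python
# constrain raw outputs the way the original NeRF MLP does:
# rgb in [0, 1], sigma >= 0; an unconstrained sigma can overflow
# the exp(-sigma * delta) terms in volume rendering
rgb = torch.sigmoid(model_out[:, 0:3])
sigma = torch.relu(model_out[:, 3])
```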
Printed output:
```
Epoch 1:   0%| | 1/976 [00:00<09:24, 1.73it/s, loss=inf]
tensor([[ 0.0486,  0.0511,  0.0563,  0.0445],
        [ 0.0033, -0.0027,  0.0009, -0.0026],
        [ 0.2292,  0.2569,  0.2867,  0.2423],
        ...,
        [-0.0786, -0.0964, -0.1001, -0.0851],
        [-0.0152, -0.0245, -0.0216, -0.0215],
        [ 0.1987,  0.2231,  0.2465,  0.2082]], device='cuda:0', grad_fn=<AddBackward0>)

Epoch 1:   0%| | 2/976 [00:03<31:20, 1.93s/it, loss=nan]
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], device='cuda:0', grad_fn=<AddBackward0>)

Epoch 1:   0%|▏| 3/976 [00:03<18:57, 1.17s/it, loss=nan]
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], device='cuda:0', grad_fn=<AddBackward0>)

Epoch 1:   0%|▏| 4/976 [00:06<28:42, 1.77s/it, loss=nan]
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], device='cuda:0', grad_fn=<AddBackward0>)
```
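My reading of this log: in iteration 1 the output tensor is still finite but the loss is already `inf`; from iteration 2 onward both the output and the loss are NaN. That pattern suggests that calling `.backward()` on an infinite loss produces non-finite gradients, and the optimizer step then overwrites the KAN weights with NaN, so every later forward pass is NaN. Two standard PyTorch checks I can add (debug only):

```python
import torch

# slow, debug-only: reports the op that first produces NaN/inf in backward
torch.autograd.set_detect_anomaly(True)

# after a bad step, check whether the weights themselves have gone non-finite
bad = [n for n, p in net.named_parameters() if not torch.isfinite(p).all()]
print("non-finite parameters:", bad)
```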
After that, I checked the training loop:
```python
for i in range(iterations):
    train_rays = next(train_iter)
    assert train_rays.shape == (Batch_size, 9)

    rays_o, rays_d, target_rgb = torch.chunk(train_rays, 3, dim=-1)
    rays_od = (rays_o, rays_d)
    rgb, _, __ = render_rays(net, rays_od, bound=bound, N_samples=N_samples,
                             device=device, use_view=use_view)
    acts_scale = [84, 0.25]  # note: unused in this snippet

    loss = mse(rgb, target_rgb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    p_bar.set_postfix({'loss': '{0:1.5f}'.format(loss.item())})
    p_bar.update(1)
```
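If that diagnosis is right, a guard like the following (the `max_norm=1.0` value is a guess, not a tuned setting) should at least keep one bad batch from destroying the weights:

```python
loss = mse(rgb, target_rgb)
optimizer.zero_grad()

# skip the update when the loss is non-finite, so an inf loss cannot
# poison the KAN parameters; clip gradients as an extra safety net
if torch.isfinite(loss):
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()
```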
Or is the way I compute the loss wrong?
Can anyone help me? Thank you!