Problem occurred when using model.refine
First of all, thank you very much for this contribution.
Following the tutorial, I tried to use a finer grid for the activation function approximation by calling model.refine() after training with a coarser grid. However, NaN then occurred in my training outputs.
Unfortunately, I found that the NaN values appear right after model.refine() is applied.
How can I resolve this?
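One possible way to narrow down where the NaN first shows up in the forward pass is to attach forward hooks to every submodule and record which ones return non-finite values. This is only a diagnostic sketch in plain PyTorch, not a fix and not part of pykan; the helper names `_has_nonfinite` and `modules_with_nonfinite_output` are made up for illustration.

```python
import torch


def _has_nonfinite(output) -> bool:
    """True if a tensor, or any tensor inside a tuple/list, contains NaN or Inf."""
    if isinstance(output, torch.Tensor):
        return not bool(torch.isfinite(output).all())
    if isinstance(output, (tuple, list)):
        return any(_has_nonfinite(o) for o in output)
    return False


def modules_with_nonfinite_output(model: torch.nn.Module, x: torch.Tensor) -> list[str]:
    """Run one forward pass and list submodules whose output is not finite.

    Hypothetical helper: names are reported in the order the modules finish,
    so the first entry is the earliest module that produced NaN/Inf.
    """
    flagged = []
    handles = []
    for name, module in model.named_modules():
        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, inputs, output, name=name):
            if _has_nonfinite(output):
                flagged.append(name or "<root>")
        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return flagged
```

Running this once before and once after model.refine() on the same input batch would show whether the refined model already produces NaN before any further training step.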
I am experiencing the same issue and have put together a small example that reproduces it. I am using pykan 0.2.8 (installed in editable mode) with Python 3.11.13 and torch 2.7.1.
The example code is:
import torch
from kan import KAN
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def train(model,steps=100):
    # optimizer = LBFGS(kan_ae.parameters(), lr=1e-3, line_search_fn='strong_wolfe')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    loss_fn = torch.nn.MSELoss()
    pbar = tqdm(range(steps), desc="Training KAN Autoencoder")
    for step in pbar:
        def closure():
            optimizer.zero_grad()
            # print("Reconstructing batch...")
            x_recon = model(x_batch)
            loss = loss_fn(x_recon,x_batch)
            # print("Calculating gradients...")
            loss.backward()
            return loss
        loss = optimizer.step(closure)
        if step % 5 == 0:
            current_loss = closure().item()
            pbar.set_description("Step: %d | Loss: %.3f" %
                                 (step, current_loss))
    return loss

## Example
batch_size = 500
num_elem = 100
x_batch = torch.rand(batch_size, num_elem, device=device)
kan_width=[num_elem,num_elem,num_elem]
k=3
G=3
kan_ae = KAN(width=kan_width, k=k, grid=G,symbolic_enabled=False).to(device=device)
kan_ae.update_grid_from_samples(x_batch)

# Train
loss_out = train(kan_ae,steps=100)

# Refine grid
kan_ae = kan_ae.refine(5)

# Train again.
loss_out = train(kan_ae,steps=100)
This produced the following output:
Using device: cuda
checkpoint directory created: ./model
saving model version 0.0
Step: 95 | Loss: 0.074: 100%|██████████| 100/100 [00:01<00:00, 73.60it/s]
saving model version 0.1
Step: 95 | Loss: nan: 100%|██████████| 100/100 [00:00<00:00, 108.92it/s]
I have tried varying the widths, the grid sizes, and the data dimensions, but in all cases some coefficients in the model turn NaN, leading to NaN outputs.
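To make the "some coefficients turn NaN" observation concrete, a small scan over the model parameters can report which tensors contain non-finite entries and how many. This is a minimal sketch in plain PyTorch; `find_nonfinite_parameters` is a hypothetical helper, not a pykan function, and it assumes the coefficients of interest are exposed through `named_parameters()`.

```python
import torch


def find_nonfinite_parameters(model: torch.nn.Module) -> dict[str, int]:
    """Map each parameter name to its number of NaN/Inf entries (only if > 0)."""
    bad = {}
    for name, param in model.named_parameters():
        # torch.isfinite is False for NaN, +Inf and -Inf.
        n_bad = int((~torch.isfinite(param.detach())).sum())
        if n_bad > 0:
            bad[name] = n_bad
    return bad


if __name__ == "__main__":
    # Stand-in model for illustration; in the example above this would be kan_ae.
    demo = torch.nn.Linear(4, 2)
    with torch.no_grad():
        demo.weight[0, 1] = float("nan")
    print(find_nonfinite_parameters(demo))
```

Calling it on kan_ae immediately after kan_ae.refine(5), before the second train(...) call, would indicate whether the non-finite values are introduced by the refinement itself or only by the optimizer steps that follow.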