
Problem occurred when using model.refine

Open MuellerJY opened this issue 8 months ago • 1 comment

First of all, I deeply appreciate this contribution.

Following the tutorial, I tried to use more grid points in the activation function approximation via model.refine() after first training with fewer grid points. However, NaNs appeared in my training outputs.

Unfortunately, I found that the NaNs appeared immediately after the call to model.refine().

How can this be resolved?
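
For context, the workflow I followed is roughly the one from the tutorial (a minimal sketch; the dataset, widths, and step counts here are placeholders, not my actual setup):

import torch
from kan import KAN, create_dataset

# Toy dataset and model, as in the pykan tutorial
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)

# Train on a coarse grid first
model = KAN(width=[2, 5, 1], grid=3, k=3)
model.fit(dataset, opt="LBFGS", steps=20)

# Switch to a finer grid and continue training;
# this is where the NaNs appear for me
model = model.refine(10)
model.fit(dataset, opt="LBFGS", steps=20)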

MuellerJY avatar Apr 17 '25 01:04 MuellerJY

I am experiencing the same issue and have produced a small example that reproduces it. I use pykan 0.2.8 installed in editable mode, with Python 3.11.13 and torch 2.7.1.

The example code is:

import torch
from kan import KAN
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def train(model, steps=100):
    # optimizer = torch.optim.LBFGS(model.parameters(), lr=1e-3, line_search_fn='strong_wolfe')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    loss_fn = torch.nn.MSELoss()
    
    pbar = tqdm(range(steps), desc="Training KAN Autoencoder")

    for step in pbar:
        def closure():
            # Recompute loss and gradients; optimizer.step() accepts this closure.
            # x_batch is taken from the enclosing scope.
            optimizer.zero_grad()
            x_recon = model(x_batch)
            loss = loss_fn(x_recon, x_batch)
            loss.backward()
            return loss

        loss = optimizer.step(closure)
        if step % 5 == 0:
            current_loss = closure().item()
            pbar.set_description("Step: %d | Loss: %.3f" %
                                 (step, current_loss))
    return loss


## Example
batch_size = 500
num_elem = 100
x_batch = torch.rand(batch_size, num_elem, device=device)
kan_width = [num_elem, num_elem, num_elem]
k = 3
G = 3
kan_ae = KAN(width=kan_width, k=k, grid=G, symbolic_enabled=False).to(device=device)
kan_ae.update_grid_from_samples(x_batch)

# Train on the coarse grid
loss_out = train(kan_ae, steps=100)

# Refine the grid from 3 to 5 points
kan_ae = kan_ae.refine(5)

# Train again on the refined grid; the loss turns NaN here
loss_out = train(kan_ae, steps=100)

This produced the following output:

Using device: cuda
checkpoint directory created: ./model
saving model version 0.0
Step: 95 | Loss: 0.074: 100%|██████████| 100/100 [00:01<00:00, 73.60it/s]
saving model version 0.1
Step: 95 | Loss: nan: 100%|██████████| 100/100 [00:00<00:00, 108.92it/s]

I have tried varying the widths, the grid sizes, and the data dimensions, but in all cases some of the model's coefficients turn NaN, leading to NaN outputs; a quick check for this is sketched below.
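
Here is a minimal diagnostic (plain PyTorch only, nothing pykan-specific) that can be run right after refine() or after the second training run to locate the affected parameters:

# Scan all model parameters and report any tensors containing NaN entries
for name, param in kan_ae.named_parameters():
    if torch.isnan(param).any():
        n_bad = torch.isnan(param).sum().item()
        print(f"{name}: {n_bad}/{param.numel()} entries are NaN")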

FrejaTerpPetersen avatar Jul 14 '25 13:07 FrejaTerpPetersen