pykan

How to use pykan to fit a piecewise function

Open · zhongjingjogy opened this issue 9 months ago · 1 comment

I've experimented with PyKan and discovered its impressive regression capabilities, particularly for modeling nonlinear functional relationships.

I tried to go a step further and perform regression on a piecewise function. However, the fit could be improved significantly. I therefore present the implementation here and kindly ask for suggestions to improve the goodness of fit. Thank you very much!

from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt

# First attempt: width=[1, 10, 1], grid=5, k=3 (the configuration whose fit
# is shown in the figure below).
model = KAN(
    width=[1, 10, 1],
    grid=5,
    k=3,
    seed=0,
)

def f(x):
    """Step target: 1 where the first input column is < 0.5, otherwise 0.

    Only column 0 of ``x`` is used; the result is a 2-D ``(N, 1)`` column
    tensor with the same dtype and device as ``x``.
    """
    col = x[:, [0]]  # list index keeps the column dimension -> (N, 1)
    return torch.lt(col, 0.5).to(col.dtype)

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# Train the model with LBFGS for 20 steps.
# NOTE(review): newer pykan releases expose this as `model.fit(...)`;
# confirm against the installed version.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    predictions = model(inputs)

# Markers only (no connecting lines), since the sampled inputs may not be
# sorted; markevery=10 draws every 10th point to reduce clutter.
plt.plot(inputs, dataset['train_label'], 'r', label='True', linestyle='none', marker='o', markerfacecolor='white', markevery=10)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()

The outcome is shown in the figure below (output).

zhongjingjogy avatar May 15 '24 23:05 zhongjingjogy

In this case, it might be more reasonable to try `model = KAN(width=[1,10,1], grid=5, k=1, seed=0)` (possibly increasing `grid` as well), but this is just a workaround that may not be ideal and requires more careful development.

KindXiaoming avatar May 16 '24 01:05 KindXiaoming

I tried a function that is continuous but not differentiable, and it works nicely.

from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import torchpwl

# Switched from the earlier grid=5, k=3 configuration to k=1 with grid=100,
# following the suggestion in the reply above.
model = KAN(
    width=[1, 10, 1],
    grid=100,
    k=1,
    seed=0,
)

def f(x):
    """Return ``x`` where it exceeds 0.5, and the constant 0.5 elsewhere.

    Continuous at 0.5 but with a kink there (not differentiable). Applied
    element-wise to every entry of ``x``; NaN entries map to 0.5 because the
    comparison is False for them.
    """
    threshold = 0.5
    above = torch.gt(x, threshold)
    return torch.where(above, x, threshold)

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# Train the model with LBFGS for 20 steps.
# NOTE(review): newer pykan releases expose this as `model.fit(...)`;
# confirm against the installed version.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    predictions = model(inputs)

# Markers only (no connecting lines), since the sampled inputs may not be
# sorted; markevery=10 draws every 10th point to reduce clutter.
plt.plot(inputs, dataset['train_label'], 'r', label='True', linestyle='none', marker='o', markerfacecolor='white', markevery=10)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()

output-1

zhongjingjogy avatar May 16 '24 02:05 zhongjingjogy

Another try. It looks nice.

from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import numpy as np

# Earlier configuration was grid=5, k=3; this run uses k=1 with grid=100.
model = KAN(
    width=[1, 10, 1],
    grid=100,
    k=1,
    seed=0,
)

def f(x, threshold=0.5, freq=20.0):
    """Piecewise target: ``sin(freq * x)`` above ``threshold``, constant below.

    The constant is ``sin(freq * threshold)``, so the function stays
    continuous at the breakpoint for any choice of ``threshold``/``freq``.
    The defaults reproduce the original hard-coded function:
    ``sin(20 x)`` for x > 0.5 and ``sin(10)`` otherwise.

    Args:
        x: input tensor; the function is applied element-wise.
        threshold: breakpoint; values at or below it get the plateau.
        freq: angular frequency of the sine branch.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Derive the plateau from the same threshold used in the condition, so the
    # two cannot drift apart; use a plain float so torch treats it as a Python
    # scalar rather than a NumPy float64.
    plateau = float(np.sin(freq * threshold))
    return torch.where(x > threshold, torch.sin(freq * x), plateau)

dataset = create_dataset(f, n_var=1, ranges=[0,1], train_num=2000, test_num=2000)

# Train the model with LBFGS for 20 steps.
# NOTE(review): newer pykan releases expose this as `model.fit(...)`;
# confirm against the installed version.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    predictions = model(inputs)

# Markers only (no connecting lines), since the sampled inputs may not be
# sorted; markevery=10 draws every 10th point to reduce clutter.
plt.plot(inputs, dataset['train_label'], 'r', label='True', linestyle='none', marker='o', markerfacecolor='white', markevery=10)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()

output-2

zhongjingjogy avatar May 16 '24 03:05 zhongjingjogy

Great, it works. Thank you so much.

from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt

# width=[1, 4, 1]: a narrower middle layer than the earlier [1, 10, 1],
# combined with k=1 and grid=100.
model = KAN(width=[1, 4, 1], grid=100, k=1,
            seed=0)

def f(x):
    """Step target on the first input column: 1 below 0.5, 0 at or above it.

    The output is an ``(N, 1)`` tensor matching the dtype/device of ``x``.
    """
    first = x[:, [0]]
    below = first < 0.5
    one, zero = torch.ones_like(first), torch.zeros_like(first)
    return torch.where(below, one, zero)

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# Train the model with LBFGS for 100 steps.
# NOTE(review): newer pykan releases expose this as `model.fit(...)`;
# confirm against the installed version.
model.train(dataset, opt="LBFGS", steps=100)

inputs = dataset['train_input']
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    predictions = model(inputs)

# Markers only (no connecting lines), since the sampled inputs may not be
# sorted; markevery=10 draws every 10th point to reduce clutter.
plt.plot(inputs, dataset['train_label'], 'r', label='True', linestyle='none', marker='o', markerfacecolor='white', markevery=10)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()

output-3

zhongjingjogy avatar May 16 '24 06:05 zhongjingjogy