pykan
pykan copied to clipboard
How to use pykan to fit a piecewise function
I've experimented with PyKan and discovered its impressive regression capabilities, particularly for modeling nonlinear functional relationships.
I tried to go a step further and perform regression on a piecewise function. However, the fit could be improved significantly. I present my implementation below and would appreciate any suggestions for improving the goodness of fit. Thank you very much!
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
# First attempt: widths [1, 10, 1] (presumably 1 input, 10 hidden, 1 output),
# grid=5 and k=3 (presumably cubic splines — confirm against pykan docs).
# Smooth cubic bases tend to fit the jump in f below poorly; later in this
# thread the suggestion is to use k=1 (and a finer grid) instead.
model = KAN(width=[1,10,1], grid=5, k=3, seed=0)
def f(x, threshold=0.5):
    """Step-down target: 1 where x[:, 0] < threshold, else 0.

    Args:
        x: 2-D tensor of samples; only the first column is used.
        threshold: location of the jump (default 0.5, as in the original).

    Returns:
        Tensor of shape (N, 1) with the same dtype as ``x``.
    """
    # Index with [0] (a list), not 0, so the column dimension is kept and the
    # result is (N, 1) rather than (N,).
    x0 = x[:, [0]]
    # The boolean mask cast to x's dtype gives 1 below the threshold, 0 otherwise.
    return (x0 < threshold).to(x0.dtype)
dataset = create_dataset(f, n_var=1, ranges=[0,1])
# train the model
# NOTE(review): newer pykan releases appear to expose this as `model.fit(...)`;
# `train` matches the version used in this thread — confirm for your install.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    predictions = model(inputs)

# One shared style so both series are drawn identically: hollow markers,
# no connecting line, every 10th sample.
marker_style = {"linestyle": "none", "marker": "o", "markerfacecolor": "white", "markevery": 10}
plt.plot(inputs, dataset['train_label'], 'r', label='True', **marker_style)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', **marker_style)
plt.legend()
plt.show()
The outcome is
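In this case, it might be more reasonable to try `model = KAN(width=[1,10,1], grid=5, k=1, seed=0)` (possibly increasing `grid` as well). With `k=1` the splines are piecewise linear, so they can follow a sharp jump more closely than smooth cubic (`k=3`) splines. This is just a workaround, though: it may not be ideal and may require more careful development.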
I also tried a function that is continuous but not differentiable (at x = 0.5), and it works nicely.
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import torchpwl
# Changed from the first attempt (grid=5, k=3) to grid=100, k=1 (presumably
# piecewise-linear splines on a much finer grid); this thread reports a good
# fit with this setting for the kinked target below.
# NOTE(review): the `torchpwl` import above is not used in this snippet.
model = KAN(width=[1,10,1], grid=100, k=1, seed=0)
def f(x, threshold=0.5):
    """Continuous but non-differentiable target with a kink at ``threshold``.

    Returns ``x`` where ``x > threshold`` and the constant ``threshold``
    elsewhere, applied elementwise.

    Args:
        x: floating-point tensor (any shape).
        threshold: location of the kink (default 0.5, as in the original).
    """
    # A full_like tensor rather than a bare Python scalar: older PyTorch
    # releases' torch.where only accepted tensor operands.
    return torch.where(x > threshold, x, torch.full_like(x, threshold))
dataset = create_dataset(f, n_var=1, ranges=[0,1])
# train the model
# NOTE(review): newer pykan releases appear to expose this as `model.fit(...)`;
# `train` matches the version used in this thread — confirm for your install.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    predictions = model(inputs)

# One shared style so both series are drawn identically: hollow markers,
# no connecting line, every 10th sample.
marker_style = {"linestyle": "none", "marker": "o", "markerfacecolor": "white", "markevery": 10}
plt.plot(inputs, dataset['train_label'], 'r', label='True', **marker_style)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', **marker_style)
plt.legend()
plt.show()
Another try. It looks nice.
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import numpy as np
# Changed from the first attempt (grid=5, k=3) to grid=100, k=1 (presumably
# piecewise-linear splines on a much finer grid); this thread reports a good
# fit with this setting for the piecewise sine target below.
model = KAN(width=[1,10,1], grid=100, k=1, seed=0)
def f(x, threshold=0.5, freq=20.0):
    """Piecewise target: sin(freq * x) above ``threshold``, constant below.

    For ``x > threshold`` the value is ``sin(freq * x)``; elsewhere it is the
    constant ``sin(freq * threshold)``, so the two pieces meet continuously.

    Args:
        x: floating-point tensor (any shape; applied elementwise).
        threshold: where the sine branch starts (default 0.5, as in the original).
        freq: angular frequency of the sine branch (default 20.0, as in the original).
    """
    # The held value is computed once as a Python float; full_like gives it
    # the same dtype/device as x, so torch.where gets two tensor operands.
    held = float(np.sin(freq * threshold))
    return torch.where(x > threshold, torch.sin(freq * x), torch.full_like(x, held))
dataset = create_dataset(f, n_var=1, ranges=[0,1], train_num=2000, test_num=2000)
# train the model
# NOTE(review): newer pykan releases appear to expose this as `model.fit(...)`;
# `train` matches the version used in this thread — confirm for your install.
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    predictions = model(inputs)

# One shared style so both series are drawn identically: hollow markers,
# no connecting line, every 10th sample.
marker_style = {"linestyle": "none", "marker": "o", "markerfacecolor": "white", "markevery": 10}
plt.plot(inputs, dataset['train_label'], 'r', label='True', **marker_style)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', **marker_style)
plt.legend()
plt.show()
Great, it works. Thank you so much.
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
# Final configuration for the step target: a narrower hidden layer
# (width [1, 4, 1]) with grid=100 and k=1 (presumably piecewise-linear
# splines — confirm against pykan docs); this thread reports that it works.
model = KAN(width=[1,4,1], grid=100, k=1, seed=0)
def f(x, threshold=0.5):
    """Step-down target: 1 where x[:, 0] < threshold, else 0.

    Args:
        x: 2-D tensor of samples; only the first column is used.
        threshold: location of the jump (default 0.5, as in the original).

    Returns:
        Tensor of shape (N, 1) with the same dtype as ``x``.
    """
    # Index with [0] (a list), not 0, so the column dimension is kept and the
    # result is (N, 1) rather than (N,).
    x0 = x[:, [0]]
    # The boolean mask cast to x's dtype gives 1 below the threshold, 0 otherwise.
    return (x0 < threshold).to(x0.dtype)
dataset = create_dataset(f, n_var=1, ranges=[0,1])
# train the model
# NOTE(review): newer pykan releases appear to expose this as `model.fit(...)`;
# `train` matches the version used in this thread — confirm for your install.
model.train(dataset, opt="LBFGS", steps=100)

inputs = dataset['train_input']
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    predictions = model(inputs)

# One shared style so both series are drawn identically: hollow markers,
# no connecting line, every 10th sample.
marker_style = {"linestyle": "none", "marker": "o", "markerfacecolor": "white", "markevery": 10}
plt.plot(inputs, dataset['train_label'], 'r', label='True', **marker_style)
plt.plot(inputs, predictions.numpy(), 'b', label='Predictions', **marker_style)
plt.legend()
plt.show()