Any step-by-step guide on KAN with code?
Are there any step-by-step guides on how to apply KAN to tabular binary classification datasets?
I used this guide https://www.kaggle.com/code/seyidcemkarakas/kan-tabular-data-binary-classification for my dataset and got this error:
checkpoint directory created: ./model
saving model version 0.0
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-16-bff360054870>](https://localhost:8080/#) in <cell line: 41>()
39
40 # KAN model training
---> 41 results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
42 metrics=(train_acc, test_acc),
43 opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
TypeError: Module.train() got an unexpected keyword argument 'metrics'
import torch
from kan import KAN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Convert the pandas splits to torch tensors: features become float32 and
# labels become int64 (torch.long).
train_input = torch.tensor(X_train.to_numpy(), dtype=torch.float32)
train_label = torch.tensor(y_train.to_numpy(), dtype=torch.long)
val_input = torch.tensor(X_val.to_numpy(), dtype=torch.float32)
val_label = torch.tensor(y_val.to_numpy(), dtype=torch.long)
test_input = torch.tensor(X_test.to_numpy(), dtype=torch.float32)
test_label = torch.tensor(y_test.to_numpy(), dtype=torch.long)
# Bundle every split into one dict keyed '<split>_input' / '<split>_label'.
dataset = {
'train_input': train_input,
'train_label': train_label,
'val_input': val_input,
'val_label': val_label,
'test_input': test_input,
'test_label': test_label
}
# Create model
# NOTE(review): the first width entry is hard-coded to 11; presumably it has
# to equal the number of feature columns in X_train -- confirm.
model = KAN(width=[11, 2], grid=10, k=2)
# Functions for getting accuracy scores while training
def train_acc():
    """Return the model's classification accuracy on the training split.

    Reads the module-level ``model`` and ``dataset``. The predicted class of
    each sample is the arg-max of the model output along dimension 1.

    Returns:
        A 0-dimensional float tensor holding the mean of the element-wise
        matches between predictions and ``dataset['train_label']``.
    """
    preds = torch.argmax(model(dataset['train_input']), dim=1)
    return torch.mean((preds == dataset['train_label']).float())
def test_acc():
preds = torch.argmax(model(dataset['test_input']), dim=1)
return torch.mean((preds == dataset['test_label']).float())
# KAN model training
# The validation split is passed under the 'test_input' / 'test_label' keys.
# NOTE: this is the call the traceback above points at
# ("Module.train() got an unexpected keyword argument 'metrics'"); the replies
# in this thread say that since pykan 0.2.0 the method to call is `fit`.
results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
metrics=(train_acc, test_acc),
opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
# Predictions of train val and test datasets
# (predicted class = arg-max of the model output along dimension 1)
test_preds = torch.argmax(model.forward(test_input).detach(),dim=1)
test_labels = test_label
train_preds = torch.argmax(model.forward(train_input).detach(),dim=1)
train_labels = train_label
val_preds = torch.argmax(model.forward(val_input).detach(),dim=1)
val_labels = val_label
# Evaluate metrics
print("Train ACC:", accuracy_score(train_labels.numpy(), train_preds.numpy()))
print("Val ACC:", accuracy_score(val_labels.numpy(), val_preds.numpy()))
print("Test ACC:", accuracy_score(test_labels.numpy(), test_preds.numpy()))
# Plotting KAN network
model.plot(scale=10)
# Learning curve based on ACC and LOSS
# The "test_*" series come from the split passed as 'test_input' above,
# i.e. the validation split, hence the 'Val' labels.
plt.figure(figsize=(10, 5))
plt.plot(results["train_acc"], label='Training Accuracy')
plt.plot(results["test_acc"], label='Val Accuracy')
plt.plot(results["train_loss"], label='Training Loss')
plt.plot(results["test_loss"], label='Val Loss')
plt.title('Training and Val Accuracy over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.legend()
plt.grid(True)
plt.show()
Hi @apavlo89
I saw your comments on my Kaggle notebook. First of all, thanks for taking a look.
Secondly, I ran my code again and did not get any errors. Can you share the versions of your libraries?
Hi @apavlo89,
Since 0.2.0, you will need to use the fit method instead of train.
Hi @apavlo89,
Since 0.2.0, you will need to use the fit method instead of train.
Yes, @apavlo89 Could you try it?
fixed with this:
import torch
from kan import KAN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Report the size of every split before converting anything
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")

# Build the dataset dict directly from the pandas splits:
# float32 tensors for the features, int64 (torch.long) tensors for the labels.
dataset = {
    'train_input': torch.tensor(X_train.to_numpy(), dtype=torch.float32),
    'train_label': torch.tensor(y_train.to_numpy(), dtype=torch.long),
    'val_input': torch.tensor(X_val.to_numpy(), dtype=torch.float32),
    'val_label': torch.tensor(y_val.to_numpy(), dtype=torch.long),
    'test_input': torch.tensor(X_test.to_numpy(), dtype=torch.float32),
    'test_label': torch.tensor(y_test.to_numpy(), dtype=torch.long),
}

# Keep the per-split names that the rest of the script refers to
train_input, train_label = dataset['train_input'], dataset['train_label']
val_input, val_label = dataset['val_input'], dataset['val_label']
test_input, test_label = dataset['test_input'], dataset['test_label']

# The input width is taken from the data (number of feature columns) instead
# of being hard-coded; the last width entry stays at 2.
input_width = X_train.shape[1]
model = KAN(width=[input_width, 2], grid=10, k=2)
# Functions for getting accuracy scores while training
def train_acc():
    """Return the model's classification accuracy on the training split.

    Reads the module-level ``model`` and ``dataset``. The predicted class of
    each sample is the arg-max of the model output along dimension 1.

    Returns:
        A 0-dimensional float tensor holding the mean of the element-wise
        matches between predictions and ``dataset['train_label']``.
    """
    preds = torch.argmax(model(dataset['train_input']), dim=1)
    return torch.mean((preds == dataset['train_label']).float())
def test_acc():
preds = torch.argmax(model(dataset['test_input']), dim=1)
return torch.mean((preds == dataset['test_label']).float())
# KAN model training using the fit method.
# `results` is keyed by 'train_loss', 'test_loss' and the names of the metric
# functions passed in ('train_acc', 'test_acc'); the evaluation split is the
# one stored under dataset['test_input'] / dataset['test_label'].
results = model.fit(dataset,
                    metrics=(train_acc, test_acc),
                    opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
# Predictions of train, val, and test datasets
# (predicted class = arg-max of the model output along dimension 1)
test_preds = torch.argmax(model.forward(test_input).detach(), dim=1)
test_labels = test_label
train_preds = torch.argmax(model.forward(train_input).detach(), dim=1)
train_labels = train_label
val_preds = torch.argmax(model.forward(val_input).detach(), dim=1)
val_labels = val_label
# Evaluate metrics
print("Train ACC:", accuracy_score(train_labels.numpy(), train_preds.numpy()))
print("Val ACC:", accuracy_score(val_labels.numpy(), val_preds.numpy()))
print("Test ACC:", accuracy_score(test_labels.numpy(), test_preds.numpy()))
# Plotting KAN network
model.plot(scale=10)
# Learning curve based on ACC and LOSS.
# There are no "val_acc" / "val_loss" entries in `results`, so indexing them
# would raise KeyError; the evaluation curves recorded during fit are the
# "test_*" ones and are labelled accordingly.
plt.figure(figsize=(10, 5))
plt.plot(results["train_acc"], label='Training Accuracy')
plt.plot(results["test_acc"], label='Test Accuracy')
plt.plot(results["train_loss"], label='Training Loss')
plt.plot(results["test_loss"], label='Test Loss')
plt.title('Training and Test Accuracy over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.legend()
plt.grid(True)
plt.show()
Any idea how to tune the model? I am getting absolutely abysmal performance, and I also get an error when all the steps finish. Is it just a case of increasing the number of steps, or is there anything else I should try? Here is the output:
Shape of X_train: (1393, 369)
Shape of X_val: (725, 369)
Shape of X_test: (249, 369)
checkpoint directory created: ./model
saving model version 0.0
description: 0%| | 0/100 [00:00<?, ?it/s]
| train_loss: 1.92e+00 | test_loss: 1.86e+00 | reg: 0.00e+00 | : 0%| | 0/100 [00:02<?, ?it/s]
| train_loss: 1.92e+00 | test_loss: 1.86e+00 | reg: 0.00e+00 | : 1%| | 1/100 [00:02<04:33, 2.76s/
| train_loss: 1.55e+00 | test_loss: 1.46e+00 | reg: 0.00e+00 | : 1%| | 1/100 [00:04<04:33, 2.76s/
| train_loss: 1.55e+00 | test_loss: 1.46e+00 | reg: 0.00e+00 | : 2%| | 2/100 [00:04<03:24, 2.09s/
| train_loss: 1.37e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 2%| | 2/100 [00:05<03:24, 2.09s/
| train_loss: 1.37e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 3%| | 3/100 [00:05<02:58, 1.84s/
| train_loss: 1.23e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 3%| | 3/100 [00:07<02:58, 1.84s/
| train_loss: 1.23e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 4%| | 4/100 [00:07<02:51, 1.79s/
| train_loss: 1.16e+00 | test_loss: 1.18e+00 | reg: 0.00e+00 | : 4%| | 4/100 [00:09<02:51, 1.79s/
| train_loss: 1.16e+00 | test_loss: 1.18e+00 | reg: 0.00e+00 | : 5%| | 5/100 [00:09<02:45, 1.74s/
| train_loss: 1.11e+00 | test_loss: 1.15e+00 | reg: 0.00e+00 | : 5%| | 5/100 [00:11<02:45, 1.74s/
| train_loss: 1.11e+00 | test_loss: 1.15e+00 | reg: 0.00e+00 | : 6%| | 6/100 [00:11<02:45, 1.76s/
| train_loss: 1.03e+00 | test_loss: 1.11e+00 | reg: 0.00e+00 | : 6%| | 6/100 [00:12<02:45, 1.76s/
| train_loss: 1.03e+00 | test_loss: 1.11e+00 | reg: 0.00e+00 | : 7%| | 7/100 [00:12<02:37, 1.69s/
| train_loss: 9.89e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 7%| | 7/100 [00:14<02:37, 1.69s/
| train_loss: 9.89e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 8%| | 8/100 [00:14<02:42, 1.77s/
| train_loss: 9.64e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 8%| | 8/100 [00:16<02:42, 1.77s/
| train_loss: 9.64e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 9%| | 9/100 [00:16<02:44, 1.80s/
| train_loss: 9.17e-01 | test_loss: 1.06e+00 | reg: 0.00e+00 | : 9%| | 9/100 [00:17<02:44, 1.80s/
| train_loss: 9.17e-01 | test_loss: 1.06e+00 | reg: 0.00e+00 | : 10%| | 10/100 [00:17<02:34, 1.71s
| train_loss: 1.04e+00 | test_loss: 1.17e+00 | reg: 0.00e+00 | : 10%| | 10/100 [00:19<02:34, 1.71s
| train_loss: 1.04e+00 | test_loss: 1.17e+00 | reg: 0.00e+00 | : 11%| | 11/100 [00:19<02:31, 1.70s
| train_loss: 9.94e-01 | test_loss: 1.14e+00 | reg: 0.00e+00 | : 11%| | 11/100 [00:21<02:31, 1.70s
| train_loss: 9.94e-01 | test_loss: 1.14e+00 | reg: 0.00e+00 | : 12%| | 12/100 [00:21<02:26, 1.67s
| train_loss: 9.54e-01 | test_loss: 1.10e+00 | reg: 0.00e+00 | : 12%| | 12/100 [00:22<02:26, 1.67s
| train_loss: 9.54e-01 | test_loss: 1.10e+00 | reg: 0.00e+00 | : 13%|▏| 13/100 [00:22<02:24, 1.66s
| train_loss: 9.22e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 13%|▏| 13/100 [00:24<02:24, 1.66s
| train_loss: 9.22e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 14%|▏| 14/100 [00:24<02:21, 1.65s
| train_loss: 8.94e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 14%|▏| 14/100 [00:26<02:21, 1.65s
| train_loss: 8.94e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | : 15%|▏| 15/100 [00:26<02:26, 1.73s
| train_loss: 1.15e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 15%|▏| 15/100 [00:28<02:26, 1.73s
| train_loss: 1.15e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 16%|▏| 16/100 [00:28<02:43, 1.95s
| train_loss: 1.06e+00 | test_loss: 1.25e+00 | reg: 0.00e+00 | : 16%|▏| 16/100 [00:30<02:43, 1.95s
| train_loss: 1.06e+00 | test_loss: 1.25e+00 | reg: 0.00e+00 | : 17%|▏| 17/100 [00:30<02:34, 1.86s
| train_loss: 9.92e-01 | test_loss: 1.24e+00 | reg: 0.00e+00 | : 17%|▏| 17/100 [00:32<02:34, 1.86s
| train_loss: 9.92e-01 | test_loss: 1.24e+00 | reg: 0.00e+00 | : 18%|▏| 18/100 [00:32<02:27, 1.79s
| train_loss: 9.44e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 18%|▏| 18/100 [00:33<02:27, 1.79s
| train_loss: 9.44e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 19%|▏| 19/100 [00:33<02:21, 1.75s
| train_loss: 9.32e-01 | test_loss: 1.20e+00 | reg: 0.00e+00 | : 19%|▏| 19/100 [00:35<02:21, 1.75s
| train_loss: 9.32e-01 | test_loss: 1.20e+00 | reg: 0.00e+00 | : 20%|▏| 20/100 [00:35<02:16, 1.70s
| train_loss: 9.57e-01 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 20%|▏| 20/100 [00:37<02:16, 1.70s
| train_loss: 9.57e-01 | test_loss: 1.28e+00 | reg: 0.00e+00 | : 21%|▏| 21/100 [00:37<02:18, 1.76s
| train_loss: 9.35e-01 | test_loss: 1.25e+00 | reg: 0.00e+00 | : 21%|▏| 21/100 [00:38<02:18, 1.76s
| train_loss: 9.35e-01 | test_loss: 1.25e+00 | reg: 0.00e+00 | : 22%|▏| 22/100 [00:38<02:16, 1.75s
| train_loss: 9.16e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 22%|▏| 22/100 [00:41<02:16, 1.75s
| train_loss: 9.16e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 23%|▏| 23/100 [00:41<02:23, 1.86s
| train_loss: 9.05e-01 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 23%|▏| 23/100 [00:42<02:23, 1.86s
| train_loss: 9.05e-01 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 24%|▏| 24/100 [00:42<02:15, 1.78s
| train_loss: 8.93e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 24%|▏| 24/100 [00:44<02:15, 1.78s
| train_loss: 8.93e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 25%|▎| 25/100 [00:44<02:05, 1.68s
| train_loss: 1.60e+00 | test_loss: 1.54e+00 | reg: 0.00e+00 | : 25%|▎| 25/100 [00:45<02:05, 1.68s
| train_loss: 1.60e+00 | test_loss: 1.54e+00 | reg: 0.00e+00 | : 26%|▎| 26/100 [00:45<02:07, 1.73s
| train_loss: 1.36e+00 | test_loss: 1.26e+00 | reg: 0.00e+00 | : 26%|▎| 26/100 [00:47<02:07, 1.73s
| train_loss: 1.36e+00 | test_loss: 1.26e+00 | reg: 0.00e+00 | : 27%|▎| 27/100 [00:47<01:59, 1.64s
| train_loss: 1.35e+00 | test_loss: 1.24e+00 | reg: 0.00e+00 | : 27%|▎| 27/100 [00:48<01:59, 1.64s
| train_loss: 1.35e+00 | test_loss: 1.24e+00 | reg: 0.00e+00 | : 28%|▎| 28/100 [00:48<01:52, 1.56s
| train_loss: 1.32e+00 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 28%|▎| 28/100 [00:50<01:52, 1.56s
| train_loss: 1.32e+00 | test_loss: 1.22e+00 | reg: 0.00e+00 | : 29%|▎| 29/100 [00:50<01:52, 1.58s
| train_loss: 1.30e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 29%|▎| 29/100 [00:52<01:52, 1.58s
| train_loss: 1.30e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | : 30%|▎| 30/100 [00:52<01:55, 1.65s
| train_loss: 2.35e+00 | test_loss: 2.64e+00 | reg: 0.00e+00 | : 30%|▎| 30/100 [00:54<01:55, 1.65s
| train_loss: 2.35e+00 | test_loss: 2.64e+00 | reg: 0.00e+00 | : 31%|▎| 31/100 [00:54<02:07, 1.84s
| train_loss: 2.19e+00 | test_loss: 2.15e+00 | reg: 0.00e+00 | : 31%|▎| 31/100 [00:56<02:07, 1.84s
| train_loss: 2.19e+00 | test_loss: 2.15e+00 | reg: 0.00e+00 | : 32%|▎| 32/100 [00:56<02:00, 1.77s
| train_loss: 2.01e+00 | test_loss: 2.02e+00 | reg: 0.00e+00 | : 32%|▎| 32/100 [00:57<02:00, 1.77s
| train_loss: 2.01e+00 | test_loss: 2.02e+00 | reg: 0.00e+00 | : 33%|▎| 33/100 [00:57<01:54, 1.70s
| train_loss: 1.78e+00 | test_loss: 1.71e+00 | reg: 0.00e+00 | : 33%|▎| 33/100 [00:59<01:54, 1.70s
| train_loss: 1.78e+00 | test_loss: 1.71e+00 | reg: 0.00e+00 | : 34%|▎| 34/100 [00:59<01:51, 1.68s
| train_loss: 1.66e+00 | test_loss: 1.57e+00 | reg: 0.00e+00 | : 34%|▎| 34/100 [01:00<01:51, 1.68s
| train_loss: 1.66e+00 | test_loss: 1.57e+00 | reg: 0.00e+00 | : 35%|▎| 35/100 [01:00<01:48, 1.66s
| train_loss: 7.41e+00 | test_loss: 7.15e+00 | reg: 0.00e+00 | : 35%|▎| 35/100 [01:02<01:48, 1.66s
| train_loss: 7.41e+00 | test_loss: 7.15e+00 | reg: 0.00e+00 | : 36%|▎| 36/100 [01:02<01:51, 1.74s
| train_loss: 3.13e+00 | test_loss: 3.25e+00 | reg: 0.00e+00 | : 36%|▎| 36/100 [01:05<01:51, 1.74s
| train_loss: 3.13e+00 | test_loss: 3.25e+00 | reg: 0.00e+00 | : 37%|▎| 37/100 [01:05<01:57, 1.87s
| train_loss: 2.09e+00 | test_loss: 2.12e+00 | reg: 0.00e+00 | : 37%|▎| 37/100 [01:06<01:57, 1.87s
| train_loss: 2.09e+00 | test_loss: 2.12e+00 | reg: 0.00e+00 | : 38%|▍| 38/100 [01:06<01:57, 1.90s
| train_loss: 1.80e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | : 38%|▍| 38/100 [01:08<01:57, 1.90s
| train_loss: 1.80e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | : 39%|▍| 39/100 [01:08<01:48, 1.79s
| train_loss: 1.70e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | : 39%|▍| 39/100 [01:10<01:48, 1.79s
| train_loss: 1.70e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | : 40%|▍| 40/100 [01:10<01:44, 1.74s
| train_loss: 9.58e+00 | test_loss: 1.87e+00 | reg: 0.00e+00 | : 40%|▍| 40/100 [01:12<01:44, 1.74s
| train_loss: 9.58e+00 | test_loss: 1.87e+00 | reg: 0.00e+00 | : 41%|▍| 41/100 [01:12<01:46, 1.81s
| train_loss: 9.53e+00 | test_loss: 2.42e+00 | reg: 0.00e+00 | : 41%|▍| 41/100 [01:13<01:46, 1.81s
| train_loss: 9.53e+00 | test_loss: 2.42e+00 | reg: 0.00e+00 | : 42%|▍| 42/100 [01:13<01:42, 1.77s
| train_loss: 9.51e+00 | test_loss: 3.70e+00 | reg: 0.00e+00 | : 42%|▍| 42/100 [01:15<01:42, 1.77s
| train_loss: 9.51e+00 | test_loss: 3.70e+00 | reg: 0.00e+00 | : 43%|▍| 43/100 [01:15<01:35, 1.68s
| train_loss: 8.86e+00 | test_loss: 1.03e+01 | reg: 0.00e+00 | : 43%|▍| 43/100 [01:17<01:35, 1.68s
| train_loss: 8.86e+00 | test_loss: 1.03e+01 | reg: 0.00e+00 | : 44%|▍| 44/100 [01:17<01:37, 1.75s
| train_loss: 7.43e+00 | test_loss: 7.65e+00 | reg: 0.00e+00 | : 44%|▍| 44/100 [01:19<01:37, 1.75s
| train_loss: 7.43e+00 | test_loss: 7.65e+00 | reg: 0.00e+00 | : 45%|▍| 45/100 [01:19<01:42, 1.86s
| train_loss: 7.11e+00 | test_loss: 6.82e+00 | reg: 0.00e+00 | : 45%|▍| 45/100 [01:21<01:42, 1.86s
| train_loss: 7.11e+00 | test_loss: 6.82e+00 | reg: 0.00e+00 | : 46%|▍| 46/100 [01:21<01:41, 1.88s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 46%|▍| 46/100 [01:22<01:41, 1.88s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 47%|▍| 47/100 [01:22<01:34, 1.78s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 47%|▍| 47/100 [01:23<01:34, 1.78s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 48%|▍| 48/100 [01:23<01:19, 1.52s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 48%|▍| 48/100 [01:24<01:19, 1.52s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 49%|▍| 49/100 [01:24<01:07, 1.32s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 49%|▍| 49/100 [01:25<01:07, 1.32s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 50%|▌| 50/100 [01:25<00:57, 1.16s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 50%|▌| 50/100 [01:25<00:57, 1.16s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 51%|▌| 51/100 [01:25<00:47, 1.03i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 51%|▌| 51/100 [01:26<00:47, 1.03i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 52%|▌| 52/100 [01:26<00:41, 1.17i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 52%|▌| 52/100 [01:27<00:41, 1.17i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 53%|▌| 53/100 [01:27<00:37, 1.26i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 53%|▌| 53/100 [01:27<00:37, 1.26i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 54%|▌| 54/100 [01:27<00:33, 1.36i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 54%|▌| 54/100 [01:28<00:33, 1.36i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 55%|▌| 55/100 [01:28<00:31, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 55%|▌| 55/100 [01:28<00:31, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 56%|▌| 56/100 [01:28<00:30, 1.43i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 56%|▌| 56/100 [01:29<00:30, 1.43i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 57%|▌| 57/100 [01:29<00:31, 1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 57%|▌| 57/100 [01:30<00:31, 1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 58%|▌| 58/100 [01:30<00:31, 1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 58%|▌| 58/100 [01:31<00:31, 1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 59%|▌| 59/100 [01:31<00:29, 1.40i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 59%|▌| 59/100 [01:31<00:29, 1.40i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 60%|▌| 60/100 [01:31<00:26, 1.53i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 60%|▌| 60/100 [01:32<00:26, 1.53i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 61%|▌| 61/100 [01:32<00:24, 1.62i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 61%|▌| 61/100 [01:32<00:24, 1.62i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 62%|▌| 62/100 [01:32<00:23, 1.64i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 62%|▌| 62/100 [01:33<00:23, 1.64i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 63%|▋| 63/100 [01:33<00:21, 1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 63%|▋| 63/100 [01:33<00:21, 1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 64%|▋| 64/100 [01:33<00:20, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 64%|▋| 64/100 [01:34<00:20, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 65%|▋| 65/100 [01:34<00:19, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 65%|▋| 65/100 [01:35<00:19, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 66%|▋| 66/100 [01:35<00:19, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 66%|▋| 66/100 [01:35<00:19, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 67%|▋| 67/100 [01:35<00:18, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 67%|▋| 67/100 [01:36<00:18, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 68%|▋| 68/100 [01:36<00:18, 1.77i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 68%|▋| 68/100 [01:36<00:18, 1.77i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 69%|▋| 69/100 [01:36<00:17, 1.74i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 69%|▋| 69/100 [01:37<00:17, 1.74i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 70%|▋| 70/100 [01:37<00:17, 1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 70%|▋| 70/100 [01:37<00:17, 1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 71%|▋| 71/100 [01:37<00:17, 1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 71%|▋| 71/100 [01:38<00:17, 1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 72%|▋| 72/100 [01:38<00:16, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 72%|▋| 72/100 [01:39<00:16, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 73%|▋| 73/100 [01:39<00:15, 1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 73%|▋| 73/100 [01:39<00:15, 1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 74%|▋| 74/100 [01:39<00:15, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 74%|▋| 74/100 [01:40<00:15, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 75%|▊| 75/100 [01:40<00:14, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 75%|▊| 75/100 [01:40<00:14, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 76%|▊| 76/100 [01:40<00:14, 1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 76%|▊| 76/100 [01:41<00:14, 1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 77%|▊| 77/100 [01:41<00:14, 1.63i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 77%|▊| 77/100 [01:42<00:14, 1.63i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 78%|▊| 78/100 [01:42<00:14, 1.54i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 78%|▊| 78/100 [01:43<00:14, 1.54i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 79%|▊| 79/100 [01:43<00:14, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 79%|▊| 79/100 [01:43<00:14, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 80%|▊| 80/100 [01:43<00:13, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 80%|▊| 80/100 [01:44<00:13, 1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 81%|▊| 81/100 [01:44<00:12, 1.49i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 81%|▊| 81/100 [01:44<00:12, 1.49i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 82%|▊| 82/100 [01:44<00:11, 1.55i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 82%|▊| 82/100 [01:45<00:11, 1.55i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 83%|▊| 83/100 [01:45<00:10, 1.58i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 83%|▊| 83/100 [01:46<00:10, 1.58i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 84%|▊| 84/100 [01:46<00:10, 1.56i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 84%|▊| 84/100 [01:46<00:10, 1.56i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 85%|▊| 85/100 [01:46<00:09, 1.65i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 85%|▊| 85/100 [01:47<00:09, 1.65i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 86%|▊| 86/100 [01:47<00:08, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 86%|▊| 86/100 [01:47<00:08, 1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 87%|▊| 87/100 [01:47<00:07, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 87%|▊| 87/100 [01:48<00:07, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 88%|▉| 88/100 [01:48<00:06, 1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 88%|▉| 88/100 [01:48<00:06, 1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 89%|▉| 89/100 [01:48<00:06, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 89%|▉| 89/100 [01:49<00:06, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 90%|▉| 90/100 [01:49<00:05, 1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 90%|▉| 90/100 [01:50<00:05, 1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 91%|▉| 91/100 [01:50<00:05, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 91%|▉| 91/100 [01:50<00:05, 1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 92%|▉| 92/100 [01:50<00:04, 1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 92%|▉| 92/100 [01:51<00:04, 1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 93%|▉| 93/100 [01:51<00:03, 1.85i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 93%|▉| 93/100 [01:51<00:03, 1.85i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 94%|▉| 94/100 [01:51<00:03, 1.79i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 94%|▉| 94/100 [01:52<00:03, 1.79i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 95%|▉| 95/100 [01:52<00:02, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 95%|▉| 95/100 [01:52<00:02, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 96%|▉| 96/100 [01:52<00:02, 1.72i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 96%|▉| 96/100 [01:53<00:02, 1.72i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 97%|▉| 97/100 [01:53<00:01, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 97%|▉| 97/100 [01:53<00:01, 1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 98%|▉| 98/100 [01:53<00:01, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 98%|▉| 98/100 [01:54<00:01, 1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | : 99%|▉| 99/100 [01:54<00:01, 1.16s
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-15-7af2f0c35dea>](https://localhost:8080/#) in <cell line: 48>()
46
47 # KAN model training using the fit method
---> 48 results = model.fit(dataset,
49 metrics=(train_acc, test_acc),
50 opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in fit(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)
938 if _ == steps-1 and old_save_act:
939 #self.save_act = True
--> 940 self.recover_save_act_in_fit()
941
942 train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
TypeError: MultKAN.recover_save_act_in_fit() missing 1 required positional argument: 'old_save_act'
Also how would one go about doing feature selection with KAN?
Also how would one go about doing feature selection with KAN?
I have written code for feature selection; I will share it.
Exciting! Where will you post it?
an example of how to do feature selection using KANs: https://github.com/KindXiaoming/pykan/blob/master/tutorials/Interp_4_feature_attribution.ipynb
It just doesn't like my dataset for some reason. I've tried all kinds of things, but I get this error -- any idea why?
from kan import *
import numpy as np
# Assuming X_train, y_train, X_val, y_val, X_test, y_test are already defined
# Combine training and validation sets
# (np.concatenate returns NumPy arrays, so everything below is an ndarray)
X_train_combined = np.concatenate((X_train, X_val), axis=0)
y_train_combined = np.concatenate((y_train, y_val), axis=0)
# Ensure the data is of type float
X_train_combined = X_train_combined.astype(float)
y_train_combined = y_train_combined.astype(float)
X_test = X_test.astype(float)
y_test = y_test.astype(float)
# Print shapes and types to debug
print(f"train_input shape: {X_train_combined.shape}, dtype: {X_train_combined.dtype}")
print(f"train_output shape: {y_train_combined.shape}, dtype: {y_train_combined.dtype}")
print(f"test_input shape: {X_test.shape}, dtype: {X_test.dtype}")
print(f"test_output shape: {y_test.shape}, dtype: {y_test.dtype}")
# Create the dataset in the expected format by mimicking the example
# NOTE: the values stored here are NumPy arrays, not torch tensors.
# NOTE(review): the other snippets in this thread use the keys 'train_label' /
# 'test_label' rather than 'train_output' / 'test_output' -- confirm which
# keys the installed pykan version reads.
dataset = {
'train_input': X_train_combined,
'train_output': y_train_combined,
'test_input': X_test,
'test_output': y_test
}
# Ensure the structure matches the expected format: list the keys, then the
# shape and dtype of every entry.
print(f"dataset keys: {dataset.keys()}")
for key in dataset:
    print(f"{key} shape: {dataset[key].shape}, dtype: {dataset[key].dtype}")
# Create and train the KAN model, same as the example.
# The input width is the number of feature columns of the combined training
# array. This fit call is the one the traceback below points at; it is kept
# exactly as posted.
model = KAN(width=[X_train_combined.shape[1], 10, 10, 1], seed=2)
model.fit(dataset, steps=50, lamb=1e-3, reg_metric='edge_forward_n')
train_input shape: (2134, 3854), dtype: float64
train_output shape: (2134,), dtype: float64
test_input shape: (251, 3854), dtype: float64
test_output shape: (251,), dtype: float64
dataset keys: dict_keys(['train_input', 'train_output', 'test_input', 'test_output'])
train_input shape: (2134, 3854), dtype: float64
train_output shape: (2134,), dtype: float64
test_input shape: (251, 3854), dtype: float64
test_output shape: (251,), dtype: float64
checkpoint directory created: ./model
saving model version 0.0
description: 0%| | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[<ipython-input-39-8d89e215e034>](https://localhost:8080/#) in <cell line: 37>()
35 # Create and train the KAN model, same as the example
36 model = KAN(width=[X_train_combined.shape[1], 10, 10, 1], seed=2)
---> 37 model.fit(dataset, steps=50, lamb=1e-3, reg_metric='edge_forward_n')
3 frames
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in fit(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)
944
945 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
--> 946 self.update_grid(dataset['train_input'][train_id])
947
948 if opt == "LBFGS":
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in update_grid(self, x)
356
357 def update_grid(self, x):
--> 358 self.update_grid_from_samples(x)
359
360 def initialize_grid_from_another_model(self, model, x):
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in update_grid_from_samples(self, x)
352 def update_grid_from_samples(self, x):
353 for l in range(self.depth):
--> 354 self.get_act(x)
355 self.act_fun[l].update_grid_from_samples(self.acts[l])
356
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in get_act(self, x)
1730 if isinstance(x, dict):
1731 x = x['train_input']
-> 1732 if x == None:
1733 if self.cache_data != None:
1734 x = self.cache_data
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
@apavlo89 Here is my code for feature importance:
# Feature-importance example: train a small KAN, then rank the input columns
# using the activation scales the model recorded for its two layers.
# Create the model
model = KAN(width=[4,7,1], grid=3, k=11)
# Train the model
# NOTE(review): other snippets in this thread call model.fit(...) instead;
# model.train(...) presumably targets an older pykan release -- confirm.
results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
opt="LBFGS", steps=150, loss_fn=torch.nn.MSELoss())
# The validation split is passed in even though the key reads 'test_input': val_input,
# because the dataset dict used for KAN training expects an entry named test_input.
# Per-layer entries of model.acts_scale, detached and converted to NumPy arrays
layer_1 = model.acts_scale[0].detach().numpy()
layer_2 = model.acts_scale[1].detach().numpy()
columns = list(X.columns)
# Matrix product of the two transposed scale arrays, flattened to 1-D;
# each entry is treated below as the score of one column of X.
importance_values = np.dot(layer_1.T, layer_2.T).flatten()
# argsort returns ascending order; the [::1] slice leaves that order unchanged
sorted_indices = np.argsort(importance_values)[::1]
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]
# Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()
You will then see the horizontal bar (barh) plot.
Hi @apavlo89, it looks like your data are NumPy arrays. Could you please try converting them to torch tensors?
I converted the NumPy arrays to torch tensors, but now I get this:
@KindXiaoming @seyidcemkarakas
import torch
from kan import KAN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Data preparation and model setup for the KAN classification run.
# Check the shape of your data
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")
# Converting data to NumPy arrays
# (X_* / y_* are defined outside this snippet; presumably pandas objects,
# since .to_numpy() and, later, X_train.columns are used -- confirm)
X_train_np = X_train.to_numpy()
y_train_np = y_train.to_numpy()
X_val_np = X_val.to_numpy()
y_val_np = y_val.to_numpy()
X_test_np = X_test.to_numpy()
y_test_np = y_test.to_numpy()
# Converting data to Torch tensors: float32 features, long (int64) labels.
# Integer labels go with the torch.nn.CrossEntropyLoss passed to fit() further down.
train_input = torch.tensor(X_train_np, dtype=torch.float32)
train_label = torch.tensor(y_train_np, dtype=torch.long)
val_input = torch.tensor(X_val_np, dtype=torch.float32)
val_label = torch.tensor(y_val_np, dtype=torch.long)
test_input = torch.tensor(X_test_np, dtype=torch.float32)
test_label = torch.tensor(y_test_np, dtype=torch.long)
# All three splits in one dict; train_acc()/test_acc() read from it.
# Note that the fit() call further down builds its own dict instead of passing this one.
dataset = {
'train_input': train_input,
'train_label': train_label,
'val_input': val_input,
'val_label': val_label,
'test_input': test_input,
'test_label': test_label
}
# Ensure the input width matches the number of features
input_width = X_train_np.shape[1]
# Create model: `width` lists input_width inputs and 2 outputs, with no hidden layer
# (presumably one output per class -- confirm)
model = KAN(width=[input_width, 2], grid=10, k=2)
# Workaround for the issue by setting save_act to False
# NOTE(review): the feature-importance code later indexes model.acts_scale, and the
# reported run fails there with "IndexError: list index out of range" at index 0,
# i.e. the list was empty. Disabling save_act is the likely cause -- confirm.
model.save_act = False
# Functions for getting accuracy scores while training
def train_acc():
    """Return the share of training samples whose predicted class is correct.

    Reads the script-level ``model`` and ``dataset``. The predicted class is
    the argmax over dim 1 of the model output; the result is the mean of the
    element-wise match with ``dataset['train_label']``.
    """
    logits = model(dataset['train_input'])
    is_correct = torch.argmax(logits, dim=1) == dataset['train_label']
    return is_correct.float().mean()
def test_acc():
    """Return the share of test samples whose predicted class is correct.

    Reads the script-level ``model`` and ``dataset``. The predicted class is
    the argmax over dim 1 of the model output; the result is the mean of the
    element-wise match with ``dataset['test_label']``.
    """
    logits = model(dataset['test_input'])
    is_correct = torch.argmax(logits, dim=1) == dataset['test_label']
    return is_correct.float().mean()
# KAN model training using the fit method
# The validation split is supplied under the 'test_input'/'test_label' keys.
# No metrics= argument is passed, so train_acc/test_acc are not handed to fit() here.
results = model.fit({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
opt="LBFGS", steps=150, loss_fn=torch.nn.CrossEntropyLoss())
# Feature importance from the entries of model.acts_scale.
# NOTE(review): the next statement is the one that raised
# "IndexError: list index out of range" in the reported run, so acts_scale was
# empty at this point. save_act was set to False before training, which is the
# likely cause -- confirm.
layer_1 = model.acts_scale[0].detach().numpy()
# NOTE(review): the model was created with width=[input_width, 2]; whether
# acts_scale ever gets a second entry (index 1) for that shape should be confirmed.
layer_2 = model.acts_scale[1].detach().numpy()
columns = list(X_train.columns)
# Matrix product of the two transposed scale arrays, flattened to 1-D;
# each entry is treated below as the score of one column of X_train.
importance_values = np.dot(layer_1.T, layer_2.T).flatten()
# argsort returns ascending order; the [::1] slice leaves that order unchanged
sorted_indices = np.argsort(importance_values)[::1]
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]
# Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()
Output and error :/
Shape of X_train: (1886, 171)
Shape of X_val: (251, 171)
Shape of X_test: (253, 171)
checkpoint directory created: ./model
saving model version 0.0
| train_loss: nan | test_loss: nan | reg: 0.00e+00 | : 100%|██████| 150/150 [01:42<00:00, 1.46it/s]saving model version 0.1
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
[<ipython-input-12-674e38bca30b>](https://localhost:8080/#) in <cell line: 63>()
61 opt="LBFGS", steps=150, loss_fn=torch.nn.CrossEntropyLoss())
62
---> 63 layer_1 = model.acts_scale[0].detach().numpy()
64 layer_2 = model.acts_scale[1].detach().numpy()
65
IndexError: list index out of range
Try `print(model.acts_scale)` — is it None? If yes, what is `model.save_act` (True or False)? If you try a dataset that does not produce a nan loss, does this problem persist?
If I remove the feature selection code, then it runs just fine. The issue appears in the feature selection part:
import torch
from kan import KAN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Data preparation and model setup for the KAN classification run.
# Check the shape of your data
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")
# Converting data to NumPy arrays
# (X_* / y_* are defined outside this snippet; presumably pandas objects,
# since .to_numpy() and, later, X_train.columns are used -- confirm)
X_train_np = X_train.to_numpy()
y_train_np = y_train.to_numpy()
X_val_np = X_val.to_numpy()
y_val_np = y_val.to_numpy()
X_test_np = X_test.to_numpy()
y_test_np = y_test.to_numpy()
# Converting data to Torch tensors: float32 features, long (int64) labels.
# Integer labels go with the torch.nn.CrossEntropyLoss passed to fit() further down.
train_input = torch.tensor(X_train_np, dtype=torch.float32)
train_label = torch.tensor(y_train_np, dtype=torch.long)
val_input = torch.tensor(X_val_np, dtype=torch.float32)
val_label = torch.tensor(y_val_np, dtype=torch.long)
test_input = torch.tensor(X_test_np, dtype=torch.float32)
test_label = torch.tensor(y_test_np, dtype=torch.long)
# All three splits in one dict; train_acc()/test_acc() read from it.
# Note that the fit() call further down builds its own dict instead of passing this one.
dataset = {
'train_input': train_input,
'train_label': train_label,
'val_input': val_input,
'val_label': val_label,
'test_input': test_input,
'test_label': test_label
}
# Ensure the input width matches the number of features
input_width = X_train_np.shape[1]
# Create model: `width` lists input_width inputs and 2 outputs, with no hidden layer
# (presumably one output per class -- confirm)
model = KAN(width=[input_width, 2], grid=4, k=1)
# Workaround for the issue by setting save_act to False
# NOTE(review): the feature-selection code later indexes model.acts_scale, and the
# reported run prints [] for that list before failing with an IndexError.
# Disabling save_act is the likely cause -- confirm.
model.save_act = False
# Functions for getting accuracy scores while training
def train_acc():
    """Compute classification accuracy on the training split.

    Uses the script-level ``model`` and ``dataset``: takes the argmax over
    dim 1 of the model output as the predicted class and averages its
    element-wise agreement with ``dataset['train_label']``.
    """
    scores = model(dataset['train_input'])
    matches = torch.argmax(scores, dim=1) == dataset['train_label']
    return matches.float().mean()
def test_acc():
    """Compute classification accuracy on the test split.

    Uses the script-level ``model`` and ``dataset``: takes the argmax over
    dim 1 of the model output as the predicted class and averages its
    element-wise agreement with ``dataset['test_label']``.
    """
    scores = model(dataset['test_input'])
    matches = torch.argmax(scores, dim=1) == dataset['test_label']
    return matches.float().mean()
# KAN model training using the fit method
# The validation split is supplied under the 'test_input'/'test_label' keys.
# No metrics= argument is passed, so train_acc/test_acc are not handed to fit() here.
results = model.fit({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
# opt can be Adam or LBFGS
# Debug print requested in the thread; the reported run shows [] here.
print(model.acts_scale)
#####################################feature selection###########################################
# NOTE(review): with acts_scale printed as [] above, the next statement raises
# "IndexError: list index out of range". save_act was set to False before
# training, which is the likely cause -- confirm.
layer_1 = model.acts_scale[0].detach().numpy()
# NOTE(review): the model was created with width=[input_width, 2]; whether
# acts_scale ever gets a second entry (index 1) for that shape should be confirmed.
layer_2 = model.acts_scale[1].detach().numpy()
columns = list(X_train.columns)
# Matrix product of the two transposed scale arrays, flattened to 1-D;
# each entry is treated below as the score of one column of X_train.
importance_values = np.dot(layer_1.T, layer_2.T).flatten()
# argsort returns ascending order; the [::1] slice leaves that order unchanged
sorted_indices = np.argsort(importance_values)[::1]
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]
## Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()
##################################################################################################
# Predictions of train val and test datasets
# Predicted class = argmax over dim 1 of the detached model output for each split.
test_preds = torch.argmax(model.forward(test_input).detach(),dim=1)
test_labels = test_label
train_preds = torch.argmax(model.forward(train_input).detach(),dim=1)
train_labels = train_label
val_preds = torch.argmax(model.forward(val_input).detach(),dim=1)
val_labels = val_label
# Evaluate metrics
# accuracy_score receives the true labels first and the predictions second.
print("Train ACC:", accuracy_score(train_labels.numpy(), train_preds.numpy()))
print("Val ACC:", accuracy_score(val_labels.numpy(), val_preds.numpy()))
print("Test ACC:", accuracy_score(test_labels.numpy(), test_preds.numpy()))
# Plotting KAN network
#model.plot(scale=input_width)
print(results)
# Learning curve based on ACC and LOSS
# NOTE(review): fit() was called without metrics= earlier in this script, so
# confirm that results actually contains 'train_acc' / 'test_acc' before
# relying on the two accuracy curves below.
plt.figure(figsize=(10, 5))
plt.plot(results["train_acc"], label='Training Accuracy')
# The 'test_*' entries correspond to the validation split passed to fit() as test_input/test_label.
plt.plot(results["test_acc"], label='Val Accuracy')
plt.plot(results["train_loss"], label='Training Loss')
plt.plot(results["test_loss"], label='Val Loss')
plt.title('Training and Val Accuracy over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.legend()
plt.grid(True)
plt.show()
`Shape of X_train: (1886, 171) Shape of X_val: (251, 171) Shape of X_test: (253, 171) checkpoint directory created: ./model saving model version 0.0 | train_loss: 6.57e-01 | test_loss: 9.09e-01 | reg: 0.00e+00 | : 100%|█| 100/100 [01:40<00:00, 1.01saving model version 0.1 []
IndexError Traceback (most recent call last)
IndexError: list index out of range`