pykan
How to apply KAN on Computer Vision
Hi Author,
Thank you for your great work. I am wondering if we can apply this network to vision-based tasks such as classification/detection/segmentation, etc.
Thank you for your help.
Update on this topic:
I wrote a short notebook to test training and evaluation on the MNIST dataset. If we want to apply KAN to a 2D or 3D task, one possible way is to change KANLayer so that it inherits from nn.Conv2d?
Here is the screen shot:
And here is the Traceback:
description: 0%| | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss());
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913 self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
916 if opt == "LBFGS":
917 optimizer.step(closure)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:242, in KAN.update_grid_from_samples(self, x)
219 '''
220 update grid from samples
221
(...)
239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
240 '''
241 for l in range(self.depth):
--> 242 self.forward(x)
243 self.act_fun[l].update_grid_from_samples(self.acts[l])
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:313, in KAN.forward(self, x)
308 self.acts.append(x) # acts shape: (batch, width[l])
311 for l in range(self.depth):
--> 313 x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
315 if self.symbolic_enabled == True:
316 x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KANLayer.py:172, in KANLayer.forward(self, x)
170 batch = x.shape[0]
171 # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
--> 172 x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim,).to(self.device)).reshape(batch, self.size).permute(1,0)
173 preacts = x.permute(1,0).clone().reshape(batch, self.out_dim, self.in_dim)
174 base = self.base_fun(x).permute(1,0) # shape (batch, size)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/functional.py:385, in einsum(*args)
380 return einsum(equation, *_operands)
382 if len(operands) <= 2 or not opt_einsum.enabled:
383 # the path for contracting 0 or 1 time(s) is already optimized
384 # or the user has disabled using opt_einsum
--> 385 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
387 path = None
388 if opt_einsum.is_available():
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
yeah I think KANs, as they are right now, cannot handle convolution. It seems reasonable to define ConvKAN layers. Given the current implementation, the only thing you can do with vision tasks is flattening a whole image into a vector, totally abandoning spatial information (which is not good, that's why I think extra development is needed).
As a quick cute example, you may try playing with KAN as if playing with an MLP for MNIST. Please make sure input data have shape [data size, indim], with indim=784. Also, the input dimension of KAN should be 784, and the output should be 10. So e.g., these KANs are valid for MNIST: KAN(width=[784,5,10]) or KAN(width=[784,5,5,10]). Also you may want to include, say, batch=128 in model.train() to train on batches rather than the whole dataset (which is fine, but I worry it might run too slowly on cpu haha).
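For reference, a minimal sketch of this setup, with random tensors standing in only for the real flattened MNIST data:

import torch
from kan import KAN

# stand-ins for MNIST: 1000 images of 28x28 pixels, labels 0-9
images = torch.rand(1000, 28, 28)
labels = torch.randint(0, 10, (1000,))

dataset = {
    'train_input': images[:800].flatten(start_dim=1),  # shape (800, 784)
    'train_label': labels[:800],
    'test_input': images[800:].flatten(start_dim=1),   # shape (200, 784)
    'test_label': labels[800:],
}

model = KAN(width=[784, 5, 10], grid=3, k=3)
model.train(dataset, opt="LBFGS", steps=20, batch=128,
            loss_fn=torch.nn.CrossEntropyLoss())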
As a quick cute example, you may try playing with KAN as if playing with an MLP for MNIST.
Thanks for the quick reply.
It did work with
model = KAN(width=[784,5,5,10], grid=3, k=3).to(device)
and
dataset['train_input'] = torch.flatten(train_dataset.data, start_dim=1).to(device)
dataset['test_input'] = torch.flatten(test_dataset.data, start_dim=1).to(device)
Now training works on the cpu device (slow as expected), but it raises an error when using an Apple chip with the mps device:
results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss(), batch=128, device='cpu');
But anyway, flattening the image into 1D is not a good idea in general. A VisionKAN or KAN_Conv2d needs to be implemented. LOL.
Nice! yes, there's still some issue with GPU training. Looking forward to your new development :-)
Yeah, about GPU training: I might need to use CUDA first. For MPS, it raised this error for the classification example:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
I already put the datasets and the model on the mps device:
dataset['train_input'] = torch.from_numpy(train_input).to(torch.float32).to(device)
dataset['test_input'] = torch.from_numpy(test_input).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(train_label[:,None]).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(test_label[:,None]).to(torch.float32).to(device)
model = KAN(width=[2,1], grid=3, k=3).to(torch.float32).to(device)
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), device=device);
The full traceback is:
description: 0%| | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[9], line 9
6 def test_acc():
7 return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
----> 9 results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), device=device);
10 results['train_acc'][-1], results['test_acc'][-1]
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913 self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
916 if opt == "LBFGS":
917 optimizer.step(closure)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:242, in KAN.update_grid_from_samples(self, x)
219 '''
220 update grid from samples
221
(...)
239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
240 '''
241 for l in range(self.depth):
--> 242 self.forward(x)
243 self.act_fun[l].update_grid_from_samples(self.acts[l])
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:313, in KAN.forward(self, x)
308 self.acts.append(x) # acts shape: (batch, width[l])
311 for l in range(self.depth):
--> 313 x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
315 if self.symbolic_enabled == True:
316 x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KANLayer.py:172, in KANLayer.forward(self, x)
170 batch = x.shape[0]
171 # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
--> 172 x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim,).to(self.device)).reshape(batch, self.size).permute(1,0)
173 preacts = x.permute(1,0).clone().reshape(batch, self.out_dim, self.in_dim)
174 base = self.base_fun(x).permute(1,0) # shape (batch, size)
File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/functional.py:385, in einsum(*args)
380 return einsum(equation, *_operands)
382 if len(operands) <= 2 or not opt_einsum.enabled:
383 # the path for contracting 0 or 1 time(s) is already optimized
384 # or the user has disabled using opt_einsum
--> 385 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
387 path = None
388 if opt_einsum.is_available():
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
It turns out that replacing model = KAN(width=[784,5,5,10], grid=3, k=3).to(device)
by model = KAN(width=[784,5,5,10], grid=3, k=3, device=device)
does the trick for me! Here is a full example training on mps for reference:
import cv2
import numpy as np
import torch
from kan import *
from tensorflow import keras
device = "mps"
model = KAN(width=[7*7, 5, 5, 128], grid=3, k=3, device=device)
(X_train,y_train),(X_test,y_test) = keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0
# downsample to 7x7
X_train = np.array([cv2.resize(x, (7,7)) for x in X_train])
X_test = np.array([cv2.resize(x, (7,7)) for x in X_test])
dataset = {}
dataset['train_input'] = torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset['test_input'] = torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(y_test).to(torch.float32).to(device)
model.train(dataset, opt="LBFGS", steps=20, batch=128)
Were you able to actually train on MNIST using a flat dataset?
No, unfortunately. @WuZhuoran's comments make sense; I was solely making a point about getting training on mps to work.
Hi everybody, please let me know if any of you have successfully applied KANs to any computer vision tasks, or integrated them with CNNs?
Also, how can I integrate and train KAN layers with CNNs after flattening the tensors? If anybody has code, please share it.
Very interesting. I'd really like to see a direct comparison between KAN and MLP in CNN architecture.
I tried something like this, but it didn't work; the loss decreases slowly:
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt
train_data = torchvision.datasets.MNIST(
root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
root="./mnist_data", train=False, download=True, transform=None
)
valid_labels = [0, 1, 2]
X_train = []
y_train = []
for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")
X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)
X_test = np.array(X_test)
y_test = np.array(y_test)
X_test = (X_test - mean) / std
X_train = (X_train - mean) / std
device = "cpu"
model = KAN(width=[x.shape[0]**2, 20, 20, len(valid_labels)], grid=3, k=3, device=device)
dataset = {}
dataset["train_input"] = (
torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset["test_input"] = (
torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).to(torch.float32).to(device)
result = model.train(dataset, opt="Adam", steps=100, lr=0.1, batch=len(valid_labels), device=device)
plt.plot(result['train_loss'], label="train_loss")
plt.plot(result['test_loss'], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.show()
My experience with MNIST is that a 2-layer KAN with an extremely small number of hidden neurons (say 5 or 10) is enough to train MNIST (but maybe my impression was from accuracy), i.e., KAN(width=[49, 10, 3]) in your case. It's likely that accuracies are high but losses are high.
So please try computing accuracy as well. You can refer to this tutorial to see how to do this. Basically, it's something like:
def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
results['train_acc'][-1], results['test_acc'][-1]
Very interesting. I'd really like to see a direct comparison between KAN and MLP in CNN architecture.
I am also willing to do so, but I don't know how to integrate and train them simultaneously.
Hi, here is a working MNIST example using CUDA. Reusing some code from above. It may be verbose and far from optimal.
I get around 73% test accuracy in about 1 minute. Playing with the network size may improve the performance.
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt
def train_acc():
    # model for some reason is on cpu only here, something about KAN's implementation
    try:
        arg = (
            torch.argmax(model(dataset["train_input"]), dim=1) == dataset["train_label"]
        )
    except:
        arg = torch.argmax(model(dataset["train_input"].to("cpu")), dim=1) == dataset[
            "train_label"
        ].to("cpu")
    return torch.mean(arg.float())

def test_acc():
    try:
        arg = torch.argmax(model(dataset["test_input"]), dim=1) == dataset["test_label"]
    except:
        arg = torch.argmax(model(dataset["test_input"].to("cpu")), dim=1) == dataset[
            "test_label"
        ].to("cpu")
    return torch.mean(arg.float())
train_data = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=None
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
valid_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
X_train = []
y_train = []
for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")
X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)
X_test = np.array(X_test)
y_test = np.array(y_test)
X_test = (X_test - mean) / std
X_train = (X_train - mean) / std
model = KAN(width=[x.shape[0] ** 2, 20, len(valid_labels)], grid=5, k=3, device=device)
dataset = {}
dataset["train_input"] = (
torch.flatten(torch.from_numpy(X_train), start_dim=1).long().to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).long().to(device)
dataset["test_input"] = (
torch.flatten(torch.from_numpy(X_test), start_dim=1).long().to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).long().to(device)
loss_fn = torch.nn.CrossEntropyLoss()
result = model.train(
dataset,
opt="Adam",
steps=50,
lr=0.1,
batch=512,
# metrics=(
# train_acc,
# test_acc,
# ), # this is the slower step, so its better to evaluate it after training
loss_fn=loss_fn,
# device=device,
)
acc = test_acc()
print(f"Test accuracy: {acc.item()}")
plt.plot(result["train_loss"], label="train_loss")
plt.plot(result["test_loss"], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.savefig("loss.png")
Hi, here's my attempted code; it takes about 30s to run on CUDA and gets about 83% accuracy.
import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt
def preprocess_data(data):
    images = []
    labels = []
    for img, label in data:
        img = cv2.resize(np.array(img), (7, 7))
        img = img.flatten() / 255.0
        images.append(img)
        labels.append(label)
    return np.array(images), np.array(labels)
train_data = torchvision.datasets.MNIST(
root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
root="./mnist_data", train=False, download=True, transform=None
)
train_images, train_labels = preprocess_data(train_data)
test_images, test_labels = preprocess_data(test_data)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
dataset = {
"train_input": torch.from_numpy(train_images).float().to(device),
"train_label": torch.from_numpy(train_labels).to(device),
"test_input": torch.from_numpy(test_images).float().to("cpu"),
"test_label": torch.from_numpy(test_labels).to("cpu"),
}
model = KAN(width=[49, 10, 10], device=device)
results = model.train(
dataset,
opt="Adam",
lr=0.05,
steps=100,
batch=512,
loss_fn=torch.nn.CrossEntropyLoss(),
)
torch.save(model.state_dict(), "kan.pth")
del model
model = KAN(width=[49, 10, 10], device="cpu")
model.load_state_dict(torch.load("kan.pth"))
def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
        return accuracy
acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")
plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()
plt.savefig("kan.png")
I also think a pure KAN implementation for computer vision does not look very promising, since it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2d convolution layer that replaces the 2d kernel with a spline (or a full KAN layer?) working on flattened 2d patches of similar sizes to the regular kernels. At small enough kernel sizes (say 3x3) the loss in fine-grained spatial locality might not be as detrimental to model performance.
I compared KAN with a 5x smaller MLP. In 10 epochs, KAN reached 91% accuracy whereas the MLP reached 97%. The KAN loss goes down more slowly than that of the MLP.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# implementation from https://github.com/Blealtan/efficient-kan
class EKAN(nn.Module):
    pass  # paste the KAN implementation from the repo above here
# a simple MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layers(x)
# Data preprocessing
transform = transforms.Compose([
transforms.ToTensor(), # Convert images to PyTorch tensors and scale to [0,1]
transforms.Normalize((0.5,), (0.5,)) # Normalize to mean=0.5, std=0.5
])
# Load the datasets
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize model, loss function, and optimizer
model = EKAN([28*28, 64, 10]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training the model
def train_model(num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')
train_model(10)
# Testing the model
def test_model():
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')
test_model()
Were you able to actually train on MNIST using a flat dataset?
Hi,
I did train on the MNIST dataset, but it just flattens the image into a 1D vector. I think we still need more development for computer vision tasks.
I also think a pure KAN implementation for computer vision does not look very promising, since it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2d convolution layer that replaces the 2d kernel with a spline (or a full KAN layer?) working on flattened 2d patches of similar sizes to the regular kernels. At small enough kernel sizes (say 3x3) the loss in fine-grained spatial locality might not be as detrimental to model performance.
Good point on the 2D conv layer. One possibility is to define a kan_conv2d layer; then we can build KAN3D or KAN directly with different conv2d layers.
Currently all the tests on images (such as MNIST) process the data into a 1D vector, which is not very useful.
Before using conv2d, what about using a VAE latent space, i.e., training KAN on MNIST with the VAE encoder output as the KAN input?
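A rough sketch of that pipeline (the tiny encoder and random tensors below are only placeholders for a real pretrained VAE encoder and the actual MNIST splits):

import torch
import torch.nn as nn
from kan import KAN

latent_dim = 16

# placeholder for the encoder half of a VAE trained separately on MNIST
encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64), nn.ReLU(),
    nn.Linear(64, latent_dim),
)

# placeholders for the real MNIST splits
train_images, train_labels = torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,))
test_images, test_labels = torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,))

with torch.no_grad():  # encode once; the KAN only ever sees the latent codes
    dataset = {
        'train_input': encoder(train_images),
        'train_label': train_labels,
        'test_input': encoder(test_images),
        'test_label': test_labels,
    }

model = KAN(width=[latent_dim, 10, 10], grid=3, k=3)
model.train(dataset, opt="Adam", lr=0.05, steps=20, batch=128,
            loss_fn=torch.nn.CrossEntropyLoss())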
According to my experiments, the modified version of KAN outperformed an MLP with the same shape on the MNIST dataset, both 768 64 10, using the efficient-kan code above with some tweaks.
This is KAN+:
This is MLP:
@xiaol did you use Handwritten Sequence Trajectories?
Also, how can I integrate and train KAN layers with CNNs after flattening the tensors? If anybody has code, please share it.
I tried to replace MLP with KAN in CNN models, and the performances are close to each other.
https://github.com/juntaoJianggavin/kan-cifar10/tree/main
How can I build a Conv-KAN? How do I integrate convolution into KAN?
I used a 'linearized version' of nn.Conv2d, using nn.Unfold and a reshape, to build a KANConv2d. I'm not completely sure whether it makes sense and I don't think it's efficient at all, but you may check it out.
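Not the linked code, but a minimal sketch of that unfold-based idea (the KANConv2d class and its arguments below are placeholders), assuming pykan's KAN maps a (batch, in_dim) tensor to (batch, out_dim):

import torch
import torch.nn as nn
from kan import KAN

class KANConv2d(nn.Module):
    # hypothetical 'linearized' conv: unfold kxk patches, run each through a shared KAN
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        self.out_channels = out_channels
        self.unfold = nn.Unfold(kernel_size, stride=stride, padding=padding)
        # one KAN shared across all spatial positions, playing the role of the conv kernel
        self.kan = KAN(width=[in_channels * kernel_size * kernel_size, out_channels])

    def forward(self, x):
        b, c, h, w = x.shape
        patches = self.unfold(x)                                # (b, c*k*k, L)
        L = patches.shape[-1]
        patches = patches.permute(0, 2, 1).reshape(b * L, -1)   # (b*L, c*k*k)
        out = self.kan(patches)                                 # (b*L, out_channels)
        h_out = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
        w_out = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
        return out.reshape(b, L, self.out_channels).permute(0, 2, 1).reshape(
            b, self.out_channels, h_out, w_out)

# e.g. KANConv2d(1, 4, kernel_size=3)(torch.randn(2, 1, 28, 28)) has shape (2, 4, 26, 26)

As noted above, every patch goes through the KAN separately, so this is far slower than a regular nn.Conv2d.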
I also tried a simple implementation of LeNet but with KAN as the classifier: https://github.com/SimoSbara/kan-lenet
KAN receives the flattened data from the convolutions.
I think a combination of them (nn.Linear, KAN) works fine for the MNIST task:
import torchvision
import torch
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from kan import KAN
import tqdm
transform = transforms.Compose(
[transforms.ToTensor(),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
]
)
trainset = torchvision.datasets.MNIST(root='./MNIST', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=500,
shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./MNIST', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=500,
shuffle=False, num_workers=2)
print(len(trainset),len(testset))
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 64).cuda()
        self.kan = KAN(width=[64,16,10], grid=5, k=3, seed=0, device='cuda:0')

    def forward(self, x):
        x = self.linear(x)
        out = self.kan(x)
        return out
net = Net().cuda()
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.002,)
for epoch in range(4):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(trainloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # print('predict.size=',pred.size())
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        x = inputs.view(inputs.size(0), -1).cuda()
        outputs = net(x)
        loss = criterion(outputs, labels.cuda())
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:  # print every 100 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')

    correct = 0
    total = 0
    # net.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            # calculate outputs by running images through the network
            x = inputs.view(inputs.size(0), -1).cuda()
            outputs = net(x)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()
    print(f'epoch {epoch} Accuracy of the network on the 10000 test images: {100 * correct // total} %')
    # net.train()
print('Finished Training')
After 4 epochs training, acc comes to 96%, the logs looks like:
60000 10000
99it [00:52, 1.90it/s][1, 100] loss: 0.028
120it [01:03, 1.89it/s]
[1, 120] loss: 0.002
epoch 0 Accuracy of the network on the 10000 test images: 93 %
99it [00:51, 1.95it/s][2, 100] loss: 0.010
120it [01:02, 1.91it/s]
[2, 120] loss: 0.002
epoch 1 Accuracy of the network on the 10000 test images: 95 %
99it [00:51, 1.91it/s][3, 100] loss: 0.006
120it [01:02, 1.93it/s]
[3, 120] loss: 0.001
epoch 2 Accuracy of the network on the 10000 test images: 95 %
99it [00:51, 1.91it/s][4, 100] loss: 0.005
120it [01:02, 1.93it/s]
[4, 120] loss: 0.001
epoch 3 Accuracy of the network on the 10000 test images: 96 %
Finished Training
I think a combination of them (nn.Linear, KAN) works fine for the MNIST task:
In comparison with an MLP it's a good improvement, although in real cases convolution gives real robustness in OCR applications.
It would be nice to have a performance benchmark for bigger nets where KAN replaces the MLP.