
A Tensorflow2 implementation of KANs

Open ZPZhou-lab opened this issue 9 months ago • 0 comments

Hi all, here is tfkan, a simple TensorFlow 2 implementation of KANs for TensorFlow users who want to give them a try.

I read the paper and several implementations of KANs (including the original pykan and efficient-kan) carefully, and I hope my implementation helps you understand them.

I have split the code into separate modules and kept to TensorFlow style as much as possible. You can combine it, as an independent layer, with any other TensorFlow module, and it should be compatible with the usual Keras APIs (such as model.compile(), model.fit(), and model.predict()). This makes it friendlier for users who want to apply KANs quickly in practical machine learning tasks.

Quick Start

You can build your model using Sequential():

import tensorflow as tf
from tfkan.layers import DenseKAN
# create model using KAN
model = tf.keras.models.Sequential([
    DenseKAN(4),
    DenseKAN(1)
])
model.build(input_shape=(None, 10))

Then call model.summary() to see the model structure and its trainable parameters:

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_kan (DenseKAN)        (None, 4)                 360       
                                                                 
 dense_kan_1 (DenseKAN)      (None, 1)                 36        
                                                                 
=================================================================
Total params: 396 (1.55 KB)
Trainable params: 396 (1.55 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
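The parameter counts above work out to exactly 9 parameters per input-output edge (360 = 10 × 4 × 9 and 36 = 4 × 1 × 9). As a rough sketch, one decomposition that reproduces these numbers is 8 B-spline coefficients per edge (grid size 5 plus spline order 3) and one scale factor per edge. The default values, the per-edge layout, and the helper name kan_layer_param_count are assumptions made for illustration and are not confirmed by tfkan itself:

```python
def kan_layer_param_count(in_features, units, grid_size=5, spline_order=3):
    """Trainable parameters of one dense KAN layer under the ASSUMED layout."""
    # Assumption: every input-output edge stores (grid_size + spline_order)
    # B-spline coefficients plus one scale factor, and nothing else is counted.
    per_edge = grid_size + spline_order + 1
    return in_features * units * per_edge


# Layer widths of the Quick Start model: 10 inputs -> 4 units -> 1 unit.
layer_sizes = [10, 4, 1]
counts = [
    kan_layer_param_count(n_in, n_out)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]
print(counts, sum(counts))  # [360, 36] 396
```

If the layer is built with a different grid size or spline order, the per-edge count, and therefore the totals, would change accordingly.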

Layers

So far, I have implemented a dense layer, DenseKAN, and a 2D convolutional layer, Conv2DKAN, and provided demos for a regression task and a Fashion-MNIST classification task. I believe these cover most of the testing needed in the early stages of exploring KANs, as opposed to directly applying KANs to large-scale models like GPT.

About Grid Update

Adaptive grid update is an important feature described in the KAN paper. In this TensorFlow implementation, each KAN layer has a method, self.update_grid_from_samples(...), that implements it. You can call it in custom training logic or from a TensorFlow callback; see the demo grid_update_demo.
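To give a feel for what an adaptive grid update does, here is a minimal pure-Python sketch of the general idea: place the spline knots at quantiles of the observed activations, optionally blended with a uniform grid. This is not tfkan's code. The function adaptive_grid and its blend parameter are hypothetical names, and a real layer would also need to refit its spline coefficients on the new grid, which this sketch leaves out:

```python
def adaptive_grid(samples, grid_size, blend=0.02):
    """Return grid_size + 1 knots that follow the distribution of `samples`.

    blend = 0 gives a purely quantile-based grid, blend = 1 a uniform grid.
    (Hypothetical helper for illustration, not part of tfkan.)
    """
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    knots = []
    for i in range(grid_size + 1):
        # Quantile knot: the sample at fraction i / grid_size of the sorted data.
        quantile_knot = xs[round(i * (n - 1) / grid_size)]
        # Uniform knot: evenly spaced between the smallest and largest sample.
        uniform_knot = lo + (hi - lo) * i / grid_size
        knots.append(blend * uniform_knot + (1 - blend) * quantile_knot)
    return knots


# Samples in [0, 1] that cluster near 0, so the knots should cluster there too.
samples = [(k / 100) ** 3 for k in range(101)]
grid = adaptive_grid(samples, grid_size=5)
print(grid)
```

Because most samples lie near 0, the interior knots end up much closer to 0 than a uniform grid would place them, which gives the splines more resolution where the data actually is.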

Differences:

Unlike pykan and efficient-kan, tfkan provides only independent layers rather than a complete model class with a corresponding training API. Users need to build the model and define the training logic themselves, so tfkan may be more suitable for TensorFlow users who are already familiar with machine learning. Moreover, the current implementation does not include symbolic features, sparse pruning, or interpretable model visualization.

Hope you will like it (BTW, perhaps there are still people in the world using tensorflow?😇)

ZPZhou-lab avatar May 13 '24 07:05 ZPZhou-lab