
Token-based learning

Open: AutomaticHourglass opened this issue 9 months ago • 2 comments

Hello,

I'm an avid user of transformer architectures and have a solid working knowledge of neural networks.

Upon reading your GitHub repo, I saw great potential in this novel idea and would like to contribute.

Basically, I have adapted the KAN network into a token-based, next-token-prediction system.

Here is how it works: given a batch of shape batch_size * seq_len, I compute token and positional embeddings for the tokens as learnable parameters. I then add the two together to get my x value; my y value is the token sequence shifted by one.

However, to do this I had to work in the embedding space, which is 3-dimensional instead of 2. Since every token is biased towards its respective position via the positional embedding, I take the mean over the 1st (0-based) dimension so the model can "hear" all tokens at once but answer one token at a time.
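
In rough code, the setup looks something like this (a simplified sketch, not my exact implementation; the sizes are placeholders and the KAN(width=[...]) call follows the pykan README):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from kan import KAN  # pykan

# Placeholder sizes, just for illustration
vocab_size, seq_len, d_embed, batch_size = 500, 32, 32, 16

tok_emb = nn.Embedding(vocab_size, d_embed)   # learnable token embeddings
pos_emb = nn.Embedding(seq_len, d_embed)      # learnable positional embeddings
model = KAN(width=[d_embed, 64, vocab_size])  # pure-KAN prediction head

# A batch of token ids; the final column serves as the shifted target
tokens = torch.randint(0, vocab_size, (batch_size, seq_len + 1))
x_ids, y = tokens[:, :-1], tokens[:, -1]

# x = token embedding + positional embedding -> (batch, seq_len, d_embed)
x = tok_emb(x_ids) + pos_emb(torch.arange(seq_len))

# Mean over dimension 1 (0-based): "hear" all tokens at once,
# but answer one token at a time -> (batch, d_embed)
x = x.mean(dim=1)

logits = model(x)                  # (batch, vocab_size)
loss = F.cross_entropy(logits, y)  # next-token prediction loss
```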

My spider sense says something is wrong, and I'm not sufficiently confident in my understanding of your algorithm, so if you are interested, let's work on this together.

Best

AutomaticHourglass • May 02 '24 23:05

Hi, sorry, I don't think I quite get your question. My understanding is that one can replace MLPs with KANs in a transformer (it should be as simple as that), but maybe I'm naive. Regarding "My spider sense says something is wrong": could you please elaborate? Do you mean the framework you just described could be wrong, or that KANs might not fit into the framework? Anyway, great initiative! I just want to understand more. Also, your goal sounds ambitious; given that, I think maybe you should create your own repo rather than merging into this one.
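
To make concrete what I mean by "replace MLPs with KANs" (just a rough sketch, not tested code; the KAN(width=[...]) call follows the pykan README, and the pre-LN block around it is a hypothetical illustration):

```python
import torch
import torch.nn as nn
from kan import KAN  # pykan

class KANTransformerBlock(nn.Module):
    """A standard pre-LN transformer block with the MLP swapped for a KAN."""

    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # KAN in place of the usual Linear -> activation -> Linear feed-forward
        self.ffn = KAN(width=[d_model, 4 * d_model, d_model])

    def forward(self, x):            # x: (batch, seq_len, d_model)
        h = self.ln1(x)
        a, _ = self.attn(h, h, h)
        x = x + a                    # residual around attention
        h = self.ln2(x)
        b, s, d = h.shape
        # KAN layers take 2-D input, so fold batch and sequence together
        x = x + self.ffn(h.reshape(b * s, d)).reshape(b, s, d)
        return x

block = KANTransformerBlock()
out = block(torch.randn(2, 10, 64))  # -> (2, 10, 64)
```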

KindXiaoming • May 02 '24 23:05

Hello,

Let me elaborate on it a bit. If I understand KANs correctly, they are far more parameter-efficient at capturing the non-linear relationships between input and output, but they are also not as efficient computationally, so I don't think replacing the MLPs in a billion-parameter transformer would work. That's why I went for a pure KAN solution, to understand it better.

During training, I see the loss value jump from around 6 to 800 and back, which is behavior I haven't seen before. That didn't feel right, which is why my senses are saying something is wrong. It's not your code; it's my implementation that could be wrong.

I didn't want to create my own repo for this reason, since I would need somebody to spar with to make this work.

AutomaticHourglass • May 03 '24 09:05