functional-transformer
A pure-functional implementation of a machine learning transformer model in Python/JAX
From https://github.com/awf/functional-transformer/discussions/4#discussioncomment-2834638: see how we might include the `sin(t)` terms, rather than just 'learned' encodings.
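For context, a minimal sketch of what fixed `sin(t)` position encodings could look like in JAX. It follows the common sine/cosine formulation rather than anything taken from this repository; the function name `sinusoidal_position_encoding`, its parameters, and the `max_period` default are hypothetical choices for illustration.

```python
import jax.numpy as jnp


def sinusoidal_position_encoding(seq_len, d_model, max_period=10000.0):
    """Return a (seq_len, d_model) array of fixed sin/cos position encodings.

    Hypothetical helper for illustration; not part of this repository.
    """
    assert d_model % 2 == 0, "this sketch assumes an even embedding width"
    # Positions t = 0, 1, ..., seq_len - 1 as a column vector.
    t = jnp.arange(seq_len, dtype=jnp.float32)[:, None]
    # One frequency per (sin, cos) pair, decaying geometrically with the pair index.
    i = jnp.arange(d_model // 2, dtype=jnp.float32)[None, :]
    freqs = jnp.exp(-jnp.log(max_period) * (2.0 * i / d_model))
    angles = t * freqs  # shape (seq_len, d_model // 2)
    # Interleave so even columns hold sin(angle) and odd columns hold cos(angle).
    pairs = jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)
    return pairs.reshape(seq_len, d_model)
```

Under these assumptions, the fixed table would simply be added to the token embeddings (for example `x = token_embeddings + sinusoidal_position_encoding(L, d)`), either in place of or alongside a learned position table.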
WandB loss curves (e.g. [here](https://wandb.ai/awfidius/pure-transformer/runs/ehf0othc)) show a sawtooth form, correlated with batch ID. Batches are [randomized](https://github.com/awf/awf-jaxutils/blob/2590cc78a4ab017e0f6bcd1ccded1f63bbd9fc6a/dataset.py#L67) and this occurs even with [1-bit gradients](https://github.com/awf/functional-transformer/blob/780073081d65df06a5c0c31dc4f9d2c8285625a0/main.py#L178), so it's not Adam... 
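To illustrate the reasoning, here is a sketch of a sign-only ("1-bit") gradient step in JAX. It assumes that "1-bit gradients" means keeping just the sign of each gradient component, which may differ from what the linked line in `main.py` does. The toy `loss_fn`, the `sign_sgd_step` helper, and the learning rate are hypothetical. Such an update carries no moment estimates between steps, so a sawtooth that persists under it is hard to attribute to Adam's internal state.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear-regression loss, used only to have something to differentiate.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def sign_sgd_step(params, x, y, lr=1e-2):
    """One update that uses only the sign of each gradient component.

    Hypothetical helper; the optimizer keeps no state between calls.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * jnp.sign(g), params, grads
    )
    return new_params, loss
```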
As noted in https://github.com/awf/functional-transformer/discussions/6, the model does not match the original code, or indeed the original transformer paper. I therefore consider this a "transformer variant", but of course it would...