
Structured State Space Sequence (S4) Implementation

Open stergiosba opened this issue 1 year ago • 13 comments

Hey Patrick,

I have been using Equinox for one of my projects and up until now it has helped immensely in using JAX effectively and seamlessly.

For this project, I was going to use the Structured State Space Sequence (S4) model from Gu et al. (2022). In short, the original paper that introduces S4 compares it to attention-based architectures on time-series problems with long-range temporal dependencies in the data (e.g. EEG signals, audio recognition, etc.). The authors showed that S4 outperformed every other model they compared it against. I think it would be a very nice addition to the package for all those who might want to use S4+Equinox in the future.

Without going into further details, a new nn layer must be created. I am unsure whether you are familiar with the S4 concept or whether you even have time to implement it. The original implementation of the layer is written in PyTorch, but the authors also provided a Flax implementation. I could use the Flax implementation as is, but I am already too deeply invested in Equinox to turn back :). Hence, I started rewriting it for Equinox. To this end, I have two questions for you.

1) Would this be a worthy addition to the main package?
2) If need be, could you provide some guidance on the sharp bits?

Best, Stergios

stergiosba • Jun 20 '23 04:06