annotated-s4

PyTorch Implementation

Open • TariqAHassan opened this issue 2 years ago • 10 comments

Hello @srush and @siddk

Thank you for putting together this wonderful library and blog post. I have found them both to be incredibly helpful!

A few months ago, when I first read your post, I attempted to reimplement S4 in PyTorch, using your blog as a guide. I managed to get reasonable results, although the code is not quite as fast as JAX 😄.

Now, I'd like to open-source the repository. However, I'm unsure which license to use. Would it be OK with you if I used the Apache 2.0 license? This is the same license used in the state-spaces repository.

I would, of course, give clear credit to both of you and link to the original blog post.

TariqAHassan, Jul 07 '22 18:07

That's awesome Tariq! Really nice. I think Apache 2 sounds good to me.

Can we help in any way? It would be nice if we were able to keep them in sync somehow. Would also love to help promote your repo as well.

Also not sure if you saw, but there were some recent updates that fixed some of the bugs in our implementation. Want to make sure they made it into yours as well.

srush, Jul 08 '22 01:07

Hi Sasha,

Thanks for getting back to me.

On the license, excellent. I've released the code under it here.

Can we help in any way? It would be nice if we were able to keep them in sync somehow.

Yes, I've thought about this a little bit. My initial plan was to simply create a pull request for this repo, but I couldn't find an elegant way to structure things while maintaining the ability to pip install the PyTorch code.

One approach would be for me to make you an administrator on the repository I created, which would make it easy for you to implement changes. In general though, I'm open to any ideas you may have here.

Also not sure if you saw, but there were some recent updates that fixed some of the bugs in our implementation

Two things here:

  • I think I have implemented some of them, but not all. For instance, my Ct parameter is complex-valued, but I am still applying ifft order when constructing K, which I think was found to no longer be necessary.
  • On some benchmarks, such as CIFAR10, the performance I achieve is lower than expected. Part of this likely stems from outstanding bugs in my implementation, but some of it is no doubt due to my use of pooling after each S4 block. At the time I wrote this code, I only had access to a 16GB V100, so this was the only way I could get the code to run with a reasonable batch size. Now I have access to an 80GB A100, which should mean I will be able to remove this memory-saving step.
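On the first point: a quick way to check whether any reordering is needed after the inverse FFT is to compare the Ct-based kernel against a naive one. Below is a minimal NumPy sketch, assuming a small dense discretized SSM (A_bar, B_bar, C_bar) instead of S4's DPLR parameterization, and defining Ct = C_bar (I - A_bar^L). The function names are hypothetical and not taken from either codebase.

```python
import numpy as np

def kernel_direct(A_bar, B_bar, C_bar, L):
    """Naive reference: K[l] = C_bar @ A_bar^l @ B_bar for l = 0..L-1."""
    K = np.empty(L, dtype=complex)
    x = B_bar.astype(complex)
    for l in range(L):
        K[l] = C_bar @ x      # C_bar A_bar^l B_bar
        x = A_bar @ x         # advance to A_bar^(l+1) B_bar
    return K

def kernel_via_ct(A_bar, B_bar, Ct, L):
    """Evaluate the truncated generating function at the L-th roots of unity, then ifft.

    Ct = C_bar @ (I - A_bar^L) absorbs the A_bar^L term, because z^L = 1 at every root.
    """
    I = np.eye(A_bar.shape[0])
    z = np.exp(-2j * np.pi * np.arange(L) / L)   # z_k = exp(-2*pi*i*k/L)
    K_hat = np.array([Ct @ np.linalg.solve(I - zk * A_bar, B_bar) for zk in z])
    # With this sign convention K_hat equals np.fft.fft(K), so the inverse FFT
    # returns the kernel in natural order 0..L-1 with no extra reordering.
    return np.fft.ifft(K_hat)

# Example with a random, stable dense SSM (illustrative values only).
rng = np.random.default_rng(0)
N, L = 4, 16
A_bar = rng.standard_normal((N, N))
A_bar *= 0.9 / np.max(np.abs(np.linalg.eigvals(A_bar)))  # rescale to spectral radius 0.9
B_bar = rng.standard_normal(N)
C_bar = rng.standard_normal(N)
Ct = C_bar @ (np.eye(N) - np.linalg.matrix_power(A_bar, L))
print(np.allclose(kernel_direct(A_bar, B_bar, C_bar, L), kernel_via_ct(A_bar, B_bar, Ct, L)))
```

Because the generating function is evaluated at z_k = exp(-2πik/L), it equals the forward DFT of K, so ifft recovers K directly. Evaluating at exp(+2πik/L) instead would make ifft return K with its indices reversed modulo L (K[0], K[L-1], ..., K[1]), which is one plausible way an explicit reordering step ends up in the code.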

TariqAHassan, Jul 08 '22 07:07

Sounds great. I definitely think it makes sense to try to get the CIFAR number up to the high 80s or so. I wonder if it is worth trying to use Torch-KeOps to reduce the memory usage?

Repo looks great. I think it is fine to have it separate. I was just thinking it would be nice to have tests to show that they are computing the same function (e.g., like how BERT models work in both TensorFlow and PyTorch).
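A cross-framework test usually loads identical weights and inputs into both implementations and compares the outputs within a tolerance. So as not to assume that both JAX and PyTorch are installed, the sketch below shows the shape of such a test using two NumPy code paths for the same discretized SSM (a step-by-step recurrence and an FFT convolution) as stand-ins for the two implementations. All names are hypothetical.

```python
import numpy as np

def ssm_recurrent(A_bar, B_bar, C_bar, u):
    """Recurrent mode: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C_bar x_k, with x_{-1} = 0."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar * u_k
        ys.append(C_bar @ x)
    return np.array(ys)

def ssm_convolution(A_bar, B_bar, C_bar, u):
    """Convolution mode: the same map as a causal convolution y = K * u, via FFTs."""
    L = len(u)
    K = np.array([C_bar @ np.linalg.matrix_power(A_bar, l) @ B_bar for l in range(L)])
    n = 2 * L  # zero-pad so the circular FFT convolution equals the linear (causal) one
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)
    return y[:L]

def test_parity(seed=0, N=4, L=32):
    """Feed identical parameters and inputs to both paths and compare outputs."""
    rng = np.random.default_rng(seed)
    A_bar = rng.standard_normal((N, N))
    A_bar *= 0.9 / np.max(np.abs(np.linalg.eigvals(A_bar)))  # keep the system stable
    B_bar, C_bar, u = rng.standard_normal(N), rng.standard_normal(N), rng.standard_normal(L)
    np.testing.assert_allclose(
        ssm_recurrent(A_bar, B_bar, C_bar, u),
        ssm_convolution(A_bar, B_bar, C_bar, u),
        rtol=1e-6, atol=1e-8,
    )

test_parity()
```

In an actual JAX-vs-PyTorch test, the two functions would be replaced by each model's forward pass (with weights copied from one model to the other), and the tolerance would likely need loosening for float32.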

srush, Jul 11 '22 02:07

Great. I have added you to the repo.

I definitely think it makes sense to try to get the CIFAR number up to the high 80s or so. I wonder if it is worth trying to use Torch-KeOps to reduce the memory usage?

I have looked into this a bit more closely. The memory usage of my implementation when training on CIFAR (without the additional pooling) peaked at about 15.5 GB with a batch size of 32, which is definitely on the outer edge of what a 16 GB V100 can handle, but it's not as bad as I feared.* Still, yes, there are likely performance gains to be had by using Torch-KeOps. I will look into this. (I may be able to draw on this code in the original state-spaces repo.)

* Of course, there's always the option of simply lowering the batch size.

I was just thinking it would be nice to have tests to show that they are computing the same function (e.g., like how BERT models work in both TensorFlow and PyTorch).

Yes, agreed. I have created a new (WIP) PR where I will be attempting to backport the bug fixes which have been made in this repository into my code. It could be a week (or more) before I have time to focus on that but, at any rate, I will post back here when I've made meaningful progress.

TariqAHassan, Jul 12 '22 20:07

I'm not sure it's worth using KeOps or the custom CUDA kernel for a pedagogical implementation. FYI, you should be able to do well on CIFAR with much smaller models. This pedagogical version of S4D gets 88% on CIFAR with 200k parameters (I think 4 layers x 128 dim).
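For reference, the core of a diagonal (S4D) kernel fits in a few lines. The sketch below assumes S4D-Lin initialization (A_n = -1/2 + iπn), zero-order-hold discretization, and a single input channel. It is written in NumPy rather than PyTorch, is not the linked pedagogical file, and its function names are hypothetical. The recurrent version is included only as a correctness check.

```python
import numpy as np

def s4d_lin_init(N):
    """S4D-Lin: A_n = -1/2 + i*pi*n for n = 0..N-1 (one half of each conjugate pair)."""
    return -0.5 + 1j * np.pi * np.arange(N)

def s4d_kernel(A, B, C, dt, L):
    """Vandermonde form of the diagonal SSM kernel under zero-order hold.

    A_bar = exp(dt*A), B_bar = (exp(dt*A) - 1) / A * B,
    K[l] = 2 * Re(sum_n C_n * B_bar_n * A_bar_n**l).
    The factor 2 * Re accounts for the conjugate half of the state, which is not stored.
    """
    dtA = dt * A
    B_bar = (np.exp(dtA) - 1.0) / A * B
    vandermonde = np.exp(dtA[:, None] * np.arange(L)[None, :])  # (N, L): A_bar_n**l
    return 2 * ((C * B_bar) @ vandermonde).real

def s4d_kernel_recurrent(A, B, C, dt, L):
    """Check: run the full 2N-dim complex system (states plus conjugates) step by step."""
    A_full = np.concatenate([A, A.conj()])
    B_full = np.concatenate([B, B.conj()])
    C_full = np.concatenate([C, C.conj()])
    A_bar = np.exp(dt * A_full)
    B_bar = (A_bar - 1.0) / A_full * B_full
    x, K = B_bar.copy(), []
    for _ in range(L):
        K.append((C_full @ x).real)  # imaginary parts cancel across conjugate pairs
        x = A_bar * x                # diagonal transition: elementwise multiply
    return np.array(K)

# Example (illustrative sizes, random C).
rng = np.random.default_rng(0)
N, L, dt = 8, 64, 0.01
A = s4d_lin_init(N)
B = np.ones(N, dtype=complex)
C = rng.standard_normal(N) + 1j * rng.standard_normal(N)
print(np.allclose(s4d_kernel(A, B, C, dt, L), s4d_kernel_recurrent(A, B, C, dt, L)))
```

The diagonal structure is what keeps this short: the kernel is a single (N, L) Vandermonde matrix product, so no Cauchy kernel or Woodbury correction is needed.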

albertfgu, Jul 12 '22 20:07

Absolutely. The complexity of the code is definitely something I'm mindful of here given its purpose.

FYI, you should be able to do well on CIFAR with much smaller models.

Yes, and I think that will be the case when I implement some of the changes that have been made to this repository over the past few months.

This pedagogical version of S4D gets 88% on CIFAR with 200k parameters (I think 4 layers x 128 dim)

Very neat. I haven't seen that version before, but it's quite helpful.

TariqAHassan, Jul 12 '22 21:07

Hey @TariqAHassan - sorry I'm late to the party! This is awesome work, and minimally, I'd love to help with figuring out how to reproduce the CIFAR numbers in your codebase.

What's your current plan for addressing some of the various changes we've made? Would love to see if I can help out in any way (or even just share things we tried when we were working on CIFAR numbers)!

siddk, Jul 13 '22 13:07

Hi @siddk — sorry about the delay getting back to you.

What's your current plan for addressing some of the various changes we've made? Would love to see if I can help out in any way (or even just share things we tried when we were working on CIFAR numbers)!

My plan is just to carefully review the changes that have been made to this repo over the past few months. (I've opened a PR where I've started implementing changes, but haven't made it very far yet and, unfortunately, it looks like it will be a week or so before I have the time to focus on it.)

As for help, absolutely! Any help and/or guidance you have to offer would be more than welcome. 😄

TariqAHassan, Jul 15 '22 16:07

Hey @TariqAHassan - sounds good, and take your time! Maybe ping me directly in the PR comments if anything seems particularly strange? Otherwise, I'll go through it in a couple of days.

siddk, Jul 17 '22 20:07

Hi @siddk, absolutely, sounds good. I've been very busy recently, but I should finally have time to start on this next week.

TariqAHassan, Jul 21 '22 08:07