
jax version

[Open] NathanYanJing opened this issue 2 years ago • 5 comments

Hi authors,

The codebase is very helpful! Thank you for your efforts in putting everything together.

While trying out the model, I noticed that the weights in this PyTorch version were ported from a JAX-based model. Is there a pointer to the JAX implementation?

Sincerely,

NathanYanJing, Dec 29 '22

I'm interested in the JAX version as well! Ideally it would include the same training code that was used, so that the training process could be reproduced.

Lime-Cakes, Dec 30 '22

Hello, I am also interested in the JAX version.

JTT94, Aug 18 '23

Hey, I've implemented a JAX version that I can release soon, once I'm finished with my project :)!

philippe-eecs, Sep 28 '23

> Hey, I've implemented a JAX version that I can release soon, once I'm finished with my project :)!

Looking forward to it! I hope it's released soon.

Lime-Cakes, Oct 09 '23

Well, "soon" is a relative term, but I just released it here! It includes an implementation of MAE and of my own method.

https://github.com/philippe-eecs/small-vision/tree/main

philippe-eecs, Jun 27 '24