DiT
JAX version
Hi authors,
The codebase is very helpful, and thanks for your efforts in putting everything together!
While trying out the model, I noticed that the weights in this PyTorch version were ported from a JAX-based model. Is there a pointer to the JAX implementation?
Sincerely!
I'm interested in the JAX version as well! Ideally it would include the same training code that was used, so that the training process could be reproduced.
Hello, I am also interested in the JAX version.
Hey, I've implemented a JAX version that I can release once I'm finished with my project :)!
Looking forward to it! I hope it's released soon.
Well, "soon" is a relative term, but I just released it here! It includes an implementation of MAE and my own method.
https://github.com/philippe-eecs/small-vision/tree/main