stable-diffusion-jax
[Don't merge] Platform tests
Current status
☑️ Easy bugs
- Set offset=1 (#4)
- Iterate for the total number of steps (#6).
☑️ Scheduler state bug
I created a stateless implementation that I test using torch-micro-scheduler.py and jax-micro-scheduler.py. Both scripts use the same random input and print the output after every call to scheduler.step. The last line from torch-micro-scheduler.py, when running on CPU, is:
50 [1]: [[ -0.03086062 -14.875211 -21.42081 4.41971 ]]
The last line from jax-micro-scheduler.py on CPU is:
50 [1]: [[[ -0.0308551 -14.875193 -21.420778 4.419718 ]]]
And when using a CUDA device it is:
50 [1]: [[[ -0.03085638 -14.875197 -21.420776 4.4197164 ]]]
I consider this close enough. These tests were all run on the same Linux box.
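For reference, the pattern the micro-scheduler scripts follow is roughly the sketch below: identical random inputs in both frameworks, a pure step function that takes and returns the scheduler state, and a print after every step call. The tiny scheduler here is a placeholder stand-in, not the repo's actual implementation.

```python
from typing import NamedTuple
import numpy as np
import jax.numpy as jnp

class SchedulerState(NamedTuple):
    timesteps: jnp.ndarray
    counter: int

def create_state(num_steps: int) -> SchedulerState:
    # Descending timesteps, as in a typical inference schedule.
    return SchedulerState(timesteps=jnp.arange(num_steps)[::-1], counter=0)

def step(state: SchedulerState, model_output, t, sample):
    # Stateless: the updated state is returned alongside the new sample.
    prev_sample = sample - 0.1 * model_output  # placeholder update rule
    return prev_sample, state._replace(counter=state.counter + 1)

rng = np.random.default_rng(0)                        # same seed/input as the PyTorch script
sample = jnp.asarray(rng.standard_normal((1, 4)))
state = create_state(50)
for i, t in enumerate(state.timesteps):
    model_output = jnp.asarray(rng.standard_normal((1, 4)))
    sample, state = step(state, model_output, t, sample)
    print(i + 1, f"[{int(t)}]:", np.asarray(sample))  # print after every scheduler.step
```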
☑️ Sinusoidal embeddings bug
- See #7
After this fix, the difference after the 50-step inference loop is ~8e-4 for the slice I'm printing.
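For context, this is one common formulation of the sinusoidal timestep embedding used in diffusion UNets; details such as the frequency denominator, the sin/cos ordering, and any frequency shift vary between implementations, and the actual fix lives in #7:

```python
import math
import jax.numpy as jnp

def get_timestep_embedding(timesteps: jnp.ndarray, dim: int) -> jnp.ndarray:
    # Half of the channels get sin, the other half cos, at geometrically spaced frequencies.
    half_dim = dim // 2
    freqs = jnp.exp(-math.log(10000.0) * jnp.arange(half_dim) / half_dim)
    args = timesteps.astype(jnp.float32)[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

# 981 is the first timestep in the logs below; 320 is an illustrative embedding size.
emb = get_timestep_embedding(jnp.array([981]), 320)
```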
❌ Remaining bugs
A simple image generation loop, without classifier-free guidance, still shows differences between the PyTorch and the Flax versions.
The test scripts print a slice of the current latents after each iteration step (a comparison sketch follows the listings below). These are the results:
PyTorch CPU
Step 0 [981]: [-0.20029372 -0.71322274 -0.6588407 1.3012452 ]
Step 1 [961]: [-0.20023333 -0.7130672 -0.6587741 1.3011351 ]
Step 2 [961]: [-0.19388276 -0.7101537 -0.6593911 1.2980776 ]
Step 3 [941]: [-0.18676892 -0.7066815 -0.6598941 1.2944702 ]
Step 4 [921]: [-0.17876498 -0.70243776 -0.660212 1.290017 ]
Step 5 [901]: [-0.16972902 -0.69751877 -0.66045487 1.284985 ]
Step 6 [881]: [-0.15949863 -0.69193625 -0.66029674 1.2790979 ]
Step 7 [861]: [-0.14811812 -0.68585664 -0.6598889 1.2726934 ]
Step 8 [841]: [-0.13533752 -0.6792915 -0.6590149 1.2656195 ]
Step 9 [821]: [-0.12131497 -0.6722729 -0.657692 1.2578422 ]
Step 10 [801]: [-0.10598446 -0.6647858 -0.65582865 1.2492626 ]
Step 11 [781]: [-0.08965665 -0.65693915 -0.6534673 1.2399257 ]
Step 12 [761]: [-0.07231925 -0.6487282 -0.6504389 1.2296318 ]
Step 13 [741]: [-0.05423042 -0.64025223 -0.64689887 1.2185035 ]
Step 14 [721]: [-0.03560711 -0.63162565 -0.64254385 1.2063537 ]
Step 15 [701]: [-0.0166229 -0.62293816 -0.63756216 1.1934131 ]
Step 16 [681]: [ 0.00244747 -0.61430305 -0.631572 1.1795033 ]
Step 17 [661]: [ 0.02151561 -0.6057371 -0.62473357 1.1647282 ]
Step 18 [641]: [ 0.04053753 -0.59714484 -0.6170455 1.1489599 ]
Step 19 [621]: [ 0.05943215 -0.5886498 -0.6084518 1.1321318 ]
Step 20 [601]: [ 0.07833029 -0.58015937 -0.5992237 1.114365 ]
Step 21 [581]: [ 0.09711783 -0.5717824 -0.5891966 1.0954585 ]
Step 22 [561]: [ 0.11586673 -0.56339365 -0.57859325 1.0755281 ]
Step 23 [541]: [ 0.13457038 -0.5550037 -0.56728995 1.054374 ]
Step 24 [521]: [ 0.15327923 -0.54639286 -0.5554153 1.0321356 ]
Step 25 [501]: [ 0.17189951 -0.53757924 -0.5428081 1.008607 ]
Step 26 [481]: [ 0.19037156 -0.52846026 -0.5295105 0.9838172 ]
Step 27 [461]: [ 0.20872533 -0.5189969 -0.5155331 0.9578293 ]
Step 28 [441]: [ 0.22693683 -0.50913084 -0.5008364 0.9305079 ]
Step 29 [421]: [ 0.24495293 -0.4988736 -0.48545808 0.9019891 ]
Step 30 [401]: [ 0.2628477 -0.48817235 -0.46945277 0.87225354]
Step 31 [381]: [ 0.28055626 -0.47705117 -0.45268854 0.8412365 ]
Step 32 [361]: [ 0.2980854 -0.46552035 -0.4353878 0.80914426]
Step 33 [341]: [ 0.3154298 -0.45355433 -0.41739047 0.77581835]
Step 34 [321]: [ 0.33256987 -0.44116807 -0.39883557 0.7414012 ]
Step 35 [301]: [ 0.34944463 -0.4283454 -0.3796378 0.705828 ]
Step 36 [281]: [ 0.3660169 -0.41508853 -0.35983148 0.6691652 ]
Step 37 [261]: [ 0.38228554 -0.40139943 -0.3394577 0.6314346 ]
Step 38 [241]: [ 0.39813244 -0.38732246 -0.31853062 0.59268993]
Step 39 [221]: [ 0.41363195 -0.37279892 -0.29705465 0.5529112 ]
Step 40 [201]: [ 0.42867348 -0.35788408 -0.27501422 0.5121036 ]
Step 41 [181]: [ 0.44330248 -0.3425106 -0.25247994 0.47017682]
Step 42 [161]: [ 0.45744026 -0.32664642 -0.22940254 0.42711228]
Step 43 [141]: [ 0.47099906 -0.3102275 -0.20562789 0.38262194]
Step 44 [121]: [ 0.48409614 -0.29324844 -0.1810959 0.3365584 ]
Step 45 [101]: [ 0.4967933 -0.27560505 -0.15544392 0.28805938]
Step 46 [81]: [ 0.5093855 -0.25712317 -0.12850419 0.23624928]
Step 47 [61]: [ 0.5223362 -0.23739506 -0.09932347 0.17866679]
Step 48 [41]: [ 0.53584635 -0.21572267 -0.06636252 0.11064841]
Step 49 [21]: [ 0.5504587 -0.18439673 -0.01844398 0.00506777]
Step 50 [1]: [ 0.5498064 -0.18463875 -0.01927387 0.00364937]
Flax CPU
Step: 0: [-0.19992101 -0.7129871 -0.6586888 1.3008868 ]
Step: 1: [-0.19981347 -0.7128339 -0.6585821 1.3007385 ]
Step: 2: [-0.19290741 -0.70965886 -0.6589236 1.2971743 ]
Step: 3: [-0.18514648 -0.7057757 -0.6591348 1.2928768 ]
Step: 4: [-0.17614478 -0.70119715 -0.6590942 1.2878238 ]
Step: 5: [-0.16603893 -0.69599694 -0.6587924 1.2821018 ]
Step: 6: [-0.1546922 -0.69035506 -0.65814054 1.2758265 ]
Step: 7: [-0.14186884 -0.6842141 -0.6570212 1.2688968 ]
Step: 8: [-0.12765141 -0.6776505 -0.65541327 1.2612278 ]
Step: 9: [-0.11221321 -0.6706691 -0.6532921 1.2528518 ]
Step: 10: [-0.09563611 -0.6633487 -0.6506718 1.2437 ]
Step: 11: [-0.07814168 -0.65575045 -0.64742196 1.2336963 ]
Step: 12: [-0.05991069 -0.6480145 -0.6434747 1.2227519 ]
Step: 13: [-0.04123378 -0.6402351 -0.63899314 1.211138 ]
Step: 14: [-0.02239008 -0.6325704 -0.6335036 1.198588 ]
Step: 15: [-0.00347782 -0.6249878 -0.62726957 1.1853817 ]
Step: 16: [ 0.01554522 -0.61738086 -0.62022686 1.1712673 ]
Step: 17: [ 0.03435938 -0.61008286 -0.61222446 1.1561339 ]
Step: 18: [ 0.05318831 -0.6028376 -0.60362244 1.140135 ]
Step: 19: [ 0.07208736 -0.5956207 -0.5944642 1.1233044 ]
Step: 20: [ 0.09092688 -0.58852077 -0.58460766 1.1053636 ]
Step: 21: [ 0.10976899 -0.58154446 -0.574152 1.0863402 ]
Step: 22: [ 0.12861738 -0.5744467 -0.5630966 1.0662225 ]
Step: 23: [ 0.14747892 -0.56707346 -0.55156934 1.0450902 ]
Step: 24: [ 0.1661616 -0.55952215 -0.53917605 1.022551 ]
Step: 25: [ 0.18472286 -0.5516906 -0.52600926 0.9986482 ]
Step: 26: [ 0.2032078 -0.54341716 -0.51236373 0.9736709 ]
Step: 27: [ 0.22153237 -0.53473914 -0.49796087 0.94738066]
Step: 28: [ 0.23979937 -0.5256013 -0.48277816 0.9196414 ]
Step: 29: [ 0.2578938 -0.51604843 -0.4671864 0.89093804]
Step: 30: [ 0.27593407 -0.5060353 -0.4508992 0.8608283 ]
Step: 31: [ 0.29384482 -0.4955864 -0.43395382 0.8294597 ]
Step: 32: [ 0.31158978 -0.48479837 -0.41666374 0.79701966]
Step: 33: [ 0.32928428 -0.4736257 -0.3989653 0.7633914 ]
Step: 34: [ 0.3467968 -0.4620075 -0.38061678 0.72844476]
Step: 35: [ 0.36414167 -0.45000747 -0.36187106 0.69242454]
Step: 36: [ 0.3813184 -0.43756184 -0.34273604 0.65527713]
Step: 37: [ 0.39826918 -0.42457163 -0.32300457 0.616891 ]
Step: 38: [ 0.41499922 -0.4111316 -0.3029643 0.5774899 ]
Step: 39: [ 0.43152043 -0.39716995 -0.28250125 0.53699094]
Step: 40: [ 0.44782242 -0.3827076 -0.26167208 0.4953915 ]
Step: 41: [ 0.46390548 -0.36765796 -0.24032444 0.45241216]
Step: 42: [ 0.47982553 -0.35209504 -0.21875334 0.40826562]
Step: 43: [ 0.49573308 -0.33579704 -0.19671243 0.36226097]
Step: 44: [ 0.5115899 -0.3185639 -0.17404059 0.31394032]
Step: 45: [ 0.527544 -0.30043286 -0.15111566 0.26273066]
Step: 46: [ 0.5436902 -0.2809677 -0.1277442 0.2067176]
Step: 47: [ 0.5601497 -0.25989434 -0.10393226 0.14277294]
Step: 48: [ 0.57713884 -0.2371628 -0.08050387 0.06425086]
Step: 49: [ 0.59747016 -0.20854189 -0.05620244 -0.06810611]
Step: 50: [ 0.5983118 -0.21079235 -0.05890573 -0.07158186]
As you can see, the numbers slowly drift apart. For example, the difference after the first step is on the order of 1e-4; after step 50 it is 0.048 for the first column (about 9% relative).
Could there be something in the UNet that is slightly different? How would you go about debugging this?
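For reference, the per-step comparison behind these numbers can be sketched as below (the .npy dumps are an assumption for illustration; the actual test scripts just print slices of the latents after each step):

```python
import numpy as np

# Hypothetical dumps of the per-step latents from each loop, e.g. saved with np.save.
torch_latents = np.load("pt_latents_per_step.npy")   # shape: (num_steps, ...), from the PyTorch loop
flax_latents = np.load("flax_latents_per_step.npy")  # same shape, from the Flax loop

for step_idx in range(len(torch_latents)):
    diff = np.abs(flax_latents[step_idx] - torch_latents[step_idx])
    print(step_idx, "max abs diff:", diff.max(), "sum abs diff:", diff.sum())
```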
After the sinusoidal fix (#7), the differences after the inference loop are on the order of 8e-4 for the slice I print in one of my test images. This is the visual result (left: before the fix; middle: PyTorch CPU reference; right: after the fix). I see no difference between the PyTorch and Flax versions:

Next steps:
- [ ] Should we enable int64 for timesteps? (see the x64 sketch after this list)
This means that some operations will be performed in float64 and then cast to float32. It has to be enabled explicitly using jax.config.update("jax_enable_x64", True). I haven't measured the performance impact; my anecdotal results on one test image are:
int32: jnp.sum(jnp.abs((latents - torch_latents))) -> 98.92086
int64: jnp.sum(jnp.abs((latents - torch_latents))) -> 96.41958
- [ ] Improve the stateless scheduler code (see the state sketch after this list). I'm using both a dataclass and a dict depending on the part of the code, and I'm sharding manually. I believe I should move to a flax.struct.dataclass to simplify.
- [ ] Test a complete generation loop. Currently I'm using text embeddings generated in PyTorch, and I'm not applying classifier-free guidance.
- [ ] Test on TPU!
- [ ] Merge the bug fixes: #4, #6, #7, #8.
- [ ] Adapt the Flax inference notebook and publish it!
- [ ] Make the demo work on TPU.
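Sketch of the x64 experiment from the first item above: jax_enable_x64 is a global, opt-in switch that should be flipped at program startup, after which default integer arrays become int64 and intermediate math can run in float64 before casting back to float32.

```python
import jax
jax.config.update("jax_enable_x64", True)  # must be set before the arrays in question are created

import jax.numpy as jnp

timesteps = jnp.arange(1000)[::-1]      # int64 once x64 is enabled
emb_in = timesteps.astype(jnp.float64)  # intermediate timestep/embedding math in float64
emb_in = emb_in.astype(jnp.float32)     # cast back to float32 before feeding the UNet
```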
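And a minimal sketch of what the flax.struct.dataclass refactor could look like (field names here are illustrative, not the repo's actual state layout): the state becomes a pytree, so it can be passed through jit/pmap and sharded like any other JAX value, and updates go through .replace(...) instead of mutating a dict.

```python
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SchedulerState:
    timesteps: jnp.ndarray
    counter: int
    cur_sample: jnp.ndarray

state = SchedulerState(
    timesteps=jnp.arange(50)[::-1],
    counter=0,
    cur_sample=jnp.zeros((1, 4, 64, 64)),
)
state = state.replace(counter=state.counter + 1)  # immutable update, returns a new state
```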
Thanks a lot for the summary @pcuenca and great job finding the bug!
Regarding the questions:
1.) I think we should not create float64 timesteps, but just keep float32 timesteps (the 2% relative diff is not worth it IMO)
More generally, I think after having verified that TPU works well we can spin up a demo with this, and then before publishing the notebook I'd actually merge everything directly into diffusers and make a bigger release (0.3.0) saying we now also support JAX.
> I think we should not create float64 timesteps

👍
> I'd actually merge everything directly into diffusers and make a bigger release (0.3.0) saying we now also support JAX
Fine for me! But this is only for Stable Diffusion with this one scheduler, we need to be clear about that :)