keras-nlp
[Flux] Port Flux Core Model
This PR ports the core model into a Keras model and includes a weight conversion script. The VAE and the rest of the pipeline would make more sense in a separate PR.
Each layer is numerically compared against the original PyTorch implementation here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
Modules included:
- Maths module
  - Timestep embedding
  - RoPE
  - Attention
  - Scaled dot product attention re-implemented in Keras (to match the PyTorch one)
- Layers module
  - MLPEmbedder
  - RMSNorm
  - QKNorm
  - SelfAttention
  - Modulation
  - DoubleStreamBlock
  - SingleStreamBlock
  - LastLayer
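As a reference for what some of these layers compute, RMSNorm (used by QKNorm to normalize queries and keys) has a well-known closed form. The function name and signature below are hypothetical, not the PR's actual API; this is just a NumPy sketch of the math:

```python
import numpy as np

def rms_norm(x, scale, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps) * scale.
    # Unlike LayerNorm, there is no mean-centering and no bias;
    # `scale` is the learned per-channel gain.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale
```

In a Keras layer this would typically live in `call()`, with `scale` created in `build()` as a trainable weight.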
**Output Comparison**
The core model's outputs are latents. We plot the PCA of the output from the original implementation and the Keras re-implementation on the same input:
Numerically, the outputs match to within an absolute tolerance of `1e-3`:

```python
>>> np.allclose(output_keras.numpy(), output_pt.detach().numpy(), atol=1e-3)
True
```
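The scaled dot product attention re-implementation mentioned above targets the defaults of PyTorch's `torch.nn.functional.scaled_dot_product_attention` (no mask, no dropout). The function below is a hypothetical NumPy sketch of that computation, not the PR's actual code:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    # softmax(q @ k^T / sqrt(d)) @ v, matching the PyTorch
    # function's defaults (no attention mask, no dropout).
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = q @ np.swapaxes(k, -1, -2) * scale
    # Numerically stable softmax over the key axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

A port would express the same ops with `keras.ops` so it runs on any backend; comparing it against the PyTorch output with `np.allclose` at a small `atol`, as above, is how each layer was verified.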