keras-nlp

[Flux] Port Flux Core Model

Open · DavidLandup0 opened this issue 1 year ago · 3 comments

This PR ports the Flux core model to Keras and includes a weight conversion script. The VAE and the rest of the pipeline are better suited to a separate PR.
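Most of a weight conversion script comes down to mapping PyTorch parameter layouts onto Keras ones. Below is a minimal NumPy sketch for a plain `nn.Linear` to `keras.layers.Dense` mapping. PyTorch stores a linear weight as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, so the weight is transposed and the bias is copied unchanged. The `convert_linear` helper and the `img_in` key are hypothetical illustrations. They are not taken from the PR's actual script.

```python
import numpy as np


def convert_linear(state_dict, prefix):
    """Hypothetical helper: map a PyTorch nn.Linear's parameters to Dense weights.

    PyTorch: y = x @ W.T + b, with W of shape (out_features, in_features).
    Keras Dense: y = x @ kernel + bias, with kernel of shape (in_features, out_features).
    """
    weight = np.asarray(state_dict[f"{prefix}.weight"])
    kernel = weight.T  # (out, in) -> (in, out)
    bias_key = f"{prefix}.bias"
    if bias_key in state_dict:
        return [kernel, np.asarray(state_dict[bias_key])]
    return [kernel]  # layers created with bias=False have no bias entry


# Sanity check on random data: both conventions must give the same output.
rng = np.random.default_rng(0)
state_dict = {
    "img_in.weight": rng.standard_normal((8, 4)).astype(np.float32),
    "img_in.bias": rng.standard_normal(8).astype(np.float32),
}
x = rng.standard_normal((2, 4)).astype(np.float32)

torch_style = x @ state_dict["img_in.weight"].T + state_dict["img_in.bias"]
kernel, bias = convert_linear(state_dict, "img_in")
keras_style = x @ kernel + bias
print(np.allclose(torch_style, keras_style))  # True
```

The returned list is ordered as `keras.layers.Dense.set_weights` expects (kernel, then bias), so it can be passed to a built layer directly.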

Each layer is compared numerically against the original PyTorch implementation in this Colab notebook: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
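A per-layer check like this feeds identical inputs to both implementations and records the error for each layer. Below is a minimal sketch. It assumes each layer's outputs have already been collected as NumPy arrays. The `compare_layers` helper and the layer names are hypothetical, and the `atol=1e-3` default mirrors the tolerance used for the full-model comparison. `np.allclose` checks `|a - b| <= atol + rtol * |b|` element-wise, and `rtol` defaults to `1e-5`.

```python
import numpy as np


def compare_layers(keras_outputs, torch_outputs, atol=1e-3):
    """Compare per-layer outputs from two implementations.

    Both arguments map a layer name to that layer's output (a NumPy array)
    computed from the same input. Returns {name: (max_abs_err, passed)}.
    """
    report = {}
    for name, expected in torch_outputs.items():
        actual = keras_outputs[name]
        if actual.shape != expected.shape:
            raise ValueError(f"{name}: shape {actual.shape} != {expected.shape}")
        max_abs_err = float(np.max(np.abs(actual - expected)))
        report[name] = (max_abs_err, bool(np.allclose(actual, expected, atol=atol)))
    return report


# Synthetic stand-ins for real layer outputs (hypothetical layer names).
rng = np.random.default_rng(0)
reference = {
    "rms_norm": rng.standard_normal((2, 16)),
    "self_attention": rng.standard_normal((2, 4, 16)),
}
candidate = {
    "rms_norm": reference["rms_norm"] + 5e-4,  # within tolerance
    "self_attention": reference["self_attention"] + 1e-2,  # outside tolerance
}
for name, (err, ok) in compare_layers(candidate, reference).items():
    print(f"{name}: max |diff| = {err:.1e}, allclose = {ok}")
```

Reporting the maximum absolute error alongside the pass/fail flag makes it easier to see which layer introduces a drift once outputs stop matching.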

Modules included:

  • Maths module
    • Timestep embedding
    • RoPE
    • Attention
    • A Keras re-implementation of scaled dot-product attention (to match PyTorch's `scaled_dot_product_attention`)
  • Layers module
    • MLPEmbedder
    • RMSNorm
    • QKNorm
    • SelfAttention
    • Modulation
    • DoubleStreamBlock
    • SingleStreamBlock
    • LastLayer
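Several of the building blocks listed above reduce to short formulas. The NumPy sketch below illustrates them in a framework-agnostic way and is not the PR's Keras code. It makes the following assumptions:

* The timestep embedding uses the common sinusoidal form (cosine half followed by sine half), with `t` multiplied by a `time_factor` of 1000. This is assumed to match the original Flux code and is worth double-checking against the source.
* RMSNorm has a learned scale and `eps=1e-6`.
* Scaled dot-product attention is plain, with no masking and no dropout.
* Modulation is adaLN-style, of the form `(1 + scale) * x + shift`.

All function names are hypothetical.

```python
import numpy as np


def timestep_embedding(t, dim, max_period=10000.0, time_factor=1000.0):
    """Sinusoidal embedding of scalar timesteps: shape (B,) -> (B, dim)."""
    t = time_factor * np.asarray(t, dtype=np.float32)
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = t[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # pad odd dimensions with a zero column
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb


def rms_norm(x, scale, eps=1e-6):
    """Divide by the root-mean-square over the last axis, then apply a learned scale."""
    rrms = 1.0 / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    return x * rrms * scale


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d)) @ v over the last two axes (no mask, no dropout)."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    return softmax(scores, axis=-1) @ v


def modulate(x, shift, scale):
    """adaLN-style modulation of an already-normalized input."""
    return (1.0 + scale) * x + shift


rng = np.random.default_rng(0)
emb = timestep_embedding(np.array([0.0, 0.5]), dim=8)
print(emb.shape)  # (2, 8)
q = k = v = rng.standard_normal((1, 2, 5, 4))  # (batch, heads, length, head_dim)
print(scaled_dot_product_attention(q, k, v).shape)  # (1, 2, 5, 4)
```

As we understand the original Flux code, QKNorm applies an RMSNorm to `q` and `k` separately. RoPE is then applied to `q` and `k` before the attention call, and the heads are merged back into a single feature axis afterwards.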

Output Comparison

The core model outputs latents. We plot a PCA projection of the outputs from the original implementation and the Keras re-implementation for the same input:

[Figure: PCA of the output latents, original PyTorch implementation vs. Keras re-implementation]
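One way to produce such a comparison is to fit the PCA on one set of latents and project both sets onto the same components, so that the two plots share axes. Below is a minimal NumPy sketch using SVD. It assumes the latents have been flattened to `(num_tokens, channels)`. It is not necessarily how the original figure was made, and the `pca_project` name is hypothetical.

```python
import numpy as np


def pca_project(reference, other, n_components=2):
    """Fit PCA on `reference` (N, C) and project both arrays onto its top components."""
    mean = reference.mean(axis=0)
    # Rows of vt are the principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(reference - mean, full_matrices=False)
    components = vt[:n_components]  # (n_components, C)
    return (reference - mean) @ components.T, (other - mean) @ components.T


rng = np.random.default_rng(0)
latents_pt = rng.standard_normal((256, 64))  # stand-in for the PyTorch output
latents_keras = latents_pt + 1e-4 * rng.standard_normal((256, 64))  # stand-in for the Keras output
proj_pt, proj_keras = pca_project(latents_pt, latents_keras)
print(proj_pt.shape)  # (256, 2)
print(np.max(np.abs(proj_pt - proj_keras)))  # small, because the inputs nearly match
```

Sharing one fitted basis matters here. Fitting a separate PCA per implementation could flip or reorder components and make near-identical outputs look different. The two projections can then be drawn side by side with any scatter-plot tool.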

Numerically, the outputs agree to within an absolute tolerance of 1e-3:

```python
>>> np.allclose(output_keras.numpy(), output_pt.detach().numpy(), atol=1e-3)
True
```

DavidLandup0 · Sep 23 '24 15:09