CLIP Weights are stored as float16
Dear Scenic team,
I wanted to point out that the original torch CLIP weights are stored in float16, but in Scenic the Flax module is defined with float32 weights. This doesn't cause problems during the forward pass, but when I tried to fine-tune CLIP it caused unexpected errors on a TPU.
This could be an overflow/underflow issue, or the TPU may not be meant for float16. This is the message I received:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [post-optimization]: Bitcast cannot have different shape sizes of output (8192) and operand (6144).
To fix the error, I had to cast the CLIP weights back to float32.
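A minimal sketch of such a cast, assuming the loaded weights form a pytree of arrays. The helper name `cast_floats_to_float32` and the toy `params` dict are hypothetical stand-ins, not part of Scenic or the actual CLIP checkpoint:

```python
import jax
import jax.numpy as jnp


def cast_floats_to_float32(params):
    """Return a copy of `params` with every floating-point leaf cast to float32.

    Non-floating leaves (e.g. integer counters) are left unchanged.
    """
    def _cast(x):
        x = jnp.asarray(x)
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.float32)
        return x

    return jax.tree_util.tree_map(_cast, params)


# Hypothetical stand-in for CLIP parameters loaded as float16.
params = {
    "visual": {"kernel": jnp.ones((2, 3), dtype=jnp.float16)},
    "step": jnp.array(0, dtype=jnp.int32),
}
params = cast_floats_to_float32(params)
```

Applying the cast once, right after loading and before building the optimizer state, should keep the parameters consistent with the float32 module definition.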
Hi @moabarar
Thanks for figuring this out. Would you like to propose a PR for your fix?
~ Alexey