Results 2 comments of krahnikblis

Right — for clarity, the issue happens in the **torch** version when using float16. My reference to bfloat16 was also about the torch version: it's very slow (slower than torch.float32), but it works...

Is there a recommendation for a Flax version for TPU on Colab? I just discovered last night that this [issue](https://github.com/google/jax/issues/14544) I reported to JAX a few weeks ago also involves the Flax version...