FlashAttention Support - TPUv3
Is FlashAttention supported on TPUv3? The same config that works on TPUv4 fails on TPUv3 with the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.
However, after changing the attention setting from autoselected to dot_product, the error disappears.
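For anyone hitting the same error, here is a minimal sketch of how the workaround could be applied automatically. It only encodes what is reported above: TPUv3 fails with autoselected and works with dot_product. The helper name pick_attention is hypothetical and not part of MaxText, and the "TPU v3" device-kind string format is an assumption about what jax.devices()[0].device_kind reports.

```python
def pick_attention(device_kind: str) -> str:
    """Hypothetical helper: choose a value for the `attention` setting.

    Assumption: `device_kind` looks like "TPU v3" or "TPU v4", as
    reported by jax.devices()[0].device_kind.
    """
    # Normalize case and spacing so "TPU v3" and "TPUv3" both match.
    normalized = device_kind.lower().replace(" ", "")
    if normalized.startswith("tpuv3"):
        # Workaround from this report: avoid the Mosaic-compiled kernel on TPUv3.
        return "dot_product"
    # Everywhere else, keep the setting that worked on TPUv4.
    return "autoselected"


if __name__ == "__main__":
    for kind in ("TPU v3", "TPU v4"):
        print(kind, "->", pick_attention(kind))
```

On a real machine the argument would come from jax.devices()[0].device_kind, and the result would be passed as the attention override. This is only a stopgap and does not answer whether FlashAttention is meant to be supported on TPUv3.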