jax
[Pallas] Add FP32-TF32 rounding before matmul
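The sketch below shows what FP32-to-TF32 rounding means. It is written in plain JAX, not taken from the PR, and it does not use Pallas. TF32 keeps float32's sign bit and 8-bit exponent but only the top 10 of the 23 mantissa bits. The sketch assumes round-to-nearest-even when the low 13 mantissa bits are dropped; whether the PR uses that rounding mode is an assumption. The names `round_fp32_to_tf32` and `tf32_matmul_reference` are hypothetical.

```python
import jax
import jax.numpy as jnp


def round_fp32_to_tf32(x: jax.Array) -> jax.Array:
    """Round float32 values to the nearest TF32-representable value (ties to even).

    Hypothetical helper: assumes round-to-nearest-even, which may differ from
    the rounding used in the PR.
    """
    bits = jax.lax.bitcast_convert_type(x, jnp.uint32)
    # Bit 13 is the lowest mantissa bit TF32 keeps; it decides ties toward even.
    lsb = (bits >> 13) & 1
    bias = jnp.uint32(0xFFF) + lsb
    # Add the rounding bias, then clear the 13 low mantissa bits TF32 drops.
    rounded = (bits + bias) & jnp.uint32(0xFFFFE000)
    out = jax.lax.bitcast_convert_type(rounded, jnp.float32)
    # The bias can corrupt NaN payloads, so pass NaNs through unchanged.
    return jnp.where(jnp.isnan(x), x, out)


def tf32_matmul_reference(a: jax.Array, b: jax.Array) -> jax.Array:
    """Hypothetical reference: round both operands to TF32, then multiply in FP32."""
    a_tf32 = round_fp32_to_tf32(a.astype(jnp.float32))
    b_tf32 = round_fp32_to_tf32(b.astype(jnp.float32))
    # HIGHEST precision keeps the product and accumulation in full float32.
    return jnp.matmul(a_tf32, b_tf32, precision=jax.lax.Precision.HIGHEST)
```

Rounding the operands before the matmul gives a reference result that matches what a TF32 matmul would see as inputs. This makes it easier to compare against hardware that truncates or rounds FP32 inputs to TF32 internally.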