Unsupported conversion from f64 to f16 in Pallas despite not using f16
Description
Hi, when I try to do a matrix multiplication in float64 in Pallas, I get the following error related to converting to float16:
```
loc("/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]"(callsite("loop_body"("/mnt/xfs/home/fp64/fp64_broken.py":18:0) at callsite("kernel_func"("/mnt/xfs/home/fp64/fp64_broken.py":23:0) at callsite("matmul"("/mnt/xfs/home/fp64/fp64_broken.py":50:0) at "<module>"("/mnt/xfs/home/fp64/fp64_broken.py":68:0)))))): error: Rounding mode is required for FP downcast
Unsupported conversion from f64 to f16
LLVM ERROR: Unsupported rounding mode for conversion.
Aborted (core dumped)
```
Why does this error occur given that we never convert to f16 anywhere?
Reproducible example:
```python
import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax import lax
import functools

jax.config.update("jax_enable_x64", True)


def kernel_func(A_ref, B_ref, C_ref, N, block_rows, block_cols):
    C_accum = jnp.zeros((block_rows, block_cols), dtype=jnp.float32)

    def loop_body(i, C_accum):
        A_tile = pl.load(A_ref, (pl.dslice(None), pl.dslice(i * block_cols, block_cols)))
        B_tile = pl.load(B_ref, (pl.dslice(i * block_cols, block_cols), pl.dslice(None)))
        C_accum += pl.dot(A_tile, B_tile)
        return C_accum

    loop_limit = pl.cdiv(N, block_cols)
    C_accum = lax.fori_loop(0, loop_limit, loop_body, C_accum)
    pl.store(C_ref, (pl.dslice(None), pl.dslice(None)), C_accum.astype(C_ref.dtype))


def matmul(A, B):
    block_rows = 16
    block_cols = 32
    N = A.shape[0]
    grid = (pl.cdiv(N, block_rows), pl.cdiv(N, block_cols))
    in_specs = [
        pl.BlockSpec(lambda r, c: (r, 0), (block_rows, N)),
        pl.BlockSpec(lambda r, c: (0, c), (N, block_cols)),
    ]
    C = jax.ShapeDtypeStruct(shape=(N, N), dtype=A.dtype)
    kernel = functools.partial(
        kernel_func,
        N=N,
        block_rows=block_rows,
        block_cols=block_cols,
    )
    out, = pl.pallas_call(
        kernel,
        grid=grid,
        in_specs=in_specs,
        out_specs=[
            pl.BlockSpec(lambda r, c: (r, c), (block_rows, block_cols)),
        ],
        out_shape=[C],
        name="matmul",
    )(A, B)
    return out


dtype = jnp.float64
N = 512
A = jax.random.uniform(jax.random.PRNGKey(0), (N, N), dtype=jnp.float32).astype(dtype)
B = jax.random.uniform(jax.random.PRNGKey(1), (N, N), dtype=jnp.float32).astype(dtype)
C_pallas = matmul(A, B)
```
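For comparison, the same product computed in plain JAX (no Pallas) stays in f64 with no error, which suggests the problem is in the Triton lowering path. A minimal sanity-check sketch, using a smaller `N` than the repro:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N = 64  # smaller than the Pallas repro; enough to observe the dtypes
A = jax.random.uniform(jax.random.PRNGKey(0), (N, N), dtype=jnp.float32).astype(jnp.float64)
B = jax.random.uniform(jax.random.PRNGKey(1), (N, N), dtype=jnp.float32).astype(jnp.float64)

# Plain XLA matmul: input and output both remain float64.
C_ref = jnp.dot(A, B)
print(C_ref.dtype)  # float64
```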
System info (python version, jaxlib version, accelerator, etc.)
```
>>> import jax; jax.print_environment_info()
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='deep-chungus-7.csail.mit.edu', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')
```
```
$ nvidia-smi
Wed Aug 21 16:47:30 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:01:00.0 Off |                    0 |
| N/A   66C    P0             85W /  300W |   66180MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
```
It looks like when Pallas lowers the matmul, it assumes the inputs are f32:

```
%77 = tt.dot %52, %75, %76, inputPrecision = tf32 : tensor<16x32xf64> * tensor<32x32xf64> -> tensor<16x32xf32> "/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float64]"(#loc54)
```
My guess is that the culprit is this line (https://github.com/google/jax/blob/main/jax/_src/pallas/triton/lowering.py#L2045), which sets the input type (f64) equal to the output type (f32).
@superbobry do you know the rationale behind this typing? I tried commenting out the line below it, which asserts input type == output type, and the user's code worked. But I'm not sure why that check exists in the first place.
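In the meantime, a possible user-side workaround (an assumption, not a confirmed fix) is to downcast the operands to f32 before the Pallas kernel and upcast afterwards, accepting the reduced precision. A sketch, where `matmul_fn` stands in for the Pallas `matmul` from the repro:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def matmul_via_f32(A, B, matmul_fn):
    # Hypothetical workaround: run the kernel in f32, then upcast.
    # `matmul_fn` stands in for the Pallas `matmul` from the repro; the
    # accumulation happens in f32, so results differ from true f64.
    out = matmul_fn(A.astype(jnp.float32), B.astype(jnp.float32))
    return out.astype(jnp.float64)

# Stand-in check with jnp.matmul instead of the Pallas kernel:
A = jnp.eye(8, dtype=jnp.float64)
B = jnp.ones((8, 8), dtype=jnp.float64)
C = matmul_via_f32(A, B, jnp.matmul)
print(C.dtype)  # float64
```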
It looks like pl.dot uses preferred_element_type=jnp.float32. However, that doesn't explain the error, and AFAICT there are no casts to fp16 anywhere in the TTIR generated by Pallas.
Maybe running with MLIR_ENABLE_DUMP=1 (i.e., dumping all intermediate MLIR lowerings) would reveal more information.
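Complementary to the Triton-side dump, the StableHLO that JAX emits before any Triton lowering can be inspected from Python. A sketch on a plain `jnp.dot` (it shows only the XLA-level IR, not the TTIR, so it confirms where f64 is still present):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(a, b):
    return jnp.dot(a, b)

a = jnp.ones((4, 4), jnp.float64)
b = jnp.ones((4, 4), jnp.float64)

# The lowered StableHLO text should show the dot with f64 operands,
# i.e. the downcast is introduced later, during Triton lowering.
print(jax.jit(f).lower(a, b).as_text())
```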
I have encountered the same error when using 64-bit floats. Hoping there is a solution.