
Unsupported conversion from f64 to f16 in Pallas despite not using fp16

Open · lengstrom opened this issue 1 year ago

Description

Hi, when I try to multiply matrices in float64 in Pallas, I get the following error about a conversion to float16:

loc("/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]"(callsite("loop_body"("/mnt/xfs/home/fp64/fp64_broken.py":18:0) at callsite("kernel_func"("/mnt/xfs/home/fp64/fp64_broken.py":23:0) at callsite("matmul"("/mnt/xfs/home/fp64/fp64_broken.py":50:0) at "<module>"("/mnt/xfs/home/fp64/fp64_broken.py":68:0)))))): error: Rounding mode is required for FP downcast
Unsupported conversion from f64 to f16
LLVM ERROR: Unsupported rounding mode for conversion.
Aborted (core dumped)

Why does this error occur, given that the code never converts to f16 anywhere?

Reproducible example:

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax import lax
import functools

jax.config.update("jax_enable_x64", True)

def kernel_func(A_ref, B_ref, C_ref, N, block_rows, block_cols):

    # Accumulate in float32 regardless of the input dtype.
    C_accum = jnp.zeros((block_rows, block_cols), dtype=jnp.float32)

    def loop_body(i, C_accum):

        # Load the i-th (block_rows, block_cols) tile of A and the
        # matching (block_cols, block_cols) tile of B.
        A_tile = pl.load(A_ref, (pl.dslice(None), pl.dslice(i * block_cols, block_cols)))
        B_tile = pl.load(B_ref, (pl.dslice(i * block_cols, block_cols), pl.dslice(None)))

        C_accum += pl.dot(A_tile, B_tile)

        return C_accum

    loop_limit = pl.cdiv(N, block_cols)
    C_accum = lax.fori_loop(0, loop_limit, loop_body, C_accum)

    # Cast the float32 accumulator back to the output dtype (float64 here).
    pl.store(C_ref, (pl.dslice(None), pl.dslice(None)), C_accum.astype(C_ref.dtype))

def matmul(A, B):

    block_rows = 16
    block_cols = 32

    N = A.shape[0]

    grid = (pl.cdiv(N, block_rows), pl.cdiv(N, block_cols))

    in_specs = [
        pl.BlockSpec(lambda r, c: (r, 0), (block_rows, N)),
        pl.BlockSpec(lambda r, c: (0, c), (N, block_cols))
    ]

    C = jax.ShapeDtypeStruct(shape=(N, N), dtype=A.dtype)

    kernel = functools.partial(
        kernel_func,
        N = N,
        block_rows = block_rows,
        block_cols = block_cols
    )

    out, = pl.pallas_call(
        kernel,
        grid=grid,
        in_specs=in_specs,
        out_specs=[
            pl.BlockSpec(lambda r, c: (r, c), (block_rows, block_cols))
        ],
        out_shape=[ C ],
        name="matmul"
    )(A, B)

    return out

dtype = jnp.float64

N = 512
A = jax.random.uniform(jax.random.PRNGKey(0), (N, N), dtype=jnp.float32).astype(dtype)
B = jax.random.uniform(jax.random.PRNGKey(1), (N, N), dtype=jnp.float32).astype(dtype)
C_pallas = matmul(A, B)

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='deep-chungus-7.csail.mit.edu', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed Aug 21 16:47:30 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:01:00.0 Off |                    0 |
| N/A   66C    P0             85W /  300W |   66180MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

lengstrom · Aug 21 '24 20:08

It looks like when Pallas lowers the matmul, it assumes the inputs are fp32:

      %77 = tt.dot %52, %75, %76, inputPrecision = tf32 : tensor<16x32xf64> * tensor<32x32xf64> -> tensor<16x32xf32> "/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float64]"(#loc54)

My guess is that the culprit is this line (https://github.com/google/jax/blob/main/jax/_src/pallas/triton/lowering.py#L2045), which sets the input type (f64) equal to the output type (f32).

@superbobry, do you know the rationale behind this typing? I tried commenting out the line below it, which asserts that the input type equals the output type, and the user's code worked. But I'm not sure why that check exists in the first place.

justinjfu · Aug 22 '24 21:08

It looks like pl.dot uses preferred_element_type=jnp.float32. However, that doesn't explain the error, and AFAICT there are no casts to fp16 anywhere in the TTIR generated by Pallas.
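To illustrate the distinction outside of Pallas, here is a minimal sketch of how `preferred_element_type` controls the accumulator/output dtype of a dot. `pl.dot` pins it to float32, whereas plain `jnp.dot` lets the caller choose (shapes and values below are arbitrary):

```python
import jax
from jax import numpy as jnp

# Sketch: preferred_element_type picks the output/accumulation dtype of a
# dot_general; pl.dot hardcodes float32, jnp.dot exposes the knob.
jax.config.update("jax_enable_x64", True)

a = jnp.ones((4, 4), dtype=jnp.float32)
b = jnp.ones((4, 4), dtype=jnp.float32)

c_default = jnp.dot(a, b)  # output dtype follows the inputs: float32
c_acc64 = jnp.dot(a, b, preferred_element_type=jnp.float64)  # float64

print(c_default.dtype, c_acc64.dtype)  # float32 float64
```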

superbobry · Aug 30 '24 14:08

Maybe running with MLIR_ENABLE_DUMP=1 (i.e., dumping all intermediate MLIR lowerings) would reveal more information.
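For example (a sketch; assumes the repro above was saved as `fp64_broken.py`, the filename from the traceback):

```shell
# Dump every intermediate MLIR module produced during lowering (the dump
# goes to stderr), then search it for where the f16 cast is introduced.
MLIR_ENABLE_DUMP=1 python fp64_broken.py 2> mlir_dump.txt || true
grep -n 'f16' mlir_dump.txt
```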

cperivol · Aug 30 '24 15:08

I have encountered the same error when using float64. Hoping there is a solution.

chaoming0625 · Oct 05 '24 12:10