
JAX segment_sum is two times slower for FP16 inputs than FP32 inputs

Open CloudyDory opened this issue 1 year ago • 10 comments

Description

I find that JAX's segment_sum is more than twice as slow for FP16 inputs as for FP32 inputs. Here is an example:

import time
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
start_time = time.time()
data_sum = jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)
data_sum.block_until_ready()
print('Run time for FP32: {:.5f} seconds.'.format(time.time()-start_time))

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
start_time = time.time()
data_sum = jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)
data_sum.block_until_ready()
print('Run time for FP16: {:.5f} seconds.'.format(time.time()-start_time))

Outputs:

Run time for FP32: 0.03310 seconds.
Run time for FP16: 0.08621 seconds.

This happens with or without jit(). Why does this happen? And is there a way to optimize the computation for FP16 inputs?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ZJ', release='6.5.0-45-generic', version='#45~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Jul 15 16:40:02 UTC 2', machine='x86_64')


$ nvidia-smi
Tue Aug 20 10:40:04 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
|  0%   49C    P8              39W / 480W |     33MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:03:00.0 Off |                  Off |
|  0%   45C    P8              36W / 480W |   5032MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1284      G   /usr/lib/xorg/Xorg                            9MiB |
|    0   N/A  N/A      1317      G   /usr/bin/gnome-shell                         10MiB |
|    1   N/A  N/A      1284      G   /usr/lib/xorg/Xorg                            4MiB |
|    1   N/A  N/A     88489      C   .../miniconda3/envs/jax/bin/python         5012MiB |
+---------------------------------------------------------------------------------------+

CloudyDory, Aug 20 '24 02:08

I don't know the answer to this, but maybe @jakevdp does?

Some notes in the meantime. It's worth checking out the JAX microbenchmark FAQ entry: timing a single call, as you do here, includes tracing and compilation overhead, which can lead to incorrect conclusions. Fixing that doesn't seem to change the conclusion in this case, though! Here's how I would write the benchmark:

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

@jax.jit
def do_sum(data):
  return jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
do_sum(data).block_until_ready()  # compile
%timeit do_sum(data).block_until_ready()

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
do_sum(data).block_until_ready()  # compile
%timeit do_sum(data).block_until_ready()

Regardless, I do find that the float16 version is consistently slower. Perhaps @jakevdp can lead us in the right direction!
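The %timeit magic only exists in IPython/Jupyter. For running as a plain script, a minimal sketch of the same benchmark with the standard timeit module could look like this; the seeded generator, the helper name `bench`, and the number/repeat values are arbitrary choices for illustration, not from the thread. As in the snippet above, the first call is made outside the timed region so that tracing and compilation are excluded.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
rng = np.random.default_rng(0)  # seeded so both dtypes use the same segment layout
segment_ids = np.repeat(np.arange(num_segments), rng.integers(40, 977, size=num_segments))

@jax.jit
def do_sum(data):
    return jax.ops.segment_sum(
        data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
    )

def bench(dtype, number=10, repeat=3):
    """Hypothetical helper: best per-call time in seconds for the given dtype."""
    key = jax.random.PRNGKey(0)
    data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=dtype)
    do_sum(data).block_until_ready()  # warm-up: trace + compile outside the timed region
    times = timeit.repeat(lambda: do_sum(data).block_until_ready(), number=number, repeat=repeat)
    return min(times) / number

results = {jnp.dtype(dt).name: bench(dt) for dt in (jnp.float32, jnp.float16)}
for name, seconds in results.items():
    print(f"{name}: {seconds * 1e6:.1f} us per call")
```

Taking the minimum over repeats is a common way to reduce noise from other work on the machine; the absolute numbers depend entirely on the backend and hardware.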

dfm, Aug 20 '24 10:08

Interesting question! I suspect the reason for the performance difference here is that the GPU hardware is designed and tuned for float32 computation, and not for float16 computation. It would be interesting to compare this across different generations of GPU hardware.

jakevdp, Aug 20 '24 14:08

> Interesting question! I suspect the reason for the performance difference here is that the GPU hardware is designed and tuned for float32 computation, and not for float16 computation. It would be interesting to compare this across different generations of GPU hardware.

But I don't think a GPU's FP16 performance should be worse than its FP32 performance. For example, the A100's FP16 FLOPS are twice its FP32 FLOPS, and for the NVIDIA RTX 4090, some data shows equal FP16 and FP32 performance.

Is it possible that JAX somehow internally converts FP16 to FP32, performs the computation, and converts the result back to FP16?

CloudyDory, Aug 20 '24 14:08

No, I don't think such conversions are happening – you can see exactly which operations the compiler emits by using ahead-of-time lowering to print the compiled HLO. This is the output on a T4 GPU:

key = jax.random.key(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype='float16')
print(jax.jit(lambda data: jax.ops.segment_sum(
        data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
      )).lower(data).compile().as_text())

Output:

HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f16[852234]{0})->f16[1700]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="31c88d085a583f79a5c8e16aa07882b5"}

%region_0.7 (Arg_0.8.0: f16[], Arg_1.9.0: f16[]) -> f16[] {
  %Arg_1.9.0 = f16[] parameter(1)
  %Arg_0.8.0 = f16[] parameter(0)
  ROOT %add.1.0 = f16[] add(f16[] %Arg_0.8.0, f16[] %Arg_1.9.0), metadata={op_name="/add" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

%fused_scatter (param_0: f16[1700], param_1: s32[852234,1], param_2.1: f16[852234]) -> f16[1700] {
  %param_0 = f16[1700]{0} parameter(0)
  %param_1 = s32[852234,1]{1,0} parameter(1)
  %param_2.1 = f16[852234]{0} parameter(2)
  %bitcast.26.1 = f16[852234,1]{1,0} bitcast(f16[852234]{0} %param_2.1)
  ROOT %scatter.11.1 = f16[1700]{0} scatter(f16[1700]{0} %param_0, s32[852234,1]{1,0} %param_1, f16[852234,1]{1,0} %bitcast.26.1), update_window_dims={1}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_0.7, metadata={op_name="jit(<lambda>)/jit(main)/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

%fused_broadcast () -> f16[1700] {
  %constant_3_1 = f16[] constant(0)
  ROOT %broadcast.1.1 = f16[1700]{0} broadcast(f16[] %constant_3_1), dimensions={}, metadata={op_name="jit(<lambda>)/jit(main)/broadcast_in_dim[shape=(1700,) broadcast_dimensions=()]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}

ENTRY %main.12 (Arg_0.1.0: f16[852234]) -> f16[1700] {
  %constant_1_0 = s32[852234,1]{1,0} constant({...})
  %Arg_0.1.0 = f16[852234]{0} parameter(0)
  %loop_broadcast_fusion = f16[1700]{0} fusion(), kind=kLoop, calls=%fused_broadcast, metadata={op_name="jit(<lambda>)/jit(main)/broadcast_in_dim[shape=(1700,) broadcast_dimensions=()]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
  ROOT %input_scatter_fusion = f16[1700]{0} fusion(f16[1700]{0} %loop_broadcast_fusion, s32[852234,1]{1,0} %constant_1_0, f16[852234]{0} %Arg_0.1.0), kind=kInput, calls=%fused_scatter, metadata={op_name="jit(<lambda>)/jit(main)/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="<ipython-input-2-a1e8edb9aab8>" source_line=14}
}
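To make the same kind of check on both sides of XLA's optimizer, a minimal sketch (with small hypothetical data, not the original arrays) can print the StableHLO that JAX emits before any backend passes, alongside the optimized HLO shown above. If an f16-to-f32 conversion were inserted, it would appear as a convert op in one of these texts. On the CPU backend the compiled text may look quite different from the GPU output above (XLA may rewrite the scatter), so the GPU question is best answered by running this on the GPU itself.

```python
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 4
segment_ids = np.array([0, 0, 1, 2, 2, 3])  # small, sorted, hypothetical example

def seg_sum(data):
    return jax.ops.segment_sum(
        data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
    )

data = jnp.ones(len(segment_ids), dtype=jnp.float16)
lowered = jax.jit(seg_sum).lower(data)

# StableHLO produced by JAX, before XLA's backend-specific optimization passes.
stablehlo_text = lowered.as_text()
# HLO after XLA optimization for the current backend (the kind of text printed above).
compiled_text = lowered.compile().as_text()

print(stablehlo_text)
print(compiled_text)
```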

jakevdp, Aug 20 '24 15:08

Thanks for the clarification!

What might be the problem, then? I'm curious how we could go about debugging this issue.
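One generic starting point (a suggestion, not something tried in this thread) is to have XLA dump every HLO module it compiles, so the float16 and float32 pipelines can be diffed pass by pass. This sketch uses the XLA flag `--xla_dump_to`; the exact file names and which passes get dumped depend on the XLA version. Kernel-level timing would need a profiler on top of this (for example `jax.profiler`).

```python
import os
import tempfile

# XLA reads XLA_FLAGS when the backend initializes, so it must be set before importing jax.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 4
segment_ids = np.array([0, 0, 1, 2, 2, 3])  # small hypothetical example

@jax.jit
def seg_sum(data):
    return jax.ops.segment_sum(
        data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
    )

# Compiling (on the first call) writes the HLO text files into dump_dir.
seg_sum(jnp.ones(len(segment_ids), dtype=jnp.float16)).block_until_ready()

dumped = sorted(os.listdir(dump_dir))
print(f"{len(dumped)} files written to {dump_dir}")
```

Repeating the run with float32 inputs into a second directory and diffing the optimized modules would show whether the two dtypes take different code paths.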

CloudyDory, Aug 21 '24 01:08

My best guess is still that the hardware you're using is not optimized for the kinds of operations you're performing (i.e. scatters) in float16, and is more optimized for float32.

Appendix A of https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf suggests that for the GeForce RTX 4090, non-tensor ops are no faster in F16 than in F32, though it doesn't indicate that they should be slower. It may be that performance is worse specifically for F16 scatters – I'm not sure.

jakevdp, Aug 21 '24 01:08

I wrote a similar benchmark using PyTorch 2.3.1 and the torch_scatter library, and I now agree that non-tensor ops are no faster in F16 than in F32 on the GeForce RTX 4090.

However, PyTorch's FP16 version appears to be about 380 times faster than JAX's FP16 version on the RTX 4090 (120 μs vs. 45.8 ms). If the following benchmark code is correct, doesn't that mean there is still much room for improvement in JAX?

PyTorch code:

import torch
import torch_scatter  # !conda install pytorch-scatter -c pyg

def do_sum(data):
    y = torch_scatter.scatter(data, segment_ids, reduce='sum')
    torch.cuda.synchronize()
    return y

device = torch.device('cuda')
num_segments = 1700
segment_ids = torch.repeat_interleave(torch.arange(num_segments, device=device), torch.randint(40,977,size=(num_segments,),device=device))

# FP32
data_fp32 = torch.rand(len(segment_ids), dtype=torch.float32, device=device)
torch.cuda.synchronize()
%timeit do_sum(data_fp32)

# FP16
data_fp16 = torch.rand(len(segment_ids), dtype=torch.float16, device=device)
torch.cuda.synchronize()
%timeit do_sum(data_fp16)

PyTorch result:

102 μs ± 75.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
120 μs ± 121 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Jax code:

import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40,977,size=num_segments))

@jax.jit
def do_sum(data):
  return jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)

key = jax.random.PRNGKey(0)
data_fp32 = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
do_sum(data_fp32).block_until_ready()  # compile
%timeit do_sum(data_fp32).block_until_ready()

key = jax.random.PRNGKey(0)
data_fp16 = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
do_sum(data_fp16).block_until_ready()  # compile
%timeit do_sum(data_fp16).block_until_ready()

Jax result:

161 μs ± 2.84 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
45.8 ms ± 162 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

CloudyDory, Aug 21 '24 10:08

I think these are not equivalent operations – wouldn't torch.scatter be equivalent to JAX scatter, not JAX segment sum?

jakevdp, Aug 21 '24 15:08

Hi, this is not torch.scatter, but torch_scatter.scatter in the torch_scatter library (https://github.com/rusty1s/pytorch_scatter).

According to the documentation (https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter), torch_scatter.scatter is actually doing segment_sum.

We can also verify this with the following code.

PyTorch:

import torch
import torch_scatter

data = torch.tensor([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=torch.float32)
segment_ids = torch.tensor([1,0,0,1,2,2,3], dtype=torch.int64)
print(torch_scatter.scatter(data, segment_ids, reduce='sum'))

Output:

tensor([9.9800, 1.8200, 1.6300, 1.2000])

JAX:

import jax
import jax.numpy as jnp

data = jnp.array([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=jnp.float32)
segment_ids = jnp.array([1,0,0,1,2,2,3], dtype=jnp.int32)
print(jax.ops.segment_sum(data, segment_ids=segment_ids))

Output:

[9.98      1.8199999 1.6299999 1.2      ]
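The same equivalence can be cross-checked without PyTorch. This sketch reuses the small example above and compares jax.ops.segment_sum against an explicit scatter-add through JAX's `.at[].add` API (the compiled HLO earlier in the thread shows segment_sum lowering to a scatter-add) and against NumPy's `np.add.at`, which performs an unbuffered add at repeated indices. The extra comparisons are an illustration added here, not part of the original comment.

```python
import numpy as np
import jax
import jax.numpy as jnp

data_np = np.array([3.12, 4.98, 5.0, -1.3, -0.45, 2.08, 1.2], dtype=np.float32)
ids_np = np.array([1, 0, 0, 1, 2, 2, 3], dtype=np.int32)
num_segments = 4

# 1) jax.ops.segment_sum, as in the snippet above
seg = jax.ops.segment_sum(jnp.asarray(data_np), jnp.asarray(ids_np), num_segments=num_segments)

# 2) explicit scatter-add through the .at[] indexed-update API
scat = jnp.zeros(num_segments, dtype=jnp.float32).at[ids_np].add(data_np)

# 3) NumPy reference: in-place, unbuffered add at (possibly repeated) indices
ref = np.zeros(num_segments, dtype=np.float32)
np.add.at(ref, ids_np, data_np)

print(np.asarray(seg), np.asarray(scat), ref)
```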

CloudyDory, Aug 21 '24 22:08

Ah, thanks for the clarification. Looks like it is doing the same thing – I'm not sure why JAX's version is slower.

jakevdp, Aug 22 '24 00:08

This should be fixed by https://github.com/openxla/xla/commit/763244822bf01b1e33f5352b0eea5cb13259af58.

We were using the non-sorted algorithm, which is very slow for fp16 (see https://github.com/openxla/xla/issues/22233#issuecomment-2884486719 if you want more information).
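For JAX/XLA builds that predate that commit, one workaround that may be worth trying (an untested suggestion, not something proposed in this thread) is to run the segment sum in float32 and cast only the small result back to float16, since the FP32 path was fast in the benchmarks above. The helper name `segment_sum_f16_via_f32` is hypothetical. The trade-off is an extra conversion and twice the memory traffic for the input. A side benefit is that accumulating hundreds of values per segment in float32 avoids the rounding error that can build up when the sum itself is kept in float16.

```python
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
rng = np.random.default_rng(0)
segment_ids = np.repeat(np.arange(num_segments), rng.integers(40, 977, size=num_segments))

@jax.jit
def segment_sum_f16_via_f32(data):
    # Hypothetical workaround: accumulate in float32, then cast the per-segment sums back.
    out = jax.ops.segment_sum(
        data.astype(jnp.float32),
        segment_ids=segment_ids,
        num_segments=num_segments,
        indices_are_sorted=True,
    )
    return out.astype(data.dtype)

data = jax.random.uniform(jax.random.PRNGKey(0), shape=(len(segment_ids),), dtype=jnp.float16)
result = segment_sum_f16_via_f32(data)
print(result.dtype, result.shape)
```

Whether this is actually faster than the native FP16 scatter on a given GPU and XLA version would need to be measured, for example with a benchmark like the ones above.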

karupayun, Jun 04 '25 08:06