flux icon indicating copy to clipboard operation
flux copied to clipboard

[BUG] Fails to launch with Ray

Open gameofdimension opened this issue 6 months ago • 2 comments

Describe the bug

The op fails to launch when using Ray, but runs successfully with torchrun.

To Reproduce

import sys
from typing import Optional

import ray
import torch

import flux
from flux.testing import initialize_distributed


def torch_base(
    TP_GROUP,
    M,
    K,
    input,
    transpose_weight: bool,
    weight,
    bias,
):
    """Compute the reference result: all-gather the input, then matmul.

    Args:
        TP_GROUP: process group passed to ``all_gather_into_tensor``.
        M: row count of the gathered input buffer.
        K: column count of the gathered input buffer.
        input: this rank's shard; gathered into an ``(M, K)`` buffer.
        transpose_weight: when true ``weight`` is used as-is on the right
            of the matmul; otherwise its transpose is used.
        weight: right-hand matmul operand (see ``transpose_weight``).
        bias: optional tensor added in place to the product; ``None`` skips it.

    Returns:
        The product of the gathered input and the weight, plus ``bias`` if given.
    """
    # Destination buffer for the collective, on the current CUDA device.
    gathered = torch.zeros(
        (M, K),
        dtype=input.dtype,
        device=torch.cuda.current_device(),
    )
    torch.distributed.all_gather_into_tensor(gathered, input, group=TP_GROUP)

    rhs = weight if transpose_weight else weight.t().contiguous()
    reference = gathered @ rhs
    if bias is not None:
        reference += bias
    return reference


@torch.no_grad()
def all_gather_gemm(
    TP_GROUP,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    transpose_weight: bool,
    ring_mode: Optional[flux.AGRingMode] = None,
    fast_acc: bool = False,
    use_pdl: bool = False,
):
    """Run ``flux.AGKernel`` once and print its max deviation from ``torch_base``.

    Args:
        TP_GROUP: process group used for rank/size and for both code paths.
        input: this rank's input shard; dimension 0 is the per-rank row count
            and dimension 1 is K.
        weight: weight operand; N is read from dimension 1 when
            ``transpose_weight`` is true, otherwise from dimension 0.
        bias: bias tensor, or ``None`` to run without bias.
        transpose_weight: forwarded to the reference and to the kernel.
        ring_mode: assigned to ``flux.AllGatherOption().mode``.
        fast_acc: forwarded to the kernel as ``fast_accum``.
        use_pdl: forwarded to the ``flux.AGKernel`` constructor.

    Returns:
        None. Shapes and the max absolute delta are printed.
    """
    RANK = TP_GROUP.rank()
    WORLD_SIZE = TP_GROUP.size()
    NNODES = flux.testing.NNODES()

    # Global problem size: every rank contributes local_M rows.
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()
    K = input.size(1)
    N = weight.size(1) if transpose_weight else weight.size(0)

    input_dtype = output_dtype = input.dtype
    ag_gemm_output = torch.empty([M, N], dtype=output_dtype, device=input.device)

    ag_option = flux.AllGatherOption()
    ag_option.mode = ring_mode

    print(f"RANK {RANK}/{WORLD_SIZE} - M: {M}, N: {N}, K: {K}, {input.shape}, {weight.shape}")

    # Reference result first; it uses the bias exactly as passed in.
    gold = torch_base(TP_GROUP, M, K, input, transpose_weight, weight, bias)

    all_gather_gemm_kernel = flux.AGKernel(
        TP_GROUP,
        NNODES,
        M,
        N,
        K,
        input_dtype,
        output_dtype=output_dtype,
        use_pdl=use_pdl,
    )

    # The kernel receives the bias repeated M times along dim 0.
    if bias is not None:
        bias = bias.repeat_interleave(M, dim=0)

    all_gather_gemm_kernel.forward(
        input,
        weight,
        bias=bias,
        output=ag_gemm_output,
        input_scale=None,
        weight_scale=None,
        output_scale=None,
        fast_accum=fast_acc,
        gathered_input=None,
        transpose_weight=transpose_weight,
        all_gather_option=ag_option,
    )
    torch.cuda.synchronize()

    delta = torch.max(torch.abs(ag_gemm_output - gold)).item()
    print(
        f"all_gather_gemm testing Output shape {ag_gemm_output.shape}, "
        f"RANK {RANK}/{WORLD_SIZE} - Max delta: {delta:.6f}")


def make_data(M, K, N, dtype, TP_GROUP):
    """Create random input, weight and bias tensors on the current CUDA device.

    The generator is seeded with ``666 + TP_GROUP.rank()``, so each rank draws
    different values. Tensors are drawn in the order input, weight, bias.

    Args:
        M: rows of the input tensor.
        K: columns of the input tensor and rows of the weight tensor.
        N: columns of the weight and bias tensors.
        dtype: dtype of all three tensors.
        TP_GROUP: object whose ``rank()`` offsets the seed.

    Returns:
        Tuple ``(input [M, K], weight [K, N], bias [1, N])``.
    """
    device = torch.device(f'cuda:{torch.cuda.current_device()}')
    generator = torch.Generator(device='cuda').manual_seed(666 + TP_GROUP.rank())

    scale = 1  # (TP_GROUP.rank() + 1) * 0.01
    offset = 0.0
    factory = dict(dtype=dtype, device=device, generator=generator)

    def remap(sample):
        # Same affine map for every tensor: x -> (2x - 1) * scale + offset.
        return (sample * 2 - 1) * scale + offset

    input_tensor = remap(torch.rand(M, K, **factory))
    weight_tensor = remap(torch.rand(K, N, **factory))
    # The bias is drawn with randn, unlike the two rand draws above.
    bias_tensor = remap(torch.randn(1, N, **factory))

    return input_tensor, weight_tensor, bias_tensor


def run_all_gather_gemm(TP_GROUP):
    """Generate bfloat16 operands with ``make_data`` and run ``all_gather_gemm``.

    Operand layouts, as noted by the original author:

    * input: [M_per_rank, K] for all types
    * weight: [K, N] if transpose_weight, [N, K] if not transpose_weight.
    * bias: [1, N] for FP8 or INT8, [M_per_rank, N] for FP16 or BF16.

    Args:
        TP_GROUP: process group forwarded to ``make_data`` and
            ``all_gather_gemm``.
    """
    transpose_weight = True
    local_M, K, N = 4096, 12288, 49152
    dtype = torch.bfloat16  # torch.float16  # torch.bfloat16

    input_tensor, weight_tensor, bias_tensor = make_data(local_M, K, N, dtype, TP_GROUP)

    # make_data yields a [K, N] weight; flip it for the non-transposed layout.
    if not transpose_weight:
        weight_tensor = weight_tensor.t().contiguous()

    all_gather_gemm(
        TP_GROUP,
        input=input_tensor,
        weight=weight_tensor,
        bias=bias_tensor,
        transpose_weight=transpose_weight,
        ring_mode=None,
        fast_acc=False,
        use_pdl=False,
    )


def main():
    """Entry point for a torchrun launch: set up the group, run the check."""
    tp_group = initialize_distributed()
    run_all_gather_gemm(tp_group)
    # run_gemm_rs()


@ray.remote(num_gpus=1)
def ray_launch():
    """Ray task requesting one GPU: set up the group, run the check."""
    tp_group = initialize_distributed()
    run_all_gather_gemm(tp_group)


def ray_main():
    """Launch one ``ray_launch`` task per GPU on this node and wait for them.

    The GPU count is read from ``sys.argv[1]``. Each task gets, through its
    Ray ``runtime_env``, the rendezvous variables that ``torchrun`` exports
    for a single-node job: ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``.

    By default Ray rewrites ``CUDA_VISIBLE_DEVICES`` so that a ``num_gpus=1``
    task sees only its own GPU, which then always appears as device 0. Every
    rank reporting the same device index is consistent with the reported
    failure ``Check failed: 2(gpu_device_ids.size()) == 1(gpu_device_set.size())``.
    Setting ``RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`` keeps all GPUs
    visible in every task so ranks can use distinct device indices.

    NOTE(review): this assumes ``flux.testing.initialize_distributed`` picks
    its CUDA device from ``LOCAL_RANK``, as it would under torchrun - confirm
    against the flux version in use.

    Raises:
        IndexError: if no GPU count is given on the command line.
        ValueError: if the GPU count is not an integer.
    """
    num_gpus = int(sys.argv[1])
    ray.init(num_gpus=num_gpus)

    # Single-node launch (MASTER_ADDR is localhost), so the local world size
    # equals the global one and LOCAL_RANK equals RANK.
    shared_env = {
        'WORLD_SIZE': str(num_gpus),
        'LOCAL_WORLD_SIZE': str(num_gpus),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
        # Stop Ray from masking GPUs per task; see the docstring.
        'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1',
    }

    tasks = []
    for rank in range(num_gpus):
        env_vars = dict(shared_env)
        env_vars['RANK'] = str(rank)
        env_vars['LOCAL_RANK'] = str(rank)
        tasks.append(
            ray_launch.options(runtime_env={'env_vars': env_vars}).remote()
        )
    # Block until every rank finishes; a failing task re-raises here.
    ray.get(tasks)


if __name__ == "__main__":
    # Launch path 1: torchrun (reported as working). Enable main() and run:
    # NVSHMEM_DISABLE_CUDA_VMM=1 torchrun --nproc_per_node=2 -m lab.example
    # main()

    # Launch path 2: Ray (the path this report is about). The argument is the
    # number of GPUs / ranks to start:
    # NVSHMEM_DISABLE_CUDA_VMM=1 python3 -m lab.example 2
    ray_main()

Expected behavior

Expect it to work well when using Ray.

Stack trace/logs

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/flux/lab/example.py", line 195, in <module>
    ray_main()
  File "/root/flux/lab/example.py", line 185, in ray_main
    ray.get(tasks)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2771, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 919, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::ray_launch() (pid=20286, ip=10.189.110.50)
  File "/root/flux/lab/example.py", line 166, in ray_launch
    run_all_gather_gemm(TP_GROUP)
  File "/root/flux/lab/example.py", line 145, in run_all_gather_gemm
    all_gather_gemm(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/flux/lab/example.py", line 77, in all_gather_gemm
    all_gather_gemm_kernel = flux.AGKernel(
RuntimeError: /root/flux-tune/src/ths_op/topo_utils.cc:245 Check failed: 2((gpu_device_ids.size())) == 1((gpu_device_set.size()))

Environment L20+python3.10+torch2.6+cuda12.4+flux@13e48df3

Proposed fix If you have a proposal for how to fix the issue state it here or link to a PR.

Additional context Add any other context about the problem here.

gameofdimension avatar Jun 17 '25 06:06 gameofdimension

We do not test with Ray. It seems that Ray limits each process to only 1 GPU?

houqi avatar Jun 17 '25 07:06 houqi

What is the recommended way to use Flux with Ray?

gameofdimension avatar Jun 17 '25 13:06 gameofdimension