[BUG] Fails to launch with Ray
Describe the bug
The op fails to launch when using Ray, but runs successfully with torchrun.
To Reproduce
import sys
from typing import Optional

import ray
import torch

import flux
from flux.testing import initialize_distributed


def torch_base(
    TP_GROUP,
    M,
    K,
    input,
    transpose_weight: bool,
    weight,
    bias,
):
    # allgather and then matmul
    input_dtype = input.dtype
    full_input = torch.zeros(
        (M, K),
        dtype=input_dtype,
        device=torch.cuda.current_device(),
    )
    torch.distributed.all_gather_into_tensor(
        full_input, input, group=TP_GROUP,
    )
    if transpose_weight:
        gold = full_input @ weight
    else:
        gold = full_input @ weight.t().contiguous()
    if bias is not None:
        gold += bias
    return gold


@torch.no_grad()
def all_gather_gemm(
    TP_GROUP,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    transpose_weight: bool,
    ring_mode: Optional[flux.AGRingMode] = None,
    fast_acc: bool = False,
    use_pdl: bool = False,
):
    RANK, WORLD_SIZE, NNODES = TP_GROUP.rank(), TP_GROUP.size(), flux.testing.NNODES()
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()
    K = input.size(1)
    if transpose_weight:
        N = weight.size(1)
    else:
        N = weight.size(0)
    input_dtype = input.dtype
    output_dtype = input.dtype
    ag_gemm_output = torch.empty([M, N], dtype=output_dtype, device=input.device)
    ag_option = flux.AllGatherOption()
    ag_option.mode = ring_mode
    print(f"RANK {RANK}/{WORLD_SIZE} - M: {M}, N: {N}, K: {K}, {input.shape}, {weight.shape}")
    gold = torch_base(
        TP_GROUP,
        M,
        K,
        input,
        transpose_weight,
        weight,
        bias,
    )
    all_gather_gemm_kernel = flux.AGKernel(
        TP_GROUP,
        NNODES,
        M,
        N,
        K,
        input_dtype,
        output_dtype=output_dtype,
        use_pdl=use_pdl,
    )
    if bias is not None:
        bias = bias.repeat_interleave(M, dim=0)
    all_gather_gemm_kernel.forward(
        input,
        weight,
        bias=bias,
        output=ag_gemm_output,
        input_scale=None,
        weight_scale=None,
        output_scale=None,
        fast_accum=fast_acc,
        gathered_input=None,
        transpose_weight=transpose_weight,
        all_gather_option=ag_option,
    )
    torch.cuda.synchronize()
    delta = (ag_gemm_output - gold).abs().max().item()
    print(
        f"all_gather_gemm testing Output shape {ag_gemm_output.shape}, "
        f"RANK {RANK}/{WORLD_SIZE} - Max delta: {delta:.6f}")


def make_data(M, K, N, dtype, TP_GROUP):
    generator = torch.Generator(device='cuda').manual_seed(666 + TP_GROUP.rank())
    device = torch.device(f'cuda:{torch.cuda.current_device()}')
    scale = 1  # (TP_GROUP.rank() + 1) * 0.01
    bias = 0.0
    input_tensor = (torch.rand(M, K, dtype=dtype, device=device,
                               generator=generator) * 2 - 1) * scale + bias
    weight_tensor = (torch.rand(K, N, dtype=dtype, device=device,
                                generator=generator) * 2 - 1) * scale + bias
    bias_tensor = (torch.randn(1, N, dtype=dtype, device=device,
                               generator=generator) * 2 - 1) * scale + bias
    return input_tensor, weight_tensor, bias_tensor


def run_all_gather_gemm(TP_GROUP):
    transpose_weight = True
    """
    * input: [M_per_rank, K] for all types
    * weight: [K, N] if transpose_weight, [N, K] if not transpose_weight.
    * bias: [1, N] for FP8 or INT8, [M_per_rank, N] for FP16 or BF16.
    """
    local_M = 4096
    K = 12288
    N = 49152
    dtype = torch.bfloat16  # torch.float16 # torch.bfloat16
    input_tensor, weight_tensor, bias_tensor = make_data(
        local_M, K, N, dtype, TP_GROUP
    )
    if not transpose_weight:
        weight_tensor = weight_tensor.t().contiguous()
    all_gather_gemm(
        TP_GROUP,
        input=input_tensor,
        weight=weight_tensor,
        bias=bias_tensor,
        transpose_weight=transpose_weight,
        ring_mode=None,
        fast_acc=False,
        use_pdl=False,
    )


def main():
    TP_GROUP = initialize_distributed()
    run_all_gather_gemm(TP_GROUP)
    # run_gemm_rs()


@ray.remote(num_gpus=1)
def ray_launch():
    TP_GROUP = initialize_distributed()
    run_all_gather_gemm(TP_GROUP)


def ray_main():
    num_gpus = int(sys.argv[1])
    ray.init(num_gpus=num_gpus)
    tasks = []
    for i in range(num_gpus):
        env_vars = {}
        env_vars['RANK'] = str(i)
        env_vars['WORLD_SIZE'] = str(num_gpus)
        env_vars['MASTER_ADDR'] = 'localhost'
        env_vars['MASTER_PORT'] = '12345'
        task_func = ray_launch.options(
            runtime_env={
                'env_vars': env_vars
            }
        ).remote()
        tasks.append(task_func)
    ray.get(tasks)


if __name__ == "__main__":
    # it is OK to launch with torchrun:
    # NVSHMEM_DISABLE_CUDA_VMM=1 torchrun --nproc_per_node=2 -m lab.example
    # main()
    # fails to launch with Ray:
    # NVSHMEM_DISABLE_CUDA_VMM=1 python3 -m lab.example 2
    ray_main()
Expected behavior
The script should run successfully when launched with Ray, just as it does with torchrun.
Stack trace/logs
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/flux/lab/example.py", line 195, in <module>
    ray_main()
  File "/root/flux/lab/example.py", line 185, in ray_main
    ray.get(tasks)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2771, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 919, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::ray_launch() (pid=20286, ip=10.189.110.50)
  File "/root/flux/lab/example.py", line 166, in ray_launch
    run_all_gather_gemm(TP_GROUP)
  File "/root/flux/lab/example.py", line 145, in run_all_gather_gemm
    all_gather_gemm(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/flux/lab/example.py", line 77, in all_gather_gemm
    all_gather_gemm_kernel = flux.AGKernel(
RuntimeError: /root/flux-tune/src/ths_op/topo_utils.cc:245 Check failed: 2((gpu_device_ids.size())) == 1((gpu_device_set.size()))
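The failed check appears to compare the number of participating GPU device ids (2) against the number of unique devices (1), which would be consistent with both Ray workers being pinned to the same visible CUDA device: with num_gpus=1, Ray rewrites CUDA_VISIBLE_DEVICES per worker, so every rank addresses its GPU as cuda:0. A minimal diagnostic sketch (not part of the original repro, interpretation of the check is an assumption) that prints what each Ray worker actually sees:

import os

import ray
import torch

ray.init(num_gpus=2)


@ray.remote(num_gpus=1)
def report_visible_devices(rank: int):
    # Ray rewrites CUDA_VISIBLE_DEVICES per worker based on num_gpus, so each
    # worker may only see a single GPU and address it as cuda:0.
    return {
        "rank": rank,
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "torch.cuda.device_count()": torch.cuda.device_count(),
        "torch.cuda.current_device()": torch.cuda.current_device(),
    }


print(ray.get([report_visible_devices.remote(i) for i in range(2)]))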
Environment
L20 + python3.10 + torch2.6 + cuda12.4 + flux@13e48df3
Additional context
We do not test with Ray. It seems that Ray limits each process to seeing only one GPU?
What is the recommended way to use flux with Ray?
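One direction that might be worth trying, assuming the root cause is Ray's per-worker CUDA_VISIBLE_DEVICES masking: keep Ray scheduling but ask it not to rewrite CUDA_VISIBLE_DEVICES, so each rank can still enumerate all local GPUs. The sketch below reuses ray_launch from the repro above; RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES is Ray's documented opt-out for the masking, but whether setting it via runtime_env takes effect early enough, whether LOCAL_RANK is what initialize_distributed uses to pick the device, and whether flux's topology detection is then satisfied are all untested assumptions.

import sys

import ray


def ray_main_no_gpu_masking():
    # Hypothetical variant of ray_main() above: same launcher, but each worker
    # is asked to keep visibility of all local GPUs instead of the single one
    # Ray assigns it.
    num_gpus = int(sys.argv[1])
    ray.init(num_gpus=num_gpus)
    tasks = []
    for i in range(num_gpus):
        env_vars = {
            "RANK": str(i),
            "LOCAL_RANK": str(i),  # assumption: used to select the CUDA device
            "WORLD_SIZE": str(num_gpus),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
            "NVSHMEM_DISABLE_CUDA_VMM": "1",
            # Ray's opt-out for per-worker CUDA_VISIBLE_DEVICES rewriting.
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
        }
        tasks.append(
            ray_launch.options(runtime_env={"env_vars": env_vars}).remote()
        )
    ray.get(tasks)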