Using CC ops with mark_sharding API throws an error.
🐛 Describe the bug
The crash seen is the following:
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1709131242.311197 36940 hlo_sharding.cc:1034] Check failed: IsTuple()
*** Check failure stack trace: ***
@ 0x7f1e46d752d9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f1e40ed9700 xla::HloSharding::GetSubSharding()
@ 0x7f1e41cadd35 xla::ShardingPropagation::InferShardingFromOperands()
@ 0x7f1e41cb1cec xla::ShardingPropagation::Run()::{lambda()#3}::operator()()
@ 0x7f1e41cb5d43 xla::ShardingPropagation::Run()
@ 0x7f1e41c98355 xla::HloPassPipeline::RunHelper()
@ 0x7f1e41c9933a xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7f1e41c99fa4 xla::HloPassPipeline::Run()
@ 0x7f1e41100d49 neuron::HloOptimization()
@ 0x7f1e410a3ab9 neuron::Optimize()
@ 0x7f1e4109f07e neuron::PJRT_Client_Compile()
@ 0x7f1e410a0638 neuron::Decorator<>::wrapper()
@ 0x7f1e51d966c5 xla::InitializeArgsAndCompile()
@ 0x7f1e51d969e0 xla::PjRtCApiClient::Compile()
@ 0x7f1e4d3411e6 torch_xla::runtime::PjRtComputationClient::Compile()
@ 0x7f1e4d14853e torch_xla::XLAGraphExecutor::Compile()
@ 0x7f1e4d149f49 torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
@ 0x7f1e4d14a58b torch_xla::XLAGraphExecutor::SyncTensorsGraph()
@ 0x7f1e4d14a9b8 torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph()
@ 0x7f1e4cf1928a torch_xla::(anonymous namespace)::StepMarker()
@ 0x7f1e4cf196c6 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f1e4cef6ed0 pybind11::cpp_function::dispatcher()
@ 0x5d5499 PyCFunction_Call
Aborted (core dumped)
```
A simple example to reproduce the bug is attached below:
```python
import os
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_CPP_VMODULE"] = 'hlo_optimization=5'
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=32 --xla_dump_hlo_as_text --xla_dump_hlo_as_proto --xla_dump_to=./xla_dump --xla_dump_hlo_pass_re='.*spmd.*'"
# Enable XLA SPMD execution mode.
xr.use_spmd()
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt_backend
import torch_xla.experimental.pjrt as pjrt
os.environ['NEURON_RT_NUM_CORES'] = '32'
os.environ['NEURON_PJRT_PROCESS_INDEX'] = '0'
os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = '32'
os.environ['WORLD_SIZE'] = '1'
num_devices = xr.global_runtime_device_count()
print(f'num device: {num_devices}')
mesh_shape = (1, num_devices)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
lin = torch.nn.Linear(8192, 32768, bias=False).to(xm.xla_device())
lin2 = torch.nn.Linear(32768, 8192, bias=False).to(xm.xla_device())
xs.mark_sharding(lin.weight, mesh, ('y', 'x'))
xs.mark_sharding(lin2.weight, mesh, ('x', 'y'))
lin.train()
lin2.train()
print(mesh.get_logical_mesh(), mesh.shape())
t1 = torch.randn(64, 8192).to(xm.xla_device())
t2 = lin(t1)
t3 = lin2(t2)
# (None, None) replicates t3 on every device
xs.mark_sharding(t3, mesh, (None, None))
xm.mark_step()
print(t3)
# CC op on the SPMD program: the combination this issue reports as failing
t4 = xm.all_reduce('sum', t3)
xm.mark_step()
print(t4)
```
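To make the sharding annotations in the repro concrete, here is a small pure-Python sketch of the per-device shard shape each `mark_sharding` call implies. The helper `local_shard_shape` is hypothetical (not a torch_xla API); it assumes each partition-spec entry names the mesh axis a tensor dimension is split over, that `None` means replicated, and it simply rejects dimensions that do not divide evenly.

```python
from typing import Optional, Sequence, Tuple


def local_shard_shape(
    tensor_shape: Sequence[int],
    mesh_shape: Sequence[int],
    axis_names: Sequence[str],
    partition_spec: Sequence[Optional[str]],
) -> Tuple[int, ...]:
    """Hypothetical helper: shape of one device's shard for an even split."""
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    axis_size = dict(zip(axis_names, mesh_shape))
    shard = []
    for dim, axis in zip(tensor_shape, partition_spec):
        if axis is None:
            # Replicated dimension: every device holds the full extent.
            shard.append(dim)
            continue
        n = axis_size[axis]
        if dim % n:
            raise ValueError(f"dim {dim} is not divisible by mesh axis {axis!r} (size {n})")
        shard.append(dim // n)
    return tuple(shard)


mesh_shape, axes = (1, 32), ("x", "y")
# lin.weight is (out_features, in_features) = (32768, 8192)
print(local_shard_shape((32768, 8192), mesh_shape, axes, ("y", "x")))  # (1024, 8192)
# lin2.weight is (8192, 32768)
print(local_shard_shape((8192, 32768), mesh_shape, axes, ("x", "y")))  # (8192, 1024)
# t3 is (64, 8192) and fully replicated
print(local_shard_shape((64, 8192), mesh_shape, axes, (None, None)))  # (64, 8192)
```

With the (1, 32) mesh, only the `'y'` axis actually splits anything, so each weight is cut 32 ways along one dimension and `t3` is left whole on every device before the all-reduce.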
### Versions
Versions of relevant libraries:
```
[pip3] numpy==1.24.4
[pip3] torch==2.1.2
[pip3] torch-neuronx==2.1.1.2.0.1b0
[pip3] torch-xla==2.1.2
[pip3] torchvision==0.16.2
```
I think what you are trying to do is:
- use SPMD to shard the HLO that belongs to the current pipeline stage
- use CC ops to communicate across all of the hosts
I think this won't work out of the box because:
- under your SPMD setup, PyTorch/XLA starts only a single process per host, and that process owns all of the XLA devices on the host.
- however, under your CC ops setup, you would need to start x processes per host and initialize the PJRT runtime so that it recognizes all of the devices across the hosts.
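A simplified model of the two process layouts described above may help. The function names are hypothetical, and the sketch assumes the CC setup runs one process per device (so x equals the number of devices per host) with devices numbered host-major; real PJRT device numbering may differ.

```python
from typing import Dict, List


def spmd_processes(num_hosts: int, devices_per_host: int) -> Dict[int, List[int]]:
    """SPMD: one process per host, owning every device on that host."""
    return {
        host: list(range(host * devices_per_host, (host + 1) * devices_per_host))
        for host in range(num_hosts)
    }


def multiprocess_processes(num_hosts: int, devices_per_host: int) -> Dict[int, List[int]]:
    """Multi-process CC setup: one process per device, keyed by global rank."""
    world_size = num_hosts * devices_per_host
    return {rank: [rank] for rank in range(world_size)}


spmd = spmd_processes(num_hosts=2, devices_per_host=32)
mp = multiprocess_processes(num_hosts=2, devices_per_host=32)
print(len(spmd), len(spmd[0]))  # 2 32
print(len(mp), len(mp[0]))  # 64 1
```

Both layouts cover the same 64 devices; what differs is how many processes (and PJRT clients) exist and which devices each one can address, which is why switching between them is a runtime configuration question rather than a graph question.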
The problem here is that I don't think there is an easy way to change the PJRT device config on the fly. @will-cromar @yeounoh in case you have better suggestions.
@baoleai I remember you mentioned something about SPMD + PP; wondering if you have some insight as well.
Currently, SPMD does not support communication operators at the Python layer. When combining SPMD-TP and PP, we made numerous changes to XLA and the OpenXLA SPMD pass to support send/recv (@yitongh). Supporting the all-reduce communication operator might be more complicated.
Based on previous experience, you will need to do the following things on GPU:
- Support communication operations such as all-reduce on the Python side within SPMD; for example, add all-reduce support in sharding_propagation.cc.
- When invoking NCCL communication, correctly handle the communication ranks for all-reduce: the CollectiveOpGroupMode used in the SPMD environment differs from the one used in replicated mode, so some hacky conversions are needed.
Even with the above handling, the all-reduce operator is currently not well-suited to handle sharded inputs and can only function as a replicated operation.
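The CollectiveOpGroupMode mismatch can be illustrated with a simplified model. The sketch below is hypothetical: the strings `"cross_replica"` and `"flattened_id"` stand in for XLA's `CollectiveOpGroupMode::kCrossReplica` and `::kFlattenedID`, and it assumes that replicated torch_xla compiles with N replicas x 1 partition while SPMD compiles with 1 replica x N partitions. Flattened ids follow XLA's `replica_id * num_partitions + partition_id` convention.

```python
from typing import List, Sequence, Tuple


def participants(
    group: Sequence[int],
    mode: str,
    num_replicas: int,
    num_partitions: int,
) -> List[Tuple[int, int]]:
    """Return the (replica_id, partition_id) pairs a replica group names (simplified model)."""
    if mode == "cross_replica":
        # Ids are replica ids. XLA instantiates the group once per partition;
        # only the copy seen by partition 0 is reported here.
        missing = [r for r in group if not 0 <= r < num_replicas]
        if missing:
            raise ValueError(f"{len(missing)} replica id(s) out of range (num_replicas={num_replicas})")
        return [(r, 0) for r in group]
    if mode == "flattened_id":
        # Ids are flattened: id = replica_id * num_partitions + partition_id.
        total = num_replicas * num_partitions
        missing = [i for i in group if not 0 <= i < total]
        if missing:
            raise ValueError(f"{len(missing)} flattened id(s) out of range ({total} devices)")
        return [divmod(i, num_partitions) for i in group]
    raise ValueError(f"unknown mode {mode!r}")


all_devices = list(range(32))
# Replicated (multi-process) setup: 32 replicas x 1 partition.
print(participants(all_devices, "cross_replica", num_replicas=32, num_partitions=1)[:2])  # [(0, 0), (1, 0)]
# SPMD: 1 replica x 32 partitions; the same ids only make sense as flattened ids.
print(participants(all_devices, "flattened_id", num_replicas=1, num_partitions=32)[:2])  # [(0, 0), (0, 1)]
try:
    participants(all_devices, "cross_replica", num_replicas=1, num_partitions=32)
except ValueError as err:
    print("cross_replica under SPMD:", err)
```

The same group `[0..31]` names all devices in replicated mode but refers to 31 nonexistent replicas under SPMD, which is the kind of mismatch a conversion to flattened ids has to paper over.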
Similar handling may be required in the TPU environment. Overall, there don't seem to be any particularly elegant solutions for supporting Python-side communication in the SPMD environment at the moment. Perhaps, as JackCaoG suggested, changing the PJRT device configuration might be a good approach.
@baoleai @yitongh does your send/recv use XLA Send/Recv? We use all-reduce instead of send/recv to simplify our stack, and we can assume that only non-sharded tensors will be passed in.
Is there any way to skip the sharding_propagation pass? This can be an isolated graph (cut off using mark_step), and we could use a custom_call or some attribute to keep the pass from "peeking" into the all-reduce.
@JackCaoG For the CC ops setup, why do we need to set up PJRT differently? All we need is a graph with the correct replica groups, correct? (These can be borrowed from the mesh during SPMD setup.) The PJRT runtime would just execute this on all "threads" (we don't need these to be different processes), and the all-reduce would look like any other all-reduce produced by the SPMD partitioner pass.
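"Borrowing" replica groups from the mesh could look roughly like the sketch below. `replica_groups_for_axis` is a hypothetical helper, not a torch_xla API; it assumes the logical mesh (what the repro prints via `mesh.get_logical_mesh()`) is `device_ids` laid out row-major over `mesh_shape`, as a NumPy reshape would produce, and groups together the devices that differ only along the reduced axis.

```python
from itertools import product
from math import prod
from typing import List, Sequence


def replica_groups_for_axis(
    device_ids: Sequence[int],
    mesh_shape: Sequence[int],
    axis_names: Sequence[str],
    reduce_axis: str,
) -> List[List[int]]:
    """Hypothetical helper: group device ids that differ only along `reduce_axis`."""
    if len(device_ids) != prod(mesh_shape):
        raise ValueError("device_ids must cover every mesh position exactly once")
    k = list(axis_names).index(reduce_axis)
    # Row-major strides: the last mesh axis varies fastest.
    strides = [1] * len(mesh_shape)
    for i in range(len(mesh_shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * mesh_shape[i + 1]
    other_axes = [i for i in range(len(mesh_shape)) if i != k]
    groups = []
    # One group per fixed coordinate on all of the other axes.
    for fixed in product(*(range(mesh_shape[i]) for i in other_axes)):
        coord = [0] * len(mesh_shape)
        for axis, value in zip(other_axes, fixed):
            coord[axis] = value
        group = []
        for step in range(mesh_shape[k]):
            coord[k] = step
            flat = sum(c * s for c, s in zip(coord, strides))
            group.append(device_ids[flat])
        groups.append(group)
    return groups


# The repro's mesh: (1, 32) with axes ('x', 'y'); reducing over 'y' spans all 32 devices.
print(replica_groups_for_axis(range(32), (1, 32), ("x", "y"), "y"))
# A 2 x 4 mesh: reducing over 'x' pairs devices that share a 'y' coordinate.
print(replica_groups_for_axis(range(8), (2, 4), ("x", "y"), "x"))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Under SPMD these ids would still have to be emitted in a group mode the partitioned program understands, so deriving the groups is only half of the problem discussed above.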