Enable lazy tensor loading for sharded tensors

pgmoka opened this issue 6 months ago • 11 comments

Currently in PyTorch/XLA, when a tensor is initialized through sharding, it is loaded onto its associated devices immediately. Logically, mark_sharding behaves like calling .to('xla') on a tensor. This eager loading is efficient, but it poses a problem for implementing SPMD+MPMD (this is what "Model initialization" in https://github.com/pytorch/xla/issues/9019 refers to).

The goal of https://github.com/pytorch/xla/issues/9019 is to achieve pipeline parallelism where each thread runs a gSPMD "world" consisting of multiple devices. This changes the data-loading requirement, because we only want to load data after the model is traced. Currently, tracing happens when we compile the model.

To see how data is currently loaded onto the device, we can follow TransferShardsToDevice in the call stack from XlaMarkSharding. We then need to work out how to trace the model before any data loading starts, as the prototype in https://github.com/GleasonK/torch_xla/blob/dtensor-exploration/docs/source/perf/distributed_pipelining.ipynb does.

Note that we still want to keep eager data loading, since it is more efficient in most cases. Lazy loading will be specific to SPMD+MPMD.

Ongoing investigation into Ray as an alternative to the MPMD IR strategy may affect this. We can follow up here as that investigation continues.

pgmoka, Jun 11 '25 21:06

@benawilson will be taking a look at this issue.

pgmoka, Jun 12 '25 18:06

Initial design concept:

Currently, if an input tensor has data, XlaMarkSharding will extract this to a cpu_tensor, and then will use CreateTensorsData along with the provided sharding spec to load this data onto the device. This is eager loading, and we want to retain this behavior as the default.

To allow for lazy sharded loading, on an opt-in basis, we will need to add a new argument lazy to both the Python mark_sharding and C++ XlaMarkSharding APIs, defaulted to false for backwards-compatibility. When enabled, we want to delay the data transfer until after compilation, but immediately before execution.

In this case, rather than updating the tensor to have an XLATensor::Data defined by a BackendDataPtr, we can instead update the XLATensor::Data to have an IR Value wrapping an XlaNode. This node would need to be a new subclass, LazyDeviceData, which would contain all the necessary information to call CreateTensorsData and do the final host-to-device transfer (aten tensor, shape/dtype info, device, etc).
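The proposed control flow can be made concrete with a toy Python model of the opt-in flag. Everything in it (`FakeDevice`, `LazyDeviceData`, `TensorState`, `mark_sharding_sketch`) is a hypothetical stand-in, not torch_xla code. The eager path "uploads" immediately, as CreateTensorsData does today. With `lazy=True`, the call only records a node holding the host data, sharding, and device needed for a later transfer.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class FakeDevice:
    """Hypothetical device stand-in that records every host-to-device transfer."""
    transfers: List[Any] = field(default_factory=list)

    def create_tensors_data(self, host_values, sharding):
        # Stand-in for CreateTensorsData: "uploads" each host value with its sharding.
        handles = []
        for value in host_values:
            self.transfers.append((value, sharding))
            handles.append(("buffer", tuple(value), sharding))
        return handles


@dataclass
class LazyDeviceData:
    """Hypothetical IR node holding everything needed for a deferred transfer."""
    host_value: List[float]
    sharding: str
    device: FakeDevice
    handle: Optional[tuple] = None  # filled in once the data reaches the device


@dataclass
class TensorState:
    handle: Optional[tuple] = None             # eager path: BackendDataPtr-like
    ir_value: Optional[LazyDeviceData] = None  # lazy path: pending IR node


def mark_sharding_sketch(host_value, sharding, device, lazy=False):
    """Eager by default; with lazy=True, defer the transfer behind an IR node."""
    state = TensorState()
    if lazy:
        state.ir_value = LazyDeviceData(list(host_value), sharding, device)
    else:
        state.handle = device.create_tensors_data([host_value], sharding)[0]
    return state
```

Because `lazy=False` is the default, every existing caller keeps today's eager behavior.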

During compilation, the lowering for LazyDeviceData should emit identical HLO/StableHLO to the existing DeviceData node, including hashes. The existing argument warm_up_cache_only in XlaGraphExecutor::SyncTensorsGraph can be used to identify compile-only fake executions.

For an actual execution, we need to identify any and all LazyDeviceData nodes, and get their data onto the device. Currently, special handling for DeviceData nodes exists within XlaGraphExecutor::CollectSyncTensors; by forwarding warm_up_cache_only and identifying LazyDeviceData, we should be able to call CreateTensorsData on all LazyDeviceData nodes at once to create their sharded data. At that point we can mark the LazyDeviceData nodes as resolved, and can handle them identically to DeviceData for the remainder of the execution (and future executions).
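Below is a sketch of the resolution step, again as a hypothetical Python model rather than the C++ in CollectSyncTensors. A compile-only run (`warm_up_cache_only=True`) moves no data. A real execution gathers every unresolved node into one batched transfer, and nodes that are already resolved are skipped on later executions.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class LazyDeviceData:
    """Hypothetical pending-transfer node (a toy model, not torch_xla code)."""
    host_value: List[float]
    sharding: str
    handle: Optional[tuple] = None

    @property
    def resolved(self) -> bool:
        return self.handle is not None


@dataclass
class TransferLog:
    batches: List[int] = field(default_factory=list)

    def create_tensors_data(self, values, shardings):
        # One batched "CreateTensorsData" call covering every pending node.
        self.batches.append(len(values))
        return [("buffer", tuple(v), s) for v, s in zip(values, shardings)]


def collect_and_resolve(nodes, log, warm_up_cache_only):
    """Transfer all unresolved lazy nodes at once, unless this is a compile-only run."""
    if warm_up_cache_only:
        return 0  # compile-only: lowering matches DeviceData, so no data must move
    pending = [n for n in nodes if not n.resolved]
    if not pending:
        return 0
    handles = log.create_tensors_data(
        [n.host_value for n in pending], [n.sharding for n in pending])
    for node, handle in zip(pending, handles):
        node.handle = handle  # from now on, treat the node exactly like DeviceData
    return len(pending)
```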

benawilson, Jun 12 '25 21:06

Possible issue with the above design:

As noted above, we should transfer the data to the device only once. When execution is requested, we need to initiate and complete the transfer, swap the node's placeholder data_ member field for a reference to the newly created buffer on the device, and then mark the node as resolved.

What is presently unclear is whether this mutation is safe to do once an IR graph has been built up with references to this initial node. Normally the IR graph is built additively and immutably; modifying the graph post-creation risks creating unresolvable cycles. While this modification (swapping a "pending" buffer with a "ready" buffer) is theoretically safe, the assumption of immutability may be present in other IR node class definitions in ways that this design could violate.

benawilson, Jun 16 '25 20:06

On further investigation, we might not need (or want) to use a new IR value after all.

An XLATensor::Data can be constructed from:

  • A BackendDataPtr, representing either a placeholder buffer to receive a computation output, or a defined buffer input.
  • An IR node (torch::lazy::Value) representing a pending computation.
  • An at::Tensor named "tensor_data". In torch_xla, this is required to be on the CPU device (enforced with an XLA_CHECK_EQ against kCPU). The setter methods imply that these are mutually exclusive states by design.
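The exclusivity of these states can be pictured with a toy Python class. The field names follow the list above, but the class and its setters are hypothetical. The CPU check only mirrors the XLA_CHECK_EQ against kCPU.

```python
from typing import Any, Optional


class TensorDataSketch:
    """Toy model of the three mutually exclusive XLATensor::Data states.

    Hypothetical: field names mirror the prose (handle, ir_value, tensor_data),
    and each setter clears the other two fields.
    """

    def __init__(self):
        self.handle: Optional[Any] = None       # BackendDataPtr-like
        self.ir_value: Optional[Any] = None     # torch::lazy::Value-like
        self.tensor_data: Optional[Any] = None  # CPU at::Tensor-like

    def set_handle(self, handle):
        self.handle, self.ir_value, self.tensor_data = handle, None, None

    def set_ir_value(self, ir_value):
        self.handle, self.ir_value, self.tensor_data = None, ir_value, None

    def set_tensor_data(self, tensor_data, device="cpu"):
        # Mirrors the kCPU constraint: tensor_data must be host memory.
        if device != "cpu":
            raise ValueError(f"tensor_data must live on cpu, got {device}")
        self.handle, self.ir_value, self.tensor_data = None, None, tensor_data

    def state(self) -> str:
        set_fields = [name for name in ("handle", "ir_value", "tensor_data")
                      if getattr(self, name) is not None]
        assert len(set_fields) <= 1, "states must be mutually exclusive"
        return set_fields[0] if set_fields else "empty"
```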

Currently XLAMarkSharding will create a BackendDataPtr; the initial proposal above was to move to an IR node, but the at::Tensor option may also be feasible, and might be preferred. There already exists a well-structured call to CreateTensorsData in CollectSyncTensors.

This CPU-to-device transfer sits behind a conditional guard on config.force_ltc_data. Going through the main SyncLiveTensorsGraph always sets this to true, but it can be disabled by calling the lower-level SyncTensorsGraph with either sync_ltc_data=false or warm_up_cache_only=true.

This implies that a lazy initialization can be achieved simply by constructing "tensor_data"-style values in XlaMarkSharding, instead of the BackendDataPtr values, at which point the behavior of warm_up_cache_only will control the compile-only vs compile-and-execute streams as mentioned.
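The guard can be modeled in a few lines of toy code, assuming (as described above) that the sync path uploads any CPU tensor_data in one batched call when force_ltc_data is set. `collect_sync_tensors` and its `upload` callback are hypothetical simplifications, not the real CollectSyncTensors signature.

```python
def collect_sync_tensors(tensors, force_ltc_data, upload):
    """Toy model: upload CPU tensor_data only when force_ltc_data is set.

    `tensors` is a list of dicts with optional "tensor_data" / "handle" keys;
    `upload` stands in for a batched CreateTensorsData call. Both are
    hypothetical simplifications of the C++ structures.
    """
    if not force_ltc_data:
        return []  # warm-up or sync_ltc_data=false: leave host data where it is
    pending = [t for t in tensors
               if t.get("tensor_data") is not None and t.get("handle") is None]
    handles = upload([t["tensor_data"] for t in pending])
    for t, h in zip(pending, handles):
        t["handle"], t["tensor_data"] = h, None  # data now lives on the device
    return pending
```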

The primary question is whether force_ltc_data is doing more than just this. For example, this comment in CollectSyncTensors:

  // The force_ltc_data controls aliasing compilation, so effectively the same
  // graph with on/off force_ltc_data should not match, hash wise.
  coll.hash = torch::lazy::MHash(config.force_ltc_data);

This implies some interaction between this existing flag and the aliasing mechanics; there is also a long comment block in XLAGraphExecutor::SetTensorData. Additional guards or config options may be needed to avoid breaking associated functionality when we only want to delay CPU-to-device loading.

Still, this seems like a lower-risk solution than the IR-based mechanism, given the (im)mutability concern.

benawilson, Jun 16 '25 22:06

The base class torch::lazy::LazyGraphExecutor accepts a sync_ltc_data argument, but the torch_xla implementation, XlaGraphExecutor, uses two arguments: sync_ltc_data and warm_up_cache_only. The possible combinations are:

  • sync_ltc_data && !warm_up_cache_only: this sets force_ltc_data=true, meaning this is a full sync, compile (or use cache), and execute. This is the primary setup when called from Python, as it is used from torch_xla.sync() through _XLAC._xla_step_marker, and many other places.
  • sync_ltc_data && warm_up_cache_only: this sets force_ltc_data=false, and the value of sync_ltc_data is otherwise unused. No data is synced, and the program is compiled but not executed. This execution path is not exposed at the Python level; no binding in _XLAC can be configured this way.
  • !sync_ltc_data && !warm_up_cache_only: this sets force_ltc_data=false. Aten tensor data is not synced to the device, but the code is still compiled and run. This path is exposed in _XLAC through _xla_sync_multi, and is used in both torch_xla/utils/serialization and torch_xla/experimental/gradient_accumulation. However, the main public API, torch_xla.xla_model.unlazy, does not offer this option; through its default, it always forces sync_ltc_data=true.
  • !sync_ltc_data && warm_up_cache_only: this sets force_ltc_data=false and will not do any device transfers or executions. This is exposed in _XLAC through _xla_warm_up_cache, and this is used in torch_xla._dynamo.dynamo_bridge.extract_graph_helper, but there is no main public Python API.
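The four combinations can be summarized as a small lookup. This only restates the list above in Python; `SyncMode`, `sync_mode`, and `PYTHON_ENTRY_POINTS` are hypothetical names.

```python
from typing import NamedTuple


class SyncMode(NamedTuple):
    force_ltc_data: bool  # whether CPU tensor data is forced onto the device
    executes: bool        # whether the compiled program is actually run


def sync_mode(sync_ltc_data: bool, warm_up_cache_only: bool) -> SyncMode:
    """Encode the four argument combinations described above."""
    if warm_up_cache_only:
        # sync_ltc_data is ignored: compile only, no transfers, no execution.
        return SyncMode(force_ltc_data=False, executes=False)
    # Compile (or hit the cache) and run; host data is forced only if requested.
    return SyncMode(force_ltc_data=sync_ltc_data, executes=True)


# Python-level entry points for each (sync_ltc_data, warm_up_cache_only) pair.
PYTHON_ENTRY_POINTS = {
    (True, False): "torch_xla.sync() via _XLAC._xla_step_marker",
    (True, True): None,  # no _XLAC binding reaches this combination
    (False, False): "_XLAC._xla_sync_multi",
    (False, True): "_XLAC._xla_warm_up_cache",
}
```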

The commenting about aliasing refers to behavior in XlaGraphExecutor::GetBufferDonors. There is a very long comment in that function, but the simple explanation is that buffer aliasing is only safe to do after a device sync, at the step_marker. This is why the compilation hash includes the value of force_ltc_data; compiling with a sync vs without a sync changes what buffers are available.

This suggests that if we allow mark_sharding/XlaMarkSharding to return an XLATensor containing at::Tensor tensor_data rather than a BackendDataPtr handle, then torch_xla.sync() (or mark_step()) will automatically perform the device transfer in the existing CollectSyncTensors code. A "warm up" execution path will compile a graph, but that graph will have a different hash and, if XLA_ENABLE_PARAM_ALIASING=1 or is unset, may compile the HLO differently due to aliasing differences. The behavior of serialization and gradient accumulation is also unclear: what would happen if data is not synced to device and you try to execute the graph anyway?
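A toy compilation cache illustrates the hashing concern. The sketch seeds the key with the flag, as `coll.hash = torch::lazy::MHash(config.force_ltc_data)` does, but uses SHA-256 in place of MHash. `graph_hash` and `compile_cached` are hypothetical. The only point is that a warm-up compile (force_ltc_data=false) does not populate the cache entry that a later torch_xla.sync() (force_ltc_data=true) looks up.

```python
import hashlib


def graph_hash(hlo_text: str, force_ltc_data: bool) -> str:
    """Toy stand-in for seeding the graph hash with force_ltc_data.

    SHA-256 replaces torch::lazy::MHash here; only the property that the
    flag changes the key matters.
    """
    h = hashlib.sha256()
    h.update(b"force_ltc_data=1" if force_ltc_data else b"force_ltc_data=0")
    h.update(hlo_text.encode())
    return h.hexdigest()


cache = {}


def compile_cached(hlo_text: str, force_ltc_data: bool) -> bool:
    """Return True on a cache hit; otherwise 'compile' and store the result."""
    key = graph_hash(hlo_text, force_ltc_data)
    hit = key in cache
    # Aliasing decisions may differ between the two modes, hence separate entries.
    cache.setdefault(key, f"executable({hlo_text}, force_ltc_data={force_ltc_data})")
    return hit
```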

benawilson, Jun 17 '25 17:06

What would happen if data is not synced to device and you try to execute the graph anyway?

The lazy aten operations are defined using XlaNodes, which are always constructed from other IR Values. Operands to aten ops are first cast from at::Tensor to XlaTensorPtrs using either GetXlaTensor or sometimes GetOrCreateXlaTensor. When this results in a creation, this will use the tensor_data construction of the XlaTensor.

Once XLATensorPtrs for all operands are acquired, then GetIrValue is called on them.

  • If the XlaTensor has IR, it's returned.
  • If it does not have IR but does have a BackendDataPtr, it is made into a DeviceData node via torch::lazy::LazyTensor::CreateTensorNode.
  • If it only has tensor data, then this tensor data is transferred to device via TensorToXlaData to create a BackendDataPtr which is then used to create a DeviceData IR node.

This means that it is not possible to construct an IR graph with reference to data not on the XLA device. Attempting to do so will immediately trigger a CPU-to-XLA device transfer on the tensor data operands.
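The resolution order above can be sketched as a toy Python model. `XlaTensorSketch`, `transfer_to_device`, and `get_ir_value` are hypothetical stand-ins for XLATensor, TensorToXlaData, and GetIrValue. The sketch shows that merely asking for an IR value is enough to trigger a host-to-device copy.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class XlaTensorSketch:
    """Hypothetical stand-in for XLATensor with its three possible data sources."""
    ir_value: Optional[Any] = None
    handle: Optional[Any] = None
    tensor_data: Optional[List[float]] = None


transfers: List[List[float]] = []


def transfer_to_device(values):
    # Stand-in for TensorToXlaData: a host-to-device copy.
    transfers.append(values)
    return ("buffer", tuple(values))


def get_ir_value(t: XlaTensorSketch):
    """Resolution order described above: IR, then handle, then tensor_data."""
    if t.ir_value is not None:
        return t.ir_value
    if t.handle is None:
        if t.tensor_data is None:
            raise RuntimeError("tensor has no data")
        # Implicit host-to-device transfer triggered just by using the tensor.
        t.handle, t.tensor_data = transfer_to_device(t.tensor_data), None
    t.ir_value = ("DeviceData", t.handle)
    return t.ir_value
```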

The behavior of serialization and gradient accumulation is unclear

The behavior in torch_xla.utils.serialization.save is similar to that in xla_model.save. In this case, the objective of the sync is to get the tensor data on the CPU, which requires completing all IR graphs. Since tensor_data already lives on the CPU, there is no reason to sync as this would just be a round-trip of CPU -> XLA -> CPU.

For gradient accumulation, _xla_sync_multi is used only on the inputs to the gradient accumulation body, before any IR graph is defined. Since the user-provided train_step function will presumably construct an IR graph, any CPU-hosted tensor_data will be forced to device at that time.

benawilson, Jun 18 '25 16:06

What this implies for the design here:

  • We can add an API option to allow mark_sharding to be called with lazy=True from Python, and pass this through to XlaMarkSharding in C++.
  • When lazy=true, we change the behavior of XlaMarkSharding from using CreateTensorsData and SetXlaData, to instead use XlaTensor::Create on the cpu_tensor (using the virtual device as the device, and passing the sharding spec in the constructor as well).
  • This allows a tensor to be marked as sharded while its data stays on the host until it is first used in an aten op. As soon as it is used, even before the IR output node is created, a host-to-device transfer creates the sharded layout.
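Here is a hypothetical Python model of the proposed lazy branch in XlaMarkSharding, assuming the input tensor is updated in place. `ShardedTensorSketch`, `create_tensors_data`, and `xla_mark_sharding_sketch` are illustrative names, not torch_xla APIs.

```python
from dataclasses import dataclass
from typing import Any, List, Optional

SPMD_VIRTUAL_DEVICE = "SPMD:0"  # the virtual device named in this thread


@dataclass
class ShardedTensorSketch:
    """Hypothetical XLATensor stand-in."""
    device: str
    sharding: Optional[str] = None
    handle: Optional[Any] = None
    tensor_data: Optional[List[float]] = None


def create_tensors_data(values, sharding):
    # Stand-in for CreateTensorsData with a sharding spec (an eager upload).
    return ("sharded_buffer", tuple(values), sharding)


def xla_mark_sharding_sketch(t, cpu_tensor, sharding, lazy=False):
    """Hypothetical in-place mark_sharding with an opt-in lazy branch."""
    t.sharding = sharding
    if lazy:
        # Lazy: keep the CPU data as tensor_data; the first aten op uploads it.
        t.handle, t.tensor_data = None, list(cpu_tensor)
    else:
        # Eager (default): upload the sharded data now, as today.
        t.handle, t.tensor_data = create_tensors_data(cpu_tensor, sharding), None
    return t
```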

I'm not sure if this suits AWS's requirements. If the tracing produces new IR nodes, then even if loading is delayed in mark_sharding, the data might still be loaded as soon as it is traced. Or it might not: if FX tracing does not dispatch fully to torch_xla code and stops at the torch level, then this might be sufficient.

benawilson, Jun 18 '25 17:06

Complication: XlaMarkSharding is an in-place operation. The input is a const at::Tensor& that is unwrapped into an XLATensorPtr, which is modified in place. The XLATensor::SetTensor setter is suitable for overriding the tensor_data value with the cpu_tensor.

However, XLATensor::Data subclasses torch::lazy::LazyTensor::Data, which has const BackendDevice device. This means that the input XLATensor cannot have its device changed to be the SPMD virtual device.

This is an issue because XLATensor::GetIRValue uses XLATensor::GetDevice to determine the device to transfer to, but for the purposes of XlaMarkSharding we want this to always be the virtual device "SPMD:0".

We will need some way to indicate that the TransferToXlaDevice within GetIRValue should sometimes transfer to this virtual device when transferring its CurrentTensorData.

benawilson, Jun 18 '25 17:06

Prior complication is mostly a non-issue; the intended API is for torch_xla.runtime.use_spmd() to be called before any XLA tensors are initialized, with an existing warning about this. If users try to mark_sharding before this is called, then XlaMarkSharding will error on an XLA_CHECK. If a non-replicated tensor is created before torch_xla.runtime.use_spmd() is called, it will be forcibly replicated on all devices.

This means that whenever XlaMarkSharding does not raise an error, the BackendDevice for all XLATensors should already be SPMD:0, and we don't need to store any extra information.

However, we do still need to ensure that XLATensor::GetXlaData, XLATensor::GetIRValue, and similar sites that currently use TransferToXlaDevice with GetDevice() have a conditional switch to preserve sharding when the device is virtual and there is a sharding on the XLATensor.
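The proposed conditional switch can be sketched as a small decision function. `plan_transfer` is hypothetical. It only captures the rule from the paragraph above: preserve the sharding when the tensor targets the SPMD virtual device and carries a sharding spec, and otherwise keep the existing transfer to GetDevice().

```python
from typing import Optional

SPMD_VIRTUAL_DEVICE = "SPMD:0"  # virtual device name used in this thread


def plan_transfer(device: str, sharding: Optional[str]) -> str:
    """Pick how CPU tensor_data should be moved when it is first needed.

    Hypothetical model of the switch proposed for GetXlaData/GetIrValue.
    """
    if device == SPMD_VIRTUAL_DEVICE and sharding is not None:
        return f"sharded transfer ({sharding})"
    # Existing behavior: TransferToXlaDevice with GetDevice().
    return f"default transfer to {device}"
```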

benawilson, Jun 23 '25 21:06

Follow-up to my prior comment:

I'm not sure if this suits AWS's requirements. If the tracing produces new IR nodes, then even if loading is delayed in mark_sharding, the data might still be loaded as soon as it is traced. Or it might not: if FX tracing does not dispatch fully to torch_xla code and stops at the torch level, then this might be sufficient.

The original RFC request was:

We want to ensure that the model initialization on PyTorch/XLA does not capture the entire model on the device for each stage. In order to avoid this, we want to extend the existing meta device context to only capture the metadata, ensuring that we can use with the XLA device interchangeably. Once a model is traced, we can then send the sharded data to the device.

This implies that the intended use case is to create the model on the PyTorch Meta device first, before moving it to the XLA device. This would guarantee zero data transfer to the XLA device, as a PyTorch/aten meta-tensor has no backing data buffer to transfer. This would also mean that the

Currently, this is impossible due to the at::kCPU requirement in XLATensor::Create. torch_xla implicitly assumes that any at::Tensor tensor_data within an XLATensor contains a real data buffer, which can be converted to an AtenSource for use in TensorToXlaData or CreateTensorsData.

XLATensors with placeholder BackendDataPtrs are already mostly supported; the Python API create_placeholder_tensor creates them. If torch_xla.runtime.use_spmd() has been called, this creates a replicated tensor on the "SPMD:0" virtual device.

However, trying to use mark_sharding on this tensor will fail; because the tensor has no backing data buffer (PjRtBuffer or xla::ifrt::Array), an XLA_CHECK in XlaMarkSharding will throw an error.

We can improve the user experience around these placeholder tensors with a few targeted changes:

  1. Relax the at::kCPU constraint in XLATensor::Create to also accept an at::kMeta input tensor.
  2. At all callsites which presently construct AtenSources and use TransferToDevice, add detection for meta at::Tensors and use CreateDataPlaceholder or their sharded equivalents to construct the requisite torch::lazy::BackendDataPtrs.
  3. Support meta tensors and placeholder data tensors in XLAMarkSharding. In this case, instead of CreateTensorsData, use CreateDataPlaceholder on the virtual device, with a sharding argument.
  4. At the Python level, add optional parameters for Mesh and PartitionSpec to create_placeholder_tensor, as a convenience method for constructing sharded placeholders directly. (This will be equivalent to calling create_placeholder_tensor and mark_sharding back-to-back.)
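Steps 2 and 3 amount to a branch on the source device of the input tensor. The sketch below is a hypothetical Python model in which `HostTensorSketch` and `make_sharded_data` are illustrative names. Meta inputs produce a bufferless record, similar to CreateDataPlaceholder, that carries shape and sharding. CPU inputs keep the eager copy, similar to CreateTensorsData.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class HostTensorSketch:
    """Hypothetical at::Tensor stand-in: 'cpu' carries data, 'meta' only a shape."""
    device: str
    shape: Tuple[int, ...]
    data: Optional[list] = None


def make_sharded_data(t: HostTensorSketch, sharding: str, device: str = "SPMD:0"):
    """Eager upload for CPU tensors; bufferless placeholder for meta tensors."""
    if t.device == "meta":
        # Placeholder-like: shape and sharding only, no buffer is allocated or copied.
        return {"device": device, "shape": t.shape, "sharding": sharding, "buffer": None}
    if t.device == "cpu":
        # CreateTensorsData-like: a real host-to-device copy.
        return {"device": device, "shape": t.shape, "sharding": sharding,
                "buffer": tuple(t.data)}
    raise ValueError(f"unsupported source device {t.device!r}")
```

In real PyTorch, a meta input would come from something like `torch.empty(4, device="meta")`. Such a tensor carries shape and dtype but no storage.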

This also means the "lazy" argument proposed above is irrelevant; we can retain eager-loading semantics for CPU tensors, and leave the API for mark_sharding unchanged. This avoids the counterintuitive behavior of seemingly unrelated ops (like torch.add) implicitly performing a device transfer via GetIrValue.

benawilson, Jun 25 '25 15:06

Supporting meta initialization of the IR graph introduces a new problem: un-executable IR graphs.

The current implementation enforces that all DeviceData nodes are backed by real data (as PjRtBuffers). The intended data flow path is XLATensor::Create from an at::kCPU value, then that data is transferred to the XLA device either from GetIrValue or XlaMarkSharding. The IR graph is traced during torch::lazy::LazyGraphExecutor::RunPostOrder, and the underlying buffers are retrieved in PjRtComputationClient::Execute and ExecuteReplicated using PjrtData->buffer->get().

Constructing placeholder inputs would mean that we would generate an IR graph with root DeviceData nodes for which no data exists, backed by nullptr PjRtBuffers. This would cause torch::lazy::LazyGraphExecutor::RunPostOrder to do unnecessary TensorCollectionBarrier syncs (essentially waiting for pending computations that do not exist). More importantly, it would cause PostOrderData::parameter_data to be filled with bufferless data pointers, which would cause PjRtComputationClient::Execute and ExecuteReplicated to fill buffers/argument_handles with nullptrs, likely causing undefined behavior.

To prevent this, we would need to add an XLA_CHECK to these execution paths that raises an exception if DataPtr->HasValue() is false for any argument. This is a good idea regardless, since create_placeholder_tensor already lets Python code build such invalid IR graphs today.
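The proposed check can be sketched as a hypothetical Python model of the argument-gathering step in Execute/ExecuteReplicated. `DataSketch.has_value` mirrors DataPtr->HasValue(). The real check would be an XLA_CHECK in C++.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class DataSketch:
    """Hypothetical stand-in for a BackendDataPtr argument."""
    name: str
    buffer: Optional[Any] = None

    def has_value(self) -> bool:  # mirrors DataPtr->HasValue()
        return self.buffer is not None


def execute_sketch(arguments: List[DataSketch]):
    """Refuse to run if any argument is a bufferless placeholder.

    Without this check, the nullptr buffers would be passed on as argument
    handles, which is the undefined behavior described above.
    """
    missing = [a.name for a in arguments if not a.has_value()]
    if missing:
        raise RuntimeError(f"cannot execute: placeholder arguments {missing}")
    return [a.buffer for a in arguments]
```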

benawilson, Jun 26 '25 16:06