How to minimize memory expansion due to padding during sharding
Hello
For a model that can be sharded with model parallelism on a TPUv4 (4x32) device, I am getting the error below at the beginning of training on a TPUv3 (8x16) device. The console message reports a 4x memory expansion due to padding. Even though the TPUv4 and TPUv3 devices have the same total memory, I cannot run the training on the TPUv3 device.
Program hbm requirement 15.45G:
global 2.36M
scoped 3.88M
HLO temp 15.45G (60.9% utilization: Unpadded (9.40G) Padded (15.44G), 0.0% fragmentation (5.52M))
Largest program allocations in hbm:
1. Size: 4.00G
Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
Unpadded size: 1.00G
Extra memory due to padding: 3.00G (4.0x expansion)
XLA label: broadcast.6042.remat3 = broadcast(bitcast.26), dimensions={2,3}
Allocation type: HLO temp
==========================
2. Size: 4.00G
Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
Unpadded size: 1.00G
Extra memory due to padding: 3.00G (4.0x expansion)
XLA label: broadcast.6043.remat3 = broadcast(bitcast.27), dimensions={0,3}
Allocation type: HLO temp
==========================
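If I read the layout {0,1,3,2:T(4,128)(2,1)} correctly (this is just my own arithmetic, not an authoritative reading), the size-1 dimension ends up in the second-minor (sublane) position and is padded up to the tile size of 4, which would account exactly for the reported 1.00G unpadded vs. 4.00G padded:
# Back-of-the-envelope check for bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}.
bf16_bytes = 2
unpadded = 2048 * 1 * 2048 * 128 * bf16_bytes  # 1.00 GiB, as reported
padded   = 2048 * 4 * 2048 * 128 * bf16_bytes  # size-1 dim rounded up to 4 -> 4.00 GiB
print(unpadded / 2**30, padded / 2**30, padded / unpadded)  # 1.0 4.0 4.0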
The lines that cause the 4x expansion are below:
def forward(self, x):  # Activation map volume = 1,128,2048,1
    ...
    ...
    x = torch.transpose(x, 1, 3)  # Activation map volume = 1,1,2048,128
    x_batch_0 = x.expand(2048, -1, -1, -1)  # Activation map volume = 2048,1,2048,128
    x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)  # Activation map volume = 2048,1,2048,128
    x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)  # Activation map volume = 2048,2,2048,128
    ...
    ...
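One reformulation I have considered (just a sketch, I have not verified that it changes the layout XLA picks) is to build x_batch from 3-D views so that the large intermediates never carry the size-1 dimension:
# Same result as the expand / repeat_interleave / cat above, but the big
# intermediates are (2048, 2048, 128) instead of (2048, 1, 2048, 128).
x3 = x.squeeze(1)                                     # (1, 2048, 128)
x_batch_0 = x3.expand(2048, -1, -1)                   # (2048, 2048, 128)
x_batch_1 = x3.repeat_interleave(2048, dim=1).reshape(2048, 2048, 128)
x_batch = torch.stack((x_batch_0, x_batch_1), dim=1)  # (2048, 2, 2048, 128)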
Here are the sharding properties that I set.
mesh_shape = (num_devices, 1, 1, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes
for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)
How can I prevent the 4x expansion?
One guess: while the total amount of HBM is the same on v3 and v4, the HBM bandwidth of v4 is higher than that of v3 (https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4), so that may be one contributing factor.
As you mentioned, one possible solution here would be to use v4, but if we want to further optimize the HBM usage on v3, the next step would be to get an HLO dump of the forward function (it doesn't necessarily have to use SPMD; it can just be a simple unit test with those ops) and try to optimize it.
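A minimal unit test for that could look something like this (a rough sketch; _get_xla_tensors_hlo is a private debugging helper, and setting the XLA_SAVE_TENSORS_FILE / XLA_SAVE_TENSORS_FMT=hlo environment variables is an alternative way to capture the graph):
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Minimal repro of the expand / repeat_interleave / cat pattern on one device.
device = xm.xla_device()
x = torch.randn(1, 1, 2048, 128, dtype=torch.bfloat16, device=device)
x_batch_0 = x.expand(2048, -1, -1, -1)
x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)
x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)

# Print the HLO of the pending graph without executing it.
print(torch_xla._XLAC._get_xla_tensors_hlo([x_batch]))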
cc @JackCaoG for visibility, in case you have any other recommendations.
@mfatih7 "TPUv4 (4x32) device, I am getting the error below at the beginning of the training on TPUv3 (8x16) device" how many hosts (or pod size) are you runing for each TPU generation? 4x32 means 32 hosts, and you are using v4-256 (appear as 128 devices)? And TPUv3 (8x16) device, you mean v3-128 (appear as 128 devices)?
Let's first confirm your pod sizes and that you are actually using the same number of chips in both cases.
Also,
mesh_shape = (num_devices, 1, 1, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes
for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)
Naively sharding the parameters can incur extra collectives, which can cost more memory on v3 than on v4. That would explain the difference, if you are indeed using the same number of chips.
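For example, a less aggressive marking that shards only the first weight dimension and explicitly replicates the rest (a sketch only; whether it actually reduces the extra collectives depends on the rest of your graph):
# Assuming: import torch_xla.distributed.spmd as xs
# (torch_xla.experimental.xla_sharding in older releases).
mesh_shape = (num_devices, 1, 1, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, None, None, None)  # shard dim 0 over 'w', replicate the rest
for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)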
Hello @yeounoh
Thank you for the answer.
Sorry for not sticking to your correct notation. I do not use the devices as a Pod; instead, I train my models on single devices with multiple cores. I just wanted to express the memory size in my notation: TPUv4 (4x32) means 4 cores, each with 32 GB of memory, and TPUv3 (8x16) means 8 cores, each with 16 GB of memory.
So both devices I use have 128 GB of memory in total. When I use model-level parallelization on TPUv4, the biggest activation map in my model fits into the 32 GB per core. But when I use model-level parallelization on TPUv3, the biggest activation map does not fit into 16 GB, and I get the error above. I would like to know whether there is something I can adjust to cancel out the 4x expansion shown in the error.
Hi @mfatih7, thanks for the clarification. Yes, the model needs to fit in the HBM.
As for the culprit, padding, try:
- keeping dim=1 a multiple of 8 instead of 1,
- doing the data formatting for your batch outside forward (see the sketch below).
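For the second point, the idea would be to hoist the batch formatting out of forward, roughly like this (a sketch reusing the names from your snippet; it assumes the formatting depends only on the layer input, not on trained parameters):
def format_batch(x):
    # x: (1, 128, 2048, 1) -> x_batch: (2048, 2, 2048, 128)
    x = torch.transpose(x, 1, 3)
    x_batch_0 = x.expand(2048, -1, -1, -1)
    x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)
    return torch.cat((x_batch_0, x_batch_1), dim=1)

# In the training loop, before calling the sharded model:
x_batch = format_batch(x)
output = model(x_batch)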