
How to minimize memory expansion due to padding during sharding

Open mfatih7 opened this issue 11 months ago • 4 comments

Hello

For a model that can be sharded with model parallelism on a TPUv4 (4x32) device, I am getting the error below at the beginning of training on a TPUv3 (8x16) device. The console message reports a 4x expansion. Even though both the TPUv4 and TPUv3 devices have the same total memory, I cannot run the training on the TPUv3 device.

Program hbm requirement 15.45G:
    global            2.36M
    scoped            3.88M
    HLO temp         15.45G (60.9% utilization: Unpadded (9.40G) Padded (15.44G), 0.0% fragmentation (5.52M))

  Largest program allocations in hbm:

  1. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6042.remat3 = broadcast(bitcast.26), dimensions={2,3}
     Allocation type: HLO temp
     ==========================

  2. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6043.remat3 = broadcast(bitcast.27), dimensions={0,3}
     Allocation type: HLO temp
     ==========================
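
As a rough sanity check, the reported numbers seem consistent with the size-1 dimension of this tensor being padded up to 4 by the tiled TPU layout (plain arithmetic below; the exact padding implied by the T(4,128)(2,1) annotation is my assumption):

# Plain arithmetic, no XLA: the 4.0x expansion matches the size-1 dimension
# of bf16[2048, 1, 2048, 128] being padded from 1 up to 4.
bytes_per_bf16 = 2
unpadded = 2048 * 1 * 2048 * 128 * bytes_per_bf16   # 1.00 GiB, as reported
padded   = 2048 * 4 * 2048 * 128 * bytes_per_bf16   # 4.00 GiB, as reported
print(padded / unpadded)                            # 4.0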

The lines that cause the 4x expansion are below:

def forward(self, x):   # Activation map volume = 1,128,2048,1
   ...
   ...
   x = torch.transpose(x, 1, 3)  # Activation map volume = 1,1,2048,128

   x_batch_0 = x.expand(2048, -1, -1, -1)  # Activation map volume = 2048,1,2048,128

   x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128) # Activation map volume = 2048,1,2048,128

   x_batch = torch.cat((x_batch_0, x_batch_1), dim=1) # Activation map volume = 2048,2,2048,128

   ...
   ...
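
A scaled-down, CPU-only check of the shape algebra (sizes reduced from 2048 and 128 just to keep it light; the final assert shows that x_batch_1 is simply x_batch_0 with dims 0 and 2 swapped, which is presumably why the dump above shows two separate broadcasts of the same source tensor):

import torch

# Small stand-ins for the real sizes (2048 and 128 in the model above).
N, C = 8, 4

x = torch.arange(N * C, dtype=torch.float32).reshape(1, C, N, 1)
x = torch.transpose(x, 1, 3)                                    # (1, 1, N, C)

x_batch_0 = x.expand(N, -1, -1, -1)                             # (N, 1, N, C)
x_batch_1 = x.repeat_interleave(N, dim=2).reshape(N, 1, N, C)   # (N, 1, N, C)
x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)              # (N, 2, N, C)

print(x_batch.shape)  # torch.Size([8, 2, 8, 4])
assert torch.equal(x_batch_1, x_batch_0.transpose(0, 2))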

Here are the sharding properties that I set.

mesh_shape = (num_devices, 1, 1, 1)

mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes

for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)

How can I prevent this 4x expansion?

mfatih7 avatar Mar 06 '24 15:03 mfatih7

One guess: while the total amount of HBM is equal between v3 and v4, the HBM bandwidth of v4 is higher than that of v3 (https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4), so that could be one possible reason.

As you mentioned, one possible solution here would be to use v4, but if we want to further optimize the HBM usage on v3, the next step would be to get an HLO dump of the forward function (it doesn't necessarily have to use SPMD; it can just be a simple unit test with those ops) and try to optimize it.
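
For example, a standalone repro along these lines should produce such a dump (a sketch, assuming a working PyTorch/XLA install on a single device; the dump directory is arbitrary):

# Run with the XLA dump flag so the compiled HLO gets written out, e.g.:
#   XLA_FLAGS=--xla_dump_to=/tmp/hlo_dump python repro.py
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x = torch.zeros(1, 128, 2048, 1, dtype=torch.bfloat16, device=device)
x = torch.transpose(x, 1, 3)                                     # (1, 1, 2048, 128)
x_batch_0 = x.expand(2048, -1, -1, -1)                           # (2048, 1, 2048, 128)
x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)
x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)               # (2048, 2, 2048, 128)

xm.mark_step()  # force compilation/execution so the HLO is dumped
print(x_batch.shape)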

cc @JackCaoG for visibility, in case you have any other recommendations.

wonjoolee95 avatar Mar 06 '24 23:03 wonjoolee95

@mfatih7, regarding "TPUv4 (4x32) device, I am getting the error below at the beginning of the training on TPUv3 (8x16) device": how many hosts (what pod size) are you running for each TPU generation? 4x32 would mean 32 hosts, so you are using a v4-256 (appearing as 128 devices)? And by TPUv3 (8x16) device, do you mean a v3-128 (appearing as 128 devices)?

Let's first confirm your pod sizes and that you are actually using the same number of chips in both cases.

yeounoh avatar Mar 07 '24 18:03 yeounoh

Also,

mesh_shape = (num_devices, 1, 1, 1)

mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes

for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)

Naively sharding the parameters can incur extra collectives, which can cost more memory on v3 than on v4. That could explain the difference, if you are indeed using the same number of chips.

yeounoh avatar Mar 07 '24 18:03 yeounoh

Hello @yeounoh

Thank you for the answer.

Sorry for not sticking to your notation. I do not use the devices as a Pod; instead, I train my models on single devices with multiple cores. I just wanted to express the memory size in my own notation: TPUv4 (4x32) means 4 cores, each with 32 GB of memory, and TPUv3 (8x16) means 8 cores, each with 16 GB of memory.

So both devices I use have 128 GB of memory in total. When I use model-level parallelization on TPUv4, the biggest activation map in my model fits into the 32 GB per-core memory. But when I use model-level parallelization on TPUv3, the biggest activation map in my model does not fit into the 16 GB per-core memory, and I get the error above. I wanted to find out whether there is something I could adjust to cancel out the 4x expansion shown in the error.

mfatih7 avatar Mar 07 '24 18:03 mfatih7

Hi @mfatih7, thanks for the clarification. Yes, the model needs to fit in HBM.

As to the culprit, padding, try:

  • keeping dim=1 a multiple of 8 instead of 1
  • doing the data formatting for your batch outside forward (a rough sketch of this follows below).
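
For illustration, the second point could look something like the following (a sketch only; build_batch is a hypothetical helper, and whether padding dim=1 up to a multiple of 8 actually removes the layout padding depends on the layout the compiler picks for this model):

import torch

# Hypothetical helper: build the (2048, 2, 2048, 128) batch once per step,
# outside the model's forward, instead of deriving it inside the traced forward.
def build_batch(x):                                 # x: (1, 128, 2048, 1)
    x = torch.transpose(x, 1, 3)                    # (1, 1, 2048, 128)
    a = x.expand(2048, -1, -1, -1)                  # (2048, 1, 2048, 128)
    b = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)
    # dim=1 of the result is 2; per the first bullet, padding it up to a
    # multiple of 8 (if the rest of the model can tolerate that) is the other
    # knob to try -- that part is model-dependent and left out of this sketch.
    return torch.cat((a, b), dim=1)                 # (2048, 2, 2048, 128)

# Training loop (sketch):
# for x, y in loader:
#     x_batch = build_batch(x.to(device))
#     out = model(x_batch)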

yeounoh avatar Mar 15 '24 23:03 yeounoh