gradient checkpoint causes bigger memory usage on GPU
❓ Questions and Help
Recently I started testing GC performance on the GPU with the master versions of PyTorch and torch xla.
Unfortunately, consistent with my previous conclusions (https://github.com/pytorch/xla/issues/3455#issuecomment-1101056839), the current torch xla GC still seems to have difficulty achieving the desired results: it does not deliver memory savings even in some very simple cases.
Simple test program:
```python
import argparse

import torch
import torch_xla.utils.checkpoint
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def run(grad_checkpoint, use_cuda=False):
    if use_cuda:
        device = 'cuda'
    else:
        device = xm.xla_device()
    model = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
            )
            for _ in range(64)
        ]
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    for step in range(200):
        dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
        optimizer.zero_grad()
        x = dummy_data
        for n_l, layer in enumerate(model):
            if n_l > 0 and grad_checkpoint:
                if use_cuda:
                    x = torch.utils.checkpoint.checkpoint(layer, x)
                else:
                    x = torch_xla.utils.checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()
        if not use_cuda:
            xm.mark_step()
            mem_info = xm.get_memory_info(device)
            mem_info['kb_used'] = mem_info['kb_total'] - mem_info['kb_free']
            print(f"step {step}, memory = {mem_info['kb_used']}")
        else:
            print(f"step {step}, memory = {torch.cuda.memory_summary()}")

    if not use_cuda:
        print(met.metrics_report())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_checkpoint", type=int, required=True)
    parser.add_argument("--use_cuda", type=int, required=True)
    args = parser.parse_args()
    run(args.grad_checkpoint, args.use_cuda)
```
Comparing GPU memory usage in four configs:
| | WITHOUT GC | WITH GC |
|---|---|---|
| torch xla | 11561 MiB | 19753 MiB |
| torch native | 7680 MiB | 4634 MiB |
Running command:
```
TF_FORCE_GPU_ALLOW_GROWTH=true CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python3 ./test_gc.py --grad_checkpoint=0/1 --use_cuda=0/1
```
torch version: 1.12.0a0+git8abf37d
torch xla version: 5dae54bc53eb6c9a11eb4706fe01d1dfa557c14f
This zip contains the HLO dump results with GC enabled/disabled.
The number of repeated model blocks in the run that generated this dump was set to 4 instead of 64, but the conclusion is consistent: turning on GC on the GPU increases memory usage in this case.
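For reference, one way such an HLO dump can be produced (an assumption about the tooling used here, not necessarily how this particular dump was generated) is to point XLA's dump flag at a directory before torch_xla compiles its first graph:

```python
# Assumption: XLA dumps the HLO modules it compiles when --xla_dump_to is set,
# so setting the flag before torch_xla initializes produces per-graph dumps.
# The path below is just an example.
import os

os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/hlo_dump"

import torch_xla.core.xla_model as xm  # import torch_xla only after setting the flag
```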
Hmm, is there a way for you to verify the peak memory usage and check whether it is reduced with GC?
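For example, something along these lines could work; this is a rough sketch and the `peak_memory_kb` helper is hypothetical, not part of the script above:

```python
# Hypothetical helper: track the peak memory seen so far. On CUDA the caching
# allocator already records its own peak; on XLA we can only sample current
# usage (after xm.mark_step()), so keep a running maximum across steps.
import torch
import torch_xla.core.xla_model as xm


def peak_memory_kb(use_cuda, device, running_peak_kb=0):
    if use_cuda:
        # Peak bytes allocated by the CUDA caching allocator since the last reset.
        return torch.cuda.max_memory_allocated(device) // 1024
    mem_info = xm.get_memory_info(device)
    used_kb = mem_info['kb_total'] - mem_info['kb_free']
    return max(running_peak_kb, used_kb)
```

Calling this right after `optimizer.step()` (and `xm.mark_step()` on XLA) each iteration and keeping the returned value would give a per-run peak to compare between the GC and non-GC configurations; note that on XLA this misses transient peaks within a step.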
I dumped the HLO for the with-checkpoint and without-checkpoint cases, and will try to find someone to take a look.
I talked with Parker, who is the author of the optimization_barrier HLO. I think my implementation of the optimization_barrier has a flaw. I was doing:
```
x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
x2 = layer1.fwd(x1)
(x2,x1) = opt_barrier(x2, x1)
x3 = layer2.fwd(x2)
(x3,x2) = opt_barrier(x3, x2)
...
grad2 = layer2.bwd(x2, grad3)
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.bwd(x0, grad1)
```
which does guarantee that the recomputation in the backward waits until the corresponding fwd has finished. However, there is nothing preventing the recomputation from being moved to right after the corresponding fwd.
For example, it can become:
```
x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
layer0.repeated_computation_for_backward
....
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.remaining_bwd(x0, grad1)
....
```
which is not ideal. I should really do it like:
```
x1 = layer0.fwd(x0)
x2 = layer1.fwd(x1)
x3 = layer2.fwd(x2)
...
x2, grad3 = opt_barrier(x2, grad3)
grad2 = layer2.bwd(x2, grad3)
x1, grad2 = opt_barrier(x1, grad2)
grad1 = layer1.bwd(x1, grad2)
x0, grad1 = opt_barrier(x0, grad1)
grad0 = layer0.bwd(x0, grad1)
```
which will fully guarantee the execution order. I'm not sure how much it will impact the real memory usage, but I will try to implement this soon. FYI @ronghanghu
In other words, instead of binding the input and output of the fwd, I should bind the grad_input and the input right before performing the backward.
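A minimal sketch of what that binding could look like, using `xm.optimization_barrier_` as the in-place barrier API; the `BarrierCheckpoint` wrapper below is illustrative only and not the actual `torch_xla.utils.checkpoint` implementation:

```python
# Illustrative sketch: bind the incoming gradient and the saved input with an
# optimization barrier right before recomputing the forward, so XLA cannot
# hoist the recomputation up next to the original forward pass.
import torch
import torch_xla.core.xla_model as xm


class BarrierCheckpoint(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        ctx.save_for_backward(x)
        with torch.no_grad():
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        detached_x = saved_x.detach()
        # The barrier ties detached_x to grad_output: the recomputed forward
        # below now depends on the backward having reached this layer.
        xm.optimization_barrier_([detached_x, grad_output])
        detached_x.requires_grad_(True)
        with torch.enable_grad():
            out = ctx.run_function(detached_x)
        torch.autograd.backward(out, grad_output)
        return None, detached_x.grad
```

For a single-input, single-output block, `x = BarrierCheckpoint.apply(layer, x)` would then stand in for the plain checkpoint call in the test script above.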
Another point that Parker raised was that because the example is so simple, the XLA compiler is super adversarial: it is trying really hard to use those unused cores for something!
Let me first try to implement the change I proposed above and see if that fixes the issue here.
@JackCaoG I see, thanks for the update on this!
@ronghanghu @cicirori Can you guys give https://github.com/pytorch/xla/pull/3721 a try?
Thanks @JackCaoG -- I'll try this out!