Problem with PyTorch Version 1.10
Hi, I am trying to reproduce the results. Everything works correctly with PyTorch 1.5, but with PyTorch 1.10, parsing the computation graph with torch.jit fails, and the fallback manual parse_graph function runs but takes up about twice as much GPU memory.
Output with PyTorch Version 1.10.0a0+0aef44c (nvcr.io/nvidia/pytorch:21.10-py3 docker container):
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Parsing Computation Graph with torch.jit failed, revert to manual parse_graph function
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|████████████████████████████████████████████████████| 330/330 [00:02<00:00, 138.06it/s]
Solving optimal gradient checkpointing takes 2.7020 s
/opt/conda/lib/python3.8/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
/opt/conda/lib/python3.8/site-packages/torch/cuda/memory.py:271: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
100%|█████████████████████████████████████████████████████| 100/100 [00:24<00:00, 4.12it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:30<00:00, 3.25it/s]
Average Iteration Time: Checkpointing 0.3082 s, Regular 0.2427 s, overhead 26.99%
Average Peak Memory: Checkpointing 5251.0508 MB, Regular 8157.9248 MB, Memory Cut off 35.63%
Average Intermediate Tensors: Checkpointing 1023.0098 MB, Regular 3929.8838 MB, Memory Cut off 73.97%
Output after commenting out the "try" at https://github.com/lordfjw/OptimalGradCheckpointing/blob/main/benchmark.py#L167:
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Traceback (most recent call last):
File "benchmark.py", line 212, in <module>
main(arch, device)
File "benchmark.py", line 168, in main
G, source, target = parse_computation_graph(net, inputs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 34, in parse_computation_graph
computation_graph, input_node_ids, output_node_ids = parse_raw_computation_graph_from_jit(module, inputs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 55, in parse_raw_computation_graph_from_jit
computation_graph, _, input_node_ids, output_node_ids = build_computation_graph_recursively(module, inputs, inputs_nodes_ids=None, outputs_nodes_ids=None, cur_node_idx=None)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in build_computation_graph_recursively
internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in <listcomp>
internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in parse_node_str
shape = [int(s) for s in shape_str.split(', ')]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in <listcomp>
shape = [int(s) for s in shape_str.split(', ')]
ValueError: invalid literal for int() with base 10: 'strides=[2048'
Output with PyTorch Version 1.5.0a0+8f84ded (nvcr.io/nvidia/pytorch:20.03-py3 docker container):
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
/opt/conda/lib/python3.6/site-packages/torch/tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|██████████████████████████████████████████████████| 350/350 [00:03<00:00, 95.77it/s]
Solving optimal gradient checkpointing takes 4.1945 s
/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
100%|█████████████████████████████████████████████████████| 100/100 [00:19<00:00, 5.02it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:26<00:00, 3.80it/s]
Average Iteration Time: Checkpointing 0.2623 s, Regular 0.1983 s, overhead 32.28%
Average Peak Memory: Checkpointing 1524.6592 MB, Regular 4306.1680 MB, Memory Cut off 64.59%
Average Intermediate Tensors: Checkpointing 1145.3750 MB, Regular 3926.8838 MB, Memory Cut off 70.83%
Hi,
Thanks for reporting the issue. The peak memory in PyTorch 1.10 is quite different from 1.5, but the intermediate tensor memory looks similar.
I can see a roughly constant increase in peak memory from 1.5 to 1.10 for both regular training and checkpointing:
Regular: 8158 MB - 4306 MB = 3852 MB
Checkpointing: 5251 MB - 1525 MB = 3726 MB
It looks like PyTorch 1.10 introduces some stationary memory cost. I will do some memory analysis in 1.10.
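The arithmetic above can be checked directly against the two logs. The sketch below copies the peak-memory values from the PyTorch 1.5 and 1.10 runs. The cut_off formula (1 minus checkpointing peak over regular peak) is an assumption inferred from the logged numbers, since it reproduces the printed 64.59% and 35.63%; it is not taken from benchmark.py.

```python
# Peak memory (MB) copied from the PyTorch 1.5 and 1.10 benchmark logs above.
peak = {
    "1.5": {"checkpointing": 1524.6592, "regular": 4306.1680},
    "1.10": {"checkpointing": 5251.0508, "regular": 8157.9248},
}


def cut_off(ckpt_mb: float, regular_mb: float) -> float:
    """Percentage of peak memory saved by checkpointing.

    Assumed formula: it matches the "Memory Cut off" values in the logs,
    but it is not copied from benchmark.py.
    """
    return 100.0 * (1.0 - ckpt_mb / regular_mb)


# Increase in peak memory from 1.5 to 1.10 for each mode.
delta = {
    mode: peak["1.10"][mode] - peak["1.5"][mode]
    for mode in ("regular", "checkpointing")
}

for mode, d in delta.items():
    print(f"{mode}: +{d:.0f} MB")

for version, p in peak.items():
    print(f"{version}: cut off {cut_off(p['checkpointing'], p['regular']):.2f}%")
```

The two increases (about 3852 MB and 3726 MB) differ by only ~126 MB, which is why the offset looks roughly constant rather than proportional to the activations being stored.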
For the error in auto-parsing: our auto-parsing works by parsing the string representation of the computation graph returned by torch.jit.trace. This is a workaround and can break whenever PyTorch changes that string format. In the short term, we will debug and maintain the auto-parsing function for the latest version. In the longer term, we plan to read the computation graph in C++, which is more stable and robust.
Thanks again for pointing out the issue.
I ran this code with PyTorch 1.10. I replaced line https://github.com/lordfjw/OptimalGradCheckpointing/blob/main/graph.py#L162 with
shape = [int(s) for s in shape_str.split(', ') if s.isdigit()]
and got normal output for resnet101:
Average Iteration Time: Checkpointing 0.3115 s, Regular 0.2474 s, overhead 25.92%
Average Peak Memory: Checkpointing 1519.5342 MB, Regular 4305.2930 MB, Memory Cut off 64.71%
Average Intermediate Tensors: Checkpointing 1145.3750 MB, Regular 3931.1338 MB, Memory Cut off 70.86%
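graph.py is not shown in this thread, so the exact contents of shape_str are an assumption here. The traceback ('strides=[2048') suggests that in PyTorch 1.10 the string also carries the strides list. If shape_str holds the full annotation, e.g. "32, 2048, 7, 7, strides=[100352, 49, 7, 1], ...", then the isdigit() filter above would also keep bare stride entries such as 49 and 7. The hypothetical parse_shape helper below is a minimal sketch of a stricter variant: it stops at the first non-integer token. The example strings are illustrative, not copied from an actual trace.

```python
from typing import List


def parse_shape(shape_str: str) -> List[int]:
    """Hypothetical helper: read only the leading integer dims of a tensor annotation.

    Assumes shape_str looks like '32, 2048, 7, 7, strides=[100352, 49, 7, 1], ...'
    (PyTorch 1.10 style) or '32, 2048, 7, 7' (PyTorch 1.5 style).
    """
    shape = []
    for token in shape_str.split(', '):
        if not token.isdigit():
            break  # the first non-integer token (e.g. 'strides=[...') ends the dims
        shape.append(int(token))
    return shape


# Illustrative annotations for a contiguous (32, 2048, 7, 7) tensor.
new_style = "32, 2048, 7, 7, strides=[100352, 49, 7, 1], requires_grad=1, device=cuda:0"
old_style = "32, 2048, 7, 7"

print(parse_shape(new_style))
print(parse_shape(old_style))

# For comparison, the isdigit() filter also keeps the bare stride entries 49 and 7.
print([int(s) for s in new_style.split(', ') if s.isdigit()])
```

Both annotation styles yield [32, 2048, 7, 7] with parse_shape, while the filter version yields [32, 2048, 7, 7, 49, 7] for the 1.10-style string. For 2D tensors such as "32, 2048, strides=[2048, 1], ..." the two approaches agree, because the only stride entry that follows the opening bracket is '1]', which is not a digit string.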
@karinaodm Say I want to modify the definition of segment cost; do you know where it is defined in the code?