lightning-thunder

High Peak Memory with CUDAGraphTransform

Open kshitij12345 opened this issue 1 year ago • 8 comments

Peak memory is much higher when CUDAGraphTransform is used. Peak allocated memory, in MB (torch.cuda.max_memory_allocated() / 1e6):

  • With CUDAGraphTransform: 27517.101568
  • Without CUDAGraphTransform: 11917.129728

Example:

```python
import torch
import thunder
import litgpt
from torch.testing import make_tensor
from functools import partial
from thunder.transforms.cudagraph import CUDAGraphTransform

device = torch.device("cuda")

# open_llama_3b config truncated to 10 layers to keep the repro small
cfg = litgpt.Config.from_name("open_llama_3b", n_layer=10)
with device:
    # random token ids of shape (1, block_size)
    make = partial(make_tensor, low=0, high=255, device=device, dtype=torch.long, requires_grad=False)
    shape = (1,) + (cfg.block_size,)

    x = make(shape)
    m = litgpt.GPT(cfg)

# m = thunder.jit(m)  # baseline without CUDA graphs
m = thunder.jit(m, transforms=[CUDAGraphTransform()])

o = m(x)
o.sum().backward()

# Peak allocated memory in MB (bytes / 1e6):
# With CUDAGraphTransform - 27517.101568
# Without CUDAGraphTransform - 11917.129728
print(torch.cuda.max_memory_allocated() / 1e6)
```

Tested with internal image dated 20241209 on RTX 6000 Ada.

kshitij12345, Dec 09 '24

Another example, with Qwen2:

```python
import torch
from thunder.dynamo import ThunderCompiler
from transformers import AutoConfig, AutoModelForCausalLM
from thunder.transforms.cudagraph import CUDAGraphTransform

model_id = "Qwen/Qwen2.5-7B-Instruct"

# Only 5 hidden layers, to keep the repro small
configuration = AutoConfig.from_pretrained(
    model_id,
    num_hidden_layers=5,
)
# Shrink the hidden size further (set it equal to the number of attention heads)
configuration.hidden_size = configuration.num_attention_heads
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16)

# backend = ThunderCompiler()  # baseline without CUDA graphs
backend = ThunderCompiler(transforms=[CUDAGraphTransform()])
compiled_model = torch.compile(model, backend=backend)

input_ids = torch.randint(0, configuration.vocab_size, (1, 4096), device="cuda")

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_output.loss.backward()

# Peak allocated memory in MB (bytes / 1e6):
# Without CUDAGraphTransform - 13312.685568
# With CUDAGraphTransform - 26071.434752
print(torch.cuda.max_memory_allocated() / 1e6)
```
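Both examples print torch.cuda.max_memory_allocated() / 1e6, i.e. decimal megabytes. A small helper (hypothetical, not part of Thunder) that turns the reported pairs into absolute and relative overhead:

```python
# Peak allocated memory in MB (bytes / 1e6), as reported in the two examples.
REPORTED = {
    "litgpt open_llama_3b, 10 layers": {"without": 11917.129728, "with": 27517.101568},
    "Qwen2.5-7B-Instruct, 5 layers": {"without": 13312.685568, "with": 26071.434752},
}


def cudagraph_overhead(without_mb: float, with_mb: float) -> tuple[float, float]:
    """Return (extra MB, ratio) of the CUDA-graph run relative to the baseline."""
    return with_mb - without_mb, with_mb / without_mb


for name, peaks in REPORTED.items():
    extra, ratio = cudagraph_overhead(peaks["without"], peaks["with"])
    print(f"{name}: +{extra:.0f} MB ({ratio:.2f}x)")
```

For these numbers the CUDA-graph runs need roughly 15.6 GB (about 2.3x) and 12.8 GB (about 2.0x) more at peak.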

kshitij12345, Dec 10 '24

This is likely an issue where we create static input buffers for tensors that should not get them (e.g. maybe tensors saved for backward when the forward is computed by the CUDA graph?). The parameters seem to be correctly marked as not needing static input buffers.
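To make the static-input-buffer idea concrete: a CUDA graph always reads its inputs from the addresses recorded at capture time, so a wrapper has to keep a persistent buffer for every input that is not already at a fixed address and copy the live value into it before each replay. The toy model below is a hypothetical sketch in plain Python (bytearrays stand in for tensors, static_mask stands in for Thunder's static-memory tagging); it is not Thunder's implementation.

```python
class ToyGraphInputs:
    """Toy model of a CUDA-graph wrapper's input handling (hypothetical).

    Inputs flagged static (e.g. parameters) are read in place; every other
    input gets a persistent buffer allocated once, at capture time.
    """

    def __init__(self, example_args: list[bytes], static_mask: list[bool]):
        self.static_mask = static_mask
        self.buffers = [
            None if is_static else bytearray(len(arg))
            for arg, is_static in zip(example_args, static_mask)
        ]

    def extra_bytes(self) -> int:
        """Memory held by the static input buffers, on top of the inputs themselves."""
        return sum(len(buf) for buf in self.buffers if buf is not None)

    def prepare_replay(self, args: list[bytes]) -> int:
        """Copy dynamic inputs into their buffers; return the number of bytes copied."""
        copied = 0
        for buf, arg in zip(self.buffers, args):
            if buf is not None:
                buf[:] = arg  # shapes are fixed, so the lengths match
                copied += len(arg)
        return copied


params, activation = bytes(1_000), bytes(500)

untagged = ToyGraphInputs([params, activation], static_mask=[True, False])
tagged = ToyGraphInputs([params, activation], static_mask=[True, True])
print(untagged.extra_bytes(), untagged.prepare_replay([params, activation]))
print(tagged.extra_bytes(), tagged.prepare_replay([params, activation]))
```

If the activation (for example, a tensor saved for backward) is not tagged, the wrapper holds a second copy of it and copies it on every replay; tagging it removes both costs.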

t-vi, Dec 10 '24

So there are two things actually:

  • the input buffers in the backward. This will improve on that:
```python
import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform


class MyCUDAGraphTransform(CUDAGraphTransform):
    # Tags backward inputs (saved-for-backward tensors) as static when the
    # forward CUDA graph produced them, to avoid separate input buffers for them.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.outputs_from_forward = None

    def transform_trace_post_optimization(self, trace, **kwargs):
        if trace.siginfo().name == 'backward_fn':
            # todo: have a backward tag
            assert self.outputs_from_forward is not None, "called on backward without forward before"
            # make this more generic or have a utility?
            assert len(trace.bound_symbols[2].args) == 2 and trace.bound_symbols[2].args[0].name == 'saved_for_backward'
            assert trace.bound_symbols[8].sym.name == 'unpack_sequence' and trace.bound_symbols[8].args[0] is trace.bound_symbols[2].output[0]
            saved_for_backwards_unpacked = trace.bound_symbols[8].output
            assert len(saved_for_backwards_unpacked) == len(self.outputs_from_forward)
            for (_, is_static), p_bw in zip(self.outputs_from_forward, saved_for_backwards_unpacked):
                if is_static:
                    p_bw.tags.add(thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION)
            self.outputs_from_forward = None

        new_trace = super().transform_trace_post_optimization(trace, **kwargs)

        if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in new_trace.tags:
            assert self.outputs_from_forward is None, "called on augmented forward twice without backward in between"
            # apparently, it is safer to go by name than assume we have the same proxies here. :(
            cudagraph_output_names = set()
            for bsym in new_trace.bound_symbols:
                if bsym.sym.name.startswith('CUDAGraph'):
                    for o in bsym.flat_proxy_outs:
                        cudagraph_output_names.add(o.name)
            saved_for_backward = thunder.core.vjp_utils.get_saved_for_backward_tensors(new_trace)
            self.outputs_from_forward = [(o.name, o.name in cudagraph_output_names or thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION in o.tags) for o in saved_for_backward]
        return new_trace
```

If we decide the sharing is the right thing, we should do that.
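The core of MyCUDAGraphTransform is a positional match: the i-th tensor saved for backward in the augmented forward trace corresponds to the i-th element unpacked from saved_for_backward in the backward trace, and it can be tagged static if the forward CUDA graph produced it (matched by name) or it already carried the static tag. A plain-Python sketch of just that bookkeeping, with a hypothetical function name and strings standing in for proxies:

```python
def static_backward_inputs(
    fw_saved: list[str],
    fw_static_tagged: set[str],
    cudagraph_outputs: set[str],
    bw_unpacked: list[str],
) -> list[str]:
    """Return the backward-side names that can be tagged as static.

    fw_saved and bw_unpacked are aligned by position, as in the transform above.
    """
    if len(fw_saved) != len(bw_unpacked):
        raise ValueError("forward saved list and backward unpack must align")
    return [
        bw_name
        for fw_name, bw_name in zip(fw_saved, bw_unpacked)
        # Static if a forward CUDA graph wrote it, or it was already tagged.
        if fw_name in cudagraph_outputs or fw_name in fw_static_tagged
    ]


print(static_backward_inputs(
    fw_saved=["t_act", "t_mask", "weight"],
    fw_static_tagged={"weight"},
    cudagraph_outputs={"t_act"},
    bw_unpacked=["a0", "a1", "a2"],
))
```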

The other part is that the memory pools of the graphs are not currently shared. I thought it would be as easy as initializing self.pool to None and then changing

https://github.com/Lightning-AI/lightning-thunder/blob/d3b22764def5acab1c27f9155e0c9f85a2f00cdf/thunder/transforms/cudagraph.py#L98-L100

to

```python
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self.pool, stream=stream):
    static_outputs = fn(*static_inputs)
if self.pool is None:
    self.pool = graph.pool()
```

but this does not help (and seems to make things worse?)...

I would highly appreciate any hints on how this should work.
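For reference, the PyTorch CUDA Graphs notes describe sharing one private memory pool across several captures by obtaining a handle with torch.cuda.graph_pool_handle() and passing it as pool= to each torch.cuda.graph capture; they caution that this is only safe when the graphs are replayed in the same order they were captured and never concurrently. A minimal standalone sketch of that pattern (capture_sharing_pool is a hypothetical helper, not Thunder code):

```python
import torch


def capture_sharing_pool(fns, static_inputs):
    """Capture each fn(static_input) into its own CUDA graph, all using one pool.

    Returns the graphs and their static outputs. Replay the graphs in capture order.
    """
    pool = torch.cuda.graph_pool_handle()  # one private pool shared by all captures

    # Warm up on a side stream, as recommended before capturing.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for fn, inp in zip(fns, static_inputs):
            fn(inp)
    torch.cuda.current_stream().wait_stream(side)

    graphs, outputs = [], []
    for fn, inp in zip(fns, static_inputs):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            outputs.append(fn(inp))  # output memory comes from the shared pool
        graphs.append(graph)
    return graphs, outputs


if torch.cuda.is_available():
    x = torch.ones(4, device="cuda")
    graphs, outs = capture_sharing_pool([lambda t: t * 2, lambda t: t + 1], [x, x])
    x.fill_(3.0)  # update the static input in place, then replay in capture order
    for graph in graphs:
        graph.replay()
    torch.cuda.synchronize()
    print(outs[0], outs[1])  # all 6s, then all 4s
```

Each output stays alive while it is referenced, so later captures in the same pool cannot reuse its memory.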

t-vi, Dec 17 '24

So I think we may have a more basic problem here than what CUDA graphs do: our tooling works with function calls that do not let the arguments go out of scope (similar to the backward itself wanting to receive mutable collections as inputs).
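A plain-Python illustration of the scoping point, with hypothetical names and CPython reference counting standing in for tensor lifetimes: a callee cannot free an argument that the caller still references, whereas if the caller hands over its only reference inside a mutable list, the callee can release the object early by clearing the list. This mirrors the backward wanting a mutable collection as input.

```python
import weakref


class Activation:
    """Stand-in for a large tensor (hypothetical)."""


def consume_positional(act, probe):
    # The caller's own variable still references `act`, so dropping
    # our local name cannot free it.
    del act
    return probe() is None  # True only if the object was freed


def consume_from_list(acts, probe):
    # The caller handed over its only reference inside a mutable list;
    # clearing the list drops the last reference and CPython frees it at once.
    acts.clear()
    return probe() is None


a = Activation()
probe = weakref.ref(a)
freed_positional = consume_positional(a, probe)  # still alive: `a` holds it

acts = [a]
del a  # the list now holds the only strong reference
freed_from_list = consume_from_list(acts, probe)

print(freed_positional, freed_from_list)  # False True (CPython)
```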

t-vi, Mar 05 '25

Most of the args to the jitted functions are the model parameters, so they need to stay around anyway. (I have probably misunderstood the above point😅)

kshitij12345, Mar 05 '25

This behaviour is expected from the CUDA Graph static memory model. Apart from the static inputs (parameters, etc.), we allocate static buffers for graph replay: https://github.com/Lightning-AI/lightning-thunder/blob/ab067fdfbe11fa83b779d7b09ad8c8c1ccf24dc0/thunder/transforms/cudagraph.py#L93

Profiling a sample model, we can see that dynamic backward inputs cause this memory spike (primarily due to the last forward activation). Comparing:

```python
cfg = litgpt.Config.from_name("open_llama_3b", n_layer=1)
with device:
    make = partial(make_tensor, low=0, high=255, device=device, dtype=torch.long, requires_grad=False)
    shape = (10,) + (cfg.block_size,)  # batch size 10

    x = make(shape)
    m = litgpt.GPT(cfg)
```

with:

```python
cfg = litgpt.Config.from_name("open_llama_3b", n_layer=1)
with device:
    make = partial(make_tensor, low=0, high=255, device=device, dtype=torch.long, requires_grad=False)
    shape = (1,) + (cfg.block_size,)  # batch size 1

    x = make(shape)
    m = litgpt.GPT(cfg)
```

We observe a peak memory difference of ~10x (matching the 10x difference in batch size).

This can be verified in the memory snapshots (the thick bands represent the forward output tensor of size [B=10, 2048, 32000]). Without CUDA graphs:

[Image: memory snapshot without CUDA graphs]

With CUDA graphs (reading the pickle logs, it's clear that the 2.5GB orange allocation ([10, 2048, 32000]) comes from get_static_buffer; the yellow allocation should be the static_output of the forward pass):

[Image: memory snapshot with CUDA graphs]

Memory copy before replaying the graph:

[Image: memcpy before graph replay]
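As a sanity check on the 2.5GB figure: the forward output (and therefore its cotangent, which has the same shape) is [10, 2048, 32000]. Assuming fp32 elements (the snippets above build litgpt.GPT in the default dtype), this is exactly 2500 MiB; in bf16 it would be half that. A quick calculation with a hypothetical helper:

```python
from math import prod

MiB = 1024 * 1024


def tensor_mib(shape, bytes_per_element):
    """Size of a dense tensor in MiB."""
    return prod(shape) * bytes_per_element / MiB


logits_shape = (10, 2048, 32000)  # [B, block_size, vocab] from the snapshot
print(tensor_mib(logits_shape, 4))  # fp32: 2500.0
print(tensor_mib(logits_shape, 2))  # bf16: 1250.0
```

With batch size 1 the same fp32 tensor is 250 MiB, consistent with the ~10x peak difference reported above.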

If we want to decrease the peak memory consumption when CUDA graphs are used, we must go to a lower level and make the data location of the forward output tensor static during graph capture, so that the backward pass can read from that location directly (maybe this is trivial; I am not familiar with the torch allocator right now). Without Thunder, we capture forward and backward together like this:

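```python
# model, loss_fn, optimizer, static_input, static_target come from the training setup
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()
```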

The forward output is used directly by the backward, without an additional tensor copy.
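A self-contained version of that pattern, following the whole-network capture example in the PyTorch CUDA Graphs notes (build_graphed_train_step is a hypothetical helper name; SGD is used because it can be captured without extra flags). Because the loss and backward are captured in the same graph as the forward, the forward output never leaves the graph and the backward starts from the scalar loss; only the small input and target tensors are copied before each replay.

```python
import torch


def build_graphed_train_step(model, loss_fn, optimizer, static_input, static_target, warmup_iters=3):
    """Capture forward, loss, backward and optimizer step in a single CUDA graph."""
    # Warm up on a side stream so lazy initialisation happens outside the capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            optimizer.zero_grad(set_to_none=True)
            loss_fn(model(static_input), static_target).backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    # Grads set to None so the captured backward allocates them from the graph's pool.
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        # The forward output feeds the backward inside the graph: no separate
        # static buffer or copy is needed for it, and backward starts from a scalar.
        static_loss = loss_fn(model(static_input), static_target)
        static_loss.backward()
        optimizer.step()

    def step(data, target):
        static_input.copy_(data)  # only the (small) inputs are copied per replay
        static_target.copy_(target)
        graph.replay()
        return static_loss  # overwritten in place by each replay

    return step


if torch.cuda.is_available():
    torch.manual_seed(0)
    net = torch.nn.Linear(8, 1).cuda()
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    data, target = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")
    train_step = build_graphed_train_step(
        net, torch.nn.functional.mse_loss, opt, data.clone(), target.clone()
    )
    for _ in range(3):
        print(train_step(data, target).item())
```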

In inference mode, we observe nearly the same peak values with and without CUDA graphs, which further supports this explanation.

cc @kshitij12345 @t-vi @IvanYashchuk for visibility

Appendix:

Repro:

import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform
import torch
import time
import litgpt
from torch.testing import make_tensor
from functools import partial

grad = False

tdtype = torch.bfloat16
device = torch.device("cuda")

cfg = litgpt.Config.from_name("open_llama_3b", n_layer=10)
with device:
    make = partial(make_tensor, low=0, high=255, device=device, dtype=torch.long, requires_grad=False)
    batch_size = 1  # This is the batch size
    shape = (batch_size, cfg.block_size)

    x = make(shape)
    print('x size in MB', x.numel() * x.element_size() / (1024 * 1024))
    m = litgpt.GPT(cfg).requires_grad_(grad)


def warmup(model, name):
    # Warm-up runs (build caches, allocate persistent buffers)
    torch.cuda.cudart().cudaProfilerStart()
    for i in range(4):
        print(f"Warmup {name} {i}")
        # torch.cuda.empty_cache()
        torch.cuda.nvtx.range_push(f"warmup_{name}_{i}")
        res = model(x)
        if grad:
            res.sum().backward()
        torch.cuda.nvtx.range_pop()

_mb = 1024 * 1024

def run(model, name):
    if name == 'plain':
        torch.cuda.cudart().cudaProfilerStart()
    # torch.cuda.empty_cache()

    max_peak_alloc_bytes = 0
    max_peak_reserved_bytes = 0
    max_cur_alloc_bytes = 0
    max_cur_reserved_bytes = 0
    max_time_ns = 0

    for i in range(10):
        # Measure peak GPU memory (allocated and reserved)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats(device=None)

        torch.cuda.nvtx.range_push(f"forward_{name}_{i}")
        t0 = time.perf_counter_ns()
        res = model(x)
        torch.cuda.nvtx.range_pop()
        # Report forward memory usage separately
        torch.cuda.synchronize()
        fwd_peak_alloc_bytes = torch.cuda.max_memory_allocated(device=None)
        fwd_peak_reserved_bytes = torch.cuda.max_memory_reserved(device=None)
        fwd_cur_alloc_bytes = torch.cuda.memory_allocated(device=None)
        fwd_cur_reserved_bytes = torch.cuda.memory_reserved(device=None)
        print(
            f"Forward mem - peak alloc/res (MB): {fwd_peak_alloc_bytes / _mb:.2f} / {fwd_peak_reserved_bytes / _mb:.2f}; "
            f"current alloc/res (MB): {fwd_cur_alloc_bytes / _mb:.2f} / {fwd_cur_reserved_bytes / _mb:.2f}"
        )
        if grad:
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats(device=None)
            torch.cuda.nvtx.range_push(f"backward_{name}_{i}")
            # print('res size in MB', res.numel() * res.element_size() / (1024 * 1024))
            # print('res data ptr', res.data_ptr())
            res.sum().backward()
            torch.cuda.nvtx.range_pop()
            # Report backward memory usage separately
            torch.cuda.synchronize()
            bwd_peak_alloc_bytes = torch.cuda.max_memory_allocated(device=None)
            bwd_peak_reserved_bytes = torch.cuda.max_memory_reserved(device=None)
            bwd_cur_alloc_bytes = torch.cuda.memory_allocated(device=None)
            bwd_cur_reserved_bytes = torch.cuda.memory_reserved(device=None)
            print(
                f"Backward mem - peak alloc/res (MB): {bwd_peak_alloc_bytes / _mb:.2f} / {bwd_peak_reserved_bytes / _mb:.2f}; "
                f"current alloc/res (MB): {bwd_cur_alloc_bytes / _mb:.2f} / {bwd_cur_reserved_bytes / _mb:.2f}"
            )

        torch.cuda.synchronize()
        t1 = time.perf_counter_ns()
        peak_alloc_bytes = torch.cuda.max_memory_allocated(device=None)
        peak_reserved_bytes = torch.cuda.max_memory_reserved(device=None)
        cur_alloc_bytes = torch.cuda.memory_allocated(device=None)
        cur_reserved_bytes = torch.cuda.memory_reserved(device=None)

        if peak_alloc_bytes > max_peak_alloc_bytes:
            max_peak_alloc_bytes = peak_alloc_bytes
        if peak_reserved_bytes > max_peak_reserved_bytes:
            max_peak_reserved_bytes = peak_reserved_bytes
        if cur_alloc_bytes > max_cur_alloc_bytes:
            max_cur_alloc_bytes = cur_alloc_bytes
        if cur_reserved_bytes > max_cur_reserved_bytes:
            max_cur_reserved_bytes = cur_reserved_bytes
        elapsed_time_ns = t1 - t0
        if elapsed_time_ns > max_time_ns:
            max_time_ns = elapsed_time_ns

    # torch.cuda.cudart().cudaProfilerStop()
    return max_peak_alloc_bytes, max_peak_reserved_bytes, max_cur_alloc_bytes, max_cur_reserved_bytes, max_time_ns

to_mb = lambda b: b / (1024 * 1024)

jm = thunder.jit(m)
torch.cuda.memory._record_memory_history()
# warmup(jm, "plain")
print("Without CUDAGraphTransform:")
peak_alloc_bytes, peak_reserved_bytes, cur_alloc_bytes, cur_reserved_bytes, time_ns = run(jm, "plain")
try:
    print(f"Capturing memory snapshot")
    torch.cuda.memory._dump_snapshot(f"plain.pickle")
except Exception as e:
    print(f"Failed to capture memory snapshot {e}")
torch.cuda.memory._record_memory_history(enabled=None)
print(f"Time: {time_ns / 1000000:.2f}ms")
print(f"Peak allocated (MB): {to_mb(peak_alloc_bytes):.2f}")
print(f"Peak reserved  (MB): {to_mb(peak_reserved_bytes):.2f}")
print(f"Current allocated (MB): {to_mb(cur_alloc_bytes):.2f}")
print(f"Current reserved  (MB): {to_mb(cur_reserved_bytes):.2f}")

# print(thunder.last_traces(jm)[-1])

torch.cuda.synchronize()
transforms = [CUDAGraphTransform(share_mem_pool=True)]
jm = thunder.jit(m, transforms=transforms)
torch.cuda.memory._record_memory_history()
# warmup(jm, "graph")
print("With CUDAGraphTransform:")
peak_alloc_bytes, peak_reserved_bytes, cur_alloc_bytes, cur_reserved_bytes, time_ns = run(jm, "graph")
try:
    print(f"Capturing memory snapshot")
    torch.cuda.memory._dump_snapshot(f"graph.pickle")
except Exception as e:
    print(f"Failed to capture memory snapshot {e}")
torch.cuda.memory._record_memory_history(enabled=None)
print(f"Time: {time_ns / 1000000:.2f}ms")
print(f"Peak allocated (MB): {to_mb(peak_alloc_bytes):.2f}")
print(f"Peak reserved  (MB): {to_mb(peak_reserved_bytes):.2f}")
print(f"Current allocated (MB): {to_mb(cur_alloc_bytes):.2f}")
print(f"Current reserved  (MB): {to_mb(cur_reserved_bytes):.2f}")

# print(thunder.last_traces(jm)[-1])
# print(thunder.last_backward_traces(jm)[-1])

mattteochen, Oct 10 '25

Excellent profiling work and attention to details, great deep dive!

Just one small clarification: the conclusion that the copied input to backward is the output of forward isn’t quite right. What’s actually passed into backward is the "output_grad", not the forward output tensor itself. When capturing backward with CUDA Graphs, it's important to start from the loss so the input is a scalar, not a large tensor (that would need to be copied).

This is similar to how fused cross entropy loss functions combine forward and backward in one step, preventing materialization of large intermediate tensors and improving memory efficiency.
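The distinction can be checked on CPU with plain PyTorch: a hook on a region's output sees exactly the cotangent that the region's backward is called with. In the sketch below, region is a hypothetical stand-in for the jitted / CUDA-graphed function, with tiny sizes standing in for [B, 2048, 32000]. When the region ends at the logits, the cotangent has the logits' shape; when the loss is computed inside the region, it is a scalar.

```python
import torch
import torch.nn.functional as F


def cotangent_shape(include_loss_in_region: bool) -> tuple:
    torch.manual_seed(0)
    batch, seq, vocab = 2, 4, 11  # tiny stand-ins for [B, 2048, 32000]
    hidden = torch.randn(batch, seq, 8, requires_grad=True)
    head = torch.randn(8, vocab)
    targets = torch.randint(0, vocab, (batch, seq))

    def region(h):
        # Stand-in for the jitted / CUDA-graphed region.
        logits = h @ head
        if include_loss_in_region:
            return F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
        return logits

    out = region(hidden)
    seen = {}
    # The hook receives the gradient w.r.t. the region's output, i.e. the
    # cotangent the region's backward would be called with.
    out.register_hook(lambda g: seen.update(shape=tuple(g.shape)))
    out.sum().backward()  # for a scalar `out`, sum() is a no-op reduction
    return seen["shape"]


print(cotangent_shape(include_loss_in_region=False))  # (2, 4, 11)
print(cotangent_shape(include_loss_in_region=True))   # ()
```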

A further thing worth double-checking is the treatment of all "saved for backward" tensors. In Thunder and PyTorch, these are often static, referenced in memory, and do not require a copy by default if the graph and intermediate shapes remain stable. However, if a Thunder pass doesn't mark them as static, extra copying may happen. It's key to verify that all tensors marked as "saved for backward" remain static and reusable in the CUDA graph to ensure maximum efficiency. Do we have tests for that?

Aligning backward input, fused losses, and confirming static memory for all saved tensors will keep the CUDA graph replay lean and fast.

IvanYashchuk, Oct 16 '25

Thank you for the clarification @IvanYashchuk. I wrongly stated that the backward input is the output of the forward. The issue comes from the cotangents being a tensor and not a scalar value. In my example, the cotangents, which are the gradients of the forward output, have a size of 2.5GB. As this input is dynamic and not marked as static, the CUDA graph allocates a buffer for it, leading to this tensor copy before replaying the graph.

I have also checked the saved-for-backward tensors (save_for_bw), and they are marked static where they should be.

What I see from this is:

  • CUDA Graph needs to be captured for forward and backward together (classical path)
  • We need to somehow start the backward capture from a scalar cotangent, which is not happening because the loss is computed externally from the jitted region in normal circumstances. A fused approach, as you've said, should solve this.

I am not familiar with Thunder jitting the whole training loop, but from a quick try it does not seem to work.

Currently, I am investigating whether torch.compile has some strategy in this regard, but I am facing some issues running the reduce-overhead mode.

mattteochen, Oct 21 '25