
[Question]: activation offload won't work for torch version < 2.5

Open Irvingwangjr opened this issue 8 months ago • 11 comments

The registered hook doesn't trigger for torch versions below 2.5; it works with torch 2.5 and 2.6. Could you kindly point out which PR changed the behavior of the registered hook? We want to use this with torch 2.4.

Irvingwangjr avatar Mar 04 '25 11:03 Irvingwangjr

hey @Irvingwangjr, glad that you are interested. We only guarantee that it works with torch nightlies and the latest torch release. Specifically for activation offloading, use_streams only works for torch >2.5: https://github.com/pytorch/torchtune/blob/80da6a5dae23a201595d07041c12ffde830332d7/torchtune/training/_activation_offloading.py#L43

I am not sure which PR changed the behavior of registered_hook :/
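For reference, a minimal sketch of that kind of version gate (an illustrative assumption using the packaging library; the authoritative check is behind the link above in _activation_offloading.py):

# Illustrative version gate for use_streams (assumption: not torchtune's actual code,
# which lives at the link above in _activation_offloading.py).
import torch
from packaging.version import parse

def streams_supported() -> bool:
    # Assumed bound for stream-based offloading; see the linked check for the exact rule.
    return parse(torch.__version__).release >= (2, 5, 0)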

felipemello1 avatar Mar 04 '25 18:03 felipemello1

If there are more questions, please feel free to reopen the issue!

felipemello1 avatar Mar 04 '25 18:03 felipemello1

Many thanks for this! I also want to report a bug here.

Reproduction script:

import torch

from torchtune.training import OffloadActivations


def test_offloading_works_with_cpu_tensors() -> None:

    class SomeFuncNeedsCpuTensors(torch.autograd.Function):
        # Custom op whose saved tensor must stay on CPU in both forward and backward.
        @staticmethod
        def forward(ctx, cpu_tensor):
            assert cpu_tensor.device == torch.device("cpu")
            ctx.save_for_backward(cpu_tensor)
            return torch.rand_like(cpu_tensor)

        @staticmethod
        def backward(ctx, dy):
            saved = ctx.saved_tensors[0]
            # Fails on the current code: offloading brings the saved tensor back to GPU.
            assert saved.device == torch.device("cpu")
            return torch.rand_like(saved)

    def fwd(c):
        a = SomeFuncNeedsCpuTensors.apply(c)
        return a.sum()

    tensor_c = torch.ones(10102, 1024, device="cpu", requires_grad=True)
    ctx = OffloadActivations(use_streams=False)
    with ctx:
        loss_c = fwd(tensor_c)
    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
    ctx.fwd_stash = {}
    loss_c.backward()

For some ops, certain tensors need to stay on CPU. The current code ignores this and brings them back to GPU, as shown below:

  # Kick off the process to bring tensors back
  with torch.cuda.stream(self.s1):
      gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
      maybe_gpu_tensor = gpu_tensor

I think a better way is to store the device and bring the tensor back to that specific device in the unpack function:

def pack_hook(x):
    print("pack_hook", x)
    return (x.device, x.cpu())

def unpack_hook(packed):
    print("unpack_hook", packed)
    device, tensor = packed
    return tensor.to(device)
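For context, a pack/unpack pair like this is installed via torch.autograd.graph.saved_tensors_hooks; a minimal self-contained sketch (illustrative only, not torchtune's implementation):

# Minimal sketch: save activations on CPU during forward and restore them to their
# original device during backward via saved-tensor hooks.
import torch

def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = (x * x).sum()
loss.backward()   # unpack_hook restores each saved tensor to its original device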

Irvingwangjr avatar Mar 05 '25 09:03 Irvingwangjr

https://github.com/tgale96/grouped_gemm is a real-world example: the grouped GEMM (GMM) op takes a parameter named batch_sizes that has to be a CPU tensor.

Irvingwangjr avatar Mar 05 '25 09:03 Irvingwangjr

@Irvingwangjr Ah, good point. Would it then make more sense to only offload if the tensor is not on CPU?

janeyx99 avatar Mar 05 '25 17:03 janeyx99

thanks @Irvingwangjr, I asked @janeyx99 if she has availability to take a look.

The way we use it in torchtune is that we only enable it with activation checkpointing, and we pass the input tensor to the transformer block. This tensor is always on GPU (unless you are training the model on CPU, which doesn't make much sense).

Can you explain your use case a bit more? Not sure if the bug you mentioned will ever occur in torchtune. Thank you!
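In code, that usage pattern looks roughly like the sketch below (a minimal illustration combining OffloadActivations with torch.utils.checkpoint; the toy block is an assumption, and this is not the exact recipe code):

# Minimal sketch: activation offloading around a checkpointed block whose input is on GPU.
# Assumption: illustrative only, not torchtune's recipe code.
import torch
from torch.utils.checkpoint import checkpoint
from torchtune.training import OffloadActivations

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda()
x = torch.randn(4, 16, device="cuda", requires_grad=True)   # input tensor lives on GPU

with OffloadActivations(use_streams=False):
    loss = checkpoint(block, x, use_reentrant=False).sum()
loss.backward()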

felipemello1 avatar Mar 05 '25 17:03 felipemello1

Sure. I actually hit this problem when trying to integrate activation checkpointing and offloading with this op: https://github.com/tgale96/grouped_gemm/blob/main/grouped_gemm/ops_test.py#L154

The parameter 'batch_sizes' is a CPU tensor indicating how the input is split into groups.

The current logic breaks this, since it moves 'batch_sizes' to GPU, and the subsequent recomputation then fails in the forward pass.

Since MoE and DeepSeek-V3-style models are becoming a trend, torchtune might hit this problem too (if it uses this op).
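For reference, the call pattern in the linked ops_test.py looks roughly like the sketch below (the shapes and the gmm signature are assumptions here; check the grouped_gemm repo for the exact API). The key point is that batch_sizes is constructed on CPU and must stay there:

# Hedged sketch of the grouped GEMM call pattern (assumption: signature per the linked
# ops_test.py; verify against the grouped_gemm repo). batch_sizes must remain a CPU tensor.
import torch
# from grouped_gemm import ops   # assumption: import path as in the linked repository

a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)        # tokens x k
b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)    # groups x k x n
batch_sizes = torch.tensor([3, 5], dtype=torch.long)               # CPU by design, sums to 8
# out = ops.gmm(a, b, batch_sizes)  # breaks if offloading silently moves batch_sizes to GPU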

Irvingwangjr avatar Mar 06 '25 02:03 Irvingwangjr

Yeah, to only offload when the tensor is not on CPU, I actually patched the code like this:

        def pack_tensor(activation: torch.Tensor) -> tuple:
            ...
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()
            device = activation.device
            if num_bytes >= self.min_tensor_size_bytes and (
                not isinstance(activation, torch.nn.Parameter)
                and not isinstance(activation, torch.nn.parameter.Buffer)
            ) and device.type != "cpu":
                ...
                # offload ops as before
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return (tensor_id, device)

And the unpack function doesn't need modification, since CPU tensors go down the 'not modified' branch. But I'm not sure what the difference is between calling tensor.to("cuda") and tensor.to("cuda:0"); if there is no difference, I think we can ignore the device.

Irvingwangjr avatar Mar 06 '25 02:03 Irvingwangjr

@Irvingwangjr if convenient, can you check whether this patch https://github.com/pytorch/torchtune/pull/2466 does the trick? I am specializing on CUDA here because our streaming logic only works on CUDA.

The difference between "cuda" and "cuda:0" is that "cuda:0" specifies a particular GPU, which can be limiting or unavailable depending on the user's environment, so it is better not to use "cuda:0" specifically.

janeyx99 avatar Mar 06 '25 15:03 janeyx99

Looks good to me!

But I still have a question about the device. Say a tensor is on 'cuda:3'; we move it to CPU and then bring it back by calling tensor.to("cuda"). Which device will it end up on? Does it depend on some env variable like local_rank?

Irvingwangjr avatar Mar 10 '25 02:03 Irvingwangjr

@Irvingwangjr "cuda" defaults to "cuda:0"
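For reference, a small sketch of how the bare "cuda" string resolves: it follows the current CUDA device, which is index 0 unless something like torch.cuda.set_device(local_rank) has changed it, as distributed launch scripts commonly do:

# Sketch: tensor.to("cuda") targets the current CUDA device, which defaults to cuda:0.
import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    x = torch.ones(2, device="cuda:1").cpu()
    print(x.to("cuda").device)      # cuda:0 -- the default current device
    torch.cuda.set_device(1)        # e.g. what a launcher does with local_rank
    print(x.to("cuda").device)      # cuda:1 -- now follows the updated current device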

janeyx99 avatar Apr 28 '25 19:04 janeyx99