
Add initial support for torch.utils.checkpoint

IvanYashchuk opened this issue 1 year ago • 4 comments

A checkpointed function doesn't save its intermediate values from forward to backward; only its inputs are kept, and all other required values are recomputed during the backward pass. Because fewer intermediates are saved, peak memory usage is usually lower.
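The recompute-in-backward behaviour can be observed directly in plain PyTorch, without Thunder. This minimal sketch runs on CPU and assumes the non-reentrant variant (`use_reentrant=False`); `body` and `calls` are illustrative names. It counts how often the checkpointed function runs: once during forward and once more during backward, when the intermediates are recomputed.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = 0  # how many times the checkpointed function has been executed


def body(t):
    global calls
    calls += 1
    return t.sin().cos().exp()


x = torch.randn(3, 4, requires_grad=True)
y = checkpoint(body, x, use_reentrant=False)
print(calls)  # 1: only the forward pass has run so far
y.sum().backward()
print(calls)  # 2: body was re-run during backward to recompute intermediates
```

The gradient is the same as without checkpointing; only the memory/compute trade-off changes.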

This PR adds support for recognizing torch.utils.checkpoint.checkpoint calls and inserting a new bound symbol into the initial trace. In the forward-backward generation pass, this bound symbol is then converted into the augmented forward and backward parts of the computation. This step requires the function argument to thunder.torch.checkpoint to be a Thunder function. No PyTorch->Thunder conversion is implemented yet, so this currently works only for simple functions that both Thunder and PyTorch recognize, for example functions that use only tensor methods.
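For intuition, the split into an augmented forward that saves only the inputs and a backward that recomputes the function before differentiating it can be sketched in plain PyTorch with a custom torch.autograd.Function. This is a hypothetical analogue (`RecomputeInBackward` is not part of Thunder or PyTorch) rather than the PR's implementation, and it handles a single tensor input only.

```python
import torch


class RecomputeInBackward(torch.autograd.Function):
    """Toy checkpoint: keep only the input, recompute fn(x) during backward."""

    @staticmethod
    def forward(ctx, fn, x):
        # Grad mode is off inside forward, so fn's intermediates are not recorded.
        ctx.fn = fn
        ctx.save_for_backward(x)  # the only tensor kept for backward
        return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Re-run fn with grad enabled on a detached copy of x, then differentiate it.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            y = ctx.fn(x_)
        (grad_x,) = torch.autograd.grad(y, x_, grad_out)
        return None, grad_x  # no gradient for fn, gradient for x


def sin_cos_exp(t):
    return t.sin().cos().exp()


x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
# As in the example below, x doubles as the incoming gradient.
RecomputeInBackward.apply(sin_cos_exp, x).backward(x)
```

The structure mirrors the traces that follow: the forward returns its output and saves only x, and the backward recomputes sin and cos from x before applying the chain rule.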

The PyTorch function needs to be converted to a Thunder function in Thunder's JIT. Previously we could simply use thunder.preprocess, which is no longer available. When I attempted to implement redispatching/reinterpretation of PyTorch functions using general_thunder_jit, I hit the following bug: https://github.com/Lightning-AI/lightning-thunder/issues/1126.

Example:

import thunder
import torch

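def f(x):
    # Thunder recognizes this call and records a checkpoint bound symbol
    return torch.utils.checkpoint.checkpoint(lambda x: x.sin().cos().exp(), x)

jf = thunder.jit(f)
x = torch.randn(3, 4, device="cuda", requires_grad=True)  # the traces below were produced on a CUDA device
jf(x).backward(x)  # x doubles as the incoming gradient (cotangent)
print(thunder.last_traces(jf)[-1])           # forward execution trace
print(thunder.last_backward_traces(jf)[-1])  # backward execution trace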

Forward execution trace:

def augmented_forward_fn(x):
  # x: "cuda:0 f32[3, 4]"
  [t2] = nvFusion0(x)
    # t0 = prims.sin(x)  # t0: "cuda:0 f32[3, 4]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[3, 4]"
    # t2 = prims.exp(t1)  # t2: "cuda:0 f32[3, 4]"
  return {'output': t2, 'flat_args': [x], 'flat_output': (t2,)}, ((x,), ())

Backward execution trace:

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t3, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  clear_mutable_collection(C0)
  del C0
  [t12] = nvFusion0(x, t3)
    # t4 = prims.sin(x)  # t4: "cuda:0 f32[3, 4]"
    # t11 = prims.cos(x)  # t11: "cuda:0 f32[3, 4]"
    # t5 = prims.cos(t4)  # t5: "cuda:0 f32[3, 4]"
    # t8 = prims.sin(t4)  # t8: "cuda:0 f32[3, 4]"
    # t6 = prims.exp(t5)  # t6: "cuda:0 f32[3, 4]"
    # t7 = prims.mul(t3, t6)  # t7: "cuda:0 f32[3, 4]"
    # t9 = prims.neg(t8)  # t9: "cuda:0 f32[3, 4]"
    # t10 = prims.mul(t7, t9)  # t10: "cuda:0 f32[3, 4]"
    # t12 = prims.mul(t10, t11)  # t12: "cuda:0 f32[3, 4]"
  del x, t3
  return (t12,)
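As a sanity check on the backward trace, the chain rule for y = exp(cos(sin x)) with incoming cotangent g (t3) gives exactly the product that t12 computes. Each factor is rebuilt from x inside nvFusion0, since the forward saved nothing but x:

```latex
\frac{\partial L}{\partial x}
  = g \cdot \underbrace{e^{\cos(\sin x)}}_{t_6}
      \cdot \underbrace{\bigl(-\sin(\sin x)\bigr)}_{t_9}
      \cdot \underbrace{\cos x}_{t_{11}},
\qquad t_7 = g\,t_6,\quad t_{10} = t_7\,t_9,\quad t_{12} = t_{10}\,t_{11}.
```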

IvanYashchuk, Sep 09 '24 14:09

Why not just let the function be any function and keep a "currently checkpointing" state that tells Thunder to add a tag to the proxies generated during checkpointing? We would need to clear that tag on the outputs, but that would be easier than a reentrant JIT.

Do you have ideas about how the "currently checkpointing" approach would generalize to supporting, for example, torch.cond? Please continue in the issue https://github.com/Lightning-AI/lightning-thunder/issues/1134.

IvanYashchuk, Sep 10 '24 07:09

I don't have immediate ideas, but I don't think we should be adding higher-order functions right now. If anything, it's the wrong sequencing.

t-vi, Sep 10 '24 08:09

@IvanYashchuk You might want to check out selective activation checkpointing, available in PyTorch nightlies (https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts), which lets you specify which activations to save for backward.

syed-ahmed, Oct 02 '24 19:10

Awesome, thanks for the link, Syed! I'm not a fan of ATen ops leaking into the PyTorch Python interface, with torch.matmul becoming torch.ops.aten.mm.default, but I will check out how Thunder could recognize it.

IvanYashchuk, Oct 09 '24 14:10