
Debug mode that associates backward ops to forward ops in the trace


🚀 Feature

Motivation

People who are not familiar with the autograd definitions for their models can have a hard time inspecting the backward trace generated by Thunder, because it's not easy to tell which section of backward operations relates to a specific forward operator.

Using a dummy example:

import thunder
import torch

device = torch.device("cuda")

model = torch.nn.Linear(100, 10, bias=False, device=device)
x = torch.randn(10, 100, device=device)

fn = thunder.jit(model)
out = fn(x)
print(out.shape)

fwd_trace, bwd_trace = thunder.last_traces(fn)
print(fwd_trace[-1].python())
print("=" * 20)
print(bwd_trace[-1].python())

This generates:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t_0, t_weight):
  # t_0: "cuda:0 f32[10, 100]" 
  # t_weight: "cuda:0 f32[10, 100]" 
  t0 = torch.nn.functional.linear(t_0, t_weight, None)  # t0: "cuda:0 f32[10, 10]"
  return {'output': t0, 'flat_args': [t_0, t_weight], 'flat_output': (t0,)}, ((t_0,), ())
====================
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection" 
  # cotangents: "Collection" 
  C0, \
  _, \
  = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t1, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t_0, \
  = C0
  clear_collection(C0)
  del C0
  t6 = torch.reshape(t1, (-1, 10))  # t6: "cuda:0 f32[10, 10]"
  del t1
  t7 = torch.permute(t6, (1, 0))  # t7: "cuda:0 f32[10, 10]"
  del t6
  t8 = torch.reshape(t_0, (-1, 100))  # t8: "cuda:0 f32[10, 100]"
  del t_0
  t9 = torch.matmul(t7, t8)  # t9: "cuda:0 f32[10, 100]"
  del t7, t8
  return (None, t9)

Since there's only one operation in forward, the link between traces is obvious.
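
To make the correspondence concrete: the four backward ops above are just the analytic weight gradient of linear, grad_weight = cotangent^T @ input. A quick standalone check against PyTorch autograd (a minimal sketch on CPU; the variable names g and x are made up here, mirroring t1 and t_0 above):

import torch

g = torch.randn(10, 10)   # cotangent arriving at the linear output (t1 above)
x = torch.randn(10, 100)  # saved forward input (t_0 above)

# Replay the backward trace by hand: reshape, permute, matmul.
grad_w = torch.matmul(g.reshape(-1, 10).permute(1, 0), x.reshape(-1, 100))

# Compare against PyTorch autograd for the same linear layer.
w = torch.randn(10, 100, requires_grad=True)
torch.nn.functional.linear(x, w).backward(g)
assert torch.allclose(grad_w, w.grad)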

But for a trace like that of LitGPT, with far more operator diversity and long sequences of Transformer blocks, inspecting the backward trace is a challenge.

Pitch

Since Thunder is all about debuggability, we could introduce a mode (assuming we don't want it to be the default) where Thunder adds information about the linked forward operator to each backward operation. As an example, extending the backward trace above:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection" 
  # cotangents: "Collection" 
  C0, \
  _, \
  = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t1, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t_0, \
  = C0
  clear_collection(C0)
  del C0
  t6 = torch.reshape(t1, (-1, 10))  # t6: "cuda:0 f32[10, 10]", forward: t0: "torch.nn.functional.linear"
  del t1
  t7 = torch.permute(t6, (1, 0))  # t7: "cuda:0 f32[10, 10]", forward: t0: "torch.nn.functional.linear"
  del t6
  t8 = torch.reshape(t_0, (-1, 100))  # t8: "cuda:0 f32[10, 100]", forward: t0: "torch.nn.functional.linear"
  del t_0
  t9 = torch.matmul(t7, t8)  # t9: "cuda:0 f32[10, 100]", forward: t0: "torch.nn.functional.linear"
  del t7, t8
  return (None, t9)

This would be useful both for debugging a specific section of your trace and for learning how autograd works.
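
One possible shape for this, sketched without Thunder's real internals (BoundSymbolInfo, forward_provenance, register_backward_ops, and annotate are all hypothetical names, not actual Thunder APIs): while the autograd transform expands a forward symbol into its backward ops, record which forward symbol each one came from, then append that link to the printed comment.

from dataclasses import dataclass

@dataclass
class BoundSymbolInfo:
    name: str      # e.g. "torch.reshape"
    outputs: tuple # proxy names, e.g. ("t6",)
    comment: str = ""

# Maps a backward output proxy name to (forward output proxy, forward op name).
forward_provenance: dict[str, tuple[str, str]] = {}

def register_backward_ops(forward_sym, backward_syms):
    # Called while expanding a forward symbol into its backward ops:
    # remember which forward symbol produced each backward output.
    for bsym in backward_syms:
        for out in bsym.outputs:
            forward_provenance[out] = (forward_sym.outputs[0], forward_sym.name)

def annotate(bsym):
    # Called when printing the backward trace: append the linked forward op.
    fwd = forward_provenance.get(bsym.outputs[0])
    if fwd is None:
        return bsym.comment
    fwd_out, fwd_name = fwd
    return f'{bsym.comment}, forward: {fwd_out}: "{fwd_name}"'

linear = BoundSymbolInfo("torch.nn.functional.linear", ("t0",))
reshape = BoundSymbolInfo("torch.reshape", ("t6",), 't6: "cuda:0 f32[10, 10]"')
register_backward_ops(linear, [reshape])
print(annotate(reshape))
# t6: "cuda:0 f32[10, 10]", forward: t0: "torch.nn.functional.linear"

Keeping the provenance map separate from the symbols themselves means the annotation can be purely a printing concern, so the mode would not need to change how traces execute.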

cc @carmocca
