Debug mode that associates backward ops with forward ops in the trace
🚀 Feature
Motivation
People who are not familiar with the autograd definitions for their models can have a hard time inspecting the backward trace generated by Thunder, because it's not easy to tell which section of backward operations relates to a specific forward operator.
Consider this dummy example:
import thunder
import torch
device = torch.device("cuda")
model = torch.nn.Linear(100, 10, bias=False, device=device)
x = torch.randn(10, 100, device=device)
fn = thunder.jit(model)
out = fn(x)
print(out.shape)
# Grab the trace histories; [-1] below selects the final trace in each
fwd_trace, bwd_trace = thunder.last_traces(fn)
print(fwd_trace[-1].python())
print("=" * 20)
print(bwd_trace[-1].python())
This generates:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t_0, t_weight):
# t_0: "cuda:0 f32[10, 100]"
# t_weight: "cuda:0 f32[10, 100]"
t0 = torch.nn.functional.linear(t_0, t_weight, None) # t0: "cuda:0 f32[10, 10]"
return {'output': t0, 'flat_args': [t_0, t_weight], 'flat_output': (t0,)}, ((t_0,), ())
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
_, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t1, \
= cotangents
clear_collection(cotangents)
del cotangents
t_0, \
= C0
clear_collection(C0)
del C0
t6 = torch.reshape(t1, (-1, 10)) # t6: "cuda:0 f32[10, 10]"
del t1
t7 = torch.permute(t6, (1, 0)) # t7: "cuda:0 f32[10, 10]"
del t6
t8 = torch.reshape(t_0, (-1, 100)) # t8: "cuda:0 f32[10, 100]"
del t_0
t9 = torch.matmul(t7, t8) # t9: "cuda:0 f32[10, 100]"
del t7, t8
return (None, t9)
Since there's only one operation in the forward trace, the link between the two traces is obvious.
But for a trace like that of LitGPT, with far more operator diversity and long sequences of Transformer blocks, inspecting the backward trace is a challenge.
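Even in the small example, the correspondence can be checked by hand: t6 through t9 compute exactly the weight gradient of the forward linear. Here is a quick standalone PyTorch sanity check (not part of the issue's repro; it uses CPU tensors for simplicity):
import torch
x = torch.randn(10, 100)
weight = torch.randn(10, 100, requires_grad=True)
out = torch.nn.functional.linear(x, weight)  # out: [10, 10]
cotangent = torch.randn_like(out)
out.backward(cotangent)
# Mirror t6..t9 from the backward trace above
t6 = torch.reshape(cotangent, (-1, 10))
t7 = torch.permute(t6, (1, 0))
t8 = torch.reshape(x, (-1, 100))
t9 = torch.matmul(t7, t8)
torch.testing.assert_close(weight.grad, t9)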
Pitch
Since Thunder is all about debuggability, we could introduce a mode (assuming we don't want it to be the default) where Thunder annotates each backward operation with the forward operator it was derived from. As an example, extending the backward trace above:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
_, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t1, \
= cotangents
clear_collection(cotangents)
del cotangents
t_0, \
= C0
clear_collection(C0)
del C0
t6 = torch.reshape(t1, (-1, 10)) # t6: "cuda:0 f32[10, 10]", forward: t0: "torch.nn.functional.linear"
del t1
t7 = torch.permute(t6, (1, 0)) # t7: "cuda:0 f32[10, 10]", forward: t0: "torch.nn.functional.linear"
del t6
t8 = torch.reshape(t_0, (-1, 100)) # t8: "cuda:0 f32[10, 100]", forward: t0: "torch.nn.functional.linear"
del t_0
t9 = torch.matmul(t7, t8) # t9: "cuda:0 f32[10, 100]", forward: t0: "torch.nn.functional.linear"
del t7, t8
return (None, t9)
This would be useful both for debugging a specific section of your trace and for learning how autograd works.
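One possible implementation direction, sketched below: while the grad transform derives the backward operations for a given forward operation, stamp each of them with a reference to the forward symbol, and have the trace printer append that reference to the comment. Note that Op, forward_provenance, tag_backward_ops, and render_comment are hypothetical stand-ins for illustration, not Thunder's actual internals:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Op:
    # Hypothetical stand-in for a bound symbol in a trace
    name: str
    output: str
    forward_provenance: Optional[str] = None  # only populated in debug mode

def tag_backward_ops(forward_op: Op, backward_ops: list) -> list:
    # While deriving backward ops from one forward op, stamp each of them
    # with that forward op's output name and symbol name
    for bop in backward_ops:
        bop.forward_provenance = f'{forward_op.output}: "{forward_op.name}"'
    return backward_ops

def render_comment(op: Op) -> str:
    # The trace printer would append the provenance to the usual comment
    comment = f"# {op.output}: ..."
    if op.forward_provenance is not None:
        comment += f", forward: {op.forward_provenance}"
    return comment

linear = Op("torch.nn.functional.linear", "t0")
reshape = Op("torch.reshape", "t6")
tag_backward_ops(linear, [reshape])
print(render_comment(reshape))  # -> # t6: ..., forward: t0: "torch.nn.functional.linear"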
cc @carmocca