Export stateful executor states to `ThunderModule`
What does this PR do?
Closes #2438.
Summary
- Added `ExportStatefulExecutorsTransform` (a singleton) with a registry of export callbacks.
- Executors can register a callback to export runtime state post-execution.
- Currently integrates with `transformer_engine_ex` (see the sketch below).
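For illustration, here is a minimal sketch of the callback-registry idea. The names register_export_callback and export_states are hypothetical and not necessarily the API added in this PR; the sketch only conveys how a stateful executor could plug its runtime state into the transform.
# Illustrative sketch only: the names below are hypothetical, not the exact API of this PR.
class ExportStatefulExecutorsTransform:
    _instance = None   # singleton instance
    _callbacks = {}    # executor name -> callable returning that executor's runtime state
    def __new__(cls):
        # Always hand back the same instance (singleton).
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    @classmethod
    def register_export_callback(cls, executor_name, callback):
        # An executor (e.g. transformer_engine_ex) registers how to export its state.
        cls._callbacks[executor_name] = callback
    def export_states(self, thunder_module):
        # Post-execution: collect each registered executor's state and expose it
        # on the compiled ThunderModule for inspection.
        return {name: cb(thunder_module) for name, cb in self._callbacks.items()}
Under a scheme like this, transformer_engine_ex would register a callback that collects its FP8 state, which is what the jmodel.te_fp8_states() calls in the example below rely on.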
Usage
The export is enabled by passing the transform in the transforms list:
thunder.jit(model, executors=[...], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])
We may or may not add a compile flag later to make this more intuitive for the user (i.e., have this transform registered automatically).
Full example:
from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
import torch
import torch.nn as nn
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
from transformer_engine.common import recipe
import transformer_engine.pytorch as te
from pprint import pprint
torch.manual_seed(42)
# Device and data type
dtype = torch.bfloat16
device = "cuda"
# Inputs (3D input)
x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True)
# TE recipe
fp8_recipe = recipe.DelayedScaling()
# fp8_recipe = recipe.MXFP8BlockScaling()
# Dummy model
class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w2 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w3 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))

    def forward(self, x):
        return torch.nn.functional.linear(torch.nn.functional.linear(torch.nn.functional.linear(x, self.w2), self.w3), self.w1)
model = Module()
jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])
# Enable autocasting for the forward pass
for _ in range(2):
    with te.fp8_autocast(fp8_recipe=fp8_recipe):
        y = jmodel(torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True))

    rep = jmodel.te_fp8_states()
    pprint(rep)  # Inspect the exported state after the forward pass

    grad_output = torch.randn_like(y)
    y.backward(grad_output)

    rep = jmodel.te_fp8_states(mode="backward")
    pprint(rep)  # Inspect the exported state after the backward pass
Why it matters
- Not TE-specific: provides a reusable, uniform path for any stateful executor to export runtime state for debugging, validation, and reporting.