Export stateful executor states to `ThunderModule`

Open mattteochen opened this issue 2 months ago • 0 comments

What does this PR do?

Closes #2438.

Summary

  • Added ExportStatefulExecutorsTransform (singleton) with a registry of export callbacks.
  • Executors can register a callback to export runtime state post-execution.
  • Currently integrates with transformer_engine_ex.
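
To illustrate the idea of a singleton holding a registry of export callbacks, here is a minimal, self-contained sketch. It is not the code in this PR: `ExportRegistry`, `register`, `export_all`, `export_dummy_state`, and `DummyModule` are hypothetical names chosen for illustration, and the real `ExportStatefulExecutorsTransform` may be structured differently.

```python
from typing import Any, Callable, Dict

# Hypothetical names throughout: this is NOT the PR's implementation,
# only an illustration of a singleton that holds a registry of export callbacks.
ExportCallback = Callable[[Any, str], Dict[str, Any]]


class ExportRegistry:
    """Singleton holding one export callback per executor name."""

    _instance = None

    def __new__(cls):
        # Reuse the single shared instance so every executor registers into the same table.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._callbacks = {}
        return cls._instance

    def register(self, executor_name: str, callback: ExportCallback) -> None:
        # A later registration under the same name replaces the earlier one.
        self._callbacks[executor_name] = callback

    def export_all(self, module: Any, mode: str = "forward") -> Dict[str, Dict[str, Any]]:
        # Run every registered callback and collect the results keyed by executor name.
        return {name: cb(module, mode) for name, cb in self._callbacks.items()}


def export_dummy_state(module: Any, mode: str) -> Dict[str, Any]:
    # Stand-in for an executor-specific exporter (for example, one that reads FP8 scaling state).
    return {"mode": mode, "num_calls": getattr(module, "num_calls", 0)}


class DummyModule:
    # Stand-in for a compiled module that carries some runtime state.
    num_calls = 2


ExportRegistry().register("dummy_ex", export_dummy_state)
report = ExportRegistry().export_all(DummyModule(), mode="backward")
print(report)
```

In this sketch a new stateful executor only has to call `register` once; the code that collects the states never needs to know which executors exist.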

Usage

The export is controlled by passing the transform to the transforms list:

thunder.jit(model, executors=[...], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])

We may later replace this with a compiler flag that adds the transform automatically, which would be more intuitive for the user. That decision is still open.

Full example:

from thunder.dev_utils.export_stateful_ex_transform import ExportStatefulExecutorsTransform
import torch
import torch.nn as nn
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform
from transformer_engine.common import recipe
import transformer_engine.pytorch as te
from pprint import pprint

torch.manual_seed(42)

# Device and data type
dtype = torch.bfloat16
device = "cuda"

# Inputs (3D input)
x = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True)

# TE recipe
fp8_recipe = recipe.DelayedScaling()
# fp8_recipe = recipe.MXFP8BlockScaling()

# Dummy model
class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w2 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))
        self.w3 = nn.Parameter(torch.randn(4096, 4096, device=device, dtype=dtype))

    def forward(self, x):
        return torch.nn.functional.linear(torch.nn.functional.linear(torch.nn.functional.linear(x, self.w2), self.w3), self.w1)

model = Module()
jmodel = thunder.jit(model, executors=[transformer_engine_ex], transforms=[TransformerEngineTransform(), ExportStatefulExecutorsTransform()])

# Enable autocasting for the forward pass
for _ in range(2):
    with te.fp8_autocast(fp8_recipe=fp8_recipe):
        y = jmodel(torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True))
        rep = jmodel.te_fp8_states()
        pprint(rep) # Visualize states information after operation
    grad_output = torch.randn_like(y)
    y.backward(grad_output)
    rep = jmodel.te_fp8_states(mode="backward")
    pprint(rep) # Visualize states information after operation

Why it matters

  • Not TE-specific: provides a reusable, uniform path for any stateful executor to export runtime state for debugging, validation, and reporting.

mattteochen • Oct 02 '25 15:10