lightning-thunder
Implement VJP for `python_print`
🚀 Feature
Motivation
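Enable print debugging inside `thunder.jit`-compiled functions whose inputs require gradients. Today, calling `python_print` in such a function raises an error because no VJP rule exists for the print primitive.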
Pitch
Minimal reproduction:

```python
import torch
import thunder
from thunder.core.prims import python_print

def fn(x):
    y = x + 2
    python_print(y)
    return y

x = torch.randn(2, requires_grad=True)
cfn = thunder.jit(fn)
y = cfn(x)
```

Calling the compiled function on an input with `requires_grad=True` fails with:

```
NotImplementedError: VJP for PrimIDs.PRINT is not implemented
```
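A VJP for a print primitive can be trivial. Printing is a side effect that returns `None`, so it has no differentiable output and sends no cotangent back to its input. The toy reverse-mode sketch below illustrates this idea in pure Python, modeled on `fn` from the repro. `VJP_RULES`, `register_vjp`, `forward`, and `backward` are hypothetical names for illustration only. They are not thunder's API, and how thunder actually registers VJP rules is not shown here.

```python
from typing import Callable, Optional

# Hypothetical toy registry: op name -> rule mapping the output cotangent
# to a tuple of input cotangents (not thunder's API).
VJP_RULES: dict[str, Callable] = {}

def register_vjp(op_name: str):
    def decorator(rule: Callable) -> Callable:
        VJP_RULES[op_name] = rule
        return rule
    return decorator

@register_vjp("add_const")
def add_const_vjp(grad_out: Optional[float]) -> tuple:
    # y = x + c  =>  dy/dx = 1, so the cotangent passes through unchanged
    return (grad_out,)

@register_vjp("print")
def print_vjp(grad_out: Optional[float]) -> tuple:
    # print returns None and has no differentiable output, so it sends no
    # cotangent back to its input; printing happens only in the forward pass
    return (None,)

def forward(x: float, trace: list) -> float:
    # Mirrors fn from the repro: y = x + 2; print(y); return y
    y = x + 2.0
    trace.append(("add_const", "x", "y"))
    print(y)
    trace.append(("print", "y", None))
    return y

def backward(trace: list, grad_y: float) -> dict:
    # Walk the recorded ops in reverse, applying each op's VJP rule
    grads: dict = {"y": grad_y}
    for op_name, inp, out in reversed(trace):
        grad_out = grads.get(out) if out is not None else None
        (grad_in,) = VJP_RULES[op_name](grad_out)
        if grad_in is not None:
            grads[inp] = grads.get(inp, 0.0) + grad_in
    return grads

trace: list = []
y = forward(3.0, trace)       # prints 5.0
grads = backward(trace, 1.0)
print(grads["x"])             # 1.0: the print op did not affect the gradient
```

One possible design alternative is a backward rule that also prints (for example, the incoming cotangent) to help debug gradients. The sketch keeps the minimal behavior, where print is transparent to differentiation.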