`python_print` doesn't work as expected
🐛 Bug
To Reproduce
Issue 1:
```python
import thunder
import torch
from thunder.core import prims


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        for l in self.linears:
            prims.python_print("l")
            prims.python_print(x.sum())
            x = l(x)
        return x


fn = MyModel()
x = torch.randn(10, 1)
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)
```
Output:
```
tensor(2.7485)
tensor(6.1350)
tensor(-3.8157)
```
The `"l"` prints are missing from the output entirely.
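For reference, the expected output here would interleave the string with the running sums, roughly:
```
l
tensor(2.7485)
l
tensor(6.1350)
l
tensor(-3.8157)
```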
Issue 2:
```python
import thunder
import torch
from thunder.core import prims


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        prims.python_print("l0")
        prims.python_print(x.sum())
        x = self.linears[0](x)
        prims.python_print("l1")
        prims.python_print(x.sum())
        x = self.linears[1](x)
        prims.python_print("l2")
        prims.python_print(x.sum())
        x = self.linears[2](x)
        return x


fn = MyModel()
x = torch.randn(10, 1)
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)
```
Output:
```
l0
l1
l2
tensor(-1.1433)
tensor(8.6309)
tensor(-3.1484)
```
All of the string prints are emitted up front, before any of the tensor prints, instead of interleaving with the computation.
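To see where the `python_print` calls end up relative to the rest of the computation, it may help to look at the generated trace. A minimal sketch (assuming the jitted `fn` has already been called, as above):
```python
# Inspect the traces thunder recorded for the jitted callable.
# last_traces returns the sequence of transformed traces; the final
# entry is the trace that actually gets executed.
traces = thunder.last_traces(fn)
print(traces[-1])
```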
Expected behavior
Each print should appear at the point in execution where it was issued, and exactly as many times as it was issued.
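For Issue 2, that would mean interleaving the string and tensor prints, roughly:
```
l0
tensor(-1.1433)
l1
tensor(8.6309)
l2
tensor(-3.1484)
```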
Awesome issues, thanks @carmocca! I'll get on them ASAP