lightning-thunder
`python_print` should support `*args`
🐛 Bug
Traceback (most recent call last):
  File "/home/carmocca/git/lightning-thunder/kk.py", line 8, in <module>
    y = cfn()
  File "/home/carmocca/git/lightning-thunder/thunder/common.py", line 645, in _fn
    trc_or_result = trace(compile_data=cd)(cd.processed_function, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/common.py", line 480, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "<thunder-generated-139770238745040>", line 2, in fn
  File "/home/carmocca/git/lightning-thunder/thunder/core/symbol.py", line 269, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/langctx.py", line 92, in _fn
    result = fn(*args, **kwargs)
TypeError: _print_meta() takes 1 positional argument but 2 were given
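The error comes from ordinary Python argument binding. The meta function declares a single positional parameter, so a second positional argument is rejected before its body runs. The plain-Python sketch below reproduces that, and shows a variadic signature modeled on `print` as one possible shape for a fix. Both `print_meta_single` and `print_meta_variadic` are hypothetical stand-ins, not thunder's actual `_print_meta`.

```python
# Hypothetical stand-ins for a print meta function; the names and
# signatures are illustrative, not thunder's implementation.


def print_meta_single(s):
    """Accepts exactly one positional argument, like the failing meta."""
    return None  # a print produces no output value in this sketch


def print_meta_variadic(*args, sep=" ", end="\n"):
    """Accepts any number of positional arguments, mirroring print()."""
    return None  # a print produces no output value in this sketch


def call_and_report(fn, *args):
    """Call fn with args; return the TypeError message if one is raised."""
    try:
        fn(*args)
    except TypeError as e:
        return str(e)
    return "ok"


print(call_and_report(print_meta_single, "Hello", "world"))
print(call_and_report(print_meta_variadic, "Hello", "world"))
```

The first call fails with the same kind of message as the traceback above; the second succeeds because `*args` collects every positional argument.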
To Reproduce
import thunder
from thunder.core.prims import python_print

def fn():
    python_print("Hello", "world")

cfn = thunder.jit(fn)
y = cfn()
Expected behavior
It works like the built-in `print`, accepting any number of positional arguments.
A design flaw with `python_print` is that common debugging patterns like `python_print(f"Before layernorm: {x}")` don't work. The f-string is formatted while the function is being traced, when `x` is still a proxy, so you get `Before layernorm: t452` instead of `Before layernorm: torch.tensor(...)`.
The motivation for this issue is therefore to support `python_print("Before layernorm:", x)`, where `x` is passed as its own argument.
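To illustrate the difference, here is a toy tracer in plain Python. `Proxy`, `Trace`, and the `t452` binding are hypothetical; they only mimic the idea that a proxy's string form is its name during tracing and are not thunder classes. The f-string is formatted immediately, so only the name survives. The variadic call keeps the proxy object until the trace runs with real values.

```python
# A toy tracer (hypothetical, not thunder's API) showing why an f-string
# captures a proxy's name at trace time, while passing the proxy as a
# separate argument lets the real value be printed at run time.


class Proxy:
    """Stands in for a tensor during tracing; its str() is just a name."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


class Trace:
    """Records deferred print calls and replays them with real values."""

    def __init__(self):
        self.ops = []

    def python_print(self, *args):
        # Store the arguments unformatted; formatting happens at run time.
        self.ops.append(args)

    def run(self, bindings):
        # Swap each proxy for its runtime value, then join like print().
        lines = []
        for args in self.ops:
            values = [bindings[a.name] if isinstance(a, Proxy) else a for a in args]
            lines.append(" ".join(str(v) for v in values))
        return lines


trace = Trace()
x = Proxy("t452")

# The f-string is evaluated now, while x is still a proxy.
trace.python_print(f"Before layernorm: {x}")
# Here x is stored as-is and only converted to a string when the trace runs.
trace.python_print("Before layernorm:", x)

for line in trace.run({"t452": [1.0, 2.0, 3.0]}):
    print(line)
```

Running this prints `Before layernorm: t452` for the f-string version and the bound value for the variadic version.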
As a workaround, I tried this hacky idea:
from thunder.core import prims

def tprint(*args):
    # Only handles exactly two arguments: a label and a value
    a, b = args
    print(a, end=" ")
    prims.python_print(b)
But all the `print` output gets flushed before any of the `python_print` output.
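A likely explanation can be modeled in plain Python. The model assumes `print` runs while thunder traces the function and `python_print` only runs when the compiled trace executes. The two-phase structure is hypothetical: calling `fn()` stands in for tracing, and the loop over `deferred` stands in for execution. All names here are illustrative. Output is collected in a list instead of going to stdout so the ordering is easy to inspect.

```python
# Toy model (hypothetical, not thunder's implementation) of why the tprint
# workaround reorders output: print() runs while the function is traced,
# whereas python_print is recorded and only runs when the trace executes.

output = []  # everything "printed", in the order it happens
deferred = []  # recorded python_print calls, replayed at run time


def python_print(value):
    # Recorded during tracing; emitted later when the trace executes.
    deferred.append(value)


def tprint(*args):
    a, b = args
    output.append(a)  # stands in for print(a, end=" "), which runs immediately
    python_print(b)


def fn():
    tprint("Before layernorm:", "x")
    tprint("After layernorm:", "y")


fn()  # phase 1: "tracing" the function
for value in deferred:  # phase 2: "running" the compiled trace
    output.append(value)

print(output)
```

Under this model, every trace-time label is emitted before any runtime value, which matches the ordering described above. Accepting `*args` in `python_print` itself would keep each label and value in a single deferred call.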