
`python_print` should support `*args`

Open • carmocca opened this issue 2 years ago • 1 comment

🐛 Bug

```
Traceback (most recent call last):
  File "/home/carmocca/git/lightning-thunder/kk.py", line 8, in <module>
    y = cfn()
  File "/home/carmocca/git/lightning-thunder/thunder/common.py", line 645, in _fn
    trc_or_result = trace(compile_data=cd)(cd.processed_function, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/common.py", line 480, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "<thunder-generated-139770238745040>", line 2, in fn
  File "/home/carmocca/git/lightning-thunder/thunder/core/symbol.py", line 269, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/langctx.py", line 92, in _fn
    result = fn(*args, **kwargs)
TypeError: _print_meta() takes 1 positional argument but 2 were given
```

To Reproduce

```python
import thunder
from thunder.core.prims import python_print

def fn():
    python_print("Hello", "world")

cfn = thunder.jit(fn)
y = cfn()
```

Expected behavior

`python_print` accepts any number of positional arguments and works like the builtin `print`.
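The following is a minimal sketch in plain Python of the failure shown in the traceback. A meta function with a single parameter rejects two positional arguments, while a `print`-like variadic signature accepts them. Both functions are hypothetical stand-ins and do not reproduce thunder's actual `_print_meta`.

```python
def print_meta_single(s):
    """Hypothetical stand-in for a meta function that accepts exactly one value."""
    return None


def print_meta_variadic(*args, sep=" ", end="\n"):
    """Hypothetical stand-in mirroring print's signature: any number of values."""
    return None


def try_call(meta, *args):
    """Call meta with args; return the TypeError message, or None on success."""
    try:
        meta(*args)
    except TypeError as e:
        return str(e)
    return None


single_err = try_call(print_meta_single, "Hello", "world")
variadic_err = try_call(print_meta_variadic, "Hello", "world")
print(single_err)    # print_meta_single() takes 1 positional argument but 2 were given
print(variadic_err)  # None
```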

carmocca, Nov 03 '23 16:11

A design flaw of `python_print` is that common debugging patterns such as `python_print(f"Before layernorm: {x}")` don't work. The f-string is evaluated while the function is being traced, when `x` is still a proxy, so the output is `Before layernorm: t452` instead of `Before layernorm: torch.tensor(...)`.

The motivation for this issue is to allow `python_print("Before layernorm:", x)` instead, so that `x` is printed with its runtime value.
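The following toy model in plain Python illustrates the difference between the two calls. It is not thunder's implementation: `Proxy`, `ToyTrace`, and `deferred_print` are hypothetical names. A proxy's `str()` yields only its name, so an f-string bakes that name in at trace time. Passing the proxy as a separate argument lets the replay step substitute the real value.

```python
class Proxy:
    """Toy stand-in for a traced value: str() yields only its name (like t452)."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


class ToyTrace:
    """Hypothetical trace that records print calls and replays them at runtime."""

    def __init__(self):
        self.recorded = []  # each entry is the tuple of arguments of one print call

    def deferred_print(self, *args):
        # Store the arguments unformatted; formatting waits until run()
        self.recorded.append(args)

    def run(self, values):
        """Replay recorded prints, swapping each Proxy for its runtime value."""
        lines = []
        for args in self.recorded:
            resolved = [values[a.name] if isinstance(a, Proxy) else a for a in args]
            lines.append(" ".join(str(v) for v in resolved))
        return lines


trace = ToyTrace()
x = Proxy("t452")
# The f-string is evaluated now, during "tracing", so only the proxy name is captured
trace.deferred_print(f"Before layernorm: {x}")
# Passing x as a separate argument keeps the proxy, so it is resolved at runtime
trace.deferred_print("Before layernorm:", x)

lines = trace.run({"t452": [1.0, 2.0]})
for line in lines:
    print(line)
# Before layernorm: t452
# Before layernorm: [1.0, 2.0]
```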

As a workaround, I tried this hacky idea:

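```python
from thunder.core import prims

def tprint(*args):
    a, b = args
    print(a, end=" ")  # runs immediately, while the function is being traced
    prims.python_print(b)  # recorded in the trace; runs when the compiled function executes
```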

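But all the `print` output appears before any of the `python_print` output, because `print` executes while the function is traced, whereas `python_print` only executes when the compiled function runs.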
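The following toy model in plain Python shows why the two halves of `tprint` come apart. The eager print emits output during tracing, while the deferred print only emits output when the recorded trace is executed. `eager_print` and `deferred_print` are hypothetical analogues of `print` and `python_print`, not thunder's API.

```python
output = []    # everything "printed", in the order it appears
recorded = []  # values captured by the deferred print during tracing


def eager_print(value):
    # Analogue of builtin print: emits output immediately, i.e. during tracing
    output.append(str(value))


def deferred_print(value):
    # Analogue of python_print: only records the value; output happens at run time
    recorded.append(value)


def tprint(a, b):
    eager_print(a)
    deferred_print(b)


# "Tracing": the user function calls tprint twice
tprint("Before layernorm:", "x_value")
tprint("After layernorm:", "y_value")

# "Execution": the recorded prints run only now
for value in recorded:
    output.append(str(value))

print(output)
# ['Before layernorm:', 'After layernorm:', 'x_value', 'y_value']
```

Both labels are printed before either value, which matches the ordering reported above.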

carmocca, Nov 03 '23 17:11