DCE pass does not correctly handle the `DONT_DCE` tag in subsymbols
🐛 Bug
This is a very specific and hard-to-reach bug, but a bug nonetheless. If one of the subsymbols has the `DONT_DCE` tag but the output of the parent bound symbol is not used, running `dce` removes the output proxy but keeps the bound symbol.
This happens because `is_needed` is `True` (because of the tag), while the output proxy is not in `needed_proxies`:
https://github.com/Lightning-AI/lightning-thunder/blob/d0f647479f73787389c4fa6ced3b0b9ada4083c0/thunder/core/transform_common.py#L182-L187
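To make the mismatch concrete, here is a toy model of a backward DCE pass. It is a hypothetical simplification, not Thunder's code: `BSym`, `has_dont_dce`, `toy_dce` and `make_inplace_add` are made-up names, we assume the `DONT_DCE` tag sits on the `prims.copy_` subsymbol, and dropping outputs that are not in `needed` is an assumption chosen to mimic the behaviour reported above. The `keep_protected_outs` flag sketches one possible fix.

```python
from dataclasses import dataclass, field

DONT_DCE = "DONT_DCE"  # hypothetical stand-in for Thunder's DONT_DCE tag


@dataclass
class BSym:
    """Hypothetical, simplified stand-in for a Thunder bound symbol."""
    name: str
    args: list
    outs: list
    tags: set = field(default_factory=set)
    subsymbols: list = field(default_factory=list)


def has_dont_dce(bsym):
    # Protected if the symbol itself or any nested subsymbol carries the tag
    return DONT_DCE in bsym.tags or any(has_dont_dce(s) for s in bsym.subsymbols)


def toy_dce(trace, outputs, keep_protected_outs=False):
    needed = set(outputs)
    kept = []
    # Walk the trace backwards, as a DCE pass does
    for bsym in reversed(trace):
        is_needed = any(o in needed for o in bsym.outs)
        if has_dont_dce(bsym):
            is_needed = True
            if keep_protected_outs:
                # Possible fix (sketch): a symbol kept because of DONT_DCE
                # also keeps its outputs alive
                needed.update(bsym.outs)
        if is_needed:
            # Assumed behaviour: outputs that are not needed are dropped
            outs = [o for o in bsym.outs if o in needed]
            kept.append(BSym(bsym.name, bsym.args, outs, bsym.tags, bsym.subsymbols))
            needed.update(bsym.args)
    return list(reversed(kept))


def make_inplace_add(out, tmp):
    # Mirrors ltorch.add_: prims.add followed by a DONT_DCE-tagged prims.copy_
    return BSym(
        "ltorch.add_", ["a"], [out],
        subsymbols=[
            BSym("prims.add", ["a"], [tmp]),
            BSym("prims.copy_", [tmp, "a"], [out], tags={DONT_DCE}),
        ],
    )


trace = [make_inplace_add("b", "t0"), make_inplace_add("t3", "t2")]

buggy = toy_dce(trace, outputs=["t3"])
fixed = toy_dce(trace, outputs=["t3"], keep_protected_outs=True)
print([(s.name, s.outs) for s in buggy])  # [('ltorch.add_', []), ('ltorch.add_', ['t3'])]
print([(s.name, s.outs) for s in fixed])  # [('ltorch.add_', ['b']), ('ltorch.add_', ['t3'])]
```

In the `buggy` run the first `ltorch.add_` survives but loses its output `b`, matching the trace printed in the reproduction. Keeping the output is only one option; the alternative of keeping the bound symbol but pruning some of its subsymbols is what the discussion at the end of this issue touches on.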
To Reproduce
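```python
import torch
import thunder
from thunder.core.transform_common import dce

def foo(a):
    b = a.add_(5)  # result unused, but the in-place update must survive DCE
    return a.add_(4)

a = torch.randn(2, 2)
jf = thunder.jit(foo)
jf(a)

# Pick the trace recorded before functionalization; its position in the
# list may differ between Thunder versions
trace_before_functionalization = thunder.last_traces(jf)[-7]
broken_trace = dce(trace_before_functionalization)
print(broken_trace)
```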
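This prints something like the following. Note that the first `ltorch.add_` bound symbol is kept, but its output `b` has been removed: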
```python
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cpu f32[2, 2]"

  # <ipython-input-17-c5d60ae12d96>:6: b = a.add_(5)
  ltorch.add_(a, 5, alpha=1)
    # t0 = ltorch.add(a, 5, alpha=1) # t0: "cpu f32[2, 2]"
      # t0 = prims.add(a, 5.0) # t0: "cpu f32[2, 2]"
    # b = prims.copy_(t0, a, grad_enabled=True) # b: "cpu f32[2, 2]"

  # <ipython-input-17-c5d60ae12d96>:7: return a.add_(4)
  t3 = ltorch.add_(a, 4, alpha=1) # t3: "cpu f32[2, 2]"
    # t2 = ltorch.add(a, 4, alpha=1) # t2: "cpu f32[2, 2]"
      # t2 = prims.add(a, 4.0) # t2: "cpu f32[2, 2]"
    # t3 = prims.copy_(t2, a, grad_enabled=True) # t3: "cpu f32[2, 2]"
  return {'output': (t3,), 'flat_args': [a]}
```
Discovered while working on #1961
Just to clarify: do you want to remove the bound symbol if its output proxy is not used, even though a subsymbol has the `DONT_DCE` tag?
That seems unclear to me. I think we should not, though perhaps we should remove some of its subsymbols.