Funny fusion effects - constant as output of compilation group
Hi,
When fusing an only slightly longer elementary pointwise computation, I get
with tvm::CompilationGroup_0 = graph(%0 : Float(*),
      %1 : Float(*)):
  %4 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=1]()
  %3 : Float(*) = aten::add(%0, %1, %2)
  return (%3, %4)
and the second output (%4) is then used outside the group. That is a problem: it triggers errors when the next fusion is attempted.
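To make the problem concrete, here is a toy sketch in plain Python of a check that flags a constant escaping a compilation group as an output. The dict-based node model is invented for illustration; it is not the real JIT IR.

```python
# Toy sketch: detect a fusion group that returns a bare prim::Constant
# as one of its outputs. Such a constant escaping the group is what
# breaks the next fusion attempt.

def constant_outputs(group_outputs):
    """Return the group outputs produced by prim::Constant nodes."""
    return [o for o in group_outputs if o["kind"] == "prim::Constant"]

# Mirroring the dump above: %3 is the add result, %4 the leaked constant.
outputs = [
    {"name": "%3", "kind": "aten::add"},
    {"name": "%4", "kind": "prim::Constant"},
]

leaked = constant_outputs(outputs)
assert [o["name"] for o in leaked] == ["%4"]
```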
Best regards
Thomas
Trying to produce a minimal example, I got:
RuntimeError: IValue is not a Scalar
The above operation failed in interpreter, with the following stack trace:
at <ipython-input-15-e6d508d37f4e>:3:12

    @torch.jit.script
    def func(x) -> torch.Tensor:
        return x + torch.tanh(x + x)
                   ~~~~~~~~~~~~~~~~ <--- HERE
So the first error seems related to prim::Constant being declared fusible.
As far as I understand, prim::Constant nodes are not usually declared fusible; instead they are pulled into the group (and duplicated) when they are arguments to a fused node. This happens in GraphFuser::mergeNodeIntoGroup:
https://github.com/pytorch/pytorch/blob/2132ea1d8d6ca9cb2ab6c2206c7cdc9203e1908f/torch/csrc/jit/passes/graph_fuser.cpp#L359-L365
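The "pull in and duplicate" behavior can be sketched in plain Python. The Node/Group classes below are made up for illustration and are not the actual JIT C++ API; they only mirror the idea of mergeNodeIntoGroup's constant handling.

```python
# Toy illustration: instead of marking prim::Constant fusible (which can
# leave the constant dangling as an extra group output), constants are
# cloned into the fused subgraph when a consumer node is merged.

class Node:
    def __init__(self, kind, inputs=(), value=None):
        self.kind = kind            # e.g. "prim::Constant", "aten::add"
        self.inputs = list(inputs)
        self.value = value          # payload for constants

class Group:
    """A fusion group with its own subgraph."""
    def __init__(self):
        self.subgraph = []          # nodes inside the group
        self.inputs = []            # values passed in from outside

    def merge_node(self, node):
        new_inputs = []
        for inp in node.inputs:
            if inp.kind == "prim::Constant":
                # Duplicate the constant inside the subgraph rather than
                # threading it through as a group input (or output).
                clone = Node("prim::Constant", value=inp.value)
                self.subgraph.append(clone)
                new_inputs.append(clone)
            else:
                # Non-constant producers become group inputs.
                self.inputs.append(inp)
                new_inputs.append(inp)
        merged = Node(node.kind, new_inputs)
        self.subgraph.append(merged)
        return merged

# aten::add(x, y, alpha) with alpha = 1: the constant is cloned inside.
x = Node("input")
y = Node("input")
alpha = Node("prim::Constant", value=1)
add = Node("aten::add", [x, y, alpha])

g = Group()
g.merge_node(add)
assert all(n.kind != "prim::Constant" for n in g.inputs)
```

The key property is that the constant never crosses the group boundary: only the tensor inputs do.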
Hi @t-vi, I'm experiencing the same issue while trying to compile a model from maskrcnn-benchmark. Do you know of any workaround until the issue is fixed? Thanks.
Yes, sorry, I have been meaning to send a PR. In the meantime you can simply remove the prim::Constant fusion registration:
https://github.com/pytorch/tvm/blob/7d3fa3fbd88855a1655c174cc43caac7fe6b954e/torch_tvm/operators.cpp#L417
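Conceptually, the workaround amounts to dropping that one registration. As a hypothetical Python sketch of the idea (the fusible_ops set is invented here; the actual change is deleting the prim::Constant registration at the operators.cpp line linked above):

```python
# Hypothetical sketch: imagine the fuser consults a set of operator
# kinds it declares fusible. Removing prim::Constant from that set
# stops bare constants from being fused on their own (and thus from
# leaking out of compilation groups as extra outputs).

fusible_ops = {
    "aten::add",
    "aten::tanh",
    "prim::Constant",  # the entry to remove
}
fusible_ops.discard("prim::Constant")

def is_fusible(kind):
    return kind in fusible_ops

assert not is_fusible("prim::Constant")
assert is_fusible("aten::add")
```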
I see, however, that #72 is moving away from the CustomGraphFuser, which ironically also seems to do funny constant handling, even though the current approach makes this simple to fix.