`tryMerge` in torch_glow's custom fuse pass might have an issue
Hi,
I am implementing a custom TorchScript JIT backend similar to Glow's, and I may have found a small issue in how the custom fuse pass calls `tryMerge`.
The fuse pass appears to walk the graph in reverse order, but `tryMerge` checks whether the producer can be fused before it considers the consumer.
I am wondering whether this is intended or a mistake.
The code below shows that Glow does not JIT a supported op, `aten::mul`.
When `aten::mul` is treated as the consumer, its producer is `prim::Param`, which fails the first `canMerge` check. As a result, the consumer `aten::mul` is never fused into the FusionGroup and is therefore never JITed.
import torch

# Example inputs used to specialize the graph.
x = torch.randn(4)
y = torch.randn(4)

@torch.jit.script
def foo(a, b):
    c = a.mul(b)
    return c

print(foo.graph_for(x, y))
Output:
graph(%a.1 : Float(*, requires_grad=0, device=cpu),
      %b.1 : Float(*, requires_grad=0, device=cpu)):
  %c.1 : Float(*, requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # /workspace/workspace/torch_glow_dev/glow/torch_glow/examples/basic_example.py:14:8
  return (%c.1)
I have made some modifications in this ~~commit~~ PR. With them, the pass generates a FusionGroup as I expected:
graph(%a.1 : Float(*, requires_grad=0, device=cpu),
      %b.1 : Float(*, requires_grad=0, device=cpu)):
  %c.1 : Float(*, requires_grad=0, device=cpu) = glow::FusionGroup_0(%a.1, %b.1)
  return (%c.1)
with glow::FusionGroup_0 = graph(%a.1 : Float(*, requires_grad=0, device=cpu),
      %b.1 : Float(*, requires_grad=0, device=cpu)):
  %c.1 : Float(*, requires_grad=0, device=cpu) = aten::mul(%a.1, %b.1) # /workspace/workspace/torch_glow_dev/glow/torch_glow/examples/basic_example.py:14:8
  return (%c.1)
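To make the ordering problem easier to see, here is a minimal toy model in plain Python. It is only a sketch of the behavior described in this issue, not the torch_glow code: `Node`, `SUPPORTED`, `can_merge`, `try_merge`, `fuse_producer_first`, and `fuse_consumer_first` are all hypothetical names, the inputs are modeled as separate `prim::Param` nodes for simplicity, and the consumer-first variant is just one possible way to get the result I expected, not necessarily what the PR does.

```python
from dataclasses import dataclass, field

# Assumption: the only op this toy backend can compile.
SUPPORTED = {"aten::mul"}


@dataclass
class Node:
    kind: str
    inputs: list = field(default_factory=list)
    fused: bool = False  # True once the node is placed in a fusion group


def can_merge(node):
    # Stand-in for the canMerge check: only supported ops may be fused.
    return node.kind in SUPPORTED


def try_merge(consumer, producer):
    # The producer is checked first, so an unsupported producer
    # leaves the consumer untouched.
    if not can_merge(producer):
        return False
    consumer.fused = True
    producer.fused = True
    return True


def fuse_producer_first(nodes):
    # Walk the nodes in reverse and try to merge each producer into its consumer.
    for consumer in reversed(nodes):
        for producer in consumer.inputs:
            try_merge(consumer, producer)


def fuse_consumer_first(nodes):
    # Hypothetical alternative: put a supported consumer into its own group
    # first, then try to pull its producers in.
    for consumer in reversed(nodes):
        if not can_merge(consumer):
            continue
        consumer.fused = True
        for producer in consumer.inputs:
            try_merge(consumer, producer)


def build_graph():
    # Toy version of foo(a, b): a single aten::mul fed by two graph inputs.
    a, b = Node("prim::Param"), Node("prim::Param")
    mul = Node("aten::mul", [a, b])
    return [a, b, mul], mul


nodes, mul = build_graph()
fuse_producer_first(nodes)
producer_first_result = mul.fused

nodes, mul = build_graph()
fuse_consumer_first(nodes)
consumer_first_result = mul.fused

print(producer_first_result, consumer_first_result)  # prints: False True
```

In the producer-first walk, the lone `aten::mul` is never fused because both of its producers fail `can_merge`, which matches the first output above. In the consumer-first variant, the same node ends up in a group, which matches the second output.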