AMDMIGraphX icon indicating copy to clipboard operation
AMDMIGraphX copied to clipboard

Improve horizontal fusion with multi-used splits

Open pfultz2 opened this issue 4 months ago • 1 comments

We should fuse this case:

# Reproduction program: one shared `add` result is sliced into two pieces along
# axis 1 (columns [0, 1648) and [1648, 1776)), and each slice feeds the same
# chain of pointwise ops. Each slice is consumed twice within its own chain
# (first `mul` and final `mul`), and every other operand is a literal.
p = migraphx.program()
m = p.get_main_module()
# Single-element literal holding the value 1; it is broadcast below to act as the
# constant `1` in the `1 - sigmoid(...)` step of each branch.
x_0 = m.add_literal(migraphx.create_argument(migraphx.shape(type="float_type", lens=[1]), [1]))
# Full-tensor literals. NOTE(review): the second argument to generate_argument
# looks like a seed selecting the generated data — confirm against the MIGraphX API.
# x_1, x_2, x_7 match the first slice's shape; x_4, x_5, x_6 match the second's.
x_1 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 1648]), 0))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 1648]), 1))
x_3 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 1776]), 2))
x_4 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 128]), 3))
x_5 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 128]), 4))
x_6 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 128]), 5))
x_7 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1, 1648]), 6))
# The only runtime input of the program.
p_x88 = m.add_parameter("x88", migraphx.shape(type="float_type", lens=[1, 1776]))
# Broadcast the literal x_0 to each slice's shape (one broadcast per branch).
x_9 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1648]), [x_0]) # migraphx.shape(type="float_type", lens=[1, 1648], strides=[0, 0])
x_10 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,128]), [x_0]) # migraphx.shape(type="float_type", lens=[1, 128], strides=[0, 0])
# Shared producer: both slices below read from this single `add`.
x_11 = m.add_instruction(migraphx.op("add"), [x_3, p_x88]) # migraphx.shape(type="float_type", lens=[1, 1776])
# Branch 1: columns [0, 1648) of x_11.
x_12 = m.add_instruction(migraphx.op("slice", axes=[1], starts=[0], ends=[1648]), [x_11]) # migraphx.shape(type="float_type", lens=[1, 1648])
x_13 = m.add_instruction(migraphx.op("mul"), [x_12, x_7]) # migraphx.shape(type="float_type", lens=[1, 1648])
x_14 = m.add_instruction(migraphx.op("add"), [x_13, x_1]) # migraphx.shape(type="float_type", lens=[1, 1648])
x_15 = m.add_instruction(migraphx.op("sigmoid"), [x_14]) # migraphx.shape(type="float_type", lens=[1, 1648])
# x_9 is the broadcast constant, so this computes (1 - sigmoid).
x_16 = m.add_instruction(migraphx.op("sub"), [x_9, x_15]) # migraphx.shape(type="float_type", lens=[1, 1648])
x_17 = m.add_instruction(migraphx.op("mul"), [x_16, x_2]) # migraphx.shape(type="float_type", lens=[1, 1648])
x_18 = m.add_instruction(migraphx.op("add"), [x_15, x_17]) # migraphx.shape(type="float_type", lens=[1, 1648])
# Second use of the slice x_12 (first use is x_13).
x_19 = m.add_instruction(migraphx.op("mul"), [x_18, x_12]) # migraphx.shape(type="float_type", lens=[1, 1648])
# Branch 2: columns [1648, 1776) of x_11; same op sequence as branch 1 with
# literals x_4, x_5, x_6 and broadcast x_10.
x_20 = m.add_instruction(migraphx.op("slice", axes=[1], starts=[1648], ends=[1776]), [x_11]) # migraphx.shape(type="float_type", lens=[1, 128])
x_21 = m.add_instruction(migraphx.op("mul"), [x_20, x_4]) # migraphx.shape(type="float_type", lens=[1, 128])
x_22 = m.add_instruction(migraphx.op("add"), [x_21, x_5]) # migraphx.shape(type="float_type", lens=[1, 128])
x_23 = m.add_instruction(migraphx.op("sigmoid"), [x_22]) # migraphx.shape(type="float_type", lens=[1, 128])
x_24 = m.add_instruction(migraphx.op("sub"), [x_10, x_23]) # migraphx.shape(type="float_type", lens=[1, 128])
x_25 = m.add_instruction(migraphx.op("mul"), [x_24, x_6]) # migraphx.shape(type="float_type", lens=[1, 128])
x_26 = m.add_instruction(migraphx.op("add"), [x_23, x_25]) # migraphx.shape(type="float_type", lens=[1, 128])
# Second use of the slice x_20 (first use is x_21).
x_27 = m.add_instruction(migraphx.op("mul"), [x_26, x_20]) # migraphx.shape(type="float_type", lens=[1, 128])
# Return the final result of each branch.
m.add_return([x_19, x_27])

In #3920, when there are interdependencies in the split groups, we don't fuse them; instead we fuse after we have run fuse_pointwise, as it's easier to analyze how they can be fused. This worked for the previous case because everything was scalar, so it would be embedded into the kernel instead of being passed as a parameter. In this case the operands are full tensors, and find_splits only handles unary and binary operators.

We need to extend find_splits to handle multiple arguments. As long as all the extra arguments (or data arguments) are constants, we should fuse it.

pfultz2 avatar Aug 22 '25 22:08 pfultz2

Related #3844, #3920

pfultz2 avatar Aug 22 '25 22:08 pfultz2