AMDMIGraphX

blas_shape: GPU_GEMM: Batch dimension is not collapsible

Status: Open. shivadbhavsar opened this issue 1 year ago · 2 comments

Error seen in the huggingface torch benchmark OPTForCasualLM. It only occurs after #3104.

The uncompiled model (.mxr) can be found on the NAS at: migraphx/models/torch_benchmarks/OPTForCasualLM.mxr

Repro: migraphx-driver compile OPTForCasualLM.mxr

shivadbhavsar, Sep 12 '24 16:09

Small repro:

import migraphx
import numpy as np

p = migraphx.program()
mm = p.get_main_module()

# x: [4096, 768] -> [2, 2048, 12, 64] -> transpose -> contiguous -> [24, 2048, 64]
s1 = migraphx.shape(lens=[4096, 768], type="float_type")
in1 = mm.add_parameter("x", s1)
in1 = mm.add_instruction(migraphx.op("reshape", dims=[2, 2048, 768]), [in1])
in1 = mm.add_instruction(migraphx.op("reshape", dims=[2, -1, 12, 64]), [in1])
in1 = mm.add_instruction(migraphx.op("transpose", permutation=[0, 2, 1, 3]), [in1])
in1 = mm.add_instruction(migraphx.op("contiguous"), [in1])
in1 = mm.add_instruction(migraphx.op("reshape", dims=[24, -1, 64]), [in1])

# x2: [2, 12, 2048, 2048], clamped from below at -65504 (lowest finite fp16 value)
s2 = migraphx.shape(lens=[2, 12, 2048, 2048], type="float_type")
in2 = mm.add_parameter("x2", s2)

min_lit = mm.add_literal(np.array(-65504, dtype=np.float32))
min_lit = mm.add_instruction(migraphx.op("multibroadcast", out_lens=[2, 12, 2048, 2048]), [min_lit])

max_ins = mm.add_instruction(migraphx.op("max"), [in2, min_lit])
rsp_max = mm.add_instruction(migraphx.op("reshape", dims=[24, 2048, 2048]), [max_ins])
# softmax over the last axis, then a batched dot with the reshaped x -> [24, 2048, 64]
smax = mm.add_instruction(migraphx.op("softmax", axis=-1), [rsp_max])
dot = mm.add_instruction(migraphx.op("dot"), [smax, in1])
dot_rsp = mm.add_instruction(migraphx.op("reshape", dims=[2, 12, 2048, 64]), [dot])

# Compiling for the GPU target is where the GPU_GEMM error is reported
p.compile(migraphx.get_target("gpu"))
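As a rough illustration of why the repro needs the contiguous between the transpose and the reshape to [24, -1, 64], the same memory layout can be mirrored in NumPy. This is an analogy only, assuming row-major float32 storage; NumPy plays no part in MIGraphX, and `elem_strides` is a hypothetical helper that reports strides in elements so they line up with MIGraphX's shape printouts.

```python
import numpy as np

# NumPy analogy only (not MIGraphX): mirror x -> reshape -> transpose.
x = np.zeros((4096, 768), dtype=np.float32)
r = x.reshape(2, 2048, 12, 64)     # still a view of x
t = r.transpose(0, 2, 1, 3)        # still a view of x, but no longer row-major


def elem_strides(a):
    """Hypothetical helper: strides in elements instead of bytes."""
    return tuple(s // a.itemsize for s in a.strides)


print(elem_strides(t))  # (1572864, 64, 768, 1)

# Merging the first two axes (2 x 12 -> 24) could only stay a view if
# stride[0] == stride[1] * len[1]; here 1572864 != 64 * 12, so NumPy has to
# copy the data, i.e. it does the equivalent of a contiguous before reshaping.
m = t.reshape(24, 2048, 64)
print(elem_strides(m))         # (131072, 64, 1)
print(np.shares_memory(m, x))  # False: m lives in a fresh buffer
```

The transposed strides match the {1572864, 64, 768, 1} that MIGraphX prints for the transpose in the traces below.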

Compile trace: gmm_err_trace.txt

shivadbhavsar, Sep 25 '24 18:09

Here's where the issue starts:

Pass: fuse_reduce
Pass: dead_code_elimination
x2 = @param:x2 -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
x = @param:x -> float_type, {4096, 768}, {768, 1}
@2 = reshape[dims={2, 2048, 12, 64}](x) -> float_type, {2, 2048, 12, 64}, {1572864, 768, 64, 1}
@3 = transpose[permutation={0, 2, 1, 3}](@2) -> float_type, {2, 12, 2048, 64}, {1572864, 64, 768, 1}
@4 = reshape[dims={24, 2048, 64}](@3) -> float_type, {24, 2048, 64}, {131072, 64, 1}
@5 = pointwise(x2), [main:pointwise0] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@6 = reshape[dims={24, 2048, 2048}](@5) -> float_type, {24, 2048, 2048}, {4194304, 2048, 1}
@7 = fused_reduce[axes={2}](@6), [main:reduce_sum1:main:pointwise3:main:reduce_max0:main:pointwise1] -> float_type, {24, 2048, 2048}, {4194304, 2048, 1}
@8 = dot(@7,@4) -> float_type, {24, 2048, 64}, {131072, 64, 1}
@9 = reshape[dims={2, 12, 2048, 64}](@8) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}

Pass: rewrite_reshapes
Pass: simplify_reshapes
x2 = @param:x2 -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
x = @param:x -> float_type, {4096, 768}, {768, 1}
@2 = reshape[dims={2, 2048, 12, 64}](x) -> float_type, {2, 2048, 12, 64}, {1572864, 768, 64, 1}
@3 = transpose[permutation={0, 2, 1, 3}](@2) -> float_type, {2, 12, 2048, 64}, {1572864, 64, 768, 1}
@4 = pointwise(x2), [main:pointwise0] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@5 = fused_reduce[axes={3}](@4), [main:reduce_sum1:main:pointwise3:main:reduce_max0:main:pointwise1_reshape] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@6 = dot(@5,@3) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}
@7 = identity(@6) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}

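After rewrite_reshapes performs this simplification, a contiguous needs to be added. Or should there already have been a contiguous op between the transpose and the reshape (@3 and @4 in the fuse_reduce output above)? The simplified dot now reads @3 directly, and its batch dimensions (2 and 12) have strides 1572864 and 64. Since 64 * 12 = 768, not 1572864, they cannot be merged into a single batch dimension, which matches the "Batch dimension is not collapsible" error.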
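A minimal sketch of the collapsibility condition, applied to the two dot inputs from the simplify_reshapes output above. `batch_collapsible` is a hypothetical helper written for illustration, not MIGraphX's own check; it assumes the batch dimensions are all but the last two, and that adjacent dimensions i and i+1 merge when strides[i] == strides[i + 1] * lens[i + 1].

```python
from typing import Sequence


def batch_collapsible(lens: Sequence[int], strides: Sequence[int]) -> bool:
    """Hypothetical check: can the batch dims (all but the last two) of a
    GEMM operand be merged into one dimension with a single stride?

    This sketch skips size-1 dims, since their stride never affects addressing.
    """
    batch = [(n, s) for n, s in zip(lens[:-2], strides[:-2]) if n != 1]
    # Each outer stride must equal the inner stride times the inner length.
    return all(s0 == s1 * n1 for (_, s0), (n1, s1) in zip(batch, batch[1:]))


# Inputs of @6 = dot(@5, @3) after simplify_reshapes
a = ([2, 12, 2048, 2048], [50331648, 4194304, 2048, 1])  # @5, standard layout
b = ([2, 12, 2048, 64], [1572864, 64, 768, 1])           # @3, transposed

print(batch_collapsible(*a))  # True:  4194304 * 12 == 50331648
print(batch_collapsible(*b))  # False: 64 * 12 == 768 != 1572864
```

Only the transposed operand fails, which is consistent with a contiguous being needed on @3 before it reaches the dot.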

shivadbhavsar, Sep 25 '24 18:09

This should be fixed by #3428.

shivadbhavsar, Oct 01 '24 19:10