[QUESTION] Some argument values cause test failures
When I assign different values to some arguments, the test reports errors.
My working environment: H20
Steps to reproduce:
- Run `./launch.sh test/python/moe_gather_rs/test_moe_gather_rs.py -M 1024 -N 8192 -K 2048 -G 256 -E 8 -T 1 --topk 4`: it works well.
- Change N from 8192 to 7168 and run `./launch.sh test/python/moe_gather_rs/test_moe_gather_rs.py -M 1024 -N 7168 -K 2048 -G 256 -E 8 -T 1 --topk 4`: the test fails.
Some problems of this kind (such as setting topk to 8) can be solved by adding a config in src/generator/gen_moe_gather_rs.cc, but others cannot.
My question is: how can I run the test with these arguments?
Thanks for your report.
You can refer to https://github.com/bytedance/flux/blob/main/src/generator/gen_moe_gather_rs.cc#L90 and add an entry for your shape there:
```cpp
static constexpr auto AllGemmHParams_FP16 = make_space_gemm_hparams(
    cute::make_tuple(make_gemm_v3_hparams(Shape<_1, _1, _1>{})),
    cute::make_tuple(
        make_gather_rs_hparams(cute::Int<26>{}, cute::Int<8192>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<8192>{}),
        make_gather_rs_hparams(cute::Int<30>{}, cute::Int<8192>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<6144>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<5120>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<4096>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<3072>{}),
        make_gather_rs_hparams(cute::Int<28>{}, cute::Int<2048>{})),
    cute::make_tuple(Shape<_128, _256, _64>{}));
```
The first argument is the number of gather_rs CTAs. Using a different value such as 32 or 26 may affect performance slightly, but probably not much. The second argument is N, which should be 7168 in your case.
JIT compilation would help, but compiling kernels for all shapes with JIT is not an easy job.
It works, thanks!