AMDMIGraphX icon indicating copy to clipboard operation
AMDMIGraphX copied to clipboard

Fuse quantizelinear for skip layers using multioutput fusions

Open shivadbhavsar opened this issue 11 months ago • 1 comment

Follow-up from PR #3782. Example: the quantized ResNet graph after the above PR:

NEW:

q -> conv -> dq -> add -> relu -> q .......... -> q -> conv -> dq -> add -> relu -> q
			   |	                                                    |
			   -> step -> q -----------------------------------------> concat -> conv -> ...

Doing some experimental work, it turns out that we get a slight perf boost from moving the skip-connection quantize op before the step op, and fusing it into the previous conv-pointwise kernel. This should probably be done as 2 steps:

  • [ ] 1. Write pass for multioutput fusion that can fuse to mlir_quant_convolution_dequantizelinear_dequantizelinear_add_add_relu_quantizelinear_quantizelinear
  • [ ] 2. Write pass to move q before step when this fusion is possible

These should be done in this order since swapping the order of quantize op and step op is generally not preferable unless doing it for this fusion.

shivadbhavsar avatar Jan 31 '25 21:01 shivadbhavsar

Here is some example code used for the experiment.

  1. MLIR multi-output fusion (in the fuse_mlir pass). This needs to be refactored to account for incoming changes #3569 and #3752 (or similar).
// Experimental multi-output fusion for the fuse_mlir pass.
//
// Matches a `gpu::mlir_op` whose outputs all satisfy the `mlir_pointwise`
// matcher and rebuilds it as one `gpu::mlir_op` whose submodule contains the
// original MLIR module followed by every consumer's pointwise module. The
// fused submodule returns one value per consumer; each original consumer is
// then replaced by a `get_tuple_elem` on the fused instruction.
struct find_mlir_multi_pointwise
{
    // Not read by the code in this struct.
    mlir_mode conv_mode = mlir_mode::none;
    mlir_mode dot_mode  = mlir_mode::none;

    // Selects a gpu::mlir_op for which every output matches `mlir_pointwise`.
    auto matcher() const
    {
        return match::name("gpu::mlir_op")(match::all_of[match::outputs()](mlir_pointwise));
    }

    // Builds the fused submodule and rewires the consumers of the matched op.
    //
    // mpm: pass manager used to create the new submodule and to edit the
    //      module that contains the matched instruction.
    // r:   match result; `r.result` is the matched gpu::mlir_op.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins = r.result;
        // Copy the consumer list: the rewrites at the end of this function
        // change `ins->outputs()`.
        auto pw_inss = ins->outputs();

        // `all_of` over an empty output list is vacuously true, so an unused
        // mlir_op would otherwise reach this point. Bail out before any
        // module is created.
        if(pw_inss.empty())
            return;
        // single pointwise output should already be fused in
        assert(pw_inss.size() > 1);

        auto* mlir_mod = ins->module_inputs().front();

        // Name of the fused module: "<mlir module>:<pw module>:<pw module>..."
        std::string module_name = mlir_mod->name();
        std::transform(
            pw_inss.begin(),
            pw_inss.end(),
            join_back_inserter(module_name),
            [](instruction_ref pw) { return ":" + pw->module_inputs().front()->name(); });
        module_ref mm = mpm.create_module(module_name);

        // Start from the inputs of the original mlir_op, then fuse its module
        // body. `map_main_to_mm` tracks main-module instruction -> fused-module
        // instruction.
        std::unordered_map<instruction_ref, instruction_ref> map_main_to_mm;
        mm->add_params(ins->inputs(), &map_main_to_mm);
        std::unordered_map<instruction_ref, instruction_ref> map_mlir_mod_to_mm(map_main_to_mm);
        auto original_return = mm->fuse(*mlir_mod, ins->inputs(), &map_mlir_mod_to_mm).front();
        map_main_to_mm[ins]  = original_return;

        // Append each consumer's pointwise module. `new_returns[i]` is the
        // fused result that corresponds to `pw_inss[i]`.
        std::vector<instruction_ref> new_returns;
        new_returns.reserve(pw_inss.size());
        for(auto pw_ins : pw_inss)
        {
            auto* pm = pw_ins->module_inputs().front();
            std::unordered_map<instruction_ref, instruction_ref> lit_map =
                create_param_map_with_literals(mm, pm, pw_ins->get_shape());

            mm->add_params(pw_ins->inputs(), &map_main_to_mm);
            map_main_to_mm.insert(lit_map.begin(), lit_map.end());
            std::unordered_map<instruction_ref, instruction_ref> map_pm_to_mm(map_main_to_mm);
            auto fused_pw_out = mm->fuse(*pm, pw_ins->inputs(), &map_pm_to_mm).front();

            map_main_to_mm[pw_ins] = fused_pw_out;
            new_returns.push_back(fused_pw_out);
        }

        mm->add_return(new_returns);
        auto map_mm_to_main = invert_map_ins(map_main_to_mm);
        auto new_inputs     = mm->get_inputs(map_mm_to_main);

        mm->set_bypass();
        auto fused_ins = mpm.get_module().insert_instruction(
            ins, ins->get_operator(), mlir_contiguous(mpm, new_inputs), {mm});

        // Replace consumer i with element i of the fused tuple. The consumer
        // is taken from `pw_inss` directly rather than looked up through the
        // inverted map, which is ambiguous if two main-module instructions map
        // to the same fused instruction.
        for(size_t out_idx = 0; out_idx < pw_inss.size(); out_idx++)
        {
            mpm.get_module().replace_instruction(
                pw_inss[out_idx],
                migraphx::make_op("get_tuple_elem", {{"index", out_idx}}),
                fused_ins);
        }
    }
};
  2. Moving q before step (in the simplify_qdq pass). Currently only written for scalar scales and zero-points; it can be generalized.
// Experimental simplify_qdq rewrite: turns `pw -> step -> quantizelinear`
// into `pw -> quantizelinear -> step`, so the quantize sits directly after
// the pointwise producer. Only applied when every quantization parameter has
// a scalar shape, as reported by shape::scalar().
struct match_step_qlinear
{
    // Matches a quantizelinear whose first argument is a `step` that is used
    // once and has a pointwise input. Binds the step as "step" and the
    // pointwise input as "pw".
    auto matcher() const
    {
        auto any_pointwise_input = match::any_of[match::inputs()](match::pointwise().bind("pw"));
        return match::name("quantizelinear")(match::arg(0)(
            match::name("step")(match::used_once(), any_pointwise_input).bind("step")));
    }

    // Walks up a chain of single-input instructions starting at `i` until one
    // whose shape has exactly `channels` elements is found.
    //
    // Returns that instruction. If the chain ends first (an instruction with
    // zero or several inputs), the last instruction reached is returned, so
    // callers must check the element count of the result.
    auto get_prebroadcast_qparam(instruction_ref i, size_t channels) const
    {
        instruction_ref top = i;
        while(top->get_shape().elements() != channels and top->inputs().size() == 1)
        {
            top = top->inputs()[0];
        }
        return top;
    }

    // Performs the rewrite on module `m` for the match `r`.
    //
    // Handles quantizelinear both with a zero-point (3 inputs) and without
    // one (2 inputs). Leaves the graph unchanged when a parameter is not
    // scalar or its single-element source cannot be found.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto step_ins = r.instructions["step"];
        auto pw_ins   = r.instructions["pw"];

        const auto num_args = ins->inputs().size();
        if(num_args != 2 and num_args != 3)
            return;
        const bool has_zp = (num_args == 3);

        // Validate everything before inserting any instruction, so a rejected
        // match leaves no dead instructions behind.
        auto scale = ins->inputs()[1];
        if(not scale->get_shape().scalar())
            return;
        auto sscale = get_prebroadcast_qparam(scale, 1);
        if(sscale->get_shape().elements() != 1)
            return;

        instruction_ref szp{};
        if(has_zp)
        {
            auto zp = ins->inputs()[2];
            if(not zp->get_shape().scalar())
                return;
            szp = get_prebroadcast_qparam(zp, 1);
            if(szp->get_shape().elements() != 1)
                return;
        }

        // Re-broadcast a single-element parameter to the pre-step shape.
        // NOTE(review): instructions are inserted before `step_ins`, which
        // assumes the parameter source is defined ahead of the step - confirm.
        auto broadcast_to_pw = [&](instruction_ref param) {
            return m.insert_instruction(
                step_ins,
                make_op("multibroadcast", {{"out_lens", pw_ins->get_shape().lens()}}),
                {param});
        };

        auto scale_mb = broadcast_to_pw(sscale);
        instruction_ref new_q{};
        if(has_zp)
        {
            auto zp_mb = broadcast_to_pw(szp);
            new_q = m.insert_instruction(step_ins, ins->get_operator(), {pw_ins, scale_mb, zp_mb});
        }
        else
        {
            new_q = m.insert_instruction(step_ins, ins->get_operator(), {pw_ins, scale_mb});
        }

        // The original quantizelinear becomes the step applied to the new
        // quantized tensor.
        m.replace_instruction(ins, step_ins->get_operator(), {new_q});
    }
};

shivadbhavsar avatar Jan 31 '25 21:01 shivadbhavsar