onnxruntime icon indicating copy to clipboard operation
onnxruntime copied to clipboard

Improve Kernel Explorer (KE) for command-line and programmatic tuning dispatch

Open · cloudhan opened this issue 2 years ago • 1 comment

cloudhan avatar Dec 11 '23 04:12 cloudhan

For example:

import os
import sys

# Put the kernel_explorer "kernels" directory and the build output directory
# at the front of the import path. These must run before any worker imports
# `kernel_explorer` or the `*_test` modules.
# NOTE(review): the paths are machine-specific absolute paths from the
# author's environment -- adjust them to your own checkout / build tree.
sys.path.insert(0, "/home/guangyunhan/onnxruntime/onnxruntime/python/tools/kernel_explorer/kernels")
sys.path.insert(0, "/home/guangyunhan/onnxruntime/build_rocm/Release")
# Presumably read by kernel_explorer to locate its compiled library -- TODO confirm.
os.environ["KERNEL_EXPLORER_BUILD_DIR"] = "/home/guangyunhan/onnxruntime/build_rocm/Release"


import multiprocessing as mp
from multiprocessing import Pool, current_process


def profile(name, *args, **kwargs):
    """Run the kernel_explorer profiling entry point selected by *name*.

    ``kernel_explorer`` is imported inside the function, so the import happens
    in whichever process executes the call. Before dispatching, the function
    calls ``set_return_tuning_results()`` and
    ``set_dispatchable_pattern("*Tunable*")`` on it, then prints the current
    ``HIP_VISIBLE_DEVICES`` value.

    Args:
        name: ``"gemm"`` or ``"softmax"``; selects ``gemm_test`` or
            ``softmax_test`` respectively.
        *args: Positional arguments forwarded to ``profile_with_args``.
        **kwargs: Keyword arguments forwarded to ``profile_with_args``.

    Returns:
        Whatever the selected ``profile_with_args`` returns, or an empty list
        when *name* is not recognised.

    Raises:
        KeyError: If ``HIP_VISIBLE_DEVICES`` is not set in the environment.
    """
    import kernel_explorer

    kernel_explorer.set_return_tuning_results()
    kernel_explorer.set_dispatchable_pattern("*Tunable*")
    print(os.environ["HIP_VISIBLE_DEVICES"])

    # Resolve the per-kernel entry point lazily; unknown names yield no results.
    if name == "gemm":
        from gemm_test import profile_with_args as run_kernel_profile
    elif name == "softmax":
        from softmax_test import profile_with_args as run_kernel_profile
    else:
        return []

    return run_kernel_profile(*args, **kwargs)


def init(start_gpu=2, num_gpu=14):
    """Pin the current process to a single GPU through ``HIP_VISIBLE_DEVICES``.

    Intended as a ``multiprocessing.Pool`` initializer: each worker derives a
    zero-based index from its pool identity and is assigned device
    ``start_gpu + (index % num_gpu)``, so workers are spread round-robin over
    the ``num_gpu`` devices starting at ``start_gpu``.

    Args:
        start_gpu: Index of the first device that may be used. Defaults to 2,
            the value previously hard-coded here.
        num_gpu: Number of consecutive devices to cycle through. Defaults to
            14, the value previously hard-coded here.

    Raises:
        ValueError: If ``num_gpu`` is less than 1.
    """
    if num_gpu < 1:
        raise ValueError("num_gpu must be at least 1, got {!r}".format(num_gpu))

    # NOTE: ``_identity`` is a private multiprocessing attribute; it is an
    # empty tuple in a process that was not spawned by multiprocessing (e.g.
    # the main process), in which case the process is treated as worker 0
    # instead of failing with IndexError.
    identity = current_process()._identity
    pidx = int(identity[0]) - 1 if identity else 0

    os.environ["HIP_VISIBLE_DEVICES"] = str(pidx % num_gpu + start_gpu)


if __name__ == "__main__":
    # Each tuple is (kernel name, *arguments forwarded to profile()).
    # The seven gemm cases differ only in their fifth field; the two softmax
    # cases differ only in their second field.
    configs = [
        ("gemm", "float16", False, False, size, 8912, 8912)
        for size in (1, 8, 16, 24, 32, 40, 48)
    ]
    configs += [("softmax", count, 1024, False, "float16") for count in (1, 2)]

    mp.set_start_method("spawn")

    # `init` runs once in every worker before it takes tasks; chunksize=1
    # hands the configs to the workers one at a time.
    worker_pool = Pool(processes=4, initializer=init)
    with worker_pool:
        tuning_results = worker_pool.starmap(profile, configs, chunksize=1)

    from pprint import pprint
    from onnxruntime.tools.offline_tuning import Merger

    # Fold the per-config results into one merged structure and print it.
    merger = Merger()
    for result in tuning_results:
        merger.merge(result)

    pprint(merger.get_merged())

cloudhan avatar Dec 11 '23 04:12 cloudhan