
MI300 compatibility

Open fxmarty opened this pull request 1 year ago • 2 comments

Adds support for AMD Instinct MI300 in TGI.

The main changes are:

  • Support PyTorch TunableOp to select GEMM/GEMV kernels for decoding (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable). TunableOp is disabled by default and can be enabled with PYTORCH_TUNABLEOP_ENABLED=1; see the sketch after this list.
  • Update the ROCm Dockerfile to PyTorch 2.3 (patched with changes from https://github.com/pytorch/pytorch/pull/124362)
  • Support SiLU & Linear custom kernels contributed by AMD
  • Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/, branched from a much more recent commit https://github.com/ROCm/vllm/commit/3489ce7936c5de588916ae3047c44c23c0b0c308
  • Support the FA2 Triton kernel recommended by AMD, which can be enabled with ROCM_USE_FLASH_ATTN_V2_TRITON=1.
  • Update the Dockerfile to ROCm 6.1
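
For reference, here is a minimal sketch (not TGI code) of driving TunableOp directly; the environment variables are the ones documented in the TunableOp README linked above, while the output file name and GEMM shape are only illustrative:

    # Minimal sketch: run one half-precision GEMM with TunableOp enabled so
    # that a tuning result is recorded. Not TGI code; the file name is illustrative.
    import os

    # These must be set before the first tunable op is dispatched.
    os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
    os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "tunableop_results.csv"

    import torch

    # A small-M GEMM, similar to the decode-time shapes in the example below.
    a = torch.randn(7, 8192, dtype=torch.float16, device="cuda")
    b = torch.randn(8192, 28672, dtype=torch.float16, device="cuda")
    c = a @ b  # the first occurrence of this shape triggers tuning; the winner is saved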

By default, TunableOp tuning results are saved in /data (e.g. /data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv) to avoid having to rerun the tuning on each docker run.

Example tuning results file:

Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
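
A results file like the one above is plain CSV, so it can be inspected with a few lines of Python. A minimal sketch (the path is the example file name from earlier):

    # Minimal sketch: print the kernel TunableOp selected for each GEMM shape.
    import csv

    path = "/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv"
    with open(path) as f:
        for row in csv.reader(f):
            # Skip the Validator lines that describe the tuning environment.
            if row and row[0].startswith("GemmTunableOp"):
                op, shape, kernel, duration = row
                print(f"{shape}: {kernel} ({duration})")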

fxmarty avatar Apr 18 '24 23:04 fxmarty

@seungrokj Feel free to ping anyone at AMD that would be interested in reviewing this one.

fxmarty avatar May 02 '24 15:05 fxmarty

@amathews-amd @shajrawi @andyluo7 @mawong-amd @jeffdaily @liligwu @hongxiayang Please take a look at these when you're available

seungrokj avatar May 02 '24 16:05 seungrokj

@fxmarty Can you please add one more triton.Config to the list in https://github.com/huggingface/text-generation-inference/blob/b7e98ba635367daa23c5b1f4a73f51b1f061936a/server/text_generation_server/utils/flash_attn_triton.py#L261

        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
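
For context (not stated in the thread): Triton's autotuner benchmarks every listed triton.Config the first time the kernel is launched for a given input key and caches the fastest one. BLOCK_M and BLOCK_N are the kernel's tile sizes, num_warps and num_stages are standard Triton launch parameters, waves_per_eu is an AMD-specific occupancy hint, and PRE_LOAD_V is a flag specific to this attention kernel controlling whether the V block is preloaded.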

This improves the prefill latency of Llama 3 70B with TP8 by about 3.4% to 10% for batch sizes 1 to 32 at seqlen=2048.

seungrokj avatar May 16 '24 03:05 seungrokj

@Narsil This PR is ready. Could you take a look?

We are just waiting for a patched/updated rocm/dev-ubuntu-22.04 base image that fixes an issue with libamdhip64.so on certain VMs, which would let us remove the workaround at

https://github.com/huggingface/text-generation-inference/blob/afc747337a5beb35c492fbb5fdaa5de1da9d20f1/Dockerfile_amd#L117-L122

We expect to get the updated docker image by Monday next week. Do you think a TGI release next Tuesday/Wednesday including this PR is feasible?

fxmarty avatar May 16 '24 11:05 fxmarty

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

We expect to get the updated docker image by Monday next week. Do you think a TGI release next Tuesday/Wednesday including this PR is feasible?

Sure, releases are kind of trivial now.

Narsil avatar May 16 '24 13:05 Narsil

Now that we have an updated rocm/dev-ubuntu-22.04:6.1.1_hip_update base image, this PR can be merged once the build is done and tests are passing.

fxmarty avatar May 17 '24 09:05 fxmarty