
AWQ for Yi-6b-Chat fails when world_size > 1 but succeeds when world_size is 1

Open · xikaluo opened this issue on Jan 31, 2024 · 0 comments

System Info

  • GPU: 4 × RTX 3090 (24 GB)
  • TensorRT-LLM version: 0.8.0.dev20240123
  • TensorRT version: 9.2.0.post12.dev5
  • NVIDIA driver: 535.54.03
  • CUDA version: 12.2
  • OS: Ubuntu 20.04

Who can help?

@Tracin

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

  1. Use the following script to run AWQ quantization:

```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 python ammo_quant.py \
    --model_dir ./models/yi-6b-chat \
    --dtype float16 \
    --qformat int4_awq \
    --export_path ./quant_ckpts/awq/yi-6b-chat-512 \
    --calib_size 512 \
    --ds_name cnndm \
    --max_input_length 2048 \
    --inf_tp_size 1
```

ammo_quant.py is similar to https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py; it only adds three arguments (ds_name, max_input_length, inf_tp_size), which are passed through to get_calib_dataloader and quantize_and_export. The full script is included under "additional notes" below.

  2. Use build.py (copied from examples/llama) to build an engine with world_size == 1:

```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 python build.py \
    --world_size 1 \
    --pp_size 1 \
    --model_dir ./models/yi-6b-chat \
    --quant_ckpt_path ./quant_ckpts/awq/yi-6b-chat-512/llama_tp1_rank0.npz \
    --dtype float16 \
    --max_batch_size 1 \
    --max_input_len 2048 \
    --max_output_len 512 \
    --max_num_tokens 2560 \
    --remove_input_padding \
    --use_gpt_attention_plugin float16 \
    --enable_context_fmha \
    --use_gemm_plugin float16 \
    --parallel_build \
    --use_weight_only \
    --weight_only_precision int4_awq \
    --per_group \
    --output_dir ./engines/awq/yi-6b-chat-pp1
```

  3. Use the same build.py to build an engine with world_size == 4:

```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 python build.py \
    --world_size 4 \
    --pp_size 4 \
    --model_dir ./models/yi-6b-chat \
    --quant_ckpt_path ./quant_ckpts/awq/yi-6b-chat-512/llama_tp1_rank0.npz \
    --dtype float16 \
    --max_batch_size 1 \
    --max_input_len 2048 \
    --max_output_len 512 \
    --max_num_tokens 2560 \
    --remove_input_padding \
    --use_gpt_attention_plugin float16 \
    --enable_context_fmha \
    --use_gemm_plugin float16 \
    --parallel_build \
    --use_weight_only \
    --weight_only_precision int4_awq \
    --per_group \
    --output_dir ./engines/awq/yi-6b-chat-pp4
```

  4. Use the two engines to generate output for this input (run commands like the sketch below): { "system": "", "history": [], "query": "Hello, how are you today?" }
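For reference, generation was done roughly as follows (a sketch assuming the stock examples/run.py shipped with the same release; the ChatML wrapping of the query is omitted, and flag spellings may differ slightly between versions):

```bash
# single-GPU engine (world_size == 1)
python ../run.py \
    --engine_dir ./engines/awq/yi-6b-chat-pp1 \
    --tokenizer_dir ./models/yi-6b-chat \
    --input_text "Hello, how are you today?" \
    --max_output_len 512

# 4-way engine (world_size == 4): launch one process per rank via MPI
mpirun -n 4 --allow-run-as-root python ../run.py \
    --engine_dir ./engines/awq/yi-6b-chat-pp4 \
    --tokenizer_dir ./models/yi-6b-chat \
    --input_text "Hello, how are you today?" \
    --max_output_len 512
```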

Expected behavior

Both engines should generate a coherent answer.

actual behavior

The engine built with world_size == 1 produces a good answer: "Thank you for asking! I'm doing well today. How about you? Is there anything specific you need assistance with or would like to discuss?<|im_end|>"

The engine built with world_size == 4 produces garbled output: ">=__ologists端部臂(光荣地 ;\ Noticeablybooking{| († bes '\ago"

additional notes

The code of ammo_quant.py is:

```python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""

import argparse
import random

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.logger import logger, set_level

set_level("verbose")

from tensorrt_llm.models.quantized.ammo import quantize_and_export


def get_calib_dataloader(ds_name, tokenizer, batch_size: int, calib_size: int, max_input_length: int):
    print("Loading calibration dataset")
    if ds_name == "pileval":
        dataset = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train"
        )
        dataset = dataset["text"][:calib_size]
    elif ds_name == "cnndm":
        dataset = load_from_disk("./all_dataset/ccdv-cnn_dailymail")["train"]
        dataset = dataset["article"][:calib_size]
    else:
        raise NotImplementedError

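    # if the tokenizer defines no pad token, reuse its <|im_end|> id for padding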
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.im_end_id

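    # pad/truncate every calibration sample to max_input_length so all batches share one shape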
    dataset_input_ids = tokenizer(dataset,
                                  return_tensors="pt",
                                  padding="max_length",
                                  truncation=True,
                                  max_length=max_input_length).input_ids.cuda()

    calib_dataloader = DataLoader(dataset_input_ids,
                                  batch_size=batch_size,
                                  shuffle=False)

    return calib_dataloader


def get_tokenizer(ckpt_path, **kwargs):
    logger.info(f"Loading tokenizer from {ckpt_path}")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path,
                                              padding_side="left",
                                              **kwargs)
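    # default the pad token to EOS when the tokenizer does not define one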
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_model(ckpt_path, dtype="float16", cache_dir=None):
    logger.info(f"Loading model from {ckpt_path}")
    torch_dtype = str_dtype_to_torch(dtype)
    model = AutoModelForCausalLM.from_pretrained(
        ckpt_path,
        device_map="auto",
        cache_dir=cache_dir,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )
    model.eval()
    model = model.to(memory_format=torch.channels_last)
    return model


def get_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_dir",
                        type=str,
                        required=True,
                        help="Directory of a HF model checkpoint")
    parser.add_argument("--dtype", help="Model data type.", default="float16")
    parser.add_argument("--qformat",
                        type=str,
                        choices=['fp8', 'int8_sq', 'int4_awq'],
                        default='fp8',
                        help='Quantization format.')
    parser.add_argument('--calibrate_kv_cache',
                        default=False,
                        action="store_true",
                        help='Calibrate kv cache for int8 quantization.')
    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in AWQ quantization.')
    parser.add_argument(
        '--quantize_lm_head',
        default=False,
        action="store_true",
        help='Quantize lm_head weight as well when using int4_awq.')
    parser.add_argument("--calib_size",
                        type=int,
                        default=512,
                        help="Number of samples for calibration.")
    parser.add_argument("--export_path", default="exported_model")
    parser.add_argument("--cache_dir",
                        type=str,
                        default=None,
                        help="Directory of dataset cache.")
    parser.add_argument('--seed', type=int, default=None, help='Random seed')
    parser.add_argument("--ds_name", type=str, default=None, help="dataset name")
    parser.add_argument("--max_input_length", type=int, default=None, help="max input length")
    parser.add_argument("--inf_tp_size", type=int, default=None, help="tp-size when inference")
    
    args = parser.parse_args()
    return args


def main():
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")

    args = get_args()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)

    tokenizer = get_tokenizer(args.model_dir,
                              cache_dir=args.cache_dir,
                              use_fast=True,
                              trust_remote_code=True)
    model = get_model(args.model_dir, args.dtype, cache_dir=args.cache_dir)

    calib_dataloader = get_calib_dataloader(
        ds_name=args.ds_name,
        tokenizer=tokenizer, 
        batch_size=1,
        calib_size=args.calib_size,
        max_input_length=args.max_input_length
    )

    quant_cfg_dict = {}
    if args.quantize_lm_head:
        quant_cfg_dict.update({
            "*lm_head*": {
                "enable": True
            },
        })
    if args.group_size != 128:
        quant_cfg_dict.update({
            "*weight_quantizer": {
                "num_bits": 4,
                "block_sizes": {
                    -1: args.group_size
                },
                "enable": True
            },
        })
    if args.calibrate_kv_cache:
        quant_cfg_dict.update({
            "*.query_key_value.output_quantizer": {
                "num_bits": 8,
                "axis": None,
                "enable": True
            },
            "*.k_proj.output_quantizer": {
                "num_bits": 8,
                "axis": None,
                "enable": True
            },
            "*.v_proj.output_quantizer": {
                "num_bits": 8,
                "axis": None,
                "enable": True
            },
        })

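    # inf_tp_size is forwarded as tensor_parallel_size; the ammo export writes one
    # checkpoint per rank, named like llama_tp{N}_rank{i}.npz (here: llama_tp1_rank0.npz)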
    model = quantize_and_export(model,
                                qformat=args.qformat,
                                calib_dataloader=calib_dataloader,
                                export_path=args.export_path,
                                tensor_parallel_size=args.inf_tp_size,
                                quant_cfg_dict=quant_cfg_dict)


if __name__ == "__main__":
    main()
```