
Match GPTQ state dict

Open · rahul-tuli opened this issue on Mar 19, 2024 · 0 comments

Conversion script:

from sparseml.transformers.utils.vllm_export_helpers import export_vllm_checkpoint
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer

# Load the SparseGPT-compressed checkpoint and its tokenizer.
path = "/home/rahul/projects/sparseml/local/local_output/sparsegpt-autogptq-emulation-checkpoint/stage_compression"
sparse_gpt_model = SparseAutoModelForCausalLM.from_pretrained(path)
tokenizer = SparseAutoTokenizer.from_pretrained(path)

# Export to a vLLM-compatible (exllama/GPTQ) checkpoint.
export_vllm_checkpoint(
    model=sparse_gpt_model,
    tokenizer=tokenizer,
)
2024-03-21 01:58:33 sparseml.pytorch.model_load.helpers INFO     Reloaded model state after SparseML recipe structure modifications from /home/rahul/projects/sparseml/local/local_output/sparsegpt-autogptq-emulation-checkpoint/stage_compression
2024-03-21 01:58:33 __main__     INFO     Adding exllama quantization info to config
2024-03-21 01:58:33 __main__     INFO     Translating state dict to exllama format.
2024-03-21 01:58:33 sparseml.transformers.utils.transformations INFO     Applying transformation: TRANSFORM_NAMES
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Transformation: TRANSFORM_NAMES complete
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Applying transformation: ADD_TENSORS
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Transformation: ADD_TENSORS complete
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Applying transformation: TRANSFORM_TENSORS
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Transformation: TRANSFORM_TENSORS complete
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Applying transformation: REMOVE_UNWANTED_TENSORS
2024-03-21 02:00:46 sparseml.transformers.utils.transformations INFO     Transformation: REMOVE_UNWANTED_TENSORS complete
2024-03-21 02:00:50 __main__     INFO     Model and config saved to /nm/drive0/rahul/projects/sparseml/exllama_model
2024-03-21 02:00:50 __main__     INFO     tokenizer saved to /nm/drive0/rahul/projects/sparseml/exllama_model
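
The TRANSFORM_NAMES step in the log above renames tensors from sparseml's internal quantized layout to the exllama/GPTQ naming convention. A minimal sketch of what such a renaming pass could look like; the suffix mapping and helper below are purely illustrative, not sparseml's actual implementation:

def transform_names(state_dict):
    # Hypothetical TRANSFORM_NAMES-style pass: rename tensors from an
    # internal quantized layout to exllama/GPTQ names. The suffix map
    # here is illustrative only.
    name_map = {
        ".weight_fake_quant.scale": ".scales",
        ".weight_fake_quant.zero_point": ".qzeros",
    }
    renamed = {}
    for name, tensor in state_dict.items():
        for old, new in name_map.items():
            if name.endswith(old):
                name = name[: -len(old)] + new
                break
        renamed[name] = tensor
    return renamed
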
$ tree ./exllama_model 
./exllama_model
├── config.json
├── generation_config.json
├── model.safetensors
├── special_tokens_map.json
├── tokenizer_config.json
└── tokenizer.json

0 directories, 6 files
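
A quick sanity check that the exported state dict actually matches the GPTQ layout is to list the tensor keys in model.safetensors and look for the standard qweight/qzeros/scales/g_idx entries (these expected suffixes are an assumption based on the usual GPTQ format, not something confirmed in this issue):

from safetensors import safe_open

# List tensor names in the exported checkpoint and count the ones
# following the standard GPTQ naming.
with safe_open("./exllama_model/model.safetensors", framework="pt") as f:
    keys = list(f.keys())

gptq_keys = [k for k in keys if k.endswith(("qweight", "qzeros", "scales", "g_idx"))]
print(f"{len(keys)} tensors total, {len(gptq_keys)} GPTQ-style tensors")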

config.json

{
  "_name_or_path": "/home/rahul/projects/sparseml/local/local_output/sparsegpt-autogptq-emulation-checkpoint/stage_compression",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "quantization_config": {
    "bits": 4,
    "desc_act": false,
    "group_size": -1,
    "is_marlin_format": false,
    "quant_method": "gptq",
    "sym": true
  },
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "1.7.0.20240321",
  "use_cache": true,
  "vocab_size": 32000
}
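
For comparison, the quantization_config above maps onto transformers' GPTQConfig roughly as follows (a sketch only; is_marlin_format comes from the AutoGPTQ checkpoint format and may not be a GPTQConfig argument, so it is omitted here):

from transformers import GPTQConfig

# Rough equivalent of the exported quantization_config.
# group_size=-1 means per-channel quantization (no grouping).
quant_config = GPTQConfig(
    bits=4,
    group_size=-1,
    desc_act=False,
    sym=True,
)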

Usage script (requires vLLM):

import argparse
from vllm import LLM, SamplingParams


parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)

args = parser.parse_args()


prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=1, max_tokens=100)

# Create an LLM.
llm = LLM(args.model)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"\nGenerated text: {prompt}{generated_text}\n")

