
[Draft] Avoid loading model weights before recipe application if any

Open rahul-tuli opened this pull request 1 year ago • 1 comment

Previously, when SparseAutoModelForCausalLM.from_pretrained(...) was called, the weights were loaded twice: once during model = super(AutoModelForCausalLM, cls).from_pretrained(...), and then again after recipe application, which is undesirable.

This PR updates the flow to use from_config(...) instead of from_pretrained(...), which initializes a model with freshly initialized (untrained) weights; after recipe application, the actual trained weights are then loaded in exactly once.

More info on from_config: https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#transformers.AutoModel.from_config
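
As an illustration, the new flow follows roughly the pattern sketched below (a minimal sketch, not the exact SparseML implementation; apply_recipe and the checkpoint path are placeholders):

from transformers import AutoConfig, AutoModelForCausalLM

model_path = "Xenova/llama2.c-stories15M"

# 1. Build the model skeleton from its config only; no checkpoint
#    weights are read here, the parameters are freshly initialized.
config = AutoConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_config(config)

# 2. Apply the sparsification/quantization recipe to the initialized
#    model (placeholder for SparseML's recipe-application step):
# apply_recipe(model, recipe)

# 3. Only now load the trained checkpoint weights, exactly once, e.g.:
# model.load_state_dict(torch.load("pytorch_model.bin"), strict=False)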

The initial effort was to accomplish this with accelerate.init_empty_weights, but with quantized models we ran into the issue described in https://discuss.huggingface.co/t/error-the-model-weights-are-not-tied-please-use-the-tie-weights-method-before-using-the-infer-auto-device-function-even-after-adding-model-tie-weights/46325.
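
For reference, that originally attempted variant looked roughly like the sketch below (an illustration only, assuming the model is built under accelerate's init_empty_weights context); this is the path that hit the tied-weights error linked above for quantized checkpoints:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Xenova/llama2.c-stories15M")

# Parameters are created on the "meta" device: no memory is allocated
# and no initializers run, but every weight has to be materialized
# later when the real checkpoint is loaded.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)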

Tests: Loaded dense, sparse, and quantized checkpoints; all load correctly.

Test script:


import time
from typing import List
from sparseml.transformers import SparseAutoModelForCausalLM
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model-type", type=str, choices=["dense", "sparse", "quantized"], default="quantized")
parser.add_argument("--all", action="store_true")

BASE_MODEL = "Xenova/llama2.c-stories15M"

# Define the model paths for each model type
models = {
    "dense": "Xenova/llama2.c-stories15M",
    "sparse": "/home/rahul/projects/sparseml/local/local_output/sparse_model_80",
    "quantized": "mgoin/llama2.c-stories15M-quant-pt",
}

def load_and_time(model_path):
    # Time a single SparseAutoModelForCausalLM.from_pretrained(...) call
    start_time = time.time()
    SparseAutoModelForCausalLM.from_pretrained(model_path)
    end_time = time.time()
    return end_time - start_time

def load_weights(model_types: List[str]):
    # Load each requested checkpoint type and record how long it took
    return {
        model_type: load_and_time(models[model_type])
        for model_type in model_types
    }

def main(args):
    timings = (
        load_weights(model_types=list(models.keys()))
        if args.all
        else load_weights(model_types=[args.model_type])
    )
    print(timings)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args=args)
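Assuming the script is saved as, say, load_timing.py (the filename is arbitrary), it can be run against a single checkpoint type or all of them:

python load_timing.py --model-type quantized
python load_timing.py --all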

rahul-tuli avatar Apr 08 '24 16:04 rahul-tuli

Also @rahul-tuli, the correct implementation of this PR should make this part of the from_pretrained method:

def skip(*args, **kwargs):
    pass
# Skip the initializer step. This accelerates the loading
# of the models, especially for the quantized models
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

redundant!
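
If some form of the skip does need to stay, a scoped variant that restores the original initializers afterwards avoids leaking the monkey-patch globally; a minimal sketch, not the current SparseML code:

import contextlib

import torch

@contextlib.contextmanager
def skip_torch_init():
    # Temporarily no-op the tensor initializers that module constructors
    # (nn.Linear, nn.Embedding, ...) call, then restore them on exit.
    def skip(*args, **kwargs):
        pass

    originals = (
        torch.nn.init.kaiming_uniform_,
        torch.nn.init.uniform_,
        torch.nn.init.normal_,
    )
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    try:
        yield
    finally:
        (
            torch.nn.init.kaiming_uniform_,
            torch.nn.init.uniform_,
            torch.nn.init.normal_,
        ) = originals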

dbogunowicz avatar Apr 15 '24 07:04 dbogunowicz