Nano: Model optimization pipeline APIs

Open • TheaperDeng opened this issue 1 year ago • 1 comment

Background

Nano currently offers a broad collection of PyTorch/TensorFlow inference acceleration methods, but users may need an automatic (intelligently guided) pipeline to find which one works best for their model.

Methodology

Currently we offer 3 workflows for users who care about inference performance.

Method            Accuracy Drop        Expected Acceleration Ratio   Retrain   Success Ratio
Trainer.quantize  True (except bf16)   1~4X                          False     low
Trainer.trace     False                1~2X                          False     high
Trainer.search    True                 1~20X                         True      medium

For this pipeline design, we classify users into 2 categories:

  1. Users bring a trained model (possibly loaded from checkpoint files) and would like to optimize this specific model in a short time.

    We will add a new API that automatically finds the best accelerated model for the user; the detailed API design is given below. This new API will cover the existing Trainer.quantize and Trainer.trace.

  2. Users have a model definition and would like to find the best hyperparameter configuration to balance accuracy and latency.

    Trainer.search already handles this case well, so this issue will not cover it in detail.

API Design

A prototype implementation can be found in #5336.

This API is designed so that:

  1. it is easy to use, with no extra parameters required;

  2. the detailed acceleration strategy is completely hidden from users.

# bigdl.nano.pytorch.common_inference

def accelerate_inference(model,
                         training_data,
                         validation_data=None,
                         metric=None,
                         allow_acc_drop=None,
                         cpu_num=None):
    '''
    :param model: A torch.nn.Module to be optimized.
    :param training_data: A PyTorch DataLoader for the training dataset.
           Be careful with this parameter, since this dataloader may be
           exposed to the model, which can cause data leakage. Its
           batch_size matters as well: set it to the batch size you plan
           to use in the real deployment environment. E.g. the batch size
           should be 1 if the accelerated model will serve an online service.
    :param validation_data: (optional) A PyTorch DataLoader for accuracy
           evaluation. This is only needed when users care about the
           possible accuracy drop.
    :param metric: (optional) A callable that takes a prediction and a
           target and returns an accuracy value, called as
           `metric(pred, target)`.
    :param allow_acc_drop: (optional) A float giving the accuracy drop
           ratio that can be tolerated, e.g. 0.05 means a 5% accuracy drop
           compared to the original model's accuracy is acceptable.

    :return: An accelerated model that can be used for prediction
             as if it were the original PyTorch nn.Module.
    '''
    # Pseudo-code:
    # available_methods = _check_acceleration_methods_dependencies()
    # best_model, best_latency = None, float("inf")
    # for method in available_methods:
    #     accelerated_model = method(model)
    #     latency = evaluate_latency(accelerated_model, training_data)
    #     accuracy = evaluate_accuracy(accelerated_model, validation_data, metric)
    #     if accuracy meets the requirement and latency < best_latency:
    #         best_model, best_latency = accelerated_model, latency
    # return best_model
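The selection loop in the pseudo-code can be made concrete in plain Python. The sketch below is illustrative only and is not the prototype in #5336. The names `select_fastest`, `_latency`, `_accuracy` and the `candidates` dict (method name mapped to a function that turns a model into an accelerated model) are all hypothetical. It makes three further assumptions. First, every batch is an `(x, y)` pair. Second, `allow_acc_drop` is a *relative* drop from the original model's accuracy. Third, a method that raises (for example, because an optional dependency is missing) is simply skipped. The original model is kept as a fallback, so the function always returns something.

```python
import time


def _latency(model, data):
    """Mean wall-clock seconds per batch; batches are assumed to be (x, y)."""
    total, count = 0.0, 0
    for x, _ in data:
        start = time.perf_counter()
        model(x)
        total += time.perf_counter() - start
        count += 1
    return total / max(count, 1)


def _accuracy(model, data, metric):
    """Mean of metric(pred, target) over all batches in `data`."""
    scores = [metric(model(x), y) for x, y in data]
    return sum(scores) / max(len(scores), 1)


def select_fastest(model, candidates, training_data,
                   validation_data=None, metric=None, allow_acc_drop=None):
    """Hypothetical sketch: return (name, model) of the fastest acceptable candidate.

    `training_data` and `validation_data` must be re-iterable (e.g. lists or
    DataLoaders), because each candidate iterates over them again.
    """
    check_acc = all(v is not None for v in (validation_data, metric, allow_acc_drop))
    if check_acc:
        baseline = _accuracy(model, validation_data, metric)
        # Assumption: relative drop, e.g. 0.05 -> keep >= 95% of baseline accuracy.
        threshold = baseline * (1 - allow_acc_drop)

    # The original model is the fallback if no candidate beats it.
    best_name, best_model = "original", model
    best_latency = _latency(model, training_data)
    for name, method in candidates.items():
        try:
            accelerated = method(model)
        except Exception:
            continue  # e.g. a missing optional dependency: skip this method
        if check_acc and _accuracy(accelerated, validation_data, metric) < threshold:
            continue  # accuracy drop exceeds what the user tolerates
        latency = _latency(accelerated, training_data)
        if latency < best_latency:
            best_name, best_model, best_latency = name, accelerated, latency
    return best_name, best_model
```

Measuring latency on `training_data` mirrors the docstring's point: the batch size of that loader should match deployment, because the ranking of methods can change with batch size.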

Some demo calls

A user who does not care about accuracy drop but cares about single-sample inference speed may call the function like this:

from torch.utils.data import DataLoader

train_loader = DataLoader(trainset, batch_size=1)
accelerated_model = accelerate_inference(model, train_loader)

A user who has a strict accuracy requirement and cares about large-batch inference speed may call the function like this:

import torchmetrics
from torch.utils.data import DataLoader

train_loader = DataLoader(trainset, batch_size=512)
val_loader = DataLoader(valset)
f1 = torchmetrics.F1Score(task="multiclass", num_classes=10)
accelerated_model = accelerate_inference(model, train_loader,
                                         validation_data=val_loader,
                                         metric=f1,
                                         allow_acc_drop=0.005)
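Any callable with the `metric(pred, target)` shape described in the docstring should work in place of a torchmetrics object. The function below is a hypothetical top-1 accuracy metric. It uses plain Python lists (per-sample class scores and integer labels) so that it runs without PyTorch. With real tensors, one would typically compute `(pred.argmax(dim=1) == target).float().mean().item()` instead.

```python
from typing import Sequence


def top1_accuracy(pred: Sequence[Sequence[float]], target: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target label."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    correct = sum(
        1
        for scores, label in zip(pred, target)
        # index of the largest score (first one wins on ties)
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return correct / len(target) if len(target) else 0.0
```

It would be passed as `metric=top1_accuracy`, provided the model's outputs and the loader's targets come in this list form.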

Posted by TheaperDeng, Aug 08 '22 02:08

Maybe use an object that exposes two APIs: optimize() and export().

Posted by shane-huang, Aug 11 '22 10:08
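The object-based suggestion can be sketched as follows. The class name `AccelerationPipeline`, its constructor and the `cost` argument (a lower-is-better function, such as measured latency) are hypothetical illustrations, not an agreed API. The idea behind the split is that optimize() runs the expensive search once and keeps every successful result. export() then hands back the best model, or any other named candidate, without searching again.

```python
from typing import Callable, Dict, Optional, Tuple


class AccelerationPipeline:
    """Hypothetical object-style API: optimize() once, then export() results."""

    def __init__(self, candidates: Dict[str, Callable[[Callable], Callable]]):
        self.candidates = dict(candidates)
        # name -> (accelerated model, cost); filled by optimize()
        self.results: Dict[str, Tuple[Callable, float]] = {}

    def optimize(self, model: Callable, training_data,
                 cost: Callable[[Callable, object], float]) -> "AccelerationPipeline":
        """Try every candidate method and record the cost of each one that succeeds."""
        self.results = {"original": (model, cost(model, training_data))}
        for name, method in self.candidates.items():
            try:
                accelerated = method(model)
            except Exception:
                continue  # skip methods that fail, e.g. missing dependencies
            self.results[name] = (accelerated, cost(accelerated, training_data))
        return self

    def export(self, name: Optional[str] = None) -> Callable:
        """Return the named result, or the lowest-cost one when no name is given."""
        if not self.results:
            raise RuntimeError("call optimize() before export()")
        if name is None:
            name = min(self.results, key=lambda n: self.results[n][1])
        return self.results[name][0]
```

Because optimize() returns `self`, a call like `AccelerationPipeline(candidates).optimize(model, train_loader, cost).export()` would also read naturally.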