
Quantisation of BigBirdForTokenClassification suffers significant performance drop

lewisbails opened this issue

System Info

Apple M1 Pro
macOS 12.5 Monterey

optimum[onnxruntime]==1.3.0
python==3.9.11

Who can help?

@JingyaHuang

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

Model creation

from pathlib import Path

from optimum.onnxruntime import ORTModelForTokenClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

model = "..."  # private BigBirdForTokenClassification checkpoint (name not given)
onnx_dir = Path("...")  # output directory (not given)
kwargs = {"is_static": False}  # arm64 options were varied per run, see Results

# Export the PyTorch model to ONNX
model_ort = ORTModelForTokenClassification.from_pretrained(model, use_auth_token=True, from_transformers=True, force_download=False)
model_ort.save_pretrained(onnx_dir, file_name="model.onnx")

# Define the quantization methodology
qconfig = AutoQuantizationConfig.arm64(**kwargs)
quantizer = ORTQuantizer.from_pretrained(model, feature="token-classification")
# Apply dynamic quantization on the model
quantizer.export(
    onnx_model_path=onnx_dir / "model.onnx",
    onnx_quantized_model_output_path=onnx_dir / "model_quantized.onnx",
    quantization_config=qconfig,
)
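The quantised rows in the Results table cover every combination of three boolean options of AutoQuantizationConfig.arm64, all with is_static=False. A minimal sketch of how such a sweep could be generated; the function and variable names are hypothetical, not the reporter's actual script, and each dict would be passed as AutoQuantizationConfig.arm64(**kwargs):

```python
from itertools import product

# Hypothetical sweep: dynamic quantization (is_static=False) with every
# combination of the three boolean options reported in the Results table.
OPTIONS = ("use_symmetric_activations", "use_symmetric_weights", "per_channel")


def quantization_grid():
    """Return one kwargs dict per combination of the boolean options."""
    grid = []
    for values in product((False, True), repeat=len(OPTIONS)):
        kwargs = {"is_static": False}
        kwargs.update(zip(OPTIONS, values))
        grid.append(kwargs)
    return grid


for kwargs in quantization_grid():
    # In the real script each dict would feed AutoQuantizationConfig.arm64(**kwargs)
    print(kwargs)
```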

Evaluation

import time

from optimum.onnxruntime import ORTModelForTokenClassification
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, use_auth_token=True)
ort_quantized = ORTModelForTokenClassification.from_pretrained(onnx_dir, file_name="model_quantized.onnx")

for batch in batches:  # `batches` comes from the private dataset
    inputs = tokenizer(...)
    start = time.time()
    logits = ort_quantized(**inputs).logits
    end = time.time()
    # collate batch logits and processing times

# statistical evaluation ...

Results

  • Benchmarking 10 different models on a token-classification task.
  • All ONNX-formatted models were created from the same PyTorch BigBirdForTokenClassification model, differing only in the parameters passed to AutoQuantizationConfig.arm64.
  • The dataset used is private.
  • Note that both statistical performance and processing time are far worse for the quantised models created with Optimum.
  • If a parameter to AutoQuantizationConfig.arm64 doesn't appear as a column, then it was left the default value for all quantised models.
| precision | recall | f1 | period (s/doc) | quantized | device | accelerator | api | is_static | use_symmetric_activations | use_symmetric_weights | per_channel |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.64 | 0.83 | 0.72 | 0.41 | FALSE | cuda:0 | | transformers | | | | |
| 0.64 | 0.83 | 0.72 | 2.15 | FALSE | cpu | | transformers | | | | |
| 0.75 | 0.19 | 0.30 | 5.37 | TRUE | cpu | onnxruntime | optimum | FALSE | TRUE | FALSE | FALSE |
| 1.00 | 0.18 | 0.31 | 5.48 | TRUE | cpu | onnxruntime | optimum | FALSE | FALSE | TRUE | TRUE |
| 1.00 | 0.18 | 0.31 | 5.52 | TRUE | cpu | onnxruntime | optimum | FALSE | TRUE | TRUE | TRUE |
| 0.75 | 0.19 | 0.30 | 6.72 | TRUE | cpu | onnxruntime | optimum | FALSE | FALSE | FALSE | FALSE |
| 1.00 | 0.18 | 0.31 | 7.21 | TRUE | cpu | onnxruntime | optimum | FALSE | FALSE | TRUE | FALSE |
| 1.00 | 0.18 | 0.31 | 7.28 | TRUE | cpu | onnxruntime | optimum | FALSE | TRUE | TRUE | FALSE |
| 1.00 | 0.18 | 0.31 | 7.78 | TRUE | cpu | onnxruntime | optimum | FALSE | FALSE | FALSE | TRUE |
| 1.00 | 0.18 | 0.31 | 7.85 | TRUE | cpu | onnxruntime | optimum | FALSE | TRUE | FALSE | TRUE |
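A quick arithmetic check on the table: F1 is the harmonic mean of precision and recall, so the quantised models' F1 of about 0.31 is driven by the collapse in recall (0.83 down to about 0.18), even where precision reaches 1.00. The sketch below recomputes F1 for two rows and the CPU slowdown relative to the unquantised PyTorch model; all numbers are copied from the table.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# (precision, recall, seconds per document) copied from the Results table
pytorch_cpu = (0.64, 0.83, 2.15)
fastest_quantized = (0.75, 0.19, 5.37)

for name, (p, r, period) in [("pytorch cpu", pytorch_cpu),
                             ("fastest quantized", fastest_quantized)]:
    slowdown = period / pytorch_cpu[2]
    print(f"{name}: F1={f1(p, r):.2f}, slowdown vs PyTorch CPU={slowdown:.2f}x")
```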

Expected behavior

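  • Faster inference (lower processing time per document) than the unquantised model
  • Statistical performance (precision, recall, F1) similar to the unquantised model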

lewisbails, Aug 10 '22

@deanjones if you're interested in following this.

lewisbails, Aug 10 '22

Hello,

Do you think the issue is related to running on Apple M1? Did you try running on an x86_64 CPU? Could it be that PyTorch leverages the Apple Neural Engine / Metal Performance Shaders but ONNX Runtime does not?

On a toy example on an x86_64 laptop with ripjar/bigbird-roberta-base-nrer, batch size 1, sequence length 1024 and the CPUExecutionProvider, I get the following runtimes (totals over 25 timed iterations after 5 warm-up calls):

PyTorch: 23.91 s
ONNX Runtime: 15.36 s
ONNX Runtime quantized: 7.68 s
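Since these are totals over the 25 timed iterations of the benchmark() function in the script that follows, converting them to per-call latency and speedups is simple arithmetic:

```python
ITERS = 25  # matches the default iters=25 in benchmark()

# Totals in seconds copied from the runtimes above (x86_64, CPUExecutionProvider)
totals = {"PyTorch": 23.91, "ONNX Runtime": 15.36, "ONNX Runtime quantized": 7.68}

per_call = {name: t / ITERS for name, t in totals.items()}
for name, seconds in per_call.items():
    speedup = totals["PyTorch"] / totals[name]
    print(f"{name}: {seconds:.3f} s/call, {speedup:.2f}x vs PyTorch")
```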

Script (to run with optimum main):

from optimum.onnxruntime import ORTModelForTokenClassification

from transformers import AutoTokenizer, AutoModelForTokenClassification

import time

from optimum.onnxruntime.configuration import AutoQuantizationConfig
from optimum.onnxruntime import ORTQuantizer
import torch

import onnxruntime

model_name = "ripjar/bigbird-roberta-base-nrer"

ort_model = ORTModelForTokenClassification.from_pretrained(model_name, from_transformers=True)
pt_model = AutoModelForTokenClassification.from_pretrained(model_name)
pt_model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

##
batch_size = 1
seq_length = 1024
inp = {
    "input_ids": torch.randint(low=0, high=100, size=(batch_size, seq_length)),
    "attention_mask": torch.ones(batch_size, seq_length, dtype=torch.int64),
}

##

with torch.no_grad():
    res_pt = pt_model(**inp)

res_ort = ort_model(**inp)

assert torch.allclose(res_ort.logits, res_pt.logits, atol=1e-1)

##

def benchmark(model, inp, iters=25):
    # warmup
    for _ in range(5):
        model(**inp)

    start = time.time()
    for _ in range(iters):
        model(**inp)
    end = time.time()
    return end - start

with torch.no_grad():
    pt_time = benchmark(pt_model, inp)
    
ort_time = benchmark(ort_model, inp)

##

qconfig = AutoQuantizationConfig.avx512(is_static=False)
quantizer = ORTQuantizer.from_pretrained(ort_model)

quantized_model_path = quantizer.quantize(qconfig, save_dir="outdir_bigbird")

session = onnxruntime.InferenceSession(str(quantized_model_path))
model_quantized = ORTModelForTokenClassification(session)

ort_quantized_time = benchmark(model_quantized, inp)

print(f"PyTorch: {pt_time:.2f} s")
print(f"ONNX Runtime: {ort_time:.2f} s")
print(f"ONNX Runtime quantized: {ort_quantized_time:.2f} s")
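A variant of benchmark() that may make numbers easier to compare across machines: it uses time.perf_counter(), a monotonic high-resolution clock intended for measuring durations, and reports per-call statistics instead of only a total. This is a sketch; benchmark_per_call is a hypothetical helper, and model can be any callable taking the same keyword inputs as in the script above.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark_per_call(model: Callable[..., Any], inp: Dict[str, Any],
                       iters: int = 25, warmup: int = 5) -> Dict[str, float]:
    """Time `iters` calls of model(**inp) after `warmup` untimed calls."""
    for _ in range(warmup):
        model(**inp)

    latencies = []
    for _ in range(iters):
        start = time.perf_counter()
        model(**inp)
        latencies.append(time.perf_counter() - start)

    return {
        "total_s": sum(latencies),
        "median_s": statistics.median(latencies),
        "min_s": min(latencies),
    }


# Usage with a stand-in callable; in the script above you would pass
# pt_model (inside torch.no_grad()), ort_model or model_quantized.
stats = benchmark_per_call(lambda **kw: sum(kw.values()), {"a": 1, "b": 2}, iters=10)
print(stats)
```

Reporting the median per call, rather than a total, makes runs with different iteration counts directly comparable and is less sensitive to a single slow call.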

fxmarty, Aug 12 '22

Thanks for that @fxmarty. That could certainly explain why PyTorch inference is faster on my machine! Regarding the ORT models, though, yours seem much quicker than mine, which is interesting because I wouldn't expect mine to be slower on an M1. On this toy example, was there a large difference in logits between the original model and the quantised version?

lewisbails, Aug 12 '22

I have never used onnxruntime on Apple devices, which is why I am curious. I am not sure about the logits; I only wanted to check runtime here. Could you try running the script on your Apple M1 (perhaps switching the AutoQuantizationConfig to arm64) to see what numbers you get?

One thing I noticed is that for long sequences the logits from PyTorch and from ONNX Runtime (not quantized) diverge more and more (I had to pass atol=1e-1 for a sequence length of 1024).
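For context, torch.allclose(a, b, rtol=1e-05, atol=1e-08) checks |a - b| <= atol + rtol * |b| elementwise (NumPy's np.allclose uses the same rule), so loosening atol until the assertion passes hides how large the gap actually is. A small sketch that reports the maximum absolute difference directly, using NumPy and synthetic arrays as stand-ins for the two models' logits (the label count of 9 is arbitrary):

```python
import numpy as np


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest elementwise |a - b|; more informative than a pass/fail allclose."""
    return float(np.max(np.abs(a - b)))


rng = np.random.default_rng(0)
reference = rng.normal(size=(1, 1024, 9))  # stand-in for PyTorch logits
candidate = reference + rng.normal(scale=0.02, size=reference.shape)  # stand-in for ORT logits

print(f"max |diff| = {max_abs_diff(reference, candidate):.4f}")
# allclose passes iff every element satisfies |a - b| <= atol + rtol * |b|
print(np.allclose(candidate, reference, atol=1e-1))
```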

fxmarty, Aug 12 '22

I had to make a few tweaks to your script to get around some errors that were popping up. Are you using optimum==1.3.0? These were my results on M1:

PyTorch: 15.43 s
ONNX Runtime: 13.87 s
ONNX Runtime quantized: 13.72 s

It seems I was doing something odd in my own benchmarking to get such a large difference between the ONNX and PyTorch models. Interestingly, your quantized model is almost 2x faster than mine (7.68 s vs 13.72 s).
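Putting the two runs side by side makes the contrast explicit. The totals below are copied from the thread (25 iterations each); note that the runs used different optimum versions and different quantization configs (avx512 on x86_64, arm64 on M1).

```python
# Totals in seconds over 25 iterations, copied from the thread.
x86 = {"ort": 15.36, "ort_quantized": 7.68}   # AutoQuantizationConfig.avx512
m1 = {"ort": 13.87, "ort_quantized": 13.72}   # AutoQuantizationConfig.arm64


def quantization_speedup(run: dict) -> float:
    """How much faster the quantized ONNX model is than the unquantized one."""
    return run["ort"] / run["ort_quantized"]


print(f"x86 speedup from quantization: {quantization_speedup(x86):.2f}x")
print(f"M1 speedup from quantization:  {quantization_speedup(m1):.2f}x")
print(f"M1 quantized vs x86 quantized: {m1['ort_quantized'] / x86['ort_quantized']:.2f}x slower")
```

On x86_64, quantization halves the ONNX Runtime time, while on the M1 it is essentially unchanged.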

Script to run:

from optimum.onnxruntime import ORTModelForTokenClassification

from transformers import AutoTokenizer, AutoModelForTokenClassification

import time

from optimum.onnxruntime.configuration import AutoQuantizationConfig
from optimum.onnxruntime import ORTQuantizer
import torch

model_name = "ripjar/bigbird-roberta-base-nrer"

ort_model = ORTModelForTokenClassification.from_pretrained(model_name, from_transformers=True)
pt_model = AutoModelForTokenClassification.from_pretrained(model_name)
pt_model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

##
batch_size = 1
seq_length = 1024
inp = {
    "input_ids": torch.randint(low=0, high=100, size=(batch_size, seq_length)),
    "attention_mask": torch.ones(batch_size, seq_length, dtype=torch.int64),
}

##

with torch.no_grad():
    res_pt = pt_model(**inp)

res_ort = ort_model(**inp)

assert torch.allclose(res_ort.logits, res_pt.logits, atol=1e-1)

##


def benchmark(model, inp, iters=25):
    # warmup
    for _ in range(5):
        model(**inp)

    start = time.time()
    for _ in range(iters):
        model(**inp)
    end = time.time()
    return end - start


with torch.no_grad():
    pt_time = benchmark(pt_model, inp)

ort_time = benchmark(ort_model, inp)

##
ort_model.save_pretrained(".", file_name="model.onnx")
qconfig = AutoQuantizationConfig.arm64(is_static=False)
quantizer = ORTQuantizer.from_pretrained(model_name, feature="token-classification")
quantizer.export(
    onnx_model_path="./model.onnx",
    onnx_quantized_model_output_path="./model_quantized.onnx",
    quantization_config=qconfig
)

model_quantized = ORTModelForTokenClassification.from_pretrained(".", file_name="./model_quantized.onnx")

ort_quantized_time = benchmark(model_quantized, inp)

print(f"PyTorch: {pt_time:.2f} s")
print(f"ONNX Runtime: {ort_time:.2f} s")
print(f"ONNX Runtime quantized: {ort_quantized_time:.2f} s")

lewisbails, Aug 12 '22

OK, that's great! So apparently quantization with onnxruntime on Apple M1 does not help much. cc @mfuntowicz @hollance if you have any idea

I'm running Optimum 1.2.3.dev0 (the dev version from the main branch), where ORTQuantizer has been refactored; I think that explains the errors you got. Sorry about that!

fxmarty, Aug 12 '22

Also, I had to go up to atol=3 to get the logits comparison between the vanilla ONNX model and ONNX-quantized model to pass. Seems large, but I'm not familiar enough with quantization to know otherwise. If the difference is consistent across classes, I assume it wouldn't matter if one is softmaxing the logits anyway.

res_ort_quantized = model_quantized(**inp)
assert torch.allclose(res_ort.logits, res_ort_quantized.logits, atol=3)
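On the softmax point: softmax is invariant to adding the same constant to every logit of a token, so a difference that is truly uniform across classes would change neither the probabilities nor the predicted label. A difference that varies between classes can change both. A small NumPy sketch with made-up logits for a single token:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


token_logits = np.array([2.0, 1.5, -1.0])  # hypothetical logits for 3 classes

uniform = token_logits + 3.0                          # same offset for every class
per_class = token_logits + np.array([0.0, 1.0, 0.0])  # offset differs by class

print(np.allclose(softmax(token_logits), softmax(uniform)))         # True
print(softmax(token_logits).argmax(), softmax(per_class).argmax())  # 0 1
```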

lewisbails, Aug 12 '22

What are the min and max values for the predicted logits? atol=3 seems large but if the logits themselves are huge, then a difference of 3 may not be very significant.

hollance, Aug 12 '22

Running it again with the random input ids:

(Min, Max) PyTorch: (-3.349, 3.752)
(Min, Max) ONNX Runtime: (-3.32, 3.737)
(Min, Max) ONNX Runtime quantized: (-5.626, 3.52)

lewisbails, Aug 12 '22

I didn't really look at the problem too closely, but if this model runs on the Neural Engine / M1 GPU, then precision is limited to 16-bit floats. It's possible that some intermediate layer has activations that are too large or too small, and that this messes up the results of downstream layers.
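The 16-bit float limits are easy to inspect with NumPy: the largest finite float16 is 65504 and larger values overflow to inf, while magnitudes below the smallest subnormal round to zero. Whether this model actually runs in float16 on the M1 is not established in this thread; the snippet only illustrates the range limits mentioned above.

```python
import numpy as np

info = np.finfo(np.float16)
print(float(info.max))   # largest finite float16 (65504)
print(float(info.tiny))  # smallest positive normal float16, about 6.1e-05

activations = np.array([1e-8, 1.0, 70000.0], dtype=np.float32)
with np.errstate(over="ignore"):  # silence the overflow warning for the demo
    as_fp16 = activations.astype(np.float16)
# Expected values: 0 (underflow), 1 (exact), inf (overflow)
print(as_fp16)
```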

One thing you could do is add a hook into the PyTorch model and print the min/max of the activation of each layer, to see if they are extremely large or small.
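A minimal sketch of that idea using PyTorch forward hooks (nn.Module.register_forward_hook). A small stand-in nn.Sequential is used so the sketch runs without downloading anything; for the real model you would walk pt_model.named_modules() the same way. attach_range_hooks is a hypothetical helper, and it only records plain tensor outputs, so modules returning tuples or ModelOutput objects are skipped.

```python
import torch
from torch import nn


def attach_range_hooks(model: nn.Module, stats: dict) -> list:
    """Record (min, max) of every leaf module's tensor output into `stats`."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # only hook leaf modules
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):
                stats[name] = (output.min().item(), output.max().item())
        handles.append(module.register_forward_hook(hook))
    return handles


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
stats = {}
handles = attach_range_hooks(model, stats)
with torch.no_grad():
    model(torch.randn(2, 8))
for h in handles:
    h.remove()  # detach the hooks once the ranges are collected

for name, (lo, hi) in stats.items():
    print(f"{name}: min={lo:.3f} max={hi:.3f}")
```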

hollance, Aug 12 '22

I didn't explicitly send it to the Neural Engine / M1 GPU, do you know if this is something that happens under the hood?

lewisbails, Aug 12 '22

@lewisbails @hollance I think there are two very distinct issues here:

1. runtime latency/throughput
2. accuracy and other evaluation scores

In my example above I was only focusing on runtime, as I was surprised by the worse times with onnxruntime; I don't know what PyTorch does under the hood.
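For the accuracy side, one label-free check is how often the quantized model's per-token predicted class matches the reference model's. A sketch with NumPy, assuming logits of shape (batch, seq_len, num_labels) converted to arrays and an optional attention mask to ignore padding; argmax_agreement is a hypothetical helper and the tiny arrays are placeholders:

```python
from typing import Optional

import numpy as np


def argmax_agreement(reference: np.ndarray, candidate: np.ndarray,
                     mask: Optional[np.ndarray] = None) -> float:
    """Fraction of (non-padding) tokens whose predicted label is identical."""
    same = reference.argmax(axis=-1) == candidate.argmax(axis=-1)
    if mask is not None:
        same = same[mask.astype(bool)]
    return float(same.mean())


ref = np.array([[[2.0, 0.0], [0.0, 1.0], [1.0, 0.5]]])   # shape (1, 3, 2)
cand = np.array([[[1.5, 0.2], [0.3, 0.1], [0.9, 0.4]]])
attention_mask = np.array([[1, 1, 0]])                   # last token is padding

print(argmax_agreement(ref, cand))                  # all 3 tokens
print(argmax_agreement(ref, cand, attention_mask))  # only the 2 real tokens
```

Unlike an allclose threshold, this measures directly whether quantization changes the predictions that precision, recall and F1 are computed from.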

fxmarty, Aug 12 '22