Discrepancy in model size and inference speed in INT8 between cto.coreml and cto.torch

Open james-p-xu opened this issue 8 months ago • 7 comments

❓Question

Hello, I have a few questions about the differences between the cto.coreml and cto.torch codepaths. I have followed the documentation on how to perform full INT8 (W8A8) quantization and have generated models with both approaches. I am running these examples on an M4 MBP.

Below are tables with some profiling numbers. The first table is for an Apple example; I am able to replicate the performance numbers that Apple reports. The second table is for a pre-trained torch model that I am trying to export to Core ML; it is a ViT-based architecture similar to Apple's MobileViTv2.

ViT (FP16/INT8) .mlpackage files (too large to directly attach)

https://drive.google.com/file/d/1ABWEBFnKo-s-WmFEM62sAOj25VqlKrz9/view?usp=sharing


  1. Why would the default symmetric ModuleLinearQuantizerConfig (W8A8) cto.torch codepath result in FP16 weights?
  • The cto.torch path yields a 33.5% decrease in model size, while the cto.coreml path yields a 49.8% decrease.
  • Comparing the vit_int8_cto_torch.mlpackage (attached above) with the vit_int8_cto_coreml.mlpackage in the Netron viewer, I can see that the first linear layer has a constexpr_affine_dequantize op in the cto.coreml model, but there is no such weight dequantize in the cto.torch model.
  • I suspect this is where the discrepancy in model size comes from: weights are stored as INT8 in the cto.coreml model but as FP16 in the cto.torch model (see the sketch after this list for one way to check).
  2. Why do the inference speedups differ so much (~20%) between the two INT8 models?
  • The smaller model (full INT8) is slower than the larger model (W16A8). This is non-intuitive to me: since the newer Apple chips have native INT8 hardware support, I would expect the larger model to be the slower one.
  • This thread (https://github.com/apple/coremltools/issues/2432) helped answer some questions, but I would like to understand how to get my ViT model to run with close to the same speedup as MobileViTv2.
  • I used the open-source CoreMLProfiler tool to profile these models.
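To check where the size difference actually comes from (beyond eyeballing it in Netron), something along these lines can list the dtypes of the constants stored in each .mlpackage. This is a rough sketch that assumes coremltools >= 7.0 (which provides cto.coreml.get_weights_metadata); the summarize_weight_dtypes helper is just for illustration.

```python
import coremltools as ct
from coremltools.optimize.coreml import get_weights_metadata


def summarize_weight_dtypes(mlpackage_path):
    """Count how many stored constants of each dtype an mlprogram contains."""
    model = ct.models.MLModel(mlpackage_path)
    counts = {}
    for name, meta in get_weights_metadata(model).items():
        dtype = str(meta.val.dtype)  # meta.val is the stored weight as a numpy array
        counts[dtype] = counts.get(dtype, 0) + 1
    print(mlpackage_path, counts)


# INT8-quantized weights should show up as int8 constants; unquantized ones as float16.
summarize_weight_dtypes("vit_int8_cto_torch.mlpackage")
summarize_weight_dtypes("vit_int8_cto_coreml.mlpackage")
```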

Table 1: Profiling for MobileViTv2 (Apple Example)

Source: Core ML Tools Documentation

| Model Version | Model Size (MB) | Size Change (%) | Prediction Latency (median, ms) | Prediction Latency Change (%) | Load Latency (median, ms) | Compilation Latency (median, ms) |
| --- | --- | --- | --- | --- | --- | --- |
| MobileViTv2-FP16 | 10 | N/A | 7.793 | N/A | 16.329 | 64.330 |
| MobileViTv2-W8A8 | 5.3 | -47% | 5.353 | -31.3% | 21.992 | 93.230 |

Table 2: Profiling for torch pre-trained ViT_B_16 (custom)

Note on naming:

  • cto.torch refers to quantization performed using coremltools.optimize.torch (quantization done on the PyTorch model before conversion to Core ML).
  • cto.coreml refers to quantization performed using coremltools.optimize.coreml (post-training quantization applied directly to an already converted .mlpackage Core ML model).
| Model Version / Configuration | Model Size (MB) | Size Change (%) | Prediction Latency (median, ms) | Prediction Latency Change (%) | Load Latency (median, ms) | Compilation Latency (median, ms) |
| --- | --- | --- | --- | --- | --- | --- |
| ViT_B_16-FP16 | 173.5 | N/A | 8.335 | N/A | 14.948 | 61.414 |
| ViT_B_16_W8A8 (cto.torch) | 115.4 | -33.5% | 7.457 | -10.5% | 21.847 | 88.207 |
| ViT_B_16_W8A8 (cto.coreml) | 87.1 | -49.8% | 8.209 | -1.7% | 18.496 | 77.849 |

vit_torch_coreml_export.py

import torch
import torchvision
import coremltools as ct
from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig
)
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import FakeData
from multiprocessing import freeze_support

torch.manual_seed(0)

if __name__ == '__main__':
    freeze_support()

    img_size = 224 

    print("Loading pretrained ViT model (FP32 in PyTorch)...")
    weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
    model_for_export = torchvision.models.vit_b_16(weights=weights)
    model_for_export.eval()

    vit_output_classes = 1000 
    print(f"ViT model loaded. Output classes: {vit_output_classes}")

    example_input_tensor = torch.randn(1, 3, img_size, img_size)
    example_inputs_tuple = (example_input_tensor, )

    print("Attempting FP16 CoreML export for ViT (using torch.jit.trace)...")
    model_for_export.eval()
    print("Applying JIT trace (with check_trace=False workaround)...")
    traced_model = torch.jit.trace(model_for_export, example_input_tensor, check_trace=False)
    
    # Convert to CoreML with FP16 compute precision
    converted_model_fp16 = ct.convert(
        traced_model,
        inputs=[ct.TensorType(name="image_input", shape=example_input_tensor.shape)],
        minimum_deployment_target=ct.target.macOS14,
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16
    )
    converted_model_fp16.save("vit_fp16.mlpackage")
    print("Saved FP16 CoreML ViT model: vit_fp16.mlpackage")

    transform_fake_data = transforms.Compose([
        transforms.Lambda(lambda x: torch.randn(3, img_size, img_size))
    ])
    
    dataset = FakeData(size=32,
                       image_size=(3, img_size, img_size),
                       num_classes=vit_output_classes,
                       transform=transform_fake_data)
    
    dataloader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    num_calibration_steps = 8
    quant_config = LinearQuantizerConfig(
        global_config=ModuleLinearQuantizerConfig(
            quantization_scheme="symmetric",
            weight_dtype="qint8",
            activation_dtype="quint8",
        )
    )
    
    print("Re-loading ViT model for cto.torch quantization (starts as FP32 PyTorch)...")
    model_to_quantize = torchvision.models.vit_b_16(weights=weights)
    model_to_quantize.train() 

    quantizer = LinearQuantizer(model_to_quantize, quant_config)
    prepared_model = quantizer.prepare(example_inputs_tuple) 

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(prepared_model.parameters(), lr=0.0001) 

    print(f"Starting quantization calibration loop for ViT ({num_calibration_steps} steps)...")
    for i, (inputs_batch, labels_batch) in enumerate(dataloader):
        if i >= num_calibration_steps:
            break
        optimizer.zero_grad()
        outputs = prepared_model(inputs_batch)
        loss = loss_fn(outputs, labels_batch)
        loss.backward()
        optimizer.step()
        quantizer.step()

    print("Calibration loop finished.")
    prepared_model.eval() 

    finalized_quantized_model = quantizer.finalize(inplace=True) 
    print("ViT model finalized for quantization (cto.torch).")

    print("Attempting W8A8 (quint8 based) CoreML export for ViT (cto.torch)...")
    finalized_quantized_model.eval()
    
    print("Applying JIT trace (with check_trace=False workaround)...")
    traced_quantized_model = torch.jit.trace(
        finalized_quantized_model, 
        example_input_tensor,
        check_trace=False 
    )
    
    coreml_inputs = [ct.TensorType(name="image_input", shape=example_input_tensor.shape)]
    
    quantized_model_coreml = ct.convert(
        traced_quantized_model,
        inputs=coreml_inputs,
        minimum_deployment_target=ct.target.macOS14,
        convert_to="mlprogram",
    )
    quantized_model_coreml.save("vit_int8_cto_torch.mlpackage")
    print("Saved W8A8 (cto.torch) quantized CoreML ViT model: vit_int8_cto_torch.mlpackage")

quantize_vit_coreml.py

import coremltools as ct
import coremltools.optimize as cto
import torch 
import torchvision.transforms as transforms
from torchvision.datasets import FakeData
import numpy as np

img_size = 224 
vit_output_classes = 1000 
coreml_model_input_name = "image_input"

transform_fake_data = transforms.Compose([
    transforms.Lambda(lambda x: torch.randn(3, img_size, img_size))
])

dataset = FakeData(size=32, 
                   image_size=(3, img_size, img_size),
                   num_classes=vit_output_classes,
                   transform=transform_fake_data)

num_samples_for_calibration = 32        
sample_data_for_coreml = []

print(f"Preparing {num_samples_for_calibration} samples for CoreML calibration data...")
for i in range(min(num_samples_for_calibration, len(dataset))):
    image_tensor, _ = dataset[i] 
    image_numpy = image_tensor.cpu().numpy() 
    image_numpy_with_batch = np.expand_dims(image_numpy, axis=0)
    sample_data_for_coreml.append({coreml_model_input_name: image_numpy_with_batch})
print(f"Finished preparing sample_data. Samples: {len(sample_data_for_coreml)}.")

# Load the FP16 CoreML model (generated by vit_torch_coreml_export.py)
print("Loading FP16 CoreML model: vit_fp16.mlpackage")
baseline_model = ct.models.MLModel("vit_fp16.mlpackage")

# Apply post-training quantization using coremltools.optimize.coreml
print("Attempting W8A8 PTQ using cto.coreml (user-provided two-step structure)...")
print("Step 1: Configuring and applying activation quantization...")
activation_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.OpLinearQuantizerConfig(mode="linear_symmetric")
)

compressed_model_a8 = cto.coreml.linear_quantize_activations(
        baseline_model, config=activation_config, sample_data=sample_data_for_coreml
)
print("Activation quantization step applied.")

print("Step 2: Configuring and applying weight quantization...")
weight_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.OpLinearQuantizerConfig(mode="linear_symmetric")
)

compressed_model_w8a8 = cto.coreml.linear_quantize_weights(
    compressed_model_a8, config=weight_config
)
print("Weight quantization step applied.")

compressed_model_w8a8.save("vit_int8_cto_coreml.mlpackage")
print("Saved W8A8 quantized Core ML ViT model (cto.coreml): vit_int8_cto_coreml.mlpackage")

james-p-xu avatar May 13 '25 22:05 james-p-xu

@junpeiz could you please take a look?

cymbalrush avatar May 15 '25 00:05 cymbalrush

For the model size and speed discrepancy, it's because cto.coreml does the compression more aggressively: it tries to compress every const op, while cto.torch only compresses specific torch modules (torch.nn.Linear, etc.).
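
If you want the cto.coreml path to only touch roughly the same kinds of ops, the compression can be scoped with op_type_configs instead of a global_config; ops without a matching config should be left uncompressed. This is a rough, untested sketch: the MIL op type names and the output file name are illustrative and may need adjusting for your converted graph.

```python
import coremltools as ct
import coremltools.optimize.coreml as cto_coreml

# Only quantize the weights of linear and conv ops; every other const op
# (layer norm parameters, embeddings, etc.) stays in FP16, which is closer
# to what the cto.torch path produces.
config = cto_coreml.OptimizationConfig(
    op_type_configs={
        "linear": cto_coreml.OpLinearQuantizerConfig(mode="linear_symmetric"),
        "conv": cto_coreml.OpLinearQuantizerConfig(mode="linear_symmetric"),
    },
)

baseline_model = ct.models.MLModel("vit_fp16.mlpackage")
compressed_model = cto_coreml.linear_quantize_weights(baseline_model, config=config)
compressed_model.save("vit_int8_selected_ops.mlpackage")  # example output name
```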

I think @aseemw might be a better person to route the questions in this thread.

junpeiz avatar May 15 '25 20:05 junpeiz

@junpeiz, thanks for this insight! If I am targeting maximum model compression (I did not look too closely at numerical accuracy), then I guess I should proceed with cto.coreml?

Any reason why the resultant cto.coreml model would be slower than its cto.torch counterpart? Is the on-device graph compiler potentially missing something? It's basically the same speed as the FP16 model...

james-p-xu avatar May 15 '25 20:05 james-p-xu

> maximum model compression

Compressing more ops only gives you a smaller model size on disk. If an op's compression is not well supported by the hardware, it can hurt latency (which matches your experiment results). So it's more of a case-by-case preference: if latency is important to you, go ahead with the one that has the better latency (cto.torch).

junpeiz avatar May 16 '25 16:05 junpeiz

@junpeiz, I guess I am trying to understand how to achieve this stated INT8-INT8 performance increase?

> In newer hardware with A17 Pro or M4 chips, such as iPhone 15 Pro, there is increased throughput possible for int8-int8 compute on Neural Engine, compared to previous versions

Looking at the ONNX export of the vanilla ViT, the graph has many MatMul nodes. Is there no native NPU support for INT8-INT8 MatMul (like INT8-INT8 GEMM on the GPU)? Or is there only support for convolutions, which would explain why the other example architectures (MobileNet, ResNet, MobileViT) see speedups?

james-p-xu avatar May 16 '25 18:05 james-p-xu

> I guess I am trying to understand how to achieve this stated INT8-INT8 performance increase?

I think @aseemw has more context about the benchmark you were referring to.

junpeiz avatar May 16 '25 18:05 junpeiz

In my benchmark above, I exported these models with a macOS14 minimum deployment target (which does not use the fused SDPA op). However, when I set the macOS15 minimum target (which does use the fused SDPA op), I still do not see any performance boost on either the cto.torch or the cto.coreml exported models.

(Attachments: profiling with the SDPA fused op, cto.torch codepath vs. cto.coreml codepath.)
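
Setting the macOS15 target amounts to changing minimum_deployment_target in the ct.convert call from the export script; roughly (sketch only, reusing the traced model produced by vit_torch_coreml_export.py):

```python
import coremltools as ct

# traced_quantized_model is the torch.jit.trace output from vit_torch_coreml_export.py.
# Targeting macOS15 lets the converter emit the fused scaled_dot_product_attention (SDPA) op.
converted_model_macos15 = ct.convert(
    traced_quantized_model,
    inputs=[ct.TensorType(name="image_input", shape=(1, 3, 224, 224))],
    minimum_deployment_target=ct.target.macOS15,
    convert_to="mlprogram",
)
```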

james-p-xu avatar May 16 '25 22:05 james-p-xu