
torchtune QAT quantization produces different model output compared with the documentation


I'm using torchtune for model quantization with QAT. I'm following https://pytorch.org/torchtune/main/tutorials/qat_finetune.html, but the prepared_model I print differs from what the tutorial shows. Is this normal?

from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.models.llama3 import llama3_8b

model = llama3_8b()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# fine-tuning without performing any dtype casting
prepared_model = quantizer.prepare(model)

# Inspect the first layer's attention module, as in the tutorial
print(prepared_model.layers[0].attn)

The tutorial shows this:

>>> print(prepared_model.layers[0].attn)
MultiHeadAttention(
  (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)

But I get this:

MultiHeadAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)

torch 2.4.1, torchtune 0.3.0, torchao 0.5.0; device: NVIDIA 3090; Python 3.10

elfisworking · Sep 27 '24

Hi @elfisworking, thanks for creating the issue! Actually I think our QAT tutorial may be slightly out of date. It was written when QAT was done with module swapping (which is why we'd expect Linear -> Int8DynActInt4WeightQATLinear), but QAT now uses tensor subclasses: the modules stay as Linear and the fake quantization is attached to the weights instead. If I understand correctly, the fact that you still see Linear means you're simply on the latest version. cc @andrewor14 to confirm. If so, we can update our QAT tutorial to reflect this.
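To check which mechanism was applied, you can inspect the type of a layer's weight rather than the type of the module. A minimal sketch, assuming the same prepare() call as above; the exact tensor-subclass name printed depends on your torchao version, so the expected outputs in the comments are assumptions:

from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.models.llama3 import llama3_8b

model = llama3_8b()
quantizer = Int8DynActInt4WeightQATQuantizer()
prepared_model = quantizer.prepare(model)

q_proj = prepared_model.layers[0].attn.q_proj
# With tensor-subclass QAT the module class is unchanged ...
print(type(q_proj))         # <class 'torch.nn.modules.linear.Linear'>
# ... but the weight should no longer be a plain torch.Tensor
print(type(q_proj.weight))  # expected: a fake-quantized tensor subclass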

ebsmothers · Sep 27 '24