torchtune QAT: prepared model output differs from the documentation
I'm using torchtune for QAT-based model quantization, following https://pytorch.org/torchtune/main/tutorials/qat_finetune.html, but the prepared_model I print looks different from what the tutorial shows. Is this normal?
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.models.llama3 import llama3_8b
model = llama3_8b()
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# fine-tuning without performing any dtype casting
prepared_model = quantizer.prepare(model)
The tutorial shows this:
>>> print(prepared_model.layers[0].attn)
MultiHeadAttention(
  (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)
But I get this:
MultiHeadAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)
torch 2.4.1, torchtune 0.3.0, torchao 0.5.0, device: NVIDIA 3090, Python 3.10
Hi @elfisworking, thanks for creating the issue! Actually I think our QAT tutorial may be slightly out of date. It was written when QAT was done with module swapping (hence why we'd expect Linear -> Int8DynActInt4WeightQATLinear), but now it uses tensor subclasses. If I understand correctly, the fact that you still see Linear just means you're on a newer version. cc @andrewor14 to confirm though. If so, we can update our QAT tutorial to reflect this.
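A rough way to sanity-check this, sketched under the assumption that the newer tensor-subclass flow wraps each linear layer's weight rather than swapping the module class: inspect the weight's type after prepare(). The exact subclass name depends on the torchao version, so treat the printed output as illustrative only.

# Continuing from the prepare() snippet above.
q_proj = prepared_model.layers[0].attn.q_proj
# The module class is left unchanged, so the repr still reads "Linear"...
print(type(q_proj))         # e.g. <class 'torch.nn.modules.linear.Linear'>
# ...but if fake quantization was inserted via a tensor subclass, the weight
# should no longer be a plain tensor (the subclass name varies by torchao version).
print(type(q_proj.weight))

If the weight type still prints as a plain Parameter/Tensor on your setup, that would suggest prepare() didn't insert anything, which would be a different problem from the tutorial being out of date.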